How to batch wise grad
https://discuss.pytorch.org/t/vmap-over-autograd-grad-of-a-nn-module/179038/2
jaxrev vmap虽然promising, 但是它们每次都要前向传播
这两天实现了一下,这个vmap和functional_call的思路实现复杂的含有多次求导的loss恐怕需要多次前向传播,因为方程中有很多很多的导数,而pytorch(甚至包括jax)的函数要求每求一次导,都要从头做一个前向传播。所以这个方案对于一些简单的physics informed loss是可以用的。
但现在我感觉还不如第一层导数求和求就好,第二层导数还不如循环一下,这样虽然循环了,但是前向传播次数只有一次