import torch
from torch.autograd.functional import jvp
from functools import partial

def _stack_pde_inputs(x, pde):
    """Stack the feature tensors of ``x`` along ``dim=1`` in the order used for ``pde``.

    Args:
        x: Mapping from feature name to tensor. Requires the keys 'u', 'dudx',
            'dudxdx', 'dudxdxdx', 'dudxdxdxdx' and 'dudt', plus 't' for 'nKdV'.
            Presumably each value is a 1-D per-sample tensor, giving a
            (batch, n_features) result -- TODO confirm against callers.
        pde: Equation name: one of 'KdV', 'KS', 'Burgers' or 'nKdV'.

    Returns:
        ``torch.stack([...], dim=1)`` of the selected features.

    Raises:
        ValueError: If ``pde`` is not a supported equation name.
    """
    if pde in ('KdV', 'KS', 'Burgers'):
        keys = ('u', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx', 'dudt')
    elif pde == 'nKdV':
        # The non-autonomous variant additionally takes time as an input feature.
        keys = ('t', 'u', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx', 'dudt')
    else:
        raise ValueError(
            f"Unsupported pde {pde!r}; expected one of 'KdV', 'KS', 'Burgers', 'nKdV'."
        )
    return torch.stack([x[key] for key in keys], dim=1)


def symmreg(x, generator, pde, f=None, dfdx=None, require_grad=False):
    """Compute a symmetry-regularization loss from generator-induced input variations.

    For every component ``v_i`` returned by ``generator(**x)``, the directional
    derivative of the model along ``v_i`` is computed. The mean of its square is
    accumulated into the loss. The directional derivative comes from one of two
    sources:

    * ``f``: a callable on the stacked inputs; the derivative is computed with
      ``torch.autograd.functional.jvp``.
    * ``dfdx``: a precomputed Jacobian, contracted as
      ``einsum('bjk,bk->bj', dfdx, v_i)``.

    Args:
        x: Mapping of feature name to tensor. It is passed to ``generator`` as
            keyword arguments. When ``f`` is given, its features are also
            stacked in the order implied by ``pde``.
        generator: Callable invoked as ``generator(**x)``. It returns an
            iterable of components, each a sequence of tensors that
            ``torch.stack(...).T`` turns into a 2-D tensor.
        pde: Equation name selecting which features of ``x`` feed ``f``: 'KdV',
            'KS', 'Burgers' or 'nKdV'. It is only consulted when ``f`` is given.
        f: Callable on the stacked input tensor. Mutually exclusive with ``dfdx``.
        dfdx: 3-D Jacobian tensor. Mutually exclusive with ``f``.
        require_grad: If True, gradient tracking is enabled and ``jvp`` runs
            with ``create_graph=True, strict=True`` so the loss is differentiable.

    Returns:
        The accumulated loss. It is a tensor when the generator yields at least
        one component, otherwise the float ``0.0``.

    Raises:
        ValueError: If neither or both of ``f`` and ``dfdx`` are given, or if
            ``f`` is given with an unsupported ``pde``.
    """
    if f is None and dfdx is None:
        raise ValueError('Either f or dfdx must be specified.')
    if f is not None and dfdx is not None:
        raise ValueError('Only one of f and dfdx can be specified.')
    jvp_fn = partial(jvp, create_graph=True, strict=True) if require_grad else jvp
    # Only the f-based path needs the stacked inputs. Skipping the stack on the
    # dfdx path avoids requiring those keys (or a known pde) when they are unused.
    x_tensor = _stack_pde_inputs(x, pde) if f is not None else None

    with torch.set_grad_enabled(require_grad):
        loss = 0.0
        for vi_x in generator(**x):
            # Stack this component's per-feature tensors, then transpose so that
            # rows are samples; the einsum below indexes it as (b, k).
            vi_x = torch.stack(vi_x).T
            if f is not None:
                # jvp returns (f(x_tensor), J_f(x_tensor) @ vi_x); keep only the JVP.
                input_variation = jvp_fn(f, x_tensor, vi_x)[1]
            else:
                input_variation = torch.einsum('bjk,bk->bj', dfdx, vi_x)
            loss += torch.mean(input_variation ** 2)

    return loss

def make_symmreg_pttrain(generator):
    """Return ``symmreg`` with ``generator`` pre-bound and ``require_grad=True``.

    The returned ``functools.partial`` still needs ``x``, ``pde`` and one of
    ``f`` / ``dfdx``. ``generator`` is bound by keyword, so every argument after
    ``x`` (including ``pde``) must be passed by keyword. Otherwise it would
    collide with the bound ``generator``.
    """
    bound_kwargs = {'generator': generator, 'require_grad': True}
    return partial(symmreg, **bound_kwargs)