from .modules import *
from .models import *
from typing import Union

def bmv(M, v):
    """Batched matrix-vector product.

    Treats ``v`` as a column vector (a trailing axis of size 1 is added),
    multiplies it by ``M`` with broadcasting matmul, and drops that axis again.
    """
    column = v[..., None]
    return torch.matmul(M, column)[..., 0]


def project_vec2contours(f_hat: torch.Tensor, grad_V: Union[torch.Tensor, None]):
    """Remove from ``f_hat`` its component in the span of the columns of ``grad_V``.

    With ``G = grad_V`` this evaluates ``f = f_hat - G (G^T G)^{-1} G^T f_hat``,
    so that ``G^T f = 0`` up to solver precision.

    Args:
        f_hat: vectors to project, #samples x #states.
        grad_V: gradients stacked as columns, #samples x #states x #quantities;
            may be None or empty.

    Returns:
        The projected vectors (same shape as ``f_hat``), or ``f_hat`` itself
        when ``grad_V`` is None or has no elements.
    """
    if grad_V is not None and grad_V.numel() > 0:
        G = grad_V
        Gt = G.transpose(-1, -2)
        # Coefficients of f_hat's component along each gradient column.
        coeffs = torch.linalg.solve(Gt @ G, bmv(Gt, f_hat))
        return f_hat - bmv(G, coeffs)
    return f_hat


class ContinuousFINDE(Module):
    """Project a continuous-time vector field onto the level sets of quantities V.

    ``project_to_TuM`` stacks the gradients of the active quantities into
    ``grad_V`` and removes from ``f_hat`` its component in the span of those
    gradients (via ``project_vec2contours``). The result is orthogonal to
    every gradient column.

    Args:
        quantities: callable mapping a single state to a vector of
            ``n_quantities`` values, or None. It is differentiated per sample
            with ``vmap(jacrev(...))``.
        n_quantities: number of values produced by ``quantities``.
        keeprate: probability that each learned quantity is kept whenever
            training mode is entered; 1.0 keeps all of them.
        hnn: optional object exposing ``grad(x)``. When set, its gradient
            becomes the first column of ``grad_V``.
    """
    is_discrete = False
    is_continuous = True

    def __init__(self, quantities: Union[Module, None], n_quantities: int, keeprate: float = 1.0, hnn=None):
        super(ContinuousFINDE, self).__init__()
        self.n_quantities = n_quantities
        self.quantities = quantities
        self.keeprate = keeprate
        # Boolean mask over the learned quantities; all enabled by default.
        self.used_quantities = np.ones(n_quantities).astype(bool)
        self.hnn = hnn

    def train(self, mode: bool = True):
        """Set training mode and resample which learned quantities are used.

        In training mode with ``keeprate < 1`` each quantity is kept
        independently with probability ``keeprate``; otherwise all are kept.
        ``mode`` defaults to True, matching the usual
        ``nn.Module.train(mode=True)`` signature, so ``model.train()`` works.
        """
        if mode and self.keeprate < 1.0:
            # A fresh random mask is drawn every time training mode is set.
            self.used_quantities = np.random.uniform(0, 1, self.n_quantities) < self.keeprate
        else:
            self.used_quantities = np.ones(self.n_quantities).astype(bool)
        return super(ContinuousFINDE, self).train(mode)

    def get_grad_V(self, x1):
        """Return the gradients of the active quantities at ``x1``.

        Args:
            x1: batch of states; ``vmap`` maps over its first axis.

        Returns:
            Tensor of shape ``(*x1.shape, k)``: the ``hnn`` gradient column
            (if ``hnn`` is set) followed by one column per used quantity.
            ``k`` may be 0.
        """
        grad_V = [torch.empty(*x1.shape, 0).to(x1), ]
        if self.hnn is not None:
            grad_V.append(self.hnn.grad(x1).unsqueeze(-1))
        if self.quantities is not None and np.any(self.used_quantities):
            # Imported lazily so the common "no learned quantities" path has no
            # dependency on it. torch.func supersedes the deprecated functorch.
            try:
                from torch.func import jacrev, vmap
            except ImportError:  # older PyTorch without torch.func
                from functorch import jacrev, vmap
            used_quantities = torch.tensor(np.where(self.used_quantities)[0]).to(device=x1.device, dtype=torch.int)
            # For 2-D x1 the per-sample Jacobian stack is (batch, n_quantities, n_states).
            # Transpose so that each quantity's gradient becomes a column.
            jacobians = vmap(jacrev(self.quantities))(x1).transpose(-1, -2)
            grad_V.append(torch.index_select(jacobians, -1, used_quantities))
        return torch.cat(grad_V, dim=-1)

    def project_to_TuM(self, x1, f_hat):
        """Project ``f_hat`` so it has no component along the gradients at ``x1``."""
        grad_V = self.get_grad_V(x1)
        f = project_vec2contours(f_hat, grad_V)
        return f

class ContinuousFINDE2Pend(ContinuousFINDE):
    """``ContinuousFINDE`` with four hand-coded quantities added in front.

    The state's last axis is split into eight equal chunks, read as
    ``(x1, y1, x2, y2, px1, py1, px2, py2)``. The extra quantities are:
    - ``x1^2 + y1^2``;
    - ``(x1-x2)^2 + (y1-y2)^2``;
    - ``x1*px1 + y1*py1``;
    - ``(x1-x2)*(px1-px2) + (y1-y2)*(py1-py2)``.

    These are presumably the rod-length constraints of a double pendulum and
    their velocity-level counterparts; confirm against the data layout.
    """

    def get_grad_V_2pend(self, u):
        """Return the ``(*u.shape, 4)`` gradients of the four quantities w.r.t. ``u``.

        Note: ``u.requires_grad_(True)`` flags the caller's tensor in place.
        """
        u = u.requires_grad_(True)
        with torch.enable_grad():
            x1, y1, x2, y2, px1, py1, px2, py2 = torch.chunk(u, 8, dim=-1)
            dx = x1 - x2
            dy = y1 - y2
            dpx = px1 - px2
            dpy = py1 - py2
            # Summing over the batch lets one backward pass yield per-sample gradients.
            constraints = (
                (x1 ** 2 + y1 ** 2).sum(),
                (dx ** 2 + dy ** 2).sum(),
                (x1 * px1 + y1 * py1).sum(),
                (dx * dpx + dy * dpy).sum(),
            )
        columns = []
        for constraint in constraints:
            columns.append(torch.autograd.grad(constraint, (u,), create_graph=True)[0])
        return torch.stack(columns, dim=-1)

    def get_grad_V(self, x1):
        """Hand-coded gradient columns first, then those from ``ContinuousFINDE``."""
        hand_coded = self.get_grad_V_2pend(x1)
        inherited = super().get_grad_V(x1)
        return torch.cat((hand_coded, inherited), dim=-1)

class ContinuousFINDE2Body(ContinuousFINDE):
    """``ContinuousFINDE`` with two hand-coded quantities added in front.

    The state's last axis is split into eight equal chunks, read as
    ``(x1, x2, y1, y2, px1, px2, py1, py2)``. This is a different chunk order
    from ``ContinuousFINDE2Pend``. The extra quantities are the sums
    ``px1 + px2`` and ``py1 + py2``, presumably the total momentum of a
    two-body system; confirm against the data layout.
    """

    def get_grad_V_2body(self, u):
        """Return the ``(*u.shape, 2)`` gradients of the two momentum sums w.r.t. ``u``.

        Note: ``u.requires_grad_(True)`` flags the caller's tensor in place.
        """
        u = u.requires_grad_(True)
        with torch.enable_grad():
            _x1, _x2, _y1, _y2, px1, px2, py1, py2 = torch.chunk(u, 8, dim=-1)
            totals = ((px1 + px2).sum(), (py1 + py2).sum())
        gradients = [torch.autograd.grad(total, (u,), create_graph=True)[0] for total in totals]
        return torch.stack(gradients, dim=-1)

    def get_grad_V(self, x1):
        """Hand-coded gradient columns first, then those from ``ContinuousFINDE``."""
        hand_coded = self.get_grad_V_2body(x1)
        inherited = super().get_grad_V(x1)
        return torch.cat((hand_coded, inherited), dim=-1)




class DiscreteFINDE(ContinuousFINDE):
    """Discrete-time FINDE.

    The projection uses gradients of the quantities evaluated between a state
    ``x1`` and a next state ``x2``, computed with ``set_discrete_autograd_mode``
    enabled.
    """
    is_discrete = True
    is_continuous = False

    def get_discrete_grad_V(self, x1, x2):
        """Return the discrete gradients of the active quantities between ``x1`` and ``x2``.

        Args:
            x1: current state.
            x2: next state (assumed same shape as ``x1``; TODO confirm).
                ``self.quantities(x1, x2)`` must return a pair whose first item
                holds the quantity values along its last axis.

        Returns:
            Tensor of shape ``(*x1.shape, k)``: the ``hnn.discrete_grad`` column
            (if ``hnn`` is set) followed by one column per used quantity.
        """
        grad_V = [torch.empty(*x1.shape, 0).to(x1), ]
        if self.hnn is not None:
            grad_V.append(self.hnn.discrete_grad(x1, x2).unsqueeze(-1))
        if self.quantities is not None and np.any(self.used_quantities):
            with torch.enable_grad():
                x1 = x1.requires_grad_(True)
                x2 = x2.requires_grad_(True)
                h, _ = self.quantities(x1, x2)
                # We know this is inefficient, but it's time consuming to extend discrete gradients to discrete Jacobians.
                set_discrete_autograd_mode(True)
                try:
                    for i in range(h.shape[-1]):
                        if self.used_quantities[i]:
                            grad_V.append(torch.autograd.grad(h[..., i].sum(), (x1,), create_graph=True, retain_graph=True)[0].unsqueeze(-1))
                finally:
                    # Always switch the mode back off, even if a backward pass
                    # raises; otherwise later autograd calls would run in it.
                    set_discrete_autograd_mode(False)
        return torch.cat(grad_V, dim=-1)

    def project_to_TvuM(self, x1, psi_hat, dt, x2=None):
        """Project the discrete update direction ``psi_hat``.

        If ``x2`` is given, ``psi_hat`` is projected using the discrete
        gradients between ``x1`` and ``x2``. Otherwise ``x2`` is found with
        ``fsolve_gpu`` as the root of
        ``project(x1, psi_hat, x2) * dt - (x2 - x1)``, starting from the
        explicit guess ``x1 + dt * psi_hat``. The result is then
        ``(x2 - x1) / dt``.
        """
        if x2 is not None:
            grad_V = self.get_discrete_grad_V(x1, x2)
            psi = project_vec2contours(psi_hat, grad_V)
        else:
            x2 = x1 + dt * psi_hat
            x2 = fsolve_gpu(lambda xp: self.project_to_TvuM(x1, psi_hat=psi_hat, dt=None, x2=xp) * dt - (xp - x1), x2)
            psi = (x2 - x1) / dt
        return psi


class HybridFINDE(DiscreteFINDE):
    """``DiscreteFINDE`` flagged as both discrete and continuous.

    It defines no methods of its own. It inherits ``project_to_TuM``
    (continuous) and ``project_to_TvuM`` (discrete) unchanged; only the class
    flags differ.
    """
    is_continuous = True
    is_discrete = True

def project_vec2contours_pde1d(f_hat: torch.Tensor, grad_V: Union[torch.Tensor, None]):
    """PDE variant of the contour projection with a size-1 axis at dim 1.

    Squeezes dim 1 from both inputs, removes from ``f_hat`` its component in
    the span of the columns of ``grad_V`` (``f = f_hat - G (G^T G)^{-1} G^T f_hat``),
    and restores dim 1 on the result.

    Args:
        f_hat: vectors to project; ``f_hat.shape[1]`` must be 1.
        grad_V: gradients stacked as columns along the last axis;
            ``grad_V.shape[1]`` must be 1. May be None or empty.

    Returns:
        Projected tensor with the same shape as ``f_hat``, or ``f_hat`` itself
        when ``grad_V`` is None or empty.

    Raises:
        ValueError: if dim 1 of ``f_hat`` or ``grad_V`` does not have size 1.
    """
    if grad_V is None or grad_V.numel() == 0:
        return f_hat
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``. Then squeeze(1) would silently no-op and the final
    # unsqueeze(1) would add a spurious axis.
    if f_hat.shape[1] != 1:
        raise ValueError(f"expected f_hat.shape[1] == 1, got shape {tuple(f_hat.shape)}")
    if grad_V.shape[1] != 1:
        raise ValueError(f"expected grad_V.shape[1] == 1, got shape {tuple(grad_V.shape)}")
    f_hat = f_hat.squeeze(1)
    grad_V = grad_V.squeeze(1)

    Mt = grad_V
    M = Mt.transpose(-1, -2)
    f = f_hat - bmv(Mt, torch.linalg.solve(M @ Mt, bmv(M, f_hat)))
    return f.unsqueeze(1)

class ContinuousFINDEPDE1d(ContinuousFINDE):
    """Continuous FINDE variant for 1-D PDE states.

    Gradients of the learned quantities are obtained with one reverse-mode pass
    per used quantity (``torch.autograd.grad``) rather than ``vmap``/``jacrev``.
    The projection goes through ``project_vec2contours_pde1d``.
    """
    is_discrete = False
    is_continuous = True

    def get_grad_V(self, x1):
        """Return ``(*x1.shape, k)`` gradient columns.

        The ``hnn`` gradient comes first when ``hnn`` is set, followed by one
        column per used quantity.
        """
        columns = [torch.empty(*x1.shape, 0).to(x1)]
        if self.hnn is not None:
            columns.append(self.hnn.grad(x1).unsqueeze(-1))
        if self.quantities is None or not np.any(self.used_quantities):
            return torch.cat(columns, dim=-1)
        with torch.enable_grad():
            x1 = x1.requires_grad_(True)
            # squeeze(-1) drops a trailing size-1 axis (the spatial dimension).
            values = self.quantities(x1).squeeze(-1)
            for idx in range(values.shape[-1]):
                if not self.used_quantities[idx]:
                    continue
                (grad,) = torch.autograd.grad(values[..., idx].sum(), (x1,), create_graph=True, retain_graph=True)
                columns.append(grad.unsqueeze(-1))
        return torch.cat(columns, dim=-1)

    def project_to_TuM(self, x1, f_hat):
        """Project ``f_hat`` using the gradient columns evaluated at ``x1``."""
        return project_vec2contours_pde1d(f_hat, self.get_grad_V(x1))

class DiscreteFINDEPDE1d(DiscreteFINDE):
    """Discrete FINDE variant for 1-D PDE states.

    Same scheme as ``DiscreteFINDE``, but the first output of
    ``self.quantities(x1, x2)`` has a trailing size-1 axis squeezed away, and
    the projection goes through ``project_vec2contours_pde1d``.
    """
    is_discrete = True
    is_continuous = False

    def get_discrete_grad_V(self, x1, x2):
        """Return the discrete gradients of the active quantities between ``x1`` and ``x2``.

        Returns:
            Tensor of shape ``(*x1.shape, k)``: the ``hnn.discrete_grad`` column
            (if ``hnn`` is set) followed by one column per used quantity.
        """
        grad_V = [torch.empty(*x1.shape, 0).to(x1), ]
        if self.hnn is not None:
            grad_V.append(self.hnn.discrete_grad(x1, x2).unsqueeze(-1))
        if self.quantities is not None and np.any(self.used_quantities):
            with torch.enable_grad():
                x1 = x1.requires_grad_(True)
                x2 = x2.requires_grad_(True)
                h = self.quantities(x1, x2)[0].squeeze(-1)  # remove spatial dimension
                # We know this is inefficient, but it's time consuming to extend discrete gradients to discrete Jacobians.
                set_discrete_autograd_mode(True)
                try:
                    for i in range(h.shape[-1]):
                        if self.used_quantities[i]:
                            grad_V.append(torch.autograd.grad(h[..., i].sum(), (x1,), create_graph=True, retain_graph=True)[0].unsqueeze(-1))
                finally:
                    # Always switch the mode back off, even if a backward pass
                    # raises; otherwise later autograd calls would run in it.
                    set_discrete_autograd_mode(False)
        return torch.cat(grad_V, dim=-1)

    def project_to_TvuM(self, x1, psi_hat, dt, x2=None):
        """Project the discrete update direction ``psi_hat``.

        If ``x2`` is given, ``psi_hat`` is projected using the discrete
        gradients between ``x1`` and ``x2``. Otherwise ``x2`` is found with
        ``fsolve_gpu`` as the root of
        ``project(x1, psi_hat, x2) * dt - (x2 - x1)``, starting from
        ``x1 + dt * psi_hat``. The result is then ``(x2 - x1) / dt``.
        """
        if x2 is not None:
            grad_V = self.get_discrete_grad_V(x1, x2)
            psi = project_vec2contours_pde1d(psi_hat, grad_V)
        else:
            x2 = x1 + dt * psi_hat
            x2 = fsolve_gpu(lambda xp: self.project_to_TvuM(x1, psi_hat=psi_hat, dt=None, x2=xp) * dt - (xp - x1), x2)
            psi = (x2 - x1) / dt
        return psi




def get_finde(finde, input_dim, hidden_dim, act, model, data_mean=None, data_std=None, quantities=None):
    """Build the FINDE wrapper selected by ``finde.variant``.

    Args:
        finde: config object; ``num``, ``variant``, ``hnn`` and ``keeprate`` are read.
        input_dim, hidden_dim, act, data_mean, data_std: forwarded to
            ``get_MLP`` when a quantities network has to be created.
        model: passed as ``hnn`` when ``finde.hnn`` is truthy, else dropped.
        quantities: quantities module. When None and ``finde.num > 0``, a
            bias-free MLP with ``finde.num`` outputs is created.

    Returns:
        An instance of the class registered for ``finde.variant``.

    Raises:
        NotImplementedError: if ``finde.variant`` is not a known variant.
    """
    # The MLP is built before the variant lookup, as before, so parameter
    # initialisation consumes the RNG even when the variant turns out unknown.
    if quantities is None and finde.num > 0:
        quantities = get_MLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=finde.num, act=act, bias=False, data_mean=data_mean, data_std=data_std)
    variant = finde.variant
    hnn = model if finde.hnn else None
    registry = (
        ('continuous', ContinuousFINDE),
        ('continuous2pend', ContinuousFINDE2Pend),
        ('continuous2body', ContinuousFINDE2Body),
        ('discrete', DiscreteFINDE),
        ('hybrid', HybridFINDE),
        ('continuousPDE', ContinuousFINDEPDE1d),
        ('discretePDE', DiscreteFINDEPDE1d),
    )
    for name, finde_cls in registry:
        if variant == name:
            return finde_cls(quantities, n_quantities=finde.num, keeprate=finde.keeprate, hnn=hnn)

    raise NotImplementedError(finde.variant)
