import warnings

import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
import scipy.optimize
fsolve = scipy.optimize.fsolve



class PDENd:
    """Base class for PDEs integrated in time with a discrete variational derivative scheme.

    Subclasses implement ``dudt`` and the spatial difference operators
    ``_D``, ``_D2`` and ``_Dx2``, and set ``dim`` to the number of spatial
    axes (it is 0 here).
    """

    def __init__(self, ndiv=50, width=1., device='cpu'):
        """Store the grid parameters.

        Args:
            ndiv: number of divisions of the domain along each axis.
            width: domain length along each axis; the grid spacing is
                ``dx = width / ndiv``.
            device: torch device on which subclasses create their tensors.
        """
        self.ndiv = ndiv
        self.width = width
        self.dx = width / self.ndiv
        self.device = device
        self.dim = 0  # number of spatial axes; overwritten by grid mixins

    def _pow(self, u, nu, arg):
        """Return ``u**arg`` or its two-time-level counterpart.

        With ``nu is None`` this is the plain power ``u**arg``. Otherwise it is
        the symmetric average ``sum_{k=0}^{arg} u**k * nu**(arg - k) / (arg + 1)``,
        which equals ``(nu**(arg+1) - u**(arg+1)) / ((arg + 1) * (nu - u))`` for
        ``nu != u`` and reduces to ``u**arg`` when ``nu == u``.

        Raises:
            NotImplementedError: if ``nu`` is given and ``arg`` is not a
                non-negative integer.
        """
        if nu is None:
            return u**arg
        # Explicit closed forms for the powers used by the concrete PDEs.
        if arg == 1:
            return (u + nu) * 0.5
        if arg == 2:
            return (u * u + u * nu + nu * nu) / 3.
        if arg == 3:
            return (nu * nu * nu + nu * nu * u + nu * u * u + u * u * u) * 0.25
        if not isinstance(arg, (int, np.integer)) or arg < 0:
            raise NotImplementedError(f"unsupported power {arg!r}")
        return sum(u**k * nu**(arg - k) for k in range(arg + 1)) / (arg + 1)

    def dudtdiff(self, u, nu, dt, t):
        """Residual ``(nu - u) / dt - dudt(u, nu, t)`` of one implicit time step.

        It is zero when ``nu`` is the state that follows ``u`` after a step of
        size ``dt`` under the scheme defined by ``dudt``.
        """
        dudt_right = self.dudt(u=u, nu=nu, t=t)
        dudt_left = (nu - u) / dt
        return dudt_left - dudt_right

    def dvdmint(self, x0, t_eval):
        """Integrate from ``x0`` over the times ``t_eval`` with the implicit scheme.

        Each step solves ``dudtdiff(u=x_n, nu=x_{n+1}, ...) = 0`` for ``x_{n+1}``
        with ``scipy.optimize.fsolve`` (``xtol=1e-10``), starting from an
        explicit Euler guess. A progress counter is printed to stdout.

        Args:
            x0: initial state (array-like). A 1-D input of shape ``(N,)`` is
                reshaped to ``(1, 1, N)`` and a 2-D input of shape ``(A, B)`` to
                ``(1, A, B)``; higher-rank inputs are used as given.
            t_eval: 1-D sequence of times; ``t_eval[0]`` is the time of ``x0``.

        Returns:
            ndarray of shape ``(len(t_eval), *x0_shape)`` (``x0_shape`` after the
            reshape above) with the state at every time in ``t_eval``.

        Warns:
            RuntimeWarning: if ``fsolve`` reports that a step did not converge.
        """
        x0 = np.asarray(x0)
        t_eval = np.asarray(t_eval)
        # Promote low-rank inputs by prepending singleton axes (1-D gets two).
        if x0.ndim == 1:
            x0 = x0.reshape(1, *x0.shape)
        if x0.ndim == 2:
            x0 = x0.reshape(1, *x0.shape)

        def residual(flat, xn, dt, t):
            # fsolve works on flat vectors; restore the state shape for dudt.
            return self.dudtdiff(u=xn, nu=flat.reshape(xn.shape), dt=dt, t=t).reshape(-1)

        res = [x0]
        dts = np.diff(t_eval)
        n_steps = t_eval.size - 1
        xn = x0
        for i in range(n_steps):
            print(i, '/', n_steps, end='\r')
            dt, t = dts[i], t_eval[i]
            # Explicit Euler step as the initial guess for the implicit solve.
            guess = xn + dt * self.dudt(u=xn, t=t)
            sol, _, ier, mesg = fsolve(residual, guess.reshape(-1), args=(xn, dt, t),
                                       xtol=1.0e-10, full_output=True)
            if ier != 1:
                warnings.warn(f"dvdmint: fsolve did not converge at step {i + 1}/{n_steps} "
                              f"(t={t}): {mesg}", RuntimeWarning, stacklevel=2)
            xn = sol.reshape(xn.shape)
            res.append(xn)
        return np.stack(res, axis=0)

    def _global_integral(self, u):
        """Rectangle-rule integral of ``u`` over its trailing ``self.dim`` axes."""
        # reshape (not view) so non-contiguous tensors can be flattened too
        return u.reshape(*u.shape[:-self.dim], -1).sum(-1) * self.dx**self.dim

    def dudt(self, u, *, nu=None, t=None):
        """Time derivative of ``u``; evaluated with two-level averages when ``nu`` is given."""
        raise NotImplementedError

    def _D(self, u):
        """First spatial difference of ``u`` (implemented by subclasses)."""
        raise NotImplementedError

    def _D2(self, u):
        """Second spatial difference of ``u`` (implemented by subclasses)."""
        raise NotImplementedError

    def _Dx2(self, u):
        """Discrete squared first derivative of ``u`` (implemented by subclasses)."""
        raise NotImplementedError


class PDE1dPeriodic(PDENd):
    """Finite-difference operators on a one-dimensional periodic grid.

    Meant to be combined with a concrete PDE class (as ``CahnHilliard1d`` and
    ``KdV1d`` do). ``__init__`` does not call ``PDENd.__init__``: it reads
    ``self.dx`` and ``self.device``, so those must already be set when it runs.

    Every operator treats the last axis of ``u`` as the periodic spatial axis
    and all leading axes as independent batch entries; the result has the same
    shape as ``u``.
    """

    def __init__(self):
        self.dim = 1

        def stencil(weights, scale):
            # (out_channels, in_channels, width) layout expected by F.conv1d
            kernel = torch.tensor(weights, dtype=torch.get_default_dtype(), device=self.device)
            return kernel.view(1, 1, 3) / scale

        # (u[i-1] - 2 u[i] + u[i+1]) / dx^2
        self.kernel_dx2 = stencil([1., -2., 1.], self.dx * self.dx)
        # (u[i+1] - u[i-1]) / (2 dx)
        self.kernel_dx_central = stencil([-1., 0., 1.], 2. * self.dx)
        # (u[i+1] - u[i]) / dx
        self.kernel_dx_forward = stencil([0., -1., 1.], self.dx)
        # (u[i] - u[i-1]) / dx
        self.kernel_dx_backward = stencil([-1., 1., 0.], self.dx)

    def _pad_periodic(self, u):
        """Reshape ``u`` to ``(batch, 1, N)`` and wrap one ghost cell onto each end."""
        # reshape (not view) so non-contiguous inputs are accepted as well
        u = u.reshape(-1, 1, u.shape[-1])
        return torch.cat([u[..., -1:], u, u[..., :1]], dim=-1)

    def _apply_stencil(self, u_pad, kernel):
        """Cross-correlate a padded ``(batch, 1, N + 2)`` tensor with a 3-point kernel."""
        # Match the kernel to the input so a later dtype/device change does not break conv1d.
        kernel = kernel.to(dtype=u_pad.dtype, device=u_pad.device)
        return F.conv1d(u_pad, kernel, padding=0)

    def _D(self, u):
        """Central first difference ``(u[i+1] - u[i-1]) / (2 dx)`` with periodic wrap."""
        return self._apply_stencil(self._pad_periodic(u), self.kernel_dx_central).reshape(u.shape)

    def _D2(self, u):
        """Second difference ``(u[i-1] - 2 u[i] + u[i+1]) / dx^2`` with periodic wrap."""
        return self._apply_stencil(self._pad_periodic(u), self.kernel_dx2).reshape(u.shape)

    def _Dx2(self, u):
        """Mean of the squared forward and backward first differences, periodic wrap."""
        u_pad = self._pad_periodic(u)
        forward = self._apply_stencil(u_pad, self.kernel_dx_forward)
        backward = self._apply_stencil(u_pad, self.kernel_dx_backward)
        return ((forward**2 + backward**2) / 2).reshape(u.shape)


class CahnHilliardNd(PDENd):
    """Cahn-Hilliard-type equation ``u_t = D2(-a u + b u^3 - gamma D2 u)``.

    The spatial operators ``_D2`` and ``_Dx2`` come from a grid mixin; the
    energy reported by ``get_energy`` is the integral of
    ``-a/2 u^2 + b/4 u^4 + gamma/2 * _Dx2(u)``.
    """

    def __init__(self, ndiv=50, width=1., a=1., b=1., gamma=0.005, device='cpu'):
        PDENd.__init__(self, ndiv=ndiv, width=width, device=device)
        self.a, self.b, self.gamma = a, b, gamma

    def dudt(self, u, *, nu=None, t=None):
        """Return du/dt as a NumPy array.

        ``u`` (and ``nu`` when given) are NumPy arrays converted with
        ``torch.from_numpy``; with ``nu`` the powers are replaced by the
        two-level averages from ``_pow``.
        """
        def as_torch(arr):
            return torch.from_numpy(arr).to(dtype=torch.get_default_dtype(), device=self.device)

        u = as_torch(u)
        if nu is not None:
            nu = as_torch(nu)
        avg1 = self._pow(u, nu, 1)
        avg3 = self._pow(u, nu, 3)
        # expression whose second difference gives du/dt
        mu = - self.a * avg1 + self.b * avg3 - self.gamma * self._D2(avg1)
        return self._D2(mu).detach().cpu().numpy()

    def get_energy(self, u):
        """Return the discrete energy of ``u`` as a NumPy array."""
        field = torch.tensor(u, dtype=torch.get_default_dtype(), device=self.device)
        density = (-0.5 * self.a * field**2
                   + 0.25 * self.b * field**4
                   + 0.5 * self.gamma * self._Dx2(field))
        return self._global_integral(density).detach().cpu().numpy()


class KdVNd(PDENd):
    """KdV-type equation ``u_t = D(-a/2 u^2 + b D2 u)``.

    The spatial operators ``_D``, ``_D2`` and ``_Dx2`` come from a grid mixin;
    the energy reported by ``get_energy`` is the integral of
    ``-a/6 u^3 - b/2 * _Dx2(u)``.
    """

    def __init__(self, ndiv=50, width=1., a=6., b=1., device='cpu'):
        PDENd.__init__(self, ndiv=ndiv, width=width, device=device)
        self.a, self.b = a, b

    def dudt(self, u, *, nu=None, t=None):
        """Return du/dt as a NumPy array.

        ``u`` (and ``nu`` when given) are NumPy arrays converted with
        ``torch.from_numpy``; with ``nu`` the powers are replaced by the
        two-level averages from ``_pow``.
        """
        def as_torch(arr):
            return torch.from_numpy(arr).to(dtype=torch.get_default_dtype(), device=self.device)

        u = as_torch(u)
        if nu is not None:
            nu = as_torch(nu)
        avg1 = self._pow(u, nu, 1)
        avg2 = self._pow(u, nu, 2)
        # expression whose first difference gives du/dt
        inner = - self.a / 2 * avg2 + self.b * self._D2(avg1)
        return self._D(inner).detach().cpu().numpy()

    def get_energy(self, u):
        """Return the discrete energy of ``u`` as a NumPy array."""
        field = torch.tensor(u, dtype=torch.get_default_dtype(), device=self.device)
        density = -self.a / 6. * field**3 - self.b / 2. * self._Dx2(field)
        return self._global_integral(density).detach().cpu().numpy()


class CahnHilliard1d(CahnHilliardNd, PDE1dPeriodic):
    """Cahn-Hilliard-type equation (see ``CahnHilliardNd``) on a 1-D periodic grid."""

    def __init__(self, ndiv=50, width=1., a=1., b=1., gamma=0.005, device='cpu'):
        settings = dict(ndiv=ndiv, width=width, a=a, b=b, gamma=gamma, device=device)
        CahnHilliardNd.__init__(self, **settings)
        # Must run second: it reads self.dx / self.device set above, builds the
        # difference stencils and sets dim = 1.
        PDE1dPeriodic.__init__(self)


class KdV1d(KdVNd, PDE1dPeriodic):
    """KdV-type equation (see ``KdVNd``) on a 1-D periodic grid."""

    # NOTE(review): the default a=-6. here differs from KdVNd's default a=6.,
    # which flips the sign of the nonlinear term — confirm this is intended.
    def __init__(self, ndiv=50, width=1., a=-6., b=1., device='cpu'):
        settings = dict(ndiv=ndiv, width=width, a=a, b=b, device=device)
        KdVNd.__init__(self, **settings)
        # Must run second: it reads self.dx / self.device set above, builds the
        # difference stencils and sets dim = 1.
        PDE1dPeriodic.__init__(self)
