import torch
from torch import nn, optim
import copy
import math
import utils
from utils import cat_tx, cat_stx
import sys
import NMC

class MLP(nn.Module):
    """Fully connected feed-forward network.

    The layers live in ``self.net`` (an ``nn.Sequential``) under the names
    ``L0, A0, L1, A1, ..., L{n}`` where ``L`` is a linear layer and ``A`` an
    activation. The last linear layer is followed by an activation only when
    ``output_activation`` is given. Linear weights are Xavier-normal
    initialised and biases are set to zero.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        hidden_sizes: Widths of the hidden layers. Any non-empty sequence of
            ints (list, tuple, ...) is accepted; it is never modified.
        activation: Zero-argument callable (typically an ``nn.Module`` class)
            that builds the activation placed after each hidden layer.
        output_activation: Optional zero-argument callable that builds an
            activation applied after the output layer.
    """
    def __init__(self, d_in, d_out, hidden_sizes = [100, ], activation = nn.ReLU, output_activation = None):
        super(MLP, self).__init__()
        assert len(hidden_sizes) > 0, "hidden_sizes must contain at least one layer width"
        # Build a fresh list: the caller's sequence (and the shared default) is
        # left untouched, and tuples work as well as lists.
        sizes = [d_in, *hidden_sizes, d_out]
        n_linear = len(sizes) - 1
        self.net = nn.Sequential()
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            self.net.add_module(f"L{i}", nn.Linear(n_in, n_out))
            # Hidden layers get an activation; the output layer does not.
            if i < n_linear - 1:
                self.net.add_module(f"A{i}", activation())
        if output_activation is not None:
            # Shares its index with the last linear layer, whose "A" slot is free.
            self.net.add_module(f"A{n_linear - 1}", output_activation())
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)
    def forward(self, x):
        """Apply the network to ``x`` (features on the last dimension)."""
        return self.net(x)
    def jacobian(self, x):
        """Return the per-sample Jacobian of the network.

        ``jacrev`` is vmapped over the leading dimension of ``x``, so for an
        input of shape ``(batch, d_in)`` the result has shape
        ``(batch, d_out, d_in)``.
        """
        return torch.vmap(torch.func.jacrev(self.net))(x)

class RescaledNCScoreFunc(nn.Module):
    """Score model whose raw network output is divided by the noise level.

    The underlying ``MLP`` maps ``d`` inputs (``d + 1`` when
    ``time_dependent``) to ``d`` outputs; extra keyword arguments are
    forwarded to ``MLP``.
    """
    def __init__(self, d = 2, time_dependent = False, **kwargs):
        super().__init__()
        self.time_dependent = time_dependent
        n_inputs = d + 1 if time_dependent else d
        self.net = MLP(d_in = n_inputs, d_out = d, **kwargs)
    def forward(self, t, x, s):
        """Return ``net(input) / s``.

        Args:
            t: Time tensor, expanded to ``(*x.shape[:-1], 1)`` and appended to
                ``x``; only used when the model is time dependent.
            x: State tensor. ``torch.hstack`` joins along dim 1 for inputs
                with two or more dims, so this presumably expects
                ``(batch, d)`` -- confirm against callers.
            s: Noise level the network output is divided by.
        """
        if not self.time_dependent:
            return self.net(x) / s
        t_col = t.expand(*x.shape[:-1], 1)
        return self.net(torch.hstack([x, t_col])) / s

class NCScoreFunc(nn.Module):
    """Noise-conditional score model: the noise level is a network input.

    The underlying ``MLP`` receives ``x`` plus one noise column (and one time
    column when ``time_dependent``) and returns ``d`` outputs. Extra keyword
    arguments are forwarded to ``MLP``.
    """
    def __init__(self, d = 2, time_dependent = False, **kwargs):
        super().__init__()
        self.time_dependent = time_dependent
        n_inputs = d + 2 if time_dependent else d + 1
        self.net = MLP(d_in = n_inputs, d_out = d, **kwargs)
    def forward(self, t, x, s):
        """Evaluate the network on ``[x, t, s]`` (``t`` only if time dependent).

        ``t`` and ``s`` are tensors expanded to ``(*x.shape[:-1], 1)`` before
        being stacked next to ``x`` with ``torch.hstack``.
        """
        lead = x.shape[:-1]
        columns = [x]
        if self.time_dependent:
            columns.append(t.expand(*lead, 1))
        columns.append(s.expand(*lead, 1))
        return self.net(torch.hstack(columns))
    
class CondNCScoreFunc(nn.Module):
    """Class-conditional, noise-conditional score model.

    The network input is ``[x, t, one_hot(k), s]`` when ``time_dependent``
    and ``[x, one_hot(k), s]`` otherwise, matching the ``d_in`` the ``MLP``
    is built with. Extra keyword arguments are forwarded to ``MLP``.

    Args:
        d: Dimension of the state and of the returned score.
        num_classes: Number of classes used for the one-hot encoding of ``k``.
        time_dependent: Whether the time ``t`` is fed to the network.
    """
    def __init__(self, d = 2, num_classes = 1, time_dependent = False, **kwargs):
        super(CondNCScoreFunc, self).__init__()
        self.time_dependent = time_dependent
        self.num_classes = num_classes
        self.net = MLP(d_in = d + num_classes + 2 if time_dependent else d + num_classes + 1, d_out = d, **kwargs)
    def forward(self, t, x, k, s):
        """Evaluate the score network.

        Args:
            t: Time tensor, expandable to ``(*x.shape[:-1], 1)``; ignored when
                the model is not time dependent.
            x: State tensor (presumably ``(batch, d)`` -- ``torch.hstack``
                joins along dim 1 for inputs with two or more dims).
            k: Integer class-index tensor accepted by
                ``nn.functional.one_hot``.
            s: Noise-level tensor, expandable to ``(*x.shape[:-1], 1)``.
        """
        lead = x.shape[:-1]
        one_hot = nn.functional.one_hot(k, num_classes = self.num_classes).expand(*lead, self.num_classes)
        columns = [x]
        # Only append the time column when the network was sized for it;
        # otherwise the input width would not match the first linear layer.
        if self.time_dependent:
            columns.append(t.expand(*lead, 1))
        # one_hot yields an integer tensor; cast so all stacked blocks share x's dtype.
        columns.append(one_hot.to(x.dtype))
        columns.append(s.expand(*lead, 1))
        return self.net(torch.hstack(columns))

class ScoreFunc(nn.Module):
    """Plain (not noise-conditional) score model backed by an ``MLP``.

    The network maps ``d`` inputs (``d + 1`` when ``time_dependent``) to
    ``d`` outputs. Extra keyword arguments are forwarded to ``MLP``.
    """
    def __init__(self, d = 2, time_dependent = False, **kwargs):
        super(ScoreFunc, self).__init__()
        self.time_dependent = time_dependent
        # Use this module's MLP; torch.nn provides no ``MLP`` class.
        self.net = MLP(d_in = d+1 if time_dependent else d, d_out = d, **kwargs)
    def forward(self, t, x):
        """Evaluate the network on ``x``, with ``t`` appended if time dependent.

        ``t`` is a tensor expanded to ``(*x.shape[:-1], 1)`` before being
        stacked next to ``x`` with ``torch.hstack``.
        """
        if self.time_dependent:
            y = torch.hstack([x, t.expand(*x.shape[:-1], 1)])
        else:
            y = x
        return self.net(y)
    
class VectorField(nn.Module):
    """``MLP``-parameterised vector field with ``d`` outputs.

    When ``time_dependent`` the network input is ``cat_tx(t, x)`` (``d + 1``
    features); otherwise it is ``x`` alone. Extra keyword arguments are
    forwarded to ``MLP``.
    """
    def __init__(self, d = 2, time_dependent = False, **kwargs):
        super().__init__()
        self.time_dependent = time_dependent
        n_inputs = d + 1 if time_dependent else d
        self.net = MLP(d_in = n_inputs, d_out = d, **kwargs)
    def forward(self, t, x):
        """Evaluate the field at ``(t, x)``; ``t`` is ignored if not time dependent."""
        net_input = cat_tx(t, x) if self.time_dependent else x
        return self.net(net_input)
    def jacobian(self, t, x):
        """Return the per-sample Jacobian of the field with respect to ``x``.

        In the time-dependent case the last input column of the full Jacobian
        is dropped. NOTE(review): this presumes ``cat_tx`` places the time
        column last -- confirm against ``utils.cat_tx``.
        """
        if not self.time_dependent:
            return self.net.jacobian(x)
        full_jac = self.net.jacobian(cat_tx(t, x))
        return full_jac[..., :-1]

class _Linear(nn.Module):
    """Affine map ``x -> x A^T + b`` with a ``d x d`` matrix and zero init."""
    def __init__(self, d = 2):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(d, d))
        self.b = nn.Parameter(torch.zeros(d))
    def forward(self, x):
        """Apply the affine map to ``x`` (features on the last dimension)."""
        return torch.matmul(x, self.A.t()) + self.b
    def sumsq(self):
        """Return the sum of squared entries of ``A`` and ``b``."""
        return self.A.pow(2).sum() + self.b.pow(2).sum()
    def jacobian(self, x):
        """Return ``A`` repeated (as an expanded view) once per row of ``x``.

        The result has shape ``(x.shape[0], d, d)``.
        """
        n_rows = x.shape[0]
        return self.A[None].expand(n_rows, *self.A.shape)

class _Quadratic(nn.Module):
    """Scalar quadratic form ``x^T M x + b . x + c`` with ``M = (A + A^T) / 2``.

    All parameters (``A``: ``d x d``, ``b``: ``d``, ``c``: scalar) start at zero.
    """
    def __init__(self, d = 2):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(d, d))
        self.b = nn.Parameter(torch.zeros(d))
        self.c = nn.Parameter(torch.zeros(()))
    def forward(self, x):
        """Evaluate the quadratic form, reducing over the last dimension of ``x``."""
        # Only the symmetric part of A contributes to x^T A x.
        sym = 0.5 * (self.A + self.A.t())
        quad_term = (torch.matmul(x, sym) * x).sum(dim = -1)
        lin_term = torch.matmul(x, self.b)
        return quad_term + lin_term + self.c
    def sumsq(self):
        """Return the sum of squared entries of ``A``, ``b`` and ``c``."""
        return self.A.pow(2).sum() + self.b.pow(2).sum() + self.c.pow(2).sum()
        
class LinearVectorField(nn.Module):
    """Time-independent vector field given by the affine map ``_Linear(d)``."""
    def __init__(self, d = 2):
        super().__init__()
        self.time_dependent = False
        self.net = _Linear(d)
    def forward(self, t, x):
        """Evaluate the field at ``x``; ``t`` is accepted but not used."""
        field = self.net(x)
        return field
    def jacobian(self, t, x):
        """Return the Jacobian from ``_Linear.jacobian``; ``t`` is not used."""
        jac = self.net.jacobian(x)
        return jac

class QuadraticScalarField(nn.Module):
    """Time-independent scalar field given by the quadratic form ``_Quadratic(d)``."""
    def __init__(self, d = 2):
        super().__init__()
        self.time_dependent = False
        self.net = _Quadratic(d)
    def forward(self, t, x):
        """Evaluate the field at ``x``; ``t`` is accepted but not used."""
        value = self.net(x)
        return value
    
class ODEFlowGrowth_linear(nn.Module):
    """ODE right-hand side with a linear velocity and a quadratic growth term.

    The state ``y`` carries one leading column followed by the position
    ``x = y[..., 1:]``. ``forward`` returns ``[g(x), v(x)]`` stacked in the
    same layout. Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, d = 3, **kwargs):
        super(ODEFlowGrowth_linear, self).__init__()
        self.v_net = LinearVectorField(d)
        self.g_net = QuadraticScalarField(d)
        self.time_dependent = False
    def forward(self, t, y):
        """Return ``hstack([g, v])`` evaluated at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        v = self.v_net(t, x)
        # The quadratic field reduces over the last dim; add a column axis so it
        # can be stacked next to v.
        g = self.g_net(t, x).unsqueeze(1)
        return torch.hstack([g, v])
    def v(self, t, y):
        """Return the velocity field at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        # The sub-module's forward takes (t, x), so t must be passed through.
        return self.v_net(t, x)
    def g(self, t, y):
        """Return the growth field at ``x = y[..., 1:]`` (no column axis added)."""
        x = y[..., 1:]
        return self.g_net(t, x)
        
class NGM(nn.Module):
    """Thin wrapper around ``NMC.MLPODEF``.

    ``hidden_sizes`` and any keyword arguments are passed straight to
    ``NMC.MLPODEF``.
    """
    def __init__(self, hidden_sizes, **kwargs):
        super().__init__()
        self.net = NMC.MLPODEF(hidden_sizes, **kwargs)
    def forward(self, x):
        """Call the wrapped model and squeeze all singleton dimensions.

        ``None`` is passed as the first argument -- presumably an unused time
        slot; confirm against ``NMC.MLPODEF``.
        """
        out = self.net(None, x)
        return out.squeeze()
        
class NGMVectorField(nn.Module):
    """Time-independent vector field backed by an ``NGM`` model.

    The layer sizes handed to ``NGM`` are ``[d, *hidden_sizes, 1]``. Asking
    for a time-dependent field only prints a message; the field is always
    built as time independent. Extra keyword arguments go to ``NGM``.
    """
    def __init__(self, d = 2, hidden_sizes = [64, ], time_dependent = False, **kwargs):
        super().__init__()
        if time_dependent:
            print("Error: NGM doesn't support time dependent vector field")
        # Work on a copy so the caller's list (and the shared default) is not modified.
        layer_sizes = copy.copy(hidden_sizes)
        layer_sizes.insert(0, d)
        layer_sizes.append(1)
        self.net = NGM(layer_sizes, **kwargs)
        self.time_dependent = False
    def forward(self, t, x):
        """Evaluate the field at ``x``; ``t`` is accepted but not used."""
        field = self.net(x)
        return field
        
class ScalarField(nn.Module):
    """``MLP``-parameterised scalar field (one output).

    When ``time_dependent`` the network input is ``cat_tx(t, x)`` (``d + 1``
    features); otherwise it is ``x`` alone. Extra keyword arguments are
    forwarded to ``MLP``.
    """
    def __init__(self, d = 2, time_dependent = False, **kwargs):
        super().__init__()
        self.time_dependent = time_dependent
        n_inputs = d + 1 if time_dependent else d
        self.net = MLP(d_in = n_inputs, d_out = 1, **kwargs)
    def forward(self, t, x):
        """Evaluate the field at ``(t, x)``; ``t`` is ignored if not time dependent."""
        net_input = cat_tx(t, x) if self.time_dependent else x
        return self.net(net_input)
    def grad_x(self, t, x):
        """Return the Jacobian of the network output with respect to ``x``.

        Uses ``torch.func.jacrev`` on the network. In the time-dependent case
        the last column of a 2-D Jacobian is dropped, so ``x`` is presumably a
        single unbatched sample (e.g. when called under ``vmap``).
        NOTE(review): dropping the last column presumes ``cat_tx`` places the
        time entry last -- confirm against ``utils.cat_tx``.
        """
        d_net = torch.func.jacrev(self.net)
        if not self.time_dependent:
            return d_net(x)
        full_jac = d_net(cat_tx(t, x))
        return full_jac[:, :-1]

class ODEFlowGrowth(nn.Module):
    """ODE right-hand side combining a velocity field and a growth field.

    The state ``y`` carries one leading column followed by the position
    ``x = y[..., 1:]``. ``forward`` returns ``[g(t, x), v(t, x)]`` stacked in
    the same layout.

    Args:
        d: Dimension of the position ``x``.
        v_mod: Class used to build the velocity field, called as
            ``v_mod(d = d, **kwargs_v)``.
        kwargs_v: Keyword arguments for ``v_mod``.
        g_mod: Class used to build the growth field, called as
            ``g_mod(d = d, **kwargs_g)``.
        kwargs_g: Keyword arguments for ``g_mod``.
    """
    def __init__(self, d = 2, v_mod = VectorField, kwargs_v = {}, g_mod = ScalarField, kwargs_g = {}):
        super().__init__()
        self.v_net = v_mod(d = d, **kwargs_v)
        self.g_net = g_mod(d = d, **kwargs_g)
    def forward(self, t, y):
        """Return ``hstack([g, v])`` evaluated at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        velocity = self.v_net(t, x)
        growth = self.g_net(t, x)
        return torch.hstack([growth, velocity])
    def v(self, t, y):
        """Return the velocity at ``x = y[..., 1:]`` by calling ``v_net.net`` directly."""
        x = y[..., 1:]
        net_input = cat_tx(t, x) if self.v_net.time_dependent else x
        return self.v_net.net(net_input)
    def g(self, t, y):
        """Return the growth at ``x = y[..., 1:]`` by calling ``g_net.net`` directly."""
        x = y[..., 1:]
        net_input = cat_tx(t, x) if self.g_net.time_dependent else x
        return self.g_net.net(net_input)
    
class ODEFlowGrowthCoupled(nn.Module):
    """ODE right-hand side where velocity and growth share one scalar field.

    A single ``ScalarField`` ``F`` supplies the growth ``g = F(t, x)`` and the
    velocity ``v = dF/dx``. The state ``y`` carries one leading column
    followed by the position ``x = y[..., 1:]``. Keyword arguments are
    forwarded to ``ScalarField``.
    """
    def __init__(self, d = 2, **kwargs):
        super(ODEFlowGrowthCoupled, self).__init__()
        self.F_net = ScalarField(d = d, **kwargs)
    def dF(self, t, x):
        """Return the per-sample gradient of ``F`` with respect to ``x``.

        This is a method rather than a lambda stored on the instance, so a
        deep-copied model differentiates its own ``F_net`` and the module
        stays picklable.

        Args:
            t: Time, passed unchanged to ``ScalarField.grad_x`` per sample.
            x: Batched positions; ``vmap`` maps over the leading dimension.

        Returns:
            Tensor of shape ``(x.shape[0], d)``.
        """
        grad = torch.vmap(lambda _x: self.F_net.grad_x(t, _x))(x)
        # Flatten only the singleton output axis. A bare squeeze() would also
        # drop the batch axis for a batch of one (or the feature axis when d == 1).
        return grad.reshape(x.shape[0], -1)
    def forward(self, t, y):
        """Return ``hstack([g, v])`` evaluated at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        v = self.dF(t, x)
        g = self.F_net(t, x)
        return torch.hstack([g, v])
    def v(self, t, y):
        """Return the velocity ``dF/dx`` at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        return self.dF(t,  x)
    def g(self, t, y):
        """Return the growth ``F(t, x)`` at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        return self.F_net(t, x)

class MultiplicativeNoiseFlow(nn.Module):
    """Flow drift built from two vector fields ``u``, ``v`` and a score model.

    With the diagonal diffusion ``Dm(x) = (u(x) + v(x)) * D / 2``, ``forward``
    returns ``(u - v) - div_D_func(x) - Dm(x) * s(t, x, score_sigma)``.

    Args:
        d: Dimension of the state.
        score_model: Callable invoked as ``score_model(t, x, score_sigma)``.
        D: Scale factor applied to ``u + v`` in the diffusion term.
        score_sigma: Noise level passed to ``score_model``.
        kwargs_u: Keyword arguments for the ``VectorField`` ``u``.
        kwargs_v: Keyword arguments for the ``VectorField`` ``v``.
    """
    def __init__(self, d, score_model, D, score_sigma,
                 kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False},
                 kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False}):
        super(MultiplicativeNoiseFlow, self).__init__()
        self.d = d
        self.u = VectorField(d = d, **kwargs_u)
        self.v = VectorField(d = d, **kwargs_v)
        self.s = score_model
        self.score_sigma = score_sigma
        self.D = D
    def div_D_func(self, x):
        """Return, per sample, the diagonal of the Jacobian of the diffusion.

        Entry ``i`` is ``d Dm_i / d x_i``. This is a method rather than a
        closure stored on the instance, so a deep-copied model differentiates
        its own networks and the module stays picklable.
        """
        per_sample = lambda _x: torch.func.jacrev(self.get_diffusion_matrix)(_x).diag()
        return torch.vmap(per_sample)(x)
    def get_diffusion_matrix(self, x):
        """Return the diagonal diffusion ``(u(x) + v(x)) * D / 2``.

        The underlying networks are called directly, bypassing any time input.
        """
        return (self.u.net(x) + self.v.net(x)) * self.D/2
    def forward(self, t, x):
        """Return the drift at ``(t, x)``; ``t`` is only used by the score model."""
        u = self.u.net(x)
        v = self.v.net(x)
        s = self.s(t, x, self.score_sigma)
        v_drift = (u - v)
        v_div_D = self.div_D_func(x)
        v_Ds = self.get_diffusion_matrix(x) * s
        return v_drift - v_div_D - v_Ds
    
class MultiplicativeNoiseFlowGrowth(nn.Module):
    """Multiplicative-noise flow drift with an additional growth field.

    The state ``y`` carries one leading column followed by the position
    ``x = y[..., 1:]``. With ``Dm(x) = (u(x) + v(x)) * D / 2``, ``forward``
    returns ``hstack([g(t, x), w])`` where
    ``w = (u - v) - div_D_func(x) - Dm(x) * s(t, x, score_sigma)``.

    Args:
        d: Dimension of the position ``x``.
        score_model: Callable invoked as ``score_model(t, x, score_sigma)``.
        D: Scale factor applied to ``u + v`` in the diffusion term.
        score_sigma: Noise level passed to ``score_model``.
        kwargs_u: Keyword arguments for the ``VectorField`` ``u``.
        kwargs_v: Keyword arguments for the ``VectorField`` ``v``.
        kwargs_g: Keyword arguments for the ``ScalarField`` ``g``.
    """
    def __init__(self, d, score_model, D, score_sigma,
                 kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False},
                 kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False},
                 kwargs_g = {'time_dependent' : False}):
        super(MultiplicativeNoiseFlowGrowth, self).__init__()
        self.d = d
        self.u = VectorField(d = d, **kwargs_u)
        self.v = VectorField(d = d, **kwargs_v)
        self.g = ScalarField(d = d, **kwargs_g)
        self.s = score_model
        self.score_sigma = score_sigma
        self.D = D
    def div_D_func(self, x):
        """Return, per sample, the diagonal of the Jacobian of the diffusion.

        Entry ``i`` is ``d Dm_i / d x_i``. This is a method rather than a
        closure stored on the instance, so a deep-copied model differentiates
        its own networks and the module stays picklable.
        """
        per_sample = lambda _x: torch.func.jacrev(self.get_diffusion_matrix)(_x).diag()
        return torch.vmap(per_sample)(x)
    def get_diffusion_matrix(self, x):
        """Return the diagonal diffusion ``(u(x) + v(x)) * D / 2``.

        The underlying networks are called directly, bypassing any time input.
        """
        return (self.u.net(x) + self.v.net(x)) * self.D/2
    def forward(self, t, y):
        """Return ``hstack([g, w])`` evaluated at ``x = y[..., 1:]``."""
        x = y[..., 1:]
        u = self.u.net(x)
        v = self.v.net(x)
        g = self.g(t, x)
        s = self.s(t, x, self.score_sigma)
        v_drift = (u - v)
        v_div_D = self.div_D_func(x)
        v_Ds = self.get_diffusion_matrix(x) * s
        w = v_drift - v_div_D - v_Ds
        return torch.hstack([g, w])
        
class LangevinSampler():
    """Langevin-dynamics sampler driven by a score function.

    Each step applies ``x += score * eta + sqrt(2 * eta) * noise`` with
    standard normal noise. Two modes are supported:

    * ``sigmas`` given: the score is called as ``score_func(x, sigma_k)``,
      the iterations are split evenly over the noise levels in order, and
      ``eta = dt * (sigma_k / sigmas[-1]) ** 2``.
    * ``sigmas`` omitted: the score is called as ``score_func(x)`` and the
      step size follows a geometric schedule
      ``eta = dt * temp_init * temp_ratio ** step`` with
      ``temp_ratio = (1 / temp_init) ** (1 / n_iter)``.

    Args:
        score_func: Score function (see the two call forms above).
        x0: Initial sample. It is stored as-is and updated with ``+=``, so a
            tensor passed here is modified in place.
        n_iter: Total number of steps; ``0`` returns ``x0`` unchanged.
        dt: Base step size.
        temp_init: Initial step-size multiplier (only without ``sigmas``).
        sigmas: Optional sequence of noise levels.
    """
    def __init__(self, score_func, x0, n_iter = 1_000, dt = 1e-2, temp_init = 1.0, sigmas = None):
        self.score_func = score_func
        self.n_iter = n_iter
        self.x = x0
        self.dt = dt
        self.noise_cond = False
        if sigmas is not None:
            # noise conditional score network
            self.sigmas = sigmas
            self.noise_cond = True
        else:
            # for simple score networks, still allow for annealing
            self.temp_init = temp_init
            # With zero iterations the ratio is never used; avoid dividing by zero.
            self.temp_ratio = (1 / temp_init) ** (1 / n_iter) if n_iter != 0 else 1.0
    def sample(self):
        """Run ``n_iter`` Langevin steps without gradients and return ``self.x``."""
        if self.noise_cond:
            # Rounding up keeps the noise index below len(sigmas) for every step.
            _n_iter_per_noise = math.ceil(self.n_iter / len(self.sigmas))
        with torch.no_grad():
            for step in range(self.n_iter):
                if self.noise_cond:
                    k = step // _n_iter_per_noise
                    eta = self.dt * (self.sigmas[k] / self.sigmas[-1])**2
                    _s = self.score_func(self.x, self.sigmas[k])
                else:
                    eta = self.dt * self.temp_init * (self.temp_ratio)**step
                    _s = self.score_func(self.x)
                self.x += _s * eta + (2*eta)**0.5 * torch.randn_like(self.x)
        return self.x

class SDE(torch.nn.Module):
    """SDE with a user-supplied drift and constant diagonal noise.

    ``noise_type`` and ``sde_type`` are class-level tags, presumably read by
    an SDE solver such as torchsde -- the consumer is not visible here.

    Args:
        drift: Callable invoked as ``drift(t, y)``.
        sigma: Noise scale multiplied into a tensor of ones shaped like ``y``.
    """
    noise_type = "diagonal"
    sde_type = "ito"
    def __init__(self, drift, sigma=1.0):
        super().__init__()
        self.sigma = sigma
        self.drift = drift
    def f(self, t, y):
        """Return the drift ``drift(t, y)``."""
        drift_value = self.drift(t, y)
        return drift_value
    def g(self, t, y):
        """Return the diffusion: ``sigma`` broadcast against the shape of ``y``."""
        unit = torch.ones_like(y)
        return unit * self.sigma

class MultiplicativeNoiseSDE(torch.nn.Module):
    """SDE whose drift is ``u - v`` and whose diagonal noise is ``sqrt(u + v) * sigma``.

    ``noise_type`` and ``sde_type`` are class-level tags, presumably read by
    an SDE solver such as torchsde -- the consumer is not visible here.

    Args:
        u: Callable invoked as ``u(t, y)``.
        v: Callable invoked as ``v(t, y)``.
        sigma: Scale applied to the diffusion term.
    """
    noise_type = "diagonal"
    sde_type = "ito"
    def __init__(self, u, v, sigma=1.0):
        super().__init__()
        self.u = u
        self.v = v
        self.sigma = sigma
    def f(self, t, y):
        """Return the drift ``u(t, y) - v(t, y)``."""
        up = self.u(t, y)
        down = self.v(t, y)
        return up - down
    def g(self, t, y):
        """Return the diffusion ``sqrt(u(t, y) + v(t, y)) * sigma``."""
        total = self.u(t, y) + self.v(t, y)
        return total.sqrt() * self.sigma


