from .base_sampler import SDE
from .base_trained import Trained

import torch
from functorch import jacrev
import torch.nn.init as init

class ScoreModel2D(Trained):
    """Simple MLP for approximating the score.

    The network consumes the concatenation of a point ``x`` (last dimension 2)
    and one extra feature ``log(t)``, hence the input width of 3, and emits a
    2-component output.
    """

    def __init__(self,
                 name,
                 hidden_dim = 128,
                 n_layers = 5,
                 randomize_weights = False,
                 *args, **kwargs):
        """Build the MLP.

        Args:
            name: Forwarded as the first positional argument to ``Trained``.
            hidden_dim: Width of every hidden layer.
            n_layers: Number of hidden ``Linear(hidden_dim, hidden_dim)``
                layers placed between the input and output layers (the
                network therefore holds ``n_layers + 2`` linear layers).
            randomize_weights: Forwarded by keyword to ``Trained``.
            *args, **kwargs: Forwarded unchanged to ``Trained``.
        """
        super().__init__(name, *args, randomize_weights=randomize_weights, **kwargs)

        # Input layer: 2 coordinates of x plus the log-time feature.
        layers = [torch.nn.Linear(3, hidden_dim), torch.nn.ReLU()]
        for _ in range(n_layers):
            layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
            layers.append(torch.nn.ReLU())
        # Output layer has no activation.
        layers.append(torch.nn.Linear(hidden_dim, 2))
        self.sequence = torch.nn.Sequential(*layers)
        # self.init_weights()

    def forward(self, x, t):
        """Evaluate the network at ``x`` for time ``t``.

        Args:
            x: Tensor whose last dimension is 2; leading dimensions are kept.
            t: Indexable tensor of times. Only ``t[0]`` is used: its logarithm
                is broadcast as an extra feature to every row of ``x``.

        Returns:
            Tensor with the leading shape of ``x`` and last dimension 2.
        """
        log_t = torch.log(t)
        # Earlier experimental formulation, kept for reference:
        # singularity = 1/(1-torch.exp(-t))
        # F = lambda x, l : -0.5*x + 1/(torch.sqrt(1-torch.exp(-t)))*self.sequence(torch.cat((x, l), dim = -1))
        # p = lambda x, l : (F(x, l)**2).sum()
        # return vmap(grad(p, argnums= 0))(x, l)

        # One column holding log(t[0]) for every row of x.
        time_feature = torch.full((*x.shape[:-1], 1), log_t[0], device = x.device)
        return self.sequence(torch.cat((x, time_feature), dim = -1))

    def init_weights(self):
        """Re-initialise every linear layer and mark the model as untrained.

        Weights get Xavier-uniform initialisation and biases are zeroed.
        Sets ``self.trained`` to ``False``.
        """
        self.trained = False
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                init.xavier_uniform_(m.weight)
                # `init.zeros__` (double underscore) does not exist; the
                # in-place zero initialiser is `zeros_`.
                init.zeros_(m.bias)


class Trained2D(SDE):
    """Implements the model from trained data using a simple MLP to approximate the 
    score.

    The score network is a ``ScoreModel2D`` placed on ``self.device`` and put in
    evaluation mode. ``self.device`` and ``self.data_shape`` are read but not
    assigned in this class (presumably provided by ``SDE`` -- confirm there).
    """
    def __init__(self, N, T, model_name, 
                 hidden_dim = 128, 
                 n_layers = 5,
                 device = None,
                 noise_schedule = None,
                 perturb_size = 0.0):
        """Create the sampler and its score network.

        Args:
            N, T: Forwarded positionally to ``SDE`` (followed by the literal 2).
            model_name: Passed as ``name`` to ``ScoreModel2D``.
            hidden_dim: Hidden width of the score network.
            n_layers: Number of hidden layers of the score network.
            device, noise_schedule, perturb_size: Forwarded by keyword to ``SDE``.
        """
        super().__init__(N, T, 2, device=device, noise_schedule=noise_schedule, perturb_size=perturb_size)

        self.score_model = ScoreModel2D(name = model_name, 
                                        hidden_dim=hidden_dim,
                                        n_layers=n_layers)
        self.score_model.to(self.device)
        self.score_model.eval()

    def _rotation(self):
        """Return the constant 2x2 matrix [[0, 1], [-1, 0]] on ``self.device``."""
        return torch.Tensor([[0.0, 1], [-1, 0.0]]).to(self.device)

    def sample_prior(self, n_samples=1):
        """Draw ``n_samples`` standard-normal samples of width ``self.data_shape[0]``.

        The draw itself is unchanged (same generator, same values); the result
        is then moved to ``self.device`` so it lives alongside the score
        network and the tensors built in ``score`` / ``perturbation``.
        """
        return torch.randn((n_samples, self.data_shape[0])).to(self.device)
    
    def score(self, x, t):
        """Return the network output at ``(x, t)`` scaled by ``1/sqrt(1 - exp(-t))``.

        ``x`` is given a leading batch axis of size 1 for the network call and
        that axis is stripped from the result. Note the scale factor is
        singular at ``t = 0``.
        """
        t = torch.Tensor([t]).to(self.device)
        scale = 1/(torch.sqrt(1-torch.exp(-t)))
        return scale*self.score_model(x.unsqueeze(0),t)[0]

    # def Dscore(self, x, t):
    #     return jacrev(self.score, argnums=0)(x,t)
        
    def perturbation(self, x, t):
        """Return ``R @ x`` with ``R = [[0, 1], [-1, 0]]``; ``t`` is unused."""
        return self._rotation()@x
    
    def Dperturbation(self, x, t):
        """Return the constant matrix ``[[0, 1], [-1, 0]]``; ``x`` and ``t`` are unused."""
        return self._rotation()

    