import torch
from torch import normal, Tensor, vmap, Size
from torch.func import jacrev

import abc
import numpy as np
import math
import logging
import functools
import os

logger = logging.getLogger("LOG")

# Directory, located next to this module, that is scanned for weight files.
WEIGHT_DIR = os.path.join(os.path.dirname(__file__), "weights")

# Registry of model names. It is seeded with the names of the regular files
# found in WEIGHT_DIR; SDERecorderMeta later appends class names to it.
if os.path.isdir(WEIGHT_DIR):
    # Sorted so the registry order does not depend on the file system.
    MODELS = sorted(
        fname for fname in os.listdir(WEIGHT_DIR)
        if os.path.isfile(os.path.join(WEIGHT_DIR, fname))
    )
else:
    # A missing weights directory must not make the module un-importable.
    logger.warning("Weight directory %s not found; no weight files registered.", WEIGHT_DIR)
    MODELS = []

class SDERecorderMeta(abc.ABCMeta, type):
    """Metaclass that records the names of the classes it creates.

    Every class built through this metaclass has its name appended to the
    module-level ``MODELS`` list, except a class whose name is exactly "SDE".
    """

    def __new__(cls, name, bases, class_dict):
        created = super().__new__(cls, name, bases, class_dict)
        is_base_class = name == "SDE"
        if not is_base_class:
            MODELS.append(name)
        return created

class SDE(metaclass = SDERecorderMeta):
    """Abstract SDE class.
    Simulates forward and reverse trajectories of an OU process."""

    def __init__(self, N : int, 
                 T : float, 
                 data_shape : int, 
                 noise_schedule = None,
                 device = None, 
                 perturb_size = 0.0):
        """Constructs the SDE simulator.
        
        Args: 
            N: number of discretization steps
            T: simulation time
            data_shape (int): dimension of the latent space
        """
        super().__init__()
        if device == None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Running on {self.device}")
        logger.debug(f"Instantiated the SDE {self.__class__.__name__}")

        match noise_schedule:
            case None:
                self.schedule_g = self.no_schedule_g
                self.schedule_rate = self.no_schedule_rate
                logger.info(f"Using no noise schedule")
            case 'linear':
                self.schedule_g = self.linear_schedule_g
                self.schedule_rate = self.linear_schedule_rate
                logger.info(f"Using linear noise schedule")
            case 'cosine':
                self.schedule_g = self.cosine_schedule_g
                self.schedule_rate = self.cosine_schedule_rate
                assert T<1, f"Cosine schedule not allowed to start at T= 1. Current T = {T}."
                logger.info(f"Using cosine noise schedule")
            case _:
                logger.info(f"Unrecognized noise schedule {noise_schedule}. Defaulting to no schedule.")
                self.schedule_g = self.no_schedule_g
                self.schedule_rate = self.no_schedule_rate

        self.schedule_alpha = lambda t: torch.exp(-self.schedule_g(t))

        self.perturb_size = perturb_size
        self.T = float(T)
        self.N = N
        self.dt = T/N
        self.data_shape = Size([data_shape])

    @abc.abstractmethod
    def sample_prior(self, n_samples = 1):
        """Generates samples from the approximation of 
        the chosen approximation of p_T

        Args: 
            n_samples: number of samples to generate
        Returns: 
            n_samples samples.
        """
        pass
    
    @abc.abstractmethod
    def score(self, x, t):
        """Evaluates the score at the vector of inputs x and time t

        MUST IMPLEMENT THE SCORE FOR THE STANDARD OU

        Args:
            x: a list of N inputs
            t: time
        Returns:
            a list of N scores.
        """
        pass

    def perturbation(self, x, t):
        return torch.zeros(self.data_shape).to(self.device)
    
    def Dperturbation(self, x, t):
        return torch.zeros((self.data_shape, self.data_shape)).to(self.device)

    
    def Dscore(self, x, t):
        """Evaluates the derivative of the score at the vector of inputs x and time t

        Args:
            x: imput of dimension self.data_shape
            t: time
        Returns:
            outputs a matrix of size self.data_shape*self.data_shape
        """
    
        assert x.shape == self.data_shape
        f = functools.partial(self.score, t=t)
        return jacrev(f)(x)
    
    def Dperturbation(self, x, t):
        f = functools.partial(self.perturbation, t=t)
        return jacrev(f)(x)

    # @abc.abstractmethod
    # def pdf(self, x, t):
    #     """Evaluates the pdf at the vector of inputs x and time t

    #     Args:
    #         x: 
    #         t: time
    #     Returns:
    #         a list of N values.
    #     """

    # @abc.abstractmethod
    # def log_pdf(self, x, t):
    #     """Evaluates the log of pdf at the vector of inputs x and time t

    #     Args:
    #         x: a list of N inputs
    #         t: time
    #     Returns:
    #         a list of N values.
    #     """

    #TODO: DRY THIS
    def batched_score(self, x, t):
        f = functools.partial(self.score, t = t)
        return vmap(vmap(f, randomness='same'), randomness='same')(x)
    
    def batched_Dscore(self, x, t):
        f = functools.partial(self.Dscore, t = t)
        return vmap(vmap(f, randomness='same'), randomness='same')(x)
    
    def batched_perturb(self, x, t):
        f = functools.partial(self.perturbation, t = t)
        return vmap(vmap(f))(x)
    
    def batched_Dperturb(self, x, t):
        f = functools.partial(self.Dperturbation, t = t)
        return vmap(vmap(f))(x)
    
    def sample_forward_trajectory(self, inits = None):
        """Generates sample trajectories from forward process using Euler Maruyama.

        Args: 
            inits: initial conditions to diffuse
        Returns: 
            n_samples sample trajectories.
        """

        if inits.shape == self.data_shape:
            inits = inits.unsqueeze(0)

        trajectories = torch.zeros(self.N, *inits.shape)
        trajectories[0] = inits
        
        for i in range(self.N-1):
            trajectories[i+1] =(1 - 0.5*self.dt) * trajectories[0] + \
                np.sqrt(self.dt)*normal(0,1,self.data_shape)
        return trajectories
        

    def sample_reverse_trajectory(self, n_samples = 1, n_noise_realizations = 1, inits = None):
        """Generates sample trajectories from reverse process using Euler Maruyama. 

        Args: 
            inits: initial conditions to reverse
            n_noise_realizations: if inits = None, equals number of distinct brownian motion 
                paths to sample from.
            n_samples: if inits = None, number of distinct trajectories per noise realization
        Returns: 
            tensor of size (N, n_noise_realizations, n_samples, data_shape) representing
                 sample trajectories.
        """
        with torch.no_grad():
            if not inits:
                inits = torch.zeros(n_noise_realizations, n_samples, self.data_shape[0])
                for k in range(n_samples):
                    inits[:, k, :] = self.sample_prior(n_samples=n_noise_realizations)

            inits = Tensor(inits)
            if inits.shape == self.data_shape:
                inits = inits.unsqueeze(0)
                inits = inits.unsqueeze(0)
            n_samples = inits.shape[1]

            trajectories = torch.zeros(self.N, 
                                    n_noise_realizations, 
                                    n_samples, 
                                    self.data_shape[0]).to(self.device)
            trajectories[-1] = inits
            t = self.T        

            for i in range(self.N-1):
                if i % 100 == 0:
                    logger.info(f"  Timestep {i}/{self.N}")

                g = self.schedule_g(t).to(self.device)
                rate = self.schedule_rate(t).to(self.device)

                trajectories[self.N-i-2] =trajectories[self.N-i-1] + \
                    (0.5*trajectories[self.N-i-1] + self.batched_score(trajectories[self.N-i-1], g))*self.dt*rate + \
                    self.perturb_size*self.batched_perturb(trajectories[self.N-i-1], g)*self.dt*rate +\
                    torch.sqrt(self.dt*rate)*normal(0,1, (n_noise_realizations, 1, self.data_shape[0]), device=self.device)
                t = t-self.dt

        return trajectories.detach().cpu()
    
    def generate_lyap_spectrum(self, trajectories):

        with torch.no_grad():
            trajectories = trajectories.to(self.device)
            n_samples = trajectories.shape[2]
            lexps = torch.zeros(*trajectories.shape).to(self.device)
            # lvects = torch.zeros(*trajectories.shape, self.data_shape[0])
            # lvects[-1] = torch.stack([torch.eye(self.data_shape[0]) for _ in range(n_samples)])
            # lvects = lvects.to(self.device)

            # compute lyapunov spectrum backwards
            t = self.T
            eye = torch.eye(self.data_shape[0]).to(self.device)
            Q = eye
            for i in range(self.N-1):
                if i % 100 == 0:
                    logger.info(f"  Timestep {i}/{self.N}")
                    
                g = self.schedule_g(t).to(self.device)
                rate = self.schedule_rate(t).to(self.device)

                M = (eye + \
                    (0.5*eye + self.batched_Dscore(trajectories[self.N - i - 1], g))*self.dt*rate + \
                    self.perturb_size*self.batched_Dperturb(trajectories[self.N - i - 1], g)*self.dt*rate
                    )@Q
                
                Q, R = torch.linalg.qr(M)
                diags = torch.diagonal(R, dim1 = 2, dim2 = 3)
                diag_abs = torch.abs(diags)
                # sorted, indices = torch.sort(diag_abs, dim = 2, descending=True)
                # Q = Q[i,j, indices[i,j]] for i, j in range(n_samples), range(n_noise)
                # lvects[self.N - i - 2] = Q
                lexps[self.N-i-2] = lexps[self.N-i-1] + torch.log(diag_abs)

                t = t-self.dt

            scales = torch.arange(0,self.N, dtype = torch.float32)*self.dt
            scales[1:] = 1.0/scales[1:]
            scales = scales.flip(0)
            scales = scales.to(self.device)

            lexps = lexps*scales.view(self.N, 1,1,1)

            logger.info('Sorting Lyapunov Exponents by final value.')
            sorted, indices = torch.sort(lexps[0], dim = -1, descending=True)
            logger.debug('indices.shape: ' + str(indices.shape))
            logger.debug('lexps.shape: ' + str(lexps.shape))

            lexps_idx = indices[None, :, :, :]
            lexps_idx = lexps_idx.expand(self.N, -1, -1, -1)
            lexps = torch.gather(lexps, dim = -1, index = lexps_idx)

            lvec_idx = indices[:, :, None, :]
            lvec_idx = lvec_idx.expand(-1, -1, self.data_shape[0], -1)
            Q = torch.gather(Q, dim = -1, index = lvec_idx)

        return lexps.detach().cpu(), Q.detach().cpu()
    
    # def D_reverse_trajectory(self, trajectories):

    #     trajectories = trajectories.to(self.device)
    #     n_samples = trajectories.shape[2]
    #     lexps = torch.zeros(*trajectories.shape).to(self.device)
    #     lvects = torch.zeros(*trajectories.shape, self.data_shape[0])
    #     lvects[-1] = torch.stack([torch.eye(self.data_shape[0]) for _ in range(n_samples)])
    #     lvects = lvects.to(self.device)

    #     # compute lyapunov spectrum backwards
    #     t = self.T
    #     eye = torch.eye(self.data_shape[0]).to(self.device)
    #     for i in range(self.N-1):
    #         if i % 100 == 0:
    #             logger.info(f"  Timestep {i}/{self.N}")
                
    #         M = ((1+0.5*self.dt)*eye + \
    #             self.batched_Dscore(trajectories[self.N - i - 1], t)*self.dt + \
    #             self.perturb_size*self.batched_Dperturb(trajectories[self.N - i - 1], t)
    #             )@lvects[self.N - i - 1]
            
    #         Q, R = torch.linalg.qr(M)
    #         lvects[self.N - i - 2] = Q
    #         diags = torch.diagonal(R, dim1 = 2, dim2 = 3)
    #         lexps[self.N-i-2] = lexps[self.N-i-1] + torch.log(torch.abs(diags))

    #         t = t-self.dt

    #     scales = torch.arange(0,self.N, dtype = torch.float32)*self.dt
    #     scales[1:] = 1.0/scales[1:]
    #     scales = scales.flip(0)
    #     scales = scales.to(self.device)

    #     lexps = lexps*scales.view(self.N, 1,1,1)

    #     return lexps.detach().cpu(), lvects.detach().cpu()

    
    def sample_target(self, inits = None, n_samples = 1):
        trajs = self.sample_reverse_trajectory(1, n_samples, None)
        return trajs[0]
    
    @staticmethod
    def linear_schedule_rate(t, a = 1e-4, b = 1):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return a + (b-a)*t
    
    @staticmethod
    def linear_schedule_g(t, a = 1e-4, b = 1):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return a*t + (b-a)/2*t**2
    
    @staticmethod
    def cosine_schedule_rate(t, delta = 1e-4):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        args = (t+delta)/(1+delta)*torch.pi/2 
        return torch.pi/(1+delta)*torch.sin(args)/torch.cos(args)
    
    @staticmethod
    def cosine_schedule_g(t, delta = 1e-4):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        args = (t+delta)/(1+delta)*torch.pi/2
        c = math.cos((delta)/(1+delta)*torch.pi/2)**2
        return -torch.log(torch.cos(args)**2/c)
    
    @staticmethod
    def no_schedule_rate(t):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return torch.ones_like(t)
    
    @staticmethod
    def no_schedule_g(t):
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return t
    

    