import numpy as np 
import torch
from torch.autograd import grad

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
# 1-dimensional periodic dynamical systems
##################################################################################################################################################################################################################
##################################################################################################################################################################################################################


def sawteeth(x):
    """Sawtooth wave of period 1, ramping linearly from -1 (at integer x) towards +1.

    Args:
        x: scalar or array of phases expressed in periods.

    Returns:
        ``2 * (x mod 1) - 1`` with the same shape as ``x``.
    """
    fractional = x % 1.0
    return 2 * fractional - 1

def elt_square(x):
    """Square wave over a single period: +1 where ``x < 0.5`` and -1 where ``x >= 0.5``.

    Values are only compared against 0.5; ``square`` reduces its input to [0, 1) before calling
    this. Entries for which ``x >= 0.5`` is False (including NaN) stay at +1. The output has the
    same shape and dtype as ``x``.
    """
    upper_half = x >= 0.5
    wave = np.ones_like(x)
    wave[upper_half] = -1
    return wave

def square(x):
    """Square wave of period 1: +1 on the first half of each period, -1 on the second half."""
    phase_in_period = x % 1.0
    return elt_square(phase_in_period)

def square_approx(n, x):
    """Truncated Fourier series of the unit square wave using its first ``n`` odd harmonics.

    Evaluates ``4/pi * sum_{k=1..n} sin(2*pi*(2k-1)*x) / (2k-1)``.

    Args:
        n (int): number of odd harmonics kept.
        x (np.ndarray): phases (in periods); the harmonics are stacked along a new leading axis
            and summed out, so a 1-D ``x`` yields a 1-D result of the same length.

    Returns:
        np.ndarray: the approximation evaluated at ``x``.
    """
    odd = (2 * np.arange(1, n + 1) - 1).reshape(-1, 1)
    terms = np.sin(2 * np.pi * odd * x) / odd
    return 4 / np.pi * np.sum(terms, axis=0)

def elt_triangle(x):
    """Triangle wave over a single period.

    Rises linearly from -1 at ``x = 0`` to +1 at ``x = 0.5``, then falls back towards -1 at
    ``x = 1``. ``triangle`` reduces its input to [0, 1) before calling this. Entries matching
    neither branch (e.g. NaN) stay at 0. The output has the same shape and dtype as ``x``.
    """
    rising = x < 0.5
    falling = x >= 0.5
    wave = np.zeros_like(x)
    wave[rising] = 4 * x[rising] - 1
    wave[falling] = 1 - 4 * (x[falling] - 0.5)
    return wave

def triangle(x):
    """Triangle wave of period 1: -1 at integer x, +1 at half-integer x, linear in between."""
    phase_in_period = x % 1.0
    return elt_triangle(phase_in_period)

def triangle_approx(n, x):
    """Truncated Fourier series of the unit triangle wave using its first ``n`` odd harmonics.

    Evaluates ``8/pi**2 * sum_{k=1..n} cos(2*pi*(2k-1)*(x - 0.25)) / (2k-1)**2``.

    Args:
        n (int): number of odd harmonics kept.
        x (np.ndarray): phases (in periods); the harmonics are stacked along a new leading axis
            and summed out.

    Returns:
        np.ndarray: the approximation evaluated at ``x``.
    """
    odd = (2 * np.arange(1, n + 1) - 1).reshape(-1, 1)
    harmonics = np.cos(2 * np.pi * odd * (x - 0.25)) / odd**2
    return 8 / np.pi**2 * np.sum(harmonics, axis=0)

def sin(x):
    """Sine wave of period 1: ``sin(2*pi*x)``."""
    angle = 2 * np.pi * x
    return np.sin(angle)

def cos(x):
    """Cosine wave of period 1: ``cos(2*pi*x)``."""
    angle = 2 * np.pi * x
    return np.cos(angle)

def cos_to_square(power, x):
    """Cosine wave of period 1 sharpened towards a square wave.

    Returns ``sign(c) * |c| ** (1 / power)`` with ``c = cos(2*pi*x)``. ``power = 1`` gives the plain
    cosine; larger powers push non-zero values towards +1/-1.
    """
    wave = np.cos(2 * np.pi * x)
    exponent = 1 / power
    return np.sign(wave) * np.abs(wave) ** exponent


def periodic_ts_1d(lst: list, n_samples: int, sampfreq=100, sigma=1e-3) -> np.ndarray:
    """Sample a noisy 1-D time series made of damped periodic components.

    At the sampling times ``t = k / sampfreq`` (``k = 0 .. n_samples - 1``), each component
    contributes ``amp * func(phase + freq * t) * exp(-decay * t)``. I.i.d. zero-mean Gaussian
    noise with standard deviation ``sigma`` is then added.

    Args:
        lst (list of list): one entry per component, in order: amplitude (float), decay rate (float),
            phase (float), frequency (float) and periodic function (callable).
        n_samples (int): number of samples.
        sampfreq (int, optional): sampling frequency in Hertz. Defaults to 100.
        sigma (float, optional): standard deviation of the additive zero-mean Gaussian noise.
            Defaults to 1e-3.

    Returns:
        np.ndarray: generated data of shape ``(n_samples, 1)``.
    """
    t = np.arange(n_samples, dtype=float) / sampfreq
    signal = np.zeros(n_samples)
    for amplitude, decay_rate, phase, frequency, wave in lst:
        signal += amplitude * wave(phase + t * frequency) * np.exp(-decay_rate * t)
    signal += sigma * np.random.randn(n_samples)
    # Column vector so the output matches the (n_samples, n_features) layout.
    return signal.reshape(-1, 1)

def periodic_ds(lst: list, n_samples: int, sampfreq=100, sigma=1e-3) -> np.ndarray:
    """Sample a noisy multi-channel signal made of periodic components.

    At the sampling times ``t = k / sampfreq``, each component adds ``func(phase + freq * t)``,
    evaluated independently for every entry of its ``phase`` array. I.i.d. zero-mean Gaussian
    noise with standard deviation ``sigma`` is then added.

    Args:
        lst (list of list): one triplet per component: initial phase (ndarray, every component must
            share the shape of the first one), periodic function (callable), and number of
            occurrences per second (frequency).
        n_samples (int): number of samples.
        sampfreq (int, optional): sampling frequency in Hertz. Defaults to 100.
        sigma (float, optional): standard deviation of the additive zero-mean Gaussian noise.
            Defaults to 1e-3.

    Returns:
        np.ndarray: generated data, transposed so time is the first axis. A 1-D phase of length
        ``d`` gives shape ``(n_samples, d)``.
    """
    phase_shape = lst[0][0].shape
    shape = (*phase_shape, n_samples)
    t = np.arange(n_samples, dtype=float) / sampfreq
    signal = np.zeros(shape)
    for start_phase, wave, rate in lst:
        # Append a time axis to the phase so it broadcasts against the time vector.
        signal += wave(start_phase[..., None] + t * rate)
    signal += sigma * np.random.randn(*shape)
    return signal.T

def generate_1d_lst(func: callable, n_comp: int = 1, min_freq: float = 1, max_freq: float = 1, min_amp: float = 1, max_amp: float = 1, min_offset: float = 0, max_offset: float = 1, replace=False) -> list:
    """Draw a random list of periodic components in the format expected by ``periodic_ds``.

    The first component always oscillates at ``min_freq``, which therefore sets the base period.
    The other ``n_comp - 1`` frequencies are drawn from ``np.arange(min_freq, max_freq, min_freq)``
    with ``min_freq`` itself removed: multiples of ``min_freq`` below the exclusive bound
    ``max_freq``. Amplitudes and offsets are drawn uniformly in ``[min_amp, max_amp)`` and
    ``[min_offset, max_offset)``.

    Args:
        func (callable): base periodic function, called on an array of phases.
        n_comp (int, optional): number of components. Defaults to 1.
        min_freq (float, optional): base frequency and step of the frequency grid. Defaults to 1.
        max_freq (float, optional): exclusive upper bound of the frequency grid. Defaults to 1.
        min_amp (float, optional): lower bound of the amplitudes. Defaults to 1.
        max_amp (float, optional): upper bound of the amplitudes. Defaults to 1.
        min_offset (float, optional): lower bound of the initial phases. Defaults to 0.
        max_offset (float, optional): upper bound of the initial phases. Defaults to 1.
        replace (bool, optional): whether a harmonic may be drawn more than once. Defaults to False.

    Returns:
        list: one ``[np.array([offset]), scaled_func, freq]`` triplet per component. For each
        component, ``scaled_func(x) == amp * func(x)`` uses that component's own amplitude.

    Raises:
        ValueError: if the frequency grid cannot supply the ``n_comp - 1`` extra frequencies.
    """
    if n_comp > 1:
        n_extra = n_comp - 1
        candidates = np.arange(min_freq, max_freq, min_freq)[1:]
        # Only n_comp - 1 harmonics are drawn (min_freq is always included). With replacement,
        # a single candidate is enough.
        too_few = candidates.shape[0] == 0 or (not replace and candidates.shape[0] < n_extra)
        if too_few:
            raise ValueError(f"n_comp {n_comp} is too large for the range of frequencies {min_freq} to {max_freq}")
        extra = np.random.choice(candidates, n_extra, replace=replace)
        freqs = np.concatenate(([min_freq], extra))
    else:
        freqs = np.array([min_freq])
    amps = np.random.rand(n_comp) * (max_amp - min_amp) + min_amp
    offsets = np.random.rand(n_comp) * (max_offset - min_offset) + min_offset

    def _scaled(amp):
        # Bind `amp` per component. A lambda written directly in the loop would late-bind and
        # make every component use the last amplitude.
        return lambda x: amp * func(x)

    return [[np.array([offset]), _scaled(amp), freq] for freq, amp, offset in zip(freqs, amps, offsets)]

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
# N-dimensional linear dynamical system
##################################################################################################################################################################################################################
##################################################################################################################################################################################################################


class LinearSystem:
    """Discrete-time linear dynamical system ``x_{t+1} = operator @ x_t (+ Gaussian noise)``.

    Observations are ``mapping @ x_t``. With ``complex_to_real=True`` the latent state is split into
    two halves (presumably conjugate pairs, as built by ``DiagonalSystemGenerator`` — confirm for
    custom operators). In that mode the default mapping sums the two halves, every noise draw is
    shared by both halves, and only the real part of the observations is returned.

    Args:
        operator (np.ndarray): square transition matrix of shape ``(ndim, ndim)``.
        mapping (np.ndarray, optional): observation matrix with ``ndim`` columns. Defaults to
            ``[I, I]`` of shape ``(ndim // 2, ndim)`` when ``complex_to_real`` is True, otherwise
            the ``ndim x ndim`` identity.
        noise_std (float, optional): standard deviation of the additive state noise; ``None``
            disables noise.
        complex_to_real (bool, optional): see above. Requires an even ``ndim``. Defaults to True.
    """

    def __init__(self, operator: np.ndarray, mapping: np.ndarray = None, noise_std: float = None, complex_to_real: bool = True):
        # `and` short-circuits, so a non-2-D operator fails the assertion instead of raising
        # IndexError on `operator.shape[1]`.
        assert operator.ndim == 2 and operator.shape[0] == operator.shape[1], "The operator must be a square matrix."
        if complex_to_real:
            assert operator.shape[0] % 2 == 0, "If complex_to_real True, the rank of the operator must be even."
        self.operator = operator
        self.ndim = operator.shape[0]
        self.noise_std = noise_std
        self.complex_to_real = complex_to_real

        if mapping is None:
            if complex_to_real:
                half = np.eye(self.ndim // 2)
                self.mapping = np.hstack((half, half))
            else:
                self.mapping = np.eye(self.ndim)
        else:
            self.mapping = mapping

        self.noisy = noise_std is not None

    def _generate_noise(self):
        """Draw one state-noise vector of length ``ndim``.

        With ``complex_to_real`` a single real draw of length ``ndim // 2`` is duplicated onto both
        halves of the state.
        """
        if self.complex_to_real:
            half = np.random.randn(self.ndim // 2) * self.noise_std
            return np.hstack((half, half))
        return np.random.randn(self.ndim) * self.noise_std

    def _step(self, x):
        """Return the next state ``operator @ x``, plus noise when the system is noisy."""
        x_next = self.operator @ x
        if self.noisy:
            # Out-of-place add, so an integer-valued state does not reject the float noise.
            x_next = x_next + self._generate_noise()
        return x_next

    def trajectory(self, n_samples, x_init):
        """Simulate the system starting from ``x_init``.

        Args:
            n_samples (int): number of states to return. ``x_init`` is always the first one, so at
                least one state is returned.
            x_init (np.ndarray): initial latent state of length ``ndim``.

        Returns:
            tuple: ``(traj, latent_traj)``. ``latent_traj`` stacks the states along axis 0 and
            ``traj = (mapping @ latent_traj.T).T``, keeping only the real part when ``complex_to_real``.
        """
        assert x_init.shape[0] == self.ndim, "x_init must be of the same rank as the systems"
        states = [x_init]
        x = x_init
        for _ in range(n_samples - 1):
            x = self._step(x)
            states.append(x)
        latent_traj = np.array(states)
        traj = (self.mapping @ latent_traj.T).T
        if self.complex_to_real:
            traj = traj.real
        return traj, latent_traj


class DiagonalSystem(LinearSystem):
    """Linear system whose operator is assembled from an eigen-decomposition.

    The operator is ``P @ diag(eigval) @ P_inv``. Noise, mapping and trajectory behaviour are
    inherited from ``LinearSystem``.

    Args:
        P (np.ndarray): left factor of the decomposition (presumably ``inv(P_inv)`` — not checked).
        eigval (np.ndarray): eigenvalues placed on the diagonal.
        P_inv (np.ndarray): right factor of the decomposition.
        mapping (np.ndarray or None): observation matrix forwarded to ``LinearSystem``.
        noise_std (float, optional): state-noise standard deviation forwarded to ``LinearSystem``.
        complex_to_real (bool, optional): forwarded to ``LinearSystem``. Defaults to True.
    """

    def __init__(self, P, eigval, P_inv, mapping, noise_std=None, complex_to_real=True):
        self.P, self.eigval, self.P_inv = P, eigval, P_inv
        spectrum = np.diag(eigval)
        super().__init__(P @ spectrum @ P_inv, mapping, noise_std, complex_to_real)

    def _log_eigval(self):
        """Natural logarithm of each eigenvalue (principal branch for complex values)."""
        return np.log(self.eigval)

    @property
    def decay_(self):
        """Real part of ``log(eigval)``: the per-step log-modulus, negative when ``|eigval| < 1``."""
        return self._log_eigval().real

    @property
    def frequency_(self):
        """Imaginary part of ``log(eigval)``: the per-step phase angle in radians."""
        return self._log_eigval().imag
    

class DiagonalSystemGenerator(object):
    """Random generator of diagonalisable linear systems with conjugate-pair spectra.

    :meth:`fit` draws an eigenbasis, ``rank // 2`` eigenvalues (plus their conjugates) and an
    observation mapping onto ``nsig`` signals. It stores the resulting ``DiagonalSystem`` in
    ``base_operator_``.

    Args:
        nsig (int): number of observed signals; ``rank >= 2 * nsig`` is required.
        rank (int): latent dimension; must be even.
        sampfreq (int): sampling frequency in Hertz.
        maxfreq (float, optional): upper bound of the eigen-frequencies in Hertz. Defaults to
            ``sampfreq // 2``; ``2 * maxfreq`` may not exceed ``sampfreq``.
        minfreq (float, optional): lower bound of the eigen-frequencies. Defaults to 0.
        maxdecay (float, optional): upper bound of the decay rates. Eigenvalues are
            ``exp((-decay + 2j*pi*freq) / sampfreq)``. Defaults to 1e-1.
        noise_std (float, optional): state-noise standard deviation forwarded to the system.
        eps (float, optional): the random perturbation of the eigenbasis is scaled to Frobenius
            norm ``1 / (1 + eps)``. Defaults to 1e-3.
    """

    def __init__(self, nsig: int, rank: int, sampfreq: int, maxfreq: float = None, minfreq: float = 0, maxdecay: float = 1e-1, noise_std: float = None, eps: float = 1e-3):
        assert rank % 2 == 0, "rank must be even."
        assert rank >= 2 * nsig, "rank must bet at least twice the number of signal"
        # Resolve the default before validating; otherwise the Nyquist check below would compute
        # `2 * None` and raise TypeError for the default arguments.
        if maxfreq is None:
            maxfreq = sampfreq // 2
        assert 2 * maxfreq <= sampfreq, "maximum frequency must be below the Niquist frequency."
        self.nsig = nsig
        self.rank = rank
        self.sampfreq = sampfreq
        self.maxfreq = maxfreq
        self.minfreq = minfreq
        self.maxdecay = maxdecay
        self.eps = eps
        self.noise_std = noise_std
        self.noisy = noise_std is not None
        self.fitted_ = False

    def _generate_mapping(self):
        """Return an ``(nsig, rank)`` 0/1-combination matrix.

        The latent halves are first summed pairwise (``[I, I]``). Each of the ``rank // 2`` sums is
        then assigned to one output signal, and every signal receives at least one of them.
        """
        to_real = np.hstack((np.eye(self.rank // 2), np.eye(self.rank // 2)))
        midrank = self.rank // 2
        if self.nsig == midrank:
            real_to_dim = np.eye(self.nsig)
        else:
            real_to_dim = np.zeros((self.nsig, midrank))
            # One distinct latent pair per signal guarantees no signal is left empty...
            base = np.random.choice(midrank, self.nsig, replace=False)
            real_to_dim[np.arange(self.nsig), base] = 1
            # ...and the remaining pairs are spread over random signals.
            missing = np.arange(midrank)[~np.isin(np.arange(midrank), base)]
            dims = np.random.choice(self.nsig, missing.shape[0], replace=True)
            real_to_dim[dims, missing] = 1
        mapping = real_to_dim @ to_real
        return mapping

    def _generate_invertible_matrices(self):
        """Return ``(P, P_inv)`` with ``P`` computed as the inverse of ``P_inv``.

        The top half of ``P_inv`` is ``[I, 0]`` minus a random complex perturbation; the bottom half
        is its complex conjugate. Each column is then normalised to unit norm.
        """
        P = np.random.randn(self.rank // 2, self.rank) + 1j * np.random.randn(self.rank // 2, self.rank)
        P = np.eye(self.rank // 2, self.rank) - P / ((1 + self.eps) * np.linalg.norm(P))
        P = np.vstack((P, np.conjugate(P)))
        P_inv = P / np.linalg.norm(P, axis=0).reshape(1, -1)
        P = np.linalg.solve(P_inv, np.eye(self.rank))
        return P, P_inv

    def _generate_eigenvalues(self):
        """Draw ``rank // 2`` eigenvalues, append their conjugates, and store ``freqs_`` / ``decays_``."""
        self.freqs_ = np.random.rand(self.rank // 2) * (self.maxfreq - self.minfreq) + self.minfreq
        self.decays_ = np.random.rand(self.rank // 2) * self.maxdecay
        eigvals = np.exp((-self.decays_ + 1j * 2 * np.pi * self.freqs_) / self.sampfreq)
        eigvals = np.hstack((eigvals, np.conjugate(eigvals)))
        return eigvals

    def fit(self):
        """Draw a random system, store it in ``base_operator_`` and return ``self``."""
        P, P_inv = self._generate_invertible_matrices()
        eigvals = self._generate_eigenvalues()
        mapping = self._generate_mapping()
        self.base_operator_ = DiagonalSystem(P, eigvals, P_inv, mapping, self.noise_std)
        self.fitted_ = True
        return self


def matrices_complexe_perturbation(ds: DiagonalSystem, pertubation_norm: float):
    """Return a new ``DiagonalSystem`` with the same eigenvalues but a randomly perturbed ``P_inv``.

    A Gaussian perturbation is rescaled so its Frobenius norm equals
    ``pertubation_norm * ||ds.P_inv||_F`` and is added to ``ds.P_inv``. ``P`` is then recomputed as
    the inverse of the perturbed matrix. Mapping, noise level and ``complex_to_real`` are copied
    from ``ds``.

    Args:
        ds (DiagonalSystem): system to perturb (left unchanged).
        pertubation_norm (float): perturbation size relative to the norm of ``ds.P_inv``.

    Returns:
        DiagonalSystem: the perturbed system.
    """
    dim = ds.ndim
    if not ds.complex_to_real:
        shape = ds.P.shape
        delta = np.random.randn(*shape) + 1j * np.random.randn(*shape)
    else:
        half = np.random.randn(dim // 2, dim)
        # NOTE(review): `half` is real, so its conjugate is an identical copy and this perturbation
        # has no imaginary part, unlike the complex branch above. Confirm this is intended.
        delta = np.vstack((half, np.conjugate(half)))

    delta *= pertubation_norm / np.linalg.norm(delta) * np.linalg.norm(ds.P_inv)
    perturbed_inv = ds.P_inv + delta
    perturbed = np.linalg.solve(perturbed_inv, np.eye(dim))

    return DiagonalSystem(perturbed, ds.eigval, perturbed_inv, ds.mapping, ds.noise_std, ds.complex_to_real)

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
# Langevin generator
##################################################################################################################################################################################################################
##################################################################################################################################################################################################################

def two_well_1d_potential(amplitude: float = 1.0, minimum: float = 1.0):
    """Create a symmetric 1-D double-well (quartic) potential.

    The returned function computes ``V(x) = lmbda * x**2 * (x**2 / 4 - minimum**2 / 2)`` with
    ``lmbda = 4 * amplitude / minimum**4``. It has minima at ``x = +-minimum``, where
    ``V = -amplitude``, and a local maximum ``V(0) = 0``. It works element-wise on scalars, numpy
    arrays or torch tensors.

    Args:
        amplitude (float, optional): depth of the wells. Defaults to 1.0.
        minimum (float, optional): position of the two minima (``+-minimum``). Defaults to 1.0.

    Returns:
        callable: the potential function.
    """
    def potential(x):
        scale = amplitude * 4 / minimum**4
        return scale * x**2 * (x**2 / 4 - minimum**2 / 2)

    return potential

def n_dimensional_gaussian(mean, cov):
    """Build the probability density of a multivariate normal distribution.

    The inverse and the normalising constant ``sqrt((2*pi)**ndim * det(cov))`` are computed once,
    when the function is created.

    Args:
        mean (torch.Tensor): mean vector of shape ``(ndim,)``.
        cov (torch.Tensor): covariance matrix of shape ``(ndim, ndim)``.

    Returns:
        callable: ``f(x)`` returning the density at ``x``, reduced over the last axis. An input of
        shape ``(n_samples, ndim)`` gives shape ``(n_samples,)``; a single point ``(ndim,)`` gives a
        scalar.
    """
    dim = mean.shape[0]
    precision = torch.linalg.inv(cov)
    normaliser = torch.sqrt((2 * torch.pi) ** dim * torch.linalg.det(cov))

    def density(x):
        centred = x - mean
        # Squared Mahalanobis distance, computed row-wise.
        mahalanobis_sq = torch.sum(centred @ precision * centred, dim=-1)
        return torch.exp(-0.5 * mahalanobis_sq) / normaliser

    return density


def gaussian_mixture(cov_lst: list, mean_lst: list, ponderation: torch.Tensor = None, dtype=torch.float32) -> callable:
    """Build the density of a weighted mixture of multivariate Gaussian components.

    Args:
        cov_lst (list): torch.Tensor covariance matrices, one per component.
        mean_lst (list): torch.Tensor mean vectors, one per component.
        ponderation (torch.Tensor, optional): weight of each component. Defaults to uniform weights
            ``1 / len(cov_lst)``.
        dtype (torch.dtype, optional): dtype that means and covariances are cast to.
            Defaults to ``torch.float32``.

    Returns:
        callable: ``f(x)`` returning ``sum_i ponderation[i] * N(x; mean_i, cov_i)``.
    """
    if ponderation is None:
        n_components = len(cov_lst)
        ponderation = torch.ones(n_components, dtype=dtype) / n_components
    components = [
        n_dimensional_gaussian(mean.type(dtype), cov.type(dtype))
        for mean, cov in zip(mean_lst, cov_lst)
    ]

    def mixture(x: torch.Tensor):
        total = 0
        for density, weight in zip(components, ponderation):
            total += weight * density(x)
        return total

    return mixture


def n_dimensional_wells_potential(cov_lst: list, mean_lst: list) -> callable:
    """Create a potential made of Gaussian-shaped wells in n-dimensional space.

    Each well contributes ``-exp(-0.5 * (x - mean) M (x - mean)^T)``, where ``M`` is the matrix
    given in ``cov_lst``. The wells are summed, so the potential is negative and deepest near the
    means.

    NOTE(review): ``M`` enters the exponent directly, i.e. it acts as an inverse covariance
    (precision): a larger ``M`` gives a narrower well. Confirm this is the intended convention
    for ``cov_lst``.

    Args:
        cov_lst (list): torch.Tensor matrices, one per well.
        mean_lst (list): torch.Tensor well centres, one per well.

    Returns:
        callable: ``f(x)`` returning the potential, reduced over the last axis of ``x``.
    """
    def potential(x: torch.Tensor):
        total = 0
        for matrix, centre in zip(cov_lst, mean_lst):
            offset = x - centre
            quad_form = torch.sum(offset @ matrix * offset, dim=-1)
            total -= torch.exp(-0.5 * quad_form)
        return total

    return potential
class LangevinGenerator(object):
    """Simulate overdamped Langevin dynamics ``dX = -grad U(X) dt + sqrt(2 * D) dW`` with Euler-Maruyama.

    Args:
        potential (callable): maps a tensor ``X`` to its potential energy. It may return a scalar,
            or one value per sample for a batch of independent walkers. Values are summed before
            differentiation.
        brown_scale (float): diffusion coefficient ``D`` of the Brownian term.
        sampfreq (int): integration steps per unit of time; the time step is ``dt = 1 / sampfreq``.
        device (str, optional): torch device on which trajectories are stored. Defaults to "cpu".
    """

    def __init__(self, potential: callable, brown_scale: float, sampfreq: int, device: str = "cpu"):
        self.potential = potential
        self.brown_scale = brown_scale
        self.sampfreq = sampfreq
        self.device = device

    def _grad(self, X: torch.Tensor):
        """Return the gradient of the potential at ``X``, detached from the autograd graph."""
        X = X.clone().detach().requires_grad_(True)
        # autograd.grad needs a scalar output. Summing allows per-sample energies for a batch; if the
        # potential treats samples independently, d(sum)/dX is the per-sample gradient. For a
        # scalar output the sum is a no-op.
        potential_val = self.potential(X).sum()
        gradient = grad(potential_val, X, create_graph=False, retain_graph=False)[0]
        return gradient.detach()

    def _one_step(self, X: torch.Tensor):
        """Advance ``X`` by one Euler-Maruyama step of length ``dt = 1 / sampfreq``."""
        X = X.clone().detach()
        grad_val = self._grad(X)
        # The Brownian increment over dt has standard deviation sqrt(2 * D * dt), scaling with
        # sqrt(dt) rather than dt. This keeps the simulated diffusion independent of sampfreq.
        noise_std = torch.sqrt(torch.tensor(2 * self.brown_scale / self.sampfreq, dtype=X.dtype, device=X.device))
        noise = noise_std * torch.randn_like(X)
        X_new = X - grad_val / self.sampfreq + noise
        # Clamp to avoid runaway values
        X_new = torch.clamp(X_new, -1e6, 1e6)
        # Reset non-finite states instead of propagating NaN through the rest of the trajectory.
        if torch.isnan(X_new).any() or torch.isinf(X_new).any():
            print("NaN or Inf detected in new sample, resetting to zero.")
            X_new = torch.zeros_like(X_new)
        return X_new

    def sample_trajectory(self, X_init: torch.Tensor, duration: float):
        """Simulate ``int(duration * sampfreq)`` states, with ``X_init`` stored as the first state.

        Args:
            X_init (torch.Tensor): initial state, e.g. shape ``(n_dim,)``, or a batch of independent
                walkers ``(n_walkers, n_dim)``.
            duration (float): simulated time span.

        Returns:
            torch.Tensor: trajectory of shape ``(n_samples, *X_init.shape)`` on ``self.device``. The
            first axis is empty when ``duration * sampfreq < 1``.
        """
        n_samples = max(int(duration * self.sampfreq), 0)
        X = X_init.to(self.device).clone().detach()
        traj = torch.empty((n_samples, *X.shape), device=self.device, dtype=X_init.dtype)
        if n_samples == 0:
            return traj
        traj[0] = X
        for i in range(1, n_samples):
            X = self._one_step(X)
            traj[i] = X
        return traj