import numpy as np 
import torch 
import math
from abc import ABC, abstractmethod

# Prefer Apple's Metal (MPS) backend when it is available, otherwise use the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class SyntheticDataGenerator(ABC):
    """Abstract base for generators that emit (past, future) windows of a process.

    Subclasses implement :meth:`sample`, which draws one batch of past/future
    windows, and :meth:`get_mutual_information`, which reports the mutual
    information between those windows.

    Args:
        T_past: length of the past window.
        T_fut: length of the future window.
        batch_size: number of sequences drawn per call to ``sample``.
    """

    def __init__(self, T_past: int, T_fut: int, batch_size: int):
        self.T_past, self.T_fut, self.batch_size = T_past, T_fut, batch_size

    @abstractmethod
    def sample(self):
        """Draw one batch of (past, future) windows."""

    @abstractmethod
    def get_mutual_information(self):
        """Return the mutual information between the past and future windows."""

class SyntheticSequenceGenerator(ABC):
    """Abstract base for generators that produce one long sequence per call.

    Args:
        batch_size: stored on the instance for subclasses.
    """

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    @abstractmethod
    def generate_long_array(self, N: int):
        """Generate a sequence of ``N`` observations."""

class AutoRegressiveGenerator(SyntheticSequenceGenerator):
    """Generate one long AR(p) sequence whose p lag coefficients all equal ``rho / p``.

    Update rule, applied independently to each of the ``dim`` channels::

        z[i] = sum_{k=1..p} (rho / p) * z[i-k] + sqrt(1 - rho**2) * eps[i],   eps[i] ~ N(0, 1)

    The first ``p`` values are drawn i.i.d. from N(0, 1). For ``p == 1`` this is
    the classic AR(1) update ``rho * z[i-1] + sqrt(1 - rho**2) * eps[i]``, whose
    variance stays at 1. NOTE: for ``p > 1`` the same innovation scale does not,
    in general, give unit stationary variance.

    Args:
        batch_size: stored on the base class; ``generate_long_array`` does not use it.
        p: autoregressive order, 1 <= p <= 300.
        rho: dependence coefficient; must satisfy |rho| <= 1 so the noise scale is real.
        dim: number of channels per observation.
        device: torch device on which the sequence is generated.

    Raises:
        ValueError: if ``p`` or ``rho`` is out of range.
    """

    def __init__(self, batch_size: int, p: int, rho: float, dim: int = 1, device=device):
        super().__init__(batch_size)
        if not (1 <= p <= 300):
            raise ValueError("L'ordre p doit être compris entre 1 et 300")
        if abs(rho) > 1:
            # Otherwise sqrt(1 - rho**2) is NaN and the whole sequence becomes NaN.
            raise ValueError(f"rho must satisfy |rho| <= 1, got {rho}")
        self.p = p
        self.rho = rho
        self.dim = dim
        self.device = device
        # Equal weight rho / p on each of the p lags (a single weight rho when p == 1).
        self.ar_coeffs = torch.full((p,), rho / p, device=device)

    def generate_long_array(self, N: int):
        """Return an (N, dim) numpy array holding one realisation of the process."""
        z = torch.zeros((N, self.dim), device=self.device)
        # min() guards N < p, where a (p, dim) block cannot fit into z.
        n_init = min(self.p, N)
        z[:n_init] = torch.randn((n_init, self.dim), device=self.device)
        eps = torch.randn((N, self.dim), device=self.device)

        noise_scale = math.sqrt(1 - self.rho ** 2)  # constant, so computed once
        coeffs = self.ar_coeffs.reshape(-1, 1)  # (p, 1): broadcasts over the dim channels
        for i in range(self.p, N):
            z[i] = (coeffs * z[i - self.p:i]).sum(dim=0) + noise_scale * eps[i]

        return z.cpu().numpy()

class MultivariateMarkovGenerator(SyntheticSequenceGenerator):
    """Generate one long vector Markov (VAR(1)) sequence with identity marginal covariance.

    With ``A = correlation_matrix`` and ``C = cholesky(I - A @ A.T)``::

        z[0] ~ N(0, I),   z[i] = A @ z[i-1] + C @ eps[i],   eps[i] ~ N(0, I)

    so Cov(z[i]) = A A^T + (I - A A^T) = I at every step. This is the
    multivariate analogue of the scalar update ``rho * z + sqrt(1 - rho**2) * eps``.

    Args:
        batch_size: stored on the base class; ``generate_long_array`` does not use it.
        correlation_matrix: symmetric positive-definite (dim, dim) matrix, used as
            the transition matrix ``A``.
        dim: dimension of each observation.
        device: torch device used for generation (defaults to the module-level device).

    Raises:
        ValueError: if the matrix has the wrong shape, is not symmetric, or is
            not positive definite.
    """

    def __init__(self, batch_size: int, correlation_matrix: np.ndarray, dim: int, device=device):
        super().__init__(batch_size)
        self.dim = dim
        self.device = device

        # Validate and factorise in float64 on the CPU. MPS has no float64 support,
        # and computing in float64 keeps the checks and Cholesky factors accurate.
        corr64 = torch.as_tensor(np.asarray(correlation_matrix, dtype=np.float64))
        if tuple(corr64.shape) != (dim, dim):
            raise ValueError(f"correlation_matrix must have shape {(dim, dim)}, got {tuple(corr64.shape)}")
        if not torch.allclose(corr64, corr64.T):
            raise ValueError("correlation_matrix must be symmetric")
        # eigvalsh returns real eigenvalues for symmetric input. eigvals returns
        # complex values, which cannot be compared with `> 0`.
        if not torch.all(torch.linalg.eigvalsh(corr64) > 0):
            raise ValueError("Correlation matrix must be positive definite")

        self._corr64 = corr64
        # float32 so it can multiply the float32 state vector z.
        self.correlation_matrix = corr64.to(torch.float32).to(device)
        # Cholesky factor of correlation_matrix. It is kept as a public attribute;
        # the recursion itself does not use it.
        self.L = torch.linalg.cholesky(corr64).to(torch.float32).to(device)

    def generate_long_array(self, N: int):
        """Return an (N, dim) float32 numpy array holding one realisation of the process.

        Raises:
            torch.linalg.LinAlgError: if ``I - A @ A.T`` is not positive definite.
        """
        A = self.correlation_matrix
        # The innovation factor is loop-invariant: compute it once, in float64 on the CPU.
        eye = torch.eye(self.dim, dtype=torch.float64)
        C = torch.linalg.cholesky(eye - self._corr64 @ self._corr64.T).to(torch.float32).to(self.device)

        z = torch.zeros((N, self.dim), device=self.device)
        eps = torch.randn((N, self.dim), device=self.device)
        if N > 0:
            z[0] = eps[0]  # z[0] ~ N(0, I); eps[0] is not reused by the recursion
        for i in range(1, N):
            z[i] = A @ z[i - 1] + C @ eps[i]
        return z.cpu().numpy()


class MarkovGaussianGenerator(SyntheticDataGenerator):
    """Batched Gaussian AR(1) (Markov) sequences split into past and future windows.

    Each of the ``dim`` channels follows
    ``z[i] = rho * z[i-1] + sqrt(1 - rho**2) * eps[i]`` with ``z[0] ~ N(0, 1)``.
    Every z[i] therefore has unit variance.

    Args:
        T_past: length of the past window.
        T_fut: length of the future window.
        batch_size: number of sequences per batch.
        rho: lag-one correlation; |rho| <= 1 is required for a real noise scale.
        dim: number of independent channels.
    """

    def __init__(self, T_past: int, T_fut: int, batch_size: int, rho: float, dim: int = 1):
        super().__init__(T_past, T_fut, batch_size)
        self.rho = rho
        self.dim = dim

    def sample(self):
        """Draw one batch.

        Returns:
            ``(X_past, X_fut)`` with shapes ``(batch_size, T_past, dim)`` and
            ``(batch_size, T_fut, dim)``.
        """
        n_steps = self.T_past + self.T_fut
        z = torch.zeros((self.batch_size, n_steps, self.dim), device=device)
        z[:, 0, :] = torch.randn((self.batch_size, self.dim), device=device)

        eps = torch.randn((self.batch_size, n_steps, self.dim), device=device)
        # rho is a Python float, and torch.sqrt only accepts tensors, so use math.sqrt.
        noise_scale = math.sqrt(1 - self.rho ** 2)
        for i in range(1, n_steps):
            z[:, i, :] = self.rho * z[:, i - 1, :] + noise_scale * eps[:, i, :]

        X_past = z[:, :self.T_past, :]
        X_fut = z[:, self.T_past:, :]
        return X_past, X_fut

    def get_mutual_information(self):
        """Return I(X_past; X_fut) in nats.

        By the Markov property, the windows interact only through the boundary
        pair ``(z[T_past-1], z[T_past])``. That pair has correlation ``rho`` in
        each of the ``dim`` independent channels, giving
        ``-dim/2 * log(1 - rho**2)``. This assumes T_past >= 1 and T_fut >= 1.
        """
        return (-self.dim / 2) * np.log(1 - self.rho ** 2)


class KernelBasedGenerator(SyntheticDataGenerator):
    """Gaussian-process sequences whose time covariance comes from a named kernel.

    Time steps are the integers ``0 .. T_past + T_fut - 1``. The covariance
    between steps ``t1`` and ``t2`` is ``self.kernel(t1, t2)``, plus a small
    diagonal jitter. Each of the ``data_dim`` channels is an independent draw
    from N(0, Sigma).

    Args:
        T_past: length of the past window.
        T_fut: length of the future window.
        batch_size: number of sequences per batch.
        kernel_type: one of "AR", "matern32", "matern52", "squared_exponential",
            "periodic", "rational_quadratic", "locally_periodic".
        kernel_params: hyper-parameters read by the kernel:
            AR -> sigma, rho; matern32 / matern52 / squared_exponential -> sigma, l;
            periodic -> sigma, period, l; rational_quadratic -> sigma, l, alpha;
            locally_periodic -> sigma, period, l, decay.
            A missing key raises KeyError when the kernel is first evaluated.
        data_dim: number of independent channels.

    Raises:
        ValueError: if ``kernel_type`` is unknown.
    """

    # Diagonal jitter that keeps Sigma numerically positive definite.
    _JITTER = 1e-6

    def __init__(self, T_past: int, T_fut: int, batch_size: int, kernel_type: str,
                 kernel_params: dict, data_dim: int = 1):
        super().__init__(T_past, T_fut, batch_size)
        self.kernel_type = kernel_type
        self.kernel_params = kernel_params
        self.data_dim = data_dim
        self._setup_kernel()

    def _setup_kernel(self):
        """Bind ``self.kernel`` to a scalar covariance function k(t1, t2).

        Each kernel depends on the time points only through ``abs(t1 - t2)``.
        Parameters are read from ``self.kernel_params`` at call time.
        """
        if self.kernel_type == "AR":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                self.kernel_params["rho"]**abs(t1-t2)
            )
        elif self.kernel_type == "matern32":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                (1 + math.sqrt(3)*abs(t1-t2)/self.kernel_params["l"]) *
                math.exp(-math.sqrt(3)*abs(t1-t2)/self.kernel_params["l"])
            )
        elif self.kernel_type == "matern52":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                (1 + math.sqrt(5)*abs(t1-t2)/self.kernel_params["l"] +
                 5*abs(t1-t2)**2/(3*self.kernel_params["l"]**2)) *
                math.exp(-math.sqrt(5)*abs(t1-t2)/self.kernel_params["l"])
            )
        elif self.kernel_type == "squared_exponential":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                math.exp(-0.5 * (abs(t1-t2)/self.kernel_params["l"])**2)
            )
        elif self.kernel_type == "periodic":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                math.exp(-2 * math.sin(math.pi * abs(t1-t2)/self.kernel_params["period"])**2 /
                       self.kernel_params["l"]**2)
            )
        elif self.kernel_type == "rational_quadratic":
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                (1 + abs(t1-t2)**2/(2*self.kernel_params["alpha"]*self.kernel_params["l"]**2))**(-self.kernel_params["alpha"])
            )
        elif self.kernel_type == "locally_periodic":
            # Periodic kernel multiplied by a squared-exponential envelope of width "decay".
            self.kernel = lambda t1, t2: (
                self.kernel_params["sigma"]**2 *
                math.exp(-2 * math.sin(math.pi * abs(t1-t2)/self.kernel_params["period"])**2 /
                       self.kernel_params["l"]**2) *
                math.exp(-0.5 * (abs(t1-t2)/self.kernel_params["decay"])**2)
            )
        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}")

    def _covariance_numpy(self, T_total):
        """Return the jittered (T_total, T_total) kernel matrix as a float64 numpy array.

        The matrix is built on the host in float64. This avoids per-element
        writes into an accelerator tensor and float32 rounding of
        near-singular kernels.
        """
        times = np.arange(T_total)
        Sigma = np.empty((T_total, T_total), dtype=np.float64)
        for i, t1 in enumerate(times):
            for j, t2 in enumerate(times):
                Sigma[i, j] = self.kernel(t1, t2)
        Sigma += self._JITTER * np.eye(T_total)
        return Sigma

    def _generate_covariance(self, T_total):
        """Return the jittered kernel covariance as a float32 tensor on ``device``."""
        return torch.from_numpy(self._covariance_numpy(T_total)).to(torch.float32).to(device)

    def sample(self):
        """Draw one batch.

        Returns:
            ``(X_past, X_fut)`` with shapes ``(batch_size, T_past, data_dim)`` and
            ``(batch_size, T_fut, data_dim)``.
        """
        T_total = self.T_past + self.T_fut
        # Factorise in float64 on the CPU. A float32 Cholesky can fail for
        # near-singular kernels, e.g. squared_exponential, or periodic with an
        # integer period on integer time steps.
        sigma64 = torch.from_numpy(self._covariance_numpy(T_total))
        scale_tril = torch.linalg.cholesky(sigma64).to(torch.float32).to(device)
        mean = torch.zeros(T_total, device=device)
        mvn = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)

        # One draw of shape (batch, data_dim, T) covers all channels. Reorder it
        # to (batch, T, data_dim).
        samples = mvn.sample((self.batch_size, self.data_dim)).transpose(1, 2).contiguous()

        X_past = samples[:, :self.T_past, :]
        X_fut = samples[:, self.T_past:, :]
        return X_past, X_fut

    def get_mutual_information(self):
        """Return I(X_past; X_fut) in nats for the jointly Gaussian windows.

        Per channel, MI = 0.5 * (log|S_pp| + log|S_ff| - log|S|). This is
        multiplied by ``data_dim`` because the channels are independent.
        ``slogdet`` avoids the under/overflow of a ratio of raw determinants.
        """
        T_total = self.T_past + self.T_fut
        Sigma = self._covariance_numpy(T_total)

        _, logdet_past = np.linalg.slogdet(Sigma[:self.T_past, :self.T_past])
        _, logdet_fut = np.linalg.slogdet(Sigma[self.T_past:, self.T_past:])
        _, logdet_joint = np.linalg.slogdet(Sigma)

        mi_1d = 0.5 * (logdet_past + logdet_fut - logdet_joint)
        return mi_1d * self.data_dim


class EvoRateGenerator(SyntheticDataGenerator):
    """Past/target pairs in which the target is a noisy read-out of the summed past.

    With ``T = T_past + T_fut`` and ``eps ~ N(0, 1)`` of shape ``(batch, T, dim)``::

        x = eps[:, :T-1] - 0.5
        y = sqrt(1 - rho**2) * eps[:, -1] + rho * sum_t x[:, t] / sqrt(T - 1) + 1

    The normalised sum has unit variance, so ``I(x; y) = -dim/2 * log(1 - rho**2)``.
    NOTE: ``x`` spans ``T - 1 = T_past + T_fut - 1`` steps. It therefore equals
    the past window only when ``T_fut == 1``.

    Args:
        T_past: length of the past window.
        T_fut: length of the future window.
        batch_size: number of sequences per batch.
        rho: coupling between past and target; |rho| <= 1.
        dim: number of independent channels.
    """

    def __init__(self, T_past: int, T_fut: int, batch_size: int, rho: float, dim: int = 1):
        super().__init__(T_past, T_fut, batch_size)
        self.rho = rho
        self.dim = dim

    def sample(self):
        """Draw one batch.

        Returns:
            ``(x, y)`` with shapes ``(batch_size, T - 1, dim)`` and ``(batch_size, 1, dim)``.

        Raises:
            ValueError: if ``T_past + T_fut < 2`` (the past would be empty).
        """
        T = self.T_past + self.T_fut
        if T < 2:
            raise ValueError(f"EvoRateGenerator needs T_past + T_fut >= 2, got {T}")
        eps = torch.randn((self.batch_size, T, self.dim), device=device)
        # Mean-shifted copy of the first T-1 innovations (a new tensor, not a view).
        x = eps[:, :T - 1] - 0.5

        # rho and T are Python numbers, and torch.sqrt only accepts tensors.
        y = (math.sqrt(1 - self.rho ** 2) * eps[:, -1]
             + self.rho * torch.sum(x, dim=1) / math.sqrt(T - 1) + 1)

        return x, y.unsqueeze(1)

    def get_mutual_information(self):
        """Return I(x; y) in nats: ``-dim/2 * log(1 - rho**2)``."""
        return (-self.dim / 2) * np.log(1 - self.rho ** 2)






if __name__ == "__main__":
    # Demo 1: AR(1) Markov Gaussian process split into past/future windows.
    markov_gen = MarkovGaussianGenerator(
        T_past=10, T_fut=5, batch_size=100, rho=0.7, dim=2,
    )
    X_past, X_fut = markov_gen.sample()
    mi = markov_gen.get_mutual_information()
    print(f"Markov Gaussian MI: {mi:.4f}")

    # Demo 2: one Gaussian process per kernel family, with a short description of each.
    kernel_configs = {
        "AR": {
            "params": {"sigma": 0.5, "rho": 0.8},
            "desc": "Autoregressive kernel - exponential decay with memory",
        },
        "matern32": {
            "params": {"sigma": 1.0, "l": 2.0},
            "desc": "Matérn 3/2 kernel - once differentiable process",
        },
        "matern52": {
            "params": {"sigma": 1.0, "l": 2.0},
            "desc": "Matérn 5/2 kernel - twice differentiable process",
        },
        "squared_exponential": {
            "params": {"sigma": 1.0, "l": 2.0},
            "desc": "RBF kernel - infinitely differentiable, very smooth process",
        },
        "periodic": {
            "params": {"sigma": 0.5, "period": 2.0, "l": 3.0},
            "desc": "Periodic kernel - for cyclic patterns",
        },
        "rational_quadratic": {
            "params": {"sigma": 1.0, "l": 2.0, "alpha": 1.0},
            "desc": "Rational Quadratic - mixture of SE kernels at different scales",
        },
        "locally_periodic": {
            "params": {"sigma": 1.0, "period": 4.0, "l": 1.0, "decay": 10.0},
            "desc": "Locally Periodic - periodic patterns with decaying amplitude",
        },
    }

    print("\nKernel-based processes examples:")
    print("=" * 50)

    for name, cfg in kernel_configs.items():
        params = cfg["params"]
        print(f"\n{name}:")
        print(f"Description: {cfg['desc']}")
        print(f"Parameters: {params}")

        gen = KernelBasedGenerator(
            T_past=10, T_fut=5, batch_size=100,
            kernel_type=name, kernel_params=params, data_dim=1,
        )
        past, fut = gen.sample()
        kernel_mi = gen.get_mutual_information()
        print(f"Generated shapes: past={past.shape}, future={fut.shape}")
        print(f"Mutual Information: {kernel_mi:.4f}")

    # Demo 3: evolution-rate process (single-step future).
    print("\nEvolution Rate process:")
    print("=" * 50)
    evo_gen = EvoRateGenerator(T_past=10, T_fut=1, batch_size=100, rho=0.7, dim=2)
    X_past, X_fut = evo_gen.sample()
    mi = evo_gen.get_mutual_information()
    print(f"Evolution Rate MI: {mi:.4f}")

    # Long AR(5) sequence.
    ar_gen = AutoRegressiveGenerator(batch_size=100, p=5, rho=0.7, dim=2)
    long_seq = ar_gen.generate_long_array(N=1000)

    # Long multivariate Markov sequence driven by a symmetric 2x2 matrix.
    mv_gen = MultivariateMarkovGenerator(
        batch_size=100,
        correlation_matrix=np.array([[0.7, 0.2], [0.2, 0.7]]),
        dim=2,
    )
    long_seq = mv_gen.generate_long_array(N=1000)
   




# def generateAR5(N, rho, dim=1, z0=None):
#     """
#     Génère une séquence autoregressive.
    
#     Pour i < 5, on utilise une mise à jour AR(1) (car on n'a pas encore 5 valeurs précédentes).
#     Pour i >= 5, on alterne :
#       - Si i est pair, on utilise une mise à jour AR(5) :
#           z[i] = (rho/5) * (z[i-5] + z[i-4] + z[i-3] + z[i-2] + z[i-1]) + sqrt(1 - rho**2) * eps[i]
#       - Sinon (i impair), on utilise une mise à jour AR(1) classique :
#           z[i] = rho * z[i-1] + sqrt(1 - rho**2) * eps[i]
    
#     Arguments:
#       N    : nombre total d'observations.
#       rho  : coefficient de dépendance.
#       dim  : dimension de chaque observation (par défaut 1).
#       z0   : condition initiale (si None, initialisée à 0).
      
#     Retourne:
#       z    : un tableau numpy de forme (N, dim) contenant la séquence générée.
#     """
    
#     z = np.zeros((N, dim))
    
#     if z0 is None:
#         z[0] = np.zeros(dim)
#     else:
#         z[0] = np.array(z0)
    
#     eps = np.random.randn(N, dim)
    
#     for i in range(1, N):
#         if i < 5:
#             # Pour les premiers indices, on utilise AR(1)
#             z[i] = rho * z[i-1] + np.sqrt(1 - rho**2) * eps[i]
#         else:
#             if i % 2 == 0:
#                 # Pour i pair, dépendance sur les 5 dernières valeurs (AR(5) simple à coefficients égaux)
#                 z[i] = (rho / 5.0) * (z[i-5] + z[i-4] + z[i-3] + z[i-2] + z[i-1]) \
#                        + np.sqrt(1 - rho**2) * eps[i]
#             else:
#                 # Pour i impair, on utilise AR(1)
#                 z[i] = rho * z[i-1] + np.sqrt(1 - rho**2) * eps[i]
    
#     return z



# def generateMultivariateMarkovGaussien(N, rho, dim=1, z0=None):
#     """ generating a sequence. The one tested in evoRate """
    
#     z = np.zeros((N, dim))
   
     
    
#     if z0 is None:
#         z[0] = np.zeros(dim)  
#         z[0] = np.random.randn(dim)
#     else:
#         z[0] = np.array(z0) 
        

#     eps = np.random.randn(N, dim)  #  (N, dim)
#     for i in range(1, N):
#         z[i] = rho * z[i-1] + np.sqrt(1 - rho**2) * eps[i]

#     return z


