import GPy
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal, Laplace, Gumbel, Exponential
from sklearn.linear_model import LinearRegression
import random


################### NonGaussian noise distributions ###################

class RandomDist1D(nn.Module):
    """Random 1D noise distribution defined by a fixed, randomly initialised MLP.

    Samples are produced by drawing from a simple base distribution and
    pushing the draws through a small fully connected network whose
    parameters are drawn once, at construction time, from uniform
    distributions. Because the network is fixed after construction, every
    noise term sampled from the same instance shares the same transformation.

    Parameters
    ----------
    h : int
        Width of the two hidden layers.
    activation : nn.Module
        Activation applied after the first and second linear layers.
    bias : bool
        Whether the linear layers have bias terms.
    a_weight, b_weight : float
        Bounds of the uniform distribution used to initialise the weights.
    a_bias, b_bias : float
        Bounds of the uniform distribution used to initialise the biases
        (only relevant when ``bias=True``).
    base_noise_type : str
        One of ``"Gauss"``, ``"Laplace"``, ``"Gumbel"``, ``"Exponential"``.
    base_noise_std : float
        Standard deviation of the base distribution.
    seed : int
        Seed passed to ``torch.manual_seed`` before the parameters are drawn.
        NOTE: this reseeds torch's *global* RNG as a side effect.

    Raises
    ------
    NotImplementedError
        If ``base_noise_type`` is not one of the supported names.
    """
    def __init__(self, h=100, activation=nn.Sigmoid(), bias=False, a_weight=-5., b_weight=5., a_bias=-1., b_bias=1.,
                 base_noise_type="Gauss", base_noise_std=1, seed=42) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, h, bias=bias),
            activation,
            nn.Linear(h, h, bias=bias),
            activation,
            nn.Linear(h, 1, bias=bias),
        )
        self.a_weight, self.b_weight, self.a_bias, self.b_bias = a_weight, b_weight, a_bias, b_bias
        self.y_min, self.y_max = None, None

        # Set nonlinear function. Seed for reproducibility across datasets
        torch.manual_seed(seed)
        self._init_weights()

        # Every base distribution is parameterised so that its standard
        # deviation equals base_noise_std.
        if base_noise_type == 'Gauss':
            self.base_noise = Normal(0, base_noise_std)
        elif base_noise_type == 'Laplace':
            # std(Laplace) = sqrt(2) * scale
            self.base_noise = Laplace(0, base_noise_std / np.sqrt(2))
        elif base_noise_type == 'Gumbel':
            # std(Gumbel) = scale * pi / sqrt(6)
            self.base_noise = Gumbel(0, base_noise_std * np.sqrt(6) / np.pi)
        elif base_noise_type == 'Exponential':
            # torch's Exponential takes a rate; std(Exponential) = 1 / rate
            self.base_noise = Exponential(1 / base_noise_std)
        else:
            raise NotImplementedError(f"Unknown noise type {base_noise_type}.")

    def _init_weights(self):
        """Draw weights from U[a_weight, b_weight) and biases from U[a_bias, b_bias)."""
        def init_func(m):
            if hasattr(m, 'weight') and 'Linear' in m.__class__.__name__:
                nn.init.uniform_(m.weight.data, a=self.a_weight, b=self.b_weight)

            # Initialise the bias tensor itself (not the weights a second time).
            if getattr(m, 'bias', None) is not None:
                nn.init.uniform_(m.bias.data, a=self.a_bias, b=self.b_bias)
        self.layers.apply(init_func)

    def _standardize(self, noise):
        """Shift and scale ``noise`` to zero sample mean and unit sample std."""
        return (noise - noise.mean())/noise.std()

    @torch.no_grad()
    def sample(self, N=1000, standardize=True):
        """Draw ``N`` samples.

        Returns
        -------
        tuple(torch.Tensor, torch.Tensor)
            The base-noise draws and the transformed noise, both squeezed
            (shape ``(N,)`` for ``N > 1``). The transformed noise is
            standardized over the sample when ``standardize`` is True.
        """
        # Sample from base noise
        noise = self.base_noise.sample((N, 1))
        transform_noise = self.forward(noise).squeeze()
        if standardize:
            transform_noise = self._standardize(transform_noise)
        return noise.squeeze(), transform_noise

    def sample_d(self, N, d, standardize=True):
        """Return an ``(N, d)`` tensor whose columns are independent calls to ``sample``."""
        noise = torch.zeros((N, d))
        for col in range(d):
            _, noise_col = self.sample(N, standardize)
            noise[:, col] += noise_col
        return noise

    @torch.no_grad()
    def forward(self, x):
        """Apply the fixed network to ``x`` (no gradients are tracked)."""
        x = self.layers(x)
        return x


############################ Data Sampler #############################
class DataSimulator(object):
    """Simulate data from an additive noise model on a fixed causal graph.

    The noise matrix is drawn once in the constructor; ``sample`` then
    propagates it through the graph, adding to each node a randomly drawn
    (linear or Gaussian-process) function of its parents.
    """

    # Half-width of the uniform weight range given to RandomDist1D for each
    # non-Gaussian noise type (weights are drawn from U[-w, w)).
    _NONLIN_WEIGHT_BOUND = {"nonlin_weak": .5, "nonlin_mid": 1.5, "nonlin_strong": 3}

    def __init__(self, num_samples, num_nodes, noise_std_support, noise_type, adjacency, GP = True, lengthscale = 1, f_magn = 1):
        """
        Parameters
        ----------
        num_samples : int
            Number of samples in the data matrix
        num_nodes : int
            Number of nodes in the graph. NOTE: the passed value is ignored;
            the number of nodes is always read from ``adjacency.shape[0]``.
        noise_std_support : tuple(float, float)
            Support of std deviation for noise terms
        noise_type : str
            One of "gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"
        adjacency : np.array
            Adjacency matrix representation of causal graph. Must be strictly
            upper triangular; column i is read as the parents of node i.
        GP : bool
            If True, sample nonlinear causal mechanisms from a Gaussian Process
        lengthscale : float
            Lengthscale of the RBF kernel used for GP mechanisms
        f_magn : float
            Variance of the RBF kernel used for GP mechanisms

        Raises
        ------
        ValueError
            If ``noise_type`` is not supported.
        AssertionError
            If ``adjacency`` is not strictly upper triangular.
        """
        num_nodes, _ = adjacency.shape

        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.GP = GP
        self.lengthscale = lengthscale
        self.f_magn = f_magn
        self.adjacency = adjacency

        # One noise std per node, drawn uniformly from the given support
        a, b = noise_std_support
        noise_std = torch.FloatTensor(self.num_nodes).uniform_(a, b)
        if noise_type == "gauss":
            noise_sampler = Normal(0, noise_std)
            self.noise = noise_sampler.sample((self.num_samples,))
        elif noise_type in self._NONLIN_WEIGHT_BOUND:
            bound = self._NONLIN_WEIGHT_BOUND[noise_type]
            noise_sampler = RandomDist1D(base_noise_type="Gauss", a_weight=-bound, b_weight=bound, seed=35)
            # Standardized columns are rescaled to the per-node std
            self.noise = noise_sampler.sample_d(N=self.num_samples, d=num_nodes, standardize=True)*noise_std
        else:
            raise ValueError(f"Required noise {noise_type} is not supported")

        # Assert upper triangular
        assert np.allclose(np.triu(self.adjacency, k=1), self.adjacency), \
            "adjacency must be strictly upper triangular"


    def sample_linear_mechanism(self, X : np.array):
        """Return a random linear combination of the columns of X.

        Coefficients are drawn uniformly between -1 and 1 and redrawn until
        every coefficient has absolute value of at least 0.05. No intercept
        is added.

        Parameters
        ----------
        X : np.array
            2D array of shape (n_samples, n_covariates)

        Returns
        -------
        np.array
            1D array of length n_samples
        """
        X = np.asarray(X)
        _, n_covariates = X.shape
        coef = np.random.uniform(-1, 1, n_covariates)
        while np.min(np.abs(coef)) < 0.05: # avoid ~0 coefficients
            coef = np.random.uniform(-1, 1, n_covariates)
        return X @ coef


    def sampleGP(self, X, lengthscale=1):
        """Draw one function sample, evaluated at the rows of X, from a
        zero-mean Gaussian process with an RBF kernel of the given
        lengthscale and variance ``self.f_magn``.
        """
        ker = GPy.kern.RBF(input_dim=X.shape[1],lengthscale=lengthscale,variance=self.f_magn)
        C = ker.K(X,X) # compute covariance of the kernel
        X_sample = np.random.multivariate_normal(np.zeros(len(X)),C)
        return X_sample


    def sample(self, p_linear=0.0):
        """Generate the data matrix from the stored noise.

        Nodes are visited in index order; for each node with parents, a
        mechanism is drawn (linear with probability ``p_linear``, otherwise
        a GP sample) and its output is added to the node's noise. The stored
        noise itself is not modified.

        Parameters
        ----------
        p_linear : float
            Probability of linear relationship

        Returns
        -------
        torch.Tensor
            Tensor with the same shape and dtype as ``self.noise``

        Raises
        ------
        ValueError
            If the simulator was built with ``GP=False``.
        """
        X = torch.clone(self.noise)

        if not self.GP:
            raise ValueError("Need more options for nonlinear mechanisms generation")

        for i in range(self.num_nodes):
            parents = np.nonzero(self.adjacency[:,i])[0]
            if len(parents) == 0:
                continue  # root node: keeps its noise only
            X_par = np.array(X[:,parents])

            # Randomly choose a linear or nonlinear mechanism
            if np.random.binomial(n=1, p=p_linear) == 1:
                effect = self.sample_linear_mechanism(X_par)
            else:
                effect = self.sampleGP(X_par, self.lengthscale)
            # Both mechanisms return NumPy arrays; convert explicitly before
            # the in-place add so X keeps its dtype.
            X[:, i] += torch.as_tensor(effect)

        return X
