import os
import GPy
import random
import igraph as ig
import numpy as np
import copy

import torch
from torch.distributions import MultivariateNormal, Normal, Laplace, Gumbel
from torch.distributions import Uniform


class Dist(object):
    """Simulate four consecutive panels (X0..X3) of data driven by a DAG.

    Each panel is an ``n x d`` tensor. X0 is pure noise; every later panel
    adds, per node, a term from the node's own previous value, terms from
    lagged parents (``AutoREG_1st`` / ``AutoREG_2nd``, scaled by 0.1) and
    terms from same-panel parents (``SummaryDAG``).
    """

    def __init__(self, d, SummaryDAG, AutoREG_1st, AutoREG_2nd, gunc = 'linear', func = 'linear', norm = False, 
                 noise_std = [1, 0.4, 0.4, 0.4], noise_type = ['Gauss','Gauss','Gauss','Gauss'], 
                 lengthscale = 1, f_magn = 1, GraNDAG_like = False):
        """Set up the noise distributions and store the graph structure.

        Args:
            d (int): number of nodes.
            SummaryDAG: adjacency matrix of the same-panel DAG; entry
                ``[j, i] != 0`` makes j a parent of i. Must be upper
                triangular (checked with an assert).
            AutoREG_1st: adjacency matrix for parents taken from the
                previous panel.
            AutoREG_2nd: adjacency matrix for parents taken from the panel
                two steps back.
            gunc: name of the function applied to lagged terms in the
                non-GP path (see ``non_linear_back``).
            func: name of the function applied to same-panel parents.
                ``'GP'`` or ``'Gaussian Processes'`` switches every term to
                draws from an RBF-kernel Gaussian process instead.
            norm: whether to standardize each generated panel column-wise.
            noise_std: standard deviation of the noise for {E0, E1, E2, E3}.
                Either one value per panel or a single number that is used
                for all four panels.
            noise_type: noise distribution for {E0, E1, E2, E3}: 'Gauss',
                'Laplace' or 'Uniform'. A single string is used for all
                four panels.
            lengthscale: lengthscale of the RBF kernel.
            f_magn: variance of the RBF kernel.
            GraNDAG_like: if True, the configured ``noise_std`` is replaced
                by ones and, in the GP path, the noise of each node is
                rescaled by a randomly drawn variance.
        """
        self.d = d
        # A scalar applies to all four panels. (Expanding it to a length-d
        # vector and then indexing it per panel, as was done before, picked
        # single node entries and produced wrongly shaped samples.)
        if isinstance(noise_std, (int, float)):
            noise_std = [noise_std] * 4
        # Indexing a bare string would pick single characters.
        if isinstance(noise_type, str):
            noise_type = [noise_type] * 4

        self.func = func
        self.gunc = gunc
        self.norm = norm
        self.lengthscale = lengthscale
        self.f_magn = f_magn
        self.GraNDAG_like = GraNDAG_like

        if self.GraNDAG_like:
            noise_std = [torch.ones(d), torch.ones(d), torch.ones(d), torch.ones(d)]

        self.noiseE0 = self.noiseDist(noise_type[0], noise_std[0])
        self.noiseE1 = self.noiseDist(noise_type[1], noise_std[1])
        self.noiseE2 = self.noiseDist(noise_type[2], noise_std[2])
        self.noiseE3 = self.noiseDist(noise_type[3], noise_std[3])

        self.SummaryDAG = SummaryDAG
        self.AutoREG_1st = AutoREG_1st
        self.AutoREG_2nd = AutoREG_2nd

        # The panel generators fill nodes in index order, so same-panel
        # parents must have a smaller index than their children.
        assert(np.allclose(self.SummaryDAG, np.triu(self.SummaryDAG)))

    def noiseDist(self, noise_type_, noise_std_):
        """Build the per-node noise distribution for one panel.

        Args:
            noise_type_: 'Gauss', 'Laplace' or 'Uniform'.
            noise_std_: a number (shared by all ``d`` nodes) or a
                tensor / sequence with one value per node.

        Returns:
            A torch distribution whose ``sample((n,))`` is ``n x d``.

        Raises:
            NotImplementedError: for any other ``noise_type_``.
        """
        if not torch.is_tensor(noise_std_):
            noise_std_ = torch.as_tensor(noise_std_, dtype=torch.get_default_dtype())
        if noise_std_.dim() == 0:
            noise_std_ = noise_std_ * torch.ones(self.d)

        if noise_type_ == 'Gauss':
            return Normal(0, noise_std_)  # scale is the standard deviation
        elif noise_type_ == 'Laplace':
            # Laplace(0, b) has std b*sqrt(2); divide so the std is noise_std_.
            return Laplace(0, noise_std_ / np.sqrt(2))
        elif noise_type_ == 'Uniform':
            print("Hello Uniform.")
            # NOTE(review): the support is [-noise_std_, noise_std_], whose
            # standard deviation is noise_std_/sqrt(3), unlike the two
            # branches above - confirm this is intended.
            return Uniform( noise_std_ * (-torch.ones(self.d)), noise_std_ * torch.ones(self.d))
        else:
            raise NotImplementedError("Unknown noise type: {!r}.".format(noise_type_))

    def sampleGP(self, X, lengthscale=1):
        """Draw one function sample at the rows of ``X`` from a zero-mean GP.

        The covariance is a GPy RBF kernel with the given lengthscale and
        variance ``self.f_magn``.

        Args:
            X: array of inputs, ``(n, k)``; a 1-D array of length ``n`` is
                treated as ``(n, 1)``.
            lengthscale: RBF lengthscale.

        Returns:
            numpy array of length ``n``.
        """
        X = np.asarray(X)
        if X.ndim == 1:
            # A single column arrives as 1-D; the kernel needs (n, input_dim).
            X = X.reshape(-1, 1)
        ker = GPy.kern.RBF(input_dim=X.shape[1],lengthscale=lengthscale,variance=self.f_magn)
        C = ker.K(X,X)
        return np.random.multivariate_normal(np.zeros(len(X)),C)

    def sample(self, path, n):
        """Generate ``n`` samples of all four panels and save them.

        Writes ``path + 'X.pt'`` and ``path + 'Noise.pt'`` with torch.save
        (``path`` is used as a plain string prefix).

        Args:
            path: string prefix for the two output files.
            n: number of samples per panel.

        Returns:
            ``(X, Noise)``, each a tensor stacking the four panels along
            dimension 0.
        """
        _noiseE0 = self.noiseE0.sample((n,)) # n x d noise matrix
        _noiseE1 = self.noiseE1.sample((n,)) # n x d noise matrix
        _noiseE2 = self.noiseE2.sample((n,)) # n x d noise matrix
        _noiseE3 = self.noiseE3.sample((n,)) # n x d noise matrix

        # The panel generators overwrite their last argument in place, so
        # they get copies of the noise; the earlier panels are only read.
        _X0 = _noiseE0.clone()
        _X1 = self.nextPanel_1(_X0, _noiseE1.clone())
        _X2 = self.nextPanel_23(_X0, _X1, _noiseE2.clone())
        _X3 = self.nextPanel_23(_X1, _X2, _noiseE3.clone())

        _X = torch.stack([_X0, _X1, _X2, _X3], dim=0)
        _Noise = torch.stack([_noiseE0, _noiseE1, _noiseE2, _noiseE3], dim=0)

        torch.save(_X,     path+'X.pt')
        torch.save(_Noise, path+'Noise.pt')

        print("Save to the Dir: {}. ".format(path))

        return _X, _Noise

    def non_linear_back(self, X, function):
        """Apply the element-wise function named ``function`` to ``X``.

        Any name other than the ones handled below (e.g. 'linear') returns
        ``X`` unchanged.
        """
        if function == 'sin':
            return torch.sin(X)
        elif function == 'pow2':
            return 0.1*torch.pow(X+2,2)
        elif function == 'pow3':
            # Was a copy of the 'pow2' branch (exponent 2).
            return 0.1*torch.pow(X+2,3)
        elif function == 'poly':
            # NOTE(review): currently identical to 'pow2' - confirm the
            # intended polynomial.
            return 0.1*torch.pow(X+2,2)
        elif function == 'sigmoid':
            return 3/(1+torch.exp(X))
        else:
            return X

    def _uses_gp(self):
        """Return True when the terms are drawn from Gaussian processes."""
        return self.func == 'GP' or self.func == 'Gaussian Processes'

    def _rescale_noise(self, noise, i, num_dag_parents):
        """Scale column ``i`` of ``noise`` in place by a random std.

        The variance is drawn from U(1, 2) for nodes without same-panel
        parents and from U(0.4, 0.8) otherwise.
        """
        if num_dag_parents == 0:
            variance = np.random.uniform(1,2)
        else:
            variance = np.random.uniform(0.4,0.8)
        noise[:, i] = float(np.sqrt(variance)) * noise[:, i]

    def _gp_term(self, columns):
        """Return a GP sample evaluated at the given tensor columns."""
        return torch.tensor(self.sampleGP(np.array(columns), self.lengthscale))

    def nextPanel_1(self, X0, X1):
        """Generate panel 1 from panel 0.

        Args:
            X0: ``n x d`` tensor, the previous panel (only read).
            X1: ``n x d`` tensor holding the noise of the new panel; it is
                overwritten in place.

        Returns:
            The new panel (``X1`` itself unless ``self.norm`` is set, in
            which case a standardized tensor).
        """
        # !!! Only works if SummaryDAG matrix is upper triangular !!!
        for i in range(self.d):
            lag1_parents = np.nonzero(self.AutoREG_1st[:,i])[0]
            dag_parents = np.nonzero(self.SummaryDAG[:,i])[0]

            if self._uses_gp():
                if self.GraNDAG_like:
                    self._rescale_noise(X1, i, len(dag_parents))
                X1[:, i] += self._gp_term(X0[:,i])
                if len(lag1_parents) > 0:
                    # Lagged parents come from AutoREG_1st (the SummaryDAG
                    # parents were used here before by mistake).
                    X1[:, i] += self._gp_term(X0[:,lag1_parents]) * 0.1
                if len(dag_parents) > 0:
                    X1[:, i] += self._gp_term(X1[:,dag_parents])
            else:
                X1[:, i] += self.non_linear_back(X0[:,i], self.gunc)
                for j in lag1_parents:
                    X1[:, i] += self.non_linear_back(X0[:,j], self.gunc) * 0.1
                for j in dag_parents:
                    X1[:, i] += self.non_linear_back(X1[:,j], self.func)

        if self.norm:
            X1 = (X1 - X1.mean(0))/X1.std(0)

        return X1

    def nextPanel_23(self, X0, X1, X2):
        """Generate a panel from the two panels before it.

        Args:
            X0: ``n x d`` tensor, the panel two steps back (only read).
            X1: ``n x d`` tensor, the previous panel (only read).
            X2: ``n x d`` tensor holding the noise of the new panel; it is
                overwritten in place.

        Returns:
            The new panel (``X2`` itself unless ``self.norm`` is set, in
            which case a standardized tensor).
        """
        # !!! Only works if SummaryDAG matrix is upper triangular !!!
        for i in range(self.d):
            lag2_parents = np.nonzero(self.AutoREG_2nd[:,i])[0]
            lag1_parents = np.nonzero(self.AutoREG_1st[:,i])[0]
            dag_parents = np.nonzero(self.SummaryDAG[:,i])[0]

            if self._uses_gp():
                if self.GraNDAG_like:
                    # Previously tested an undefined name and raised NameError.
                    self._rescale_noise(X2, i, len(dag_parents))
                X2[:, i] += self._gp_term(X1[:,i])
                if len(lag2_parents) > 0:
                    X2[:, i] += self._gp_term(X0[:,lag2_parents]) * 0.1
                if len(lag1_parents) > 0:
                    X2[:, i] += self._gp_term(X1[:,lag1_parents]) * 0.1
                if len(dag_parents) > 0:
                    X2[:, i] += self._gp_term(X2[:,dag_parents])
            else:
                X2[:, i] += self.non_linear_back(X1[:,i], self.gunc)
                for j in lag2_parents:
                    X2[:, i] += self.non_linear_back(X0[:,j], self.gunc) * 0.1
                for j in lag1_parents:
                    X2[:, i] += self.non_linear_back(X1[:,j], self.gunc) * 0.1
                for j in dag_parents:
                    X2[:, i] += self.non_linear_back(X2[:,j], self.func)

        if self.norm:
            X2 = (X2 - X2.mean(0))/X2.std(0)
        return X2