from matplotlib.style import available
import networkx as nx
import numpy as np 
import torch
import math

from models.functions_scout import indMLPFunction, nonlinearMLP, NoisyFunction

def standard_normal_logprob(z, noise_scale=0.5):
    """Return the elementwise log-density of a zero-mean Gaussian at ``z``.

    Despite the name, the distribution is N(0, noise_scale**2); it is only
    the standard normal when ``noise_scale`` is 1.

    Args:
        z: Point(s) at which to evaluate the log-density. Anything supporting
            ``**`` and ``/`` with a float works (float, array, tensor).
        noise_scale: Standard deviation of the Gaussian.

    Returns:
        ``-0.5 * log(2 * pi * noise_scale**2) - z**2 / (2 * noise_scale**2)``,
        with the same shape/type as ``z``.
    """
    variance = noise_scale**2
    # Log of the normalising constant 1 / sqrt(2 * pi * sigma^2).
    log_normaliser = -0.5 * math.log(2 * math.pi * variance)
    quadratic_term = z**2 / (2 * variance)
    return log_normaliser - quadratic_term

def make_non_cotractive(weights):
    """Rescale ``weights`` so that its largest singular value exceeds 1.

    If the largest singular value is at most 1, the matrix is multiplied by
    ``2 / s_max`` so that the result has largest singular value 2. Matrices
    whose largest singular value is already above 1 are returned unscaled.

    An all-zero matrix (e.g. a graph without edges) cannot be rescaled; it is
    returned as zeros instead of the NaN matrix that ``(2 / 0) * 0`` would
    produce.

    Args:
        weights: 2-D array of edge weights.

    Returns:
        A new array holding the rescaled weights.
    """
    s = np.linalg.svd(weights, compute_uv=False)
    top_singular_value = s[0]

    if top_singular_value == 0:
        # Nothing to scale: avoid the division by zero below.
        return 1.0 * weights

    scale = 1.0
    if top_singular_value <= 1.0:
        scale = 2 / top_singular_value

    return scale * weights

def make_contractive(weights):
    """Shrink ``weights`` so that its largest singular value is below 1.

    The matrix is divided by ``1.1 * s_max`` when the largest singular value
    ``s_max`` is at least 1 (giving a largest singular value of 1 / 1.1), and
    by the constant 1.1 otherwise. Note that the input is therefore always
    scaled down, even when it is already contractive.

    Args:
        weights: 2-D array of edge weights.

    Returns:
        A new array holding the rescaled weights.
    """
    top_singular_value = np.linalg.svd(weights, compute_uv=False)[0]
    # 1.1 is the safety margin keeping the result strictly below norm 1.
    divisor = 1.1 * top_singular_value if top_singular_value >= 1.0 else 1.1
    return weights / divisor

class linearSEM:

    """
    -------------------------------------------------------------------
    This class models a Linear Structural Equation Model (Linear SEM)
    -------------------------------------------------------------------
    The model is initialized with a graph and the absolute minimum and
    maximum weights for the edges. One weight matrix (``weights``) is used
    for observed nodes and a second one (``weights_i``) for intervened nodes.
    """
    def __init__(self, graph, abs_weight_low=0.2, abs_weight_high=0.9, noise_scale=0.5, contractive=True, noisy=False, noisy_weight=-1):
        """Draw random edge weights for ``graph`` and rescale them.

        Args:
            graph: networkx graph; ``nx.to_numpy_array(graph)`` is used as a
                mask so that only existing edges get a non-zero weight.
            abs_weight_low: Lower bound of the absolute edge weight.
            abs_weight_high: Upper bound of the absolute edge weight.
            noise_scale: Scale of the noise drawn for observed nodes.
            contractive: If True the weights are rescaled with
                ``make_contractive``, otherwise with ``make_non_cotractive``.
            noisy: If True the intervention weights start from
                ``noisy_weight * weights`` instead of ``weights``.
            noisy_weight: Multiplier used when ``noisy`` is True.
        """
        self.graph = graph
        self.abs_weight_low = abs_weight_low
        self.abs_weight_high = abs_weight_high
        self.contractive = contractive

        self.n_nodes = len(graph.nodes)

        # Magnitudes uniform in [low, high], then a random sign per entry,
        # then masked by the adjacency matrix so non-edges are exactly zero.
        self.weights = np.random.uniform(self.abs_weight_low, self.abs_weight_high, size=(self.n_nodes, self.n_nodes))
        self.weights *= 2 * np.random.binomial(1, 0.5, size=self.weights.shape) - 1
        adjacency = nx.to_numpy_array(self.graph)
        self.weights *= adjacency

        self.noise_scale = noise_scale

        self.noisy = noisy
        self.noisy_weight = noisy_weight

        rescale = make_contractive if self.contractive else make_non_cotractive
        self.weights = rescale(self.weights)

        # Weights used for intervened nodes.
        # NOTE(review): the rescaling is applied a second time here. Because
        # make_contractive always divides by at least 1.1, with noisy=False
        # and contractive=True this yields weights_i == weights / 1.1 rather
        # than weights_i == weights. Kept as-is; confirm whether intended.
        base_weights_i = self.noisy_weight * self.weights if self.noisy else self.weights
        self.weights_i = rescale(base_weights_i)

    @staticmethod
    def _noise_sampler(noise_type):
        """Return ``sampler(scale, shape)`` for ``noise_type`` or raise ValueError."""
        samplers = {
            'gaussian': lambda scale, shape: scale * np.random.randn(*shape),
            'exponential': lambda scale, shape: np.random.exponential(scale=scale, size=shape),
            'gumbel': lambda scale, shape: np.random.gumbel(loc=0.0, scale=scale, size=shape),
            'laplace': lambda scale, shape: np.random.laplace(loc=0.0, scale=scale, size=shape),
        }
        sampler = samplers.get(noise_type) if isinstance(noise_type, str) else None
        if sampler is None:
            raise ValueError(f"Unknown noise_type: {noise_type}")
        return sampler

    def generateData(self, n_samples, intervention_set=[None], lat_provided=False, latent_vec=None, return_latents=False, shift_scale=0, intervention_scale=0.5, noise_type='gaussian'):
        """Sample data from the linear SEM.

        Solves ``X = (I - U W^T - (I - U) W_i^T)^{-1} E`` where ``U`` is the
        diagonal 0/1 matrix selecting the observed (non-intervened) nodes.

        Args:
            n_samples: Number of samples to draw (unused for the noise when
                ``lat_provided`` is True).
            intervention_set: Indices of intervened nodes. ``[None]``, ``[]``
                or ``None`` give purely observational data. The default list
                is never mutated.
            lat_provided: If True, use ``latent_vec`` as the noise instead of
                sampling it.
            latent_vec: Noise of shape (n_samples, n_nodes); required when
                ``lat_provided`` is True.
            return_latents: If True, also return the noise matrix.
            shift_scale: Constant added to the noise of intervened nodes.
            intervention_scale: Scale of the noise of intervened nodes.
            noise_type: One of 'gaussian', 'exponential', 'gumbel', 'laplace'.

        Returns:
            ``X`` of shape (n_samples, n_nodes), or the tuple ``(X, E)`` with
            ``E`` of the same shape when ``return_latents`` is True.

        Raises:
            ValueError: If ``noise_type`` is unknown (when noise is sampled)
                or if ``lat_provided`` is True and ``latent_vec`` is None.
        """
        self.shift_scale = shift_scale
        self.intervention_scale = intervention_scale
        self.noise_type = noise_type

        if intervention_set is None:
            intervention_set = [None]
        # An empty set or a leading None both mean "no intervention".
        has_intervention = len(intervention_set) > 0 and intervention_set[0] is not None

        observed_set = np.setdiff1d(np.arange(self.n_nodes), intervention_set)
        U = np.zeros((self.n_nodes, self.n_nodes))
        U[observed_set, observed_set] = 1
        I = np.eye(self.n_nodes)

        if lat_provided:
            if latent_vec is None:
                raise ValueError("latent_vec must be given when lat_provided is True")
            # Internally the noise is laid out as (n_nodes, n_samples).
            E = np.asarray(latent_vec).T
        else:
            sampler = self._noise_sampler(self.noise_type)
            E = np.zeros((self.n_nodes, n_samples))
            # Observed nodes are drawn first, then intervened ones, so the
            # random stream is consumed in a fixed order.
            if len(observed_set) > 0:
                E[observed_set, :] = sampler(self.noise_scale, (len(observed_set), n_samples))
            if has_intervention:
                E[intervention_set, :] = sampler(self.intervention_scale, (len(intervention_set), n_samples)) + self.shift_scale

        # Rows of observed nodes use weights, rows of intervened nodes weights_i.
        X = np.linalg.inv(I - (U @ self.weights.T) - ((I - U) @ self.weights_i.T)) @ (E)

        # The final data matrix is dimensions - n_samples X self.nodes
        if return_latents:
            return X.T, E.T

        return X.T

class nonlinearSEM:
    """
    ----------------------------------------------------------------------
    This class models a Nonlinear Structural Equation Model (Nonlinear SEM)
    ----------------------------------------------------------------------
    The nonlinear function is taken from models.functions_scout
    """

    _NOISE_TYPES = ('gaussian', 'exponential', 'gumbel', 'laplace')

    def __init__(self, graph, lip_const=0.9, fun_type='sin-mlp', act_fun='tanh', device=None, noise_scale=0.5, n_hidden=1, bias=False, contractive=True, noisy=False, noisy_weight=-1):
        """Build the structural function ``f`` and the intervention function ``f_i``.

        Args:
            graph: Graph object exposing ``.nodes``; forwarded to the function
                constructor.
            lip_const: Value passed as ``lip_constant`` to the function
                constructor. Replaced by 2.0 when ``contractive`` is False.
            fun_type: ``'mul-mlp'`` selects ``indMLPFunction``; any other
                value selects ``nonlinearMLP``.
            act_fun: Activation name forwarded to the function constructor.
            device: Optional torch device the functions are moved to.
            noise_scale: Scale of the noise drawn for observed nodes.
            n_hidden: Number of layers forwarded to the function constructor.
            bias: Forwarded to the function constructor.
            contractive: Whether the SEM should be contractive.
            noisy: If True, interventions use ``NoisyFunction(f, ...)``
                instead of ``f`` itself.
            noisy_weight: Forwarded to ``NoisyFunction``.
        """
        self.lip_const = lip_const
        self.graph = graph
        self.n_nodes = len(graph.nodes)
        self.act_fun = act_fun
        self.n_hidden = n_hidden
        self.bias = bias
        self.noisy = noisy  # whether to use noisy function for interventions
        self.contractive = contractive
        self.noisy_weight = noisy_weight
        # A non-contractive model needs a constant above 1. The condition was
        # previously inverted, which overrode lip_const for the (default)
        # contractive case and left the non-contractive case contractive.
        if not self.contractive:
            self.lip_const = 2.0

        if fun_type == 'mul-mlp':
            self.f = indMLPFunction(n_nodes=self.n_nodes,
                                    lip_constant=self.lip_const,
                                    activation=self.act_fun,
                                    n_layers=n_hidden,
                                    full_input=False,
                                    graph_given=True,
                                    graph=self.graph,
                                    bias=self.bias)

        else:
            self.f = nonlinearMLP(n_nodes=self.n_nodes,
                                  lip_constant=self.lip_const,
                                  n_layers=self.n_hidden,
                                  bias=self.bias,
                                  activation_fn=self.act_fun,
                                  graph_given=True,
                                  graph=self.graph)

        if self.noisy:
            self.f_i = NoisyFunction(self.f, noisy_weight=self.noisy_weight)
        else:
            self.f_i = self.f  # initially set to be the same as f

        if device is not None:
            self.f = self.f.to(device)
            self.f_i = self.f_i.to(device)
        self.device = device
        self.noise_scale = noise_scale

    def _draw_noise(self, noise_type, scale, n_rows, n_cols):
        """Draw an (n_rows, n_cols) float tensor of ``noise_type`` noise on ``self.device``."""
        if noise_type == 'gaussian':
            return scale * torch.randn(n_rows, n_cols, device=self.device, dtype=torch.float)
        if noise_type == 'exponential':
            dist = torch.distributions.Exponential(rate=1.0 / float(scale))
        elif noise_type == 'gumbel':
            dist = torch.distributions.Gumbel(loc=0.0, scale=float(scale))
        elif noise_type == 'laplace':
            dist = torch.distributions.Laplace(loc=0.0, scale=float(scale))
        else:
            raise ValueError(f"Unknown noise_type: {noise_type}")
        return dist.sample((n_rows, n_cols)).to(device=self.device, dtype=torch.float)

    def _as_latent_tensor(self, latent_vec):
        """Convert user-supplied latents to a float tensor of shape (n_samples, n_nodes)."""
        if latent_vec is None:
            raise ValueError("latent_vec must be given when lat_provided is True")
        if isinstance(latent_vec, torch.Tensor):
            E = latent_vec.to(device=self.device, dtype=torch.float)
        else:
            E = torch.as_tensor(np.asarray(latent_vec), dtype=torch.float, device=self.device)
        if E.dim() != 2:
            raise ValueError("latent_vec must be 2-dimensional")
        # Accept the transposed (n_nodes, n_samples) layout when unambiguous.
        if E.shape[1] != self.n_nodes and E.shape[0] == self.n_nodes:
            E = E.T
        if E.shape[1] != self.n_nodes:
            raise ValueError("latent_vec must have one column per node")
        return E

    def generateData(self, n_samples, intervention_set=[None], lat_provided=False, latent_vec=None, n_iter=30, return_latents=False, intervention_scale=0.5, shift_scale=0, noise_type='gaussian'):
        """Sample data by iterating ``X <- f(X) U + f_i(X) (I - U) + E``.

        ``U`` is the diagonal 0/1 matrix selecting the observed
        (non-intervened) nodes. The whole computation runs under
        ``torch.no_grad()``.

        Args:
            n_samples: Number of samples to draw. Ignored when
                ``lat_provided`` is True (the number of rows of
                ``latent_vec`` is used instead).
            intervention_set: Indices of intervened nodes. ``[None]``, ``[]``
                or ``None`` give purely observational data. The default list
                is never mutated.
            lat_provided: If True, use ``latent_vec`` as the noise instead of
                sampling it.
            latent_vec: Array or tensor of shape (n_samples, n_nodes), i.e.
                the layout returned with ``return_latents=True``. A
                (n_nodes, n_samples) input is transposed when the shapes make
                that unambiguous.
            n_iter: Number of iterations of the update above.
            return_latents: If True, also return the noise matrix.
            intervention_scale: Scale of the noise of intervened nodes.
            shift_scale: Constant added to the noise of intervened nodes.
            noise_type: One of 'gaussian', 'exponential', 'gumbel', 'laplace'.

        Returns:
            ``X`` as a NumPy array of shape (n_samples, n_nodes), or the tuple
            ``(X, E)`` when ``return_latents`` is True.

        Raises:
            ValueError: If ``noise_type`` is unknown (when noise is sampled)
                or ``latent_vec`` is missing or has an unusable shape.
        """
        self.intervention_scale = intervention_scale
        self.shift_scale = shift_scale
        self.noise_type = noise_type

        if intervention_set is None:
            intervention_set = [None]
        # An empty set or a leading None both mean "no intervention".
        has_intervention = len(intervention_set) > 0 and intervention_set[0] is not None

        with torch.no_grad():
            observed_set = np.setdiff1d(np.arange(self.n_nodes), intervention_set)
            U = torch.zeros(self.n_nodes, self.n_nodes, device=self.device).float()
            U[observed_set, observed_set] = 1

            if lat_provided:
                E = self._as_latent_tensor(latent_vec)
            else:
                if self.noise_type not in self._NOISE_TYPES:
                    raise ValueError(f"Unknown noise_type: {self.noise_type}")
                E = torch.zeros((n_samples, self.n_nodes), device=self.device, dtype=torch.float)
                # Observed nodes are drawn first, then intervened ones, so
                # the random stream is consumed in a fixed order.
                if len(observed_set) > 0:
                    E[:, observed_set] = self._draw_noise(self.noise_type, self.noise_scale, n_samples, len(observed_set))
                if has_intervention:
                    E[:, intervention_set] = self._draw_noise(self.noise_type, self.intervention_scale, n_samples, len(intervention_set)) + self.shift_scale

            # The iteration runs for sampled and provided latents alike,
            # starting from a random point.
            I = torch.eye(self.n_nodes, device=self.device, dtype=torch.float)
            X = torch.randn(E.shape[0], self.n_nodes, device=self.device, dtype=torch.float)
            for _ in range(n_iter):
                X = (self.f(X) @ U) + (self.f_i(X) @ (I - U)) + E

        X_np = X.detach().cpu().numpy()
        if return_latents:
            return X_np, E.detach().cpu().numpy()
        return X_np



