import os

import numpy as np
import torch


## 

class G1D:
    """Scalar chain Y -> X -> A.

    * ``Y ~ N(mean, sigma**2)`` (``sigma`` is a standard deviation).
    * ``X | Y ~ N(Y, stdx**2)`` (``stdx`` is a standard deviation).
    * ``A | X ~ Bernoulli(sigmoid(coef_ax * (X - mean)))``.
    """

    def __init__(self, mean=0, sigma=1,
                 coef_ax=1, stdx=1):
        # One-element tensor; it broadcasts against X in sample_A_given_X.
        self.coef_ax = torch.tensor([coef_ax])
        self.stdx = stdx
        self.sigma = sigma
        self.mean = mean

    def sample_Y(self, n):
        """Draw ``n`` values of Y with NumPy and return them as a 1-D tensor."""
        draws = np.random.normal(loc=self.mean, scale=self.sigma, size=n)
        return torch.Tensor(draws)

    def sample_X_given_Y(self, Y, stdx):
        """Perturb each entry of ``Y`` with Gaussian noise of std ``stdx``."""
        noise_model = torch.distributions.normal.Normal(loc=Y, scale=stdx)
        return noise_model.sample()

    def sample_A_given_X(self, X, coef_ax):
        """Draw binary A whose logit is ``coef_ax * (X - self.mean)``."""
        centred = X - self.mean
        treatment = torch.distributions.bernoulli.Bernoulli(logits=centred * coef_ax)
        return treatment.sample()

    def sample_Z(self, n, random_seed=0):
        """Seed both RNGs, then return ``(X, A, Y)`` for ``n`` samples.

        X and Y are returned as (n, 1) column tensors; A is 1-D of length n.
        """
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        y = self.sample_Y(n)
        x = self.sample_X_given_Y(y, self.stdx)
        a = self.sample_A_given_X(x, self.coef_ax)
        return (x.unsqueeze(1), a, y.unsqueeze(1))



class INTRACTABLE():
    """Five-dimensional chain Y -> X -> A with a Gaussian ``Y``.

    * ``Y ~ N(0, sigma)`` where ``sigma`` is the inverse of a fixed precision
      matrix with unit diagonal, ``-0.6`` between coordinates 0 and 1 and
      ``-0.2`` between coordinate 0 and each of coordinates 2-4.
    * ``X | Y ~ N(Y, I)``.
    * ``A | X ~ Bernoulli(sigmoid(sum_j coef_j * (X_j**2 - 1)))``.

    Args:
        coef: per-coordinate weights of the treatment logit. Anything
            ``torch.as_tensor`` accepts (tensor, array, list, scalar), or
            ``None`` for all ones.
    """

    def __init__(self, coef):
        self.mean = np.zeros(5)
        # 0.5 * I plus negative couplings between coordinate 0 and the others,
        # scaled by 2 to give the unit-diagonal precision described above.
        precision = np.diag(np.repeat(.5, 5))
        precision[0, 1] = -0.3
        precision[1, 0] = -0.3
        precision[0, 2:5] = -.1
        precision[2:5, 0] = -.1
        precision = 2 * precision
        self.sigma = np.linalg.inv(precision)

        self.coef = coef

    def sample_Y(self, n):
        """Draw ``n`` rows of ``Y ~ N(self.mean, self.sigma)`` as an (n, 5) tensor."""
        Y = np.random.multivariate_normal(self.mean, self.sigma, n)
        return torch.Tensor(Y)

    def sample_X_given_Y(self, Y):
        """Add independent standard-normal noise to every coordinate of ``Y``."""
        return torch.distributions.multivariate_normal.MultivariateNormal(
            Y, torch.eye(Y.shape[1])).sample()

    def sample_A_given_X(self, X, coef=None):
        """Sample binary A with logit ``sum_j coef_j * (X_j**2 - 1)`` per row.

        Args:
            X: 2-D tensor, one sample per row.
            coef: per-column weights; ``None`` (default) means all ones.

        Returns:
            1-D tensor of 0/1 draws, one per row of ``X``.
        """
        # None sentinel rather than a tensor default argument, which would be
        # a single object shared by every call.
        if coef is None:
            coef = torch.ones(X.shape[1])
        # Accept lists / arrays / scalars as well as tensors.
        coef = torch.as_tensor(coef)
        logits = torch.sum((X**2 - 1) * coef, 1)
        return torch.distributions.bernoulli.Bernoulli(logits=logits).sample()

    def sample_Z(self, n, random_seed = 0):
        """Seed both RNGs, then return ``(X, A, Y)`` for ``n`` samples."""
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        Y = self.sample_Y(n)
        X = self.sample_X_given_Y(Y)
        A = self.sample_A_given_X(X, self.coef)
        return (X, A, Y)




class MNIST():
    """Chain Y -> X -> A whose ``Y`` rows are read from a saved tensor file.

    The file is ``<model_dir>/ll<digit>n_layer<n_layer>.pt``. Presumably it
    holds representations produced elsewhere for the given digit and layer
    count — TODO confirm.

    Args:
        digit: value substituted into the file name.
        n_layer: value substituted into the file name.
        model_dir: directory holding the ``.pt`` files.
    """

    def __init__(self, digit=0, n_layer=10, model_dir="/home/causal_ksd/models_MNIST"):
        self.digit = digit
        self.n_layer = n_layer
        self.model_dir = model_dir

    def _checkpoint_path(self):
        """Return the path of the saved ``Y`` tensor for this digit/layer pair."""
        filename = "ll" + str(self.digit) + "n_layer" + str(self.n_layer) + ".pt"
        return os.path.join(self.model_dir, filename)

    def sample_Y(self, n):
        """Load the saved tensor and return its first ``n`` rows.

        The file is re-read on every call. If it holds fewer than ``n`` rows,
        all rows are returned (slicing does not raise).
        """
        Y = torch.load(self._checkpoint_path())
        return Y[:n, :]

    def sample_X_given_Y(self, Y):
        """Add independent standard-normal noise to every coordinate of ``Y``."""
        return torch.distributions.multivariate_normal.MultivariateNormal(
            Y, torch.eye(Y.shape[1])).sample()

    def sample_A_given_X(self, X):
        """Sample binary A with logit ``sum_j coef_j * (X_j**2 - 1)`` per row.

        NOTE(review): every coefficient is zero, so each logit is 0 and
        A ~ Bernoulli(0.5) independently of X. Set nonzero weights here if an
        X-dependent treatment is intended — confirm.
        """
        # float64 torch weights keep the arithmetic inside PyTorch rather
        # than multiplying a tensor by a NumPy array.
        coef = torch.zeros(X.shape[1], dtype=torch.float64)
        logits = torch.sum((X**2 - 1) * coef, 1)
        return torch.distributions.bernoulli.Bernoulli(logits=logits).sample()

    def sample_Z(self, n, random_seed = 0):
        """Seed both RNGs, then return ``(X, A, Y)`` for up to ``n`` samples."""
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        Y = self.sample_Y(n)
        X = self.sample_X_given_Y(Y)
        A = self.sample_A_given_X(X)
        return (X, A, Y)




class RBM():
    """Sampler for a chain Y -> X -> A where Y comes from an RBM Gibbs chain.

    Y is drawn by Gibbs sampling a restricted Boltzmann machine with binary
    hidden units and Gaussian visible units (see ``gibbs_sampling``). Then
    ``X | Y ~ N(Y, 0.25 * I)`` and
    ``A | X ~ Bernoulli(sigmoid(0.2 * sum_j (X_j - 0.5)))``.

    Args:
        dvisible: number of visible units (dimension of Y and X). The fixed
            weights in ``sample_parameters`` require at least 2.
        dhidden: number of hidden units; at least 1 is required.
    """

    def __init__(self, dvisible, dhidden):
        self.dhidden = dhidden
        self.dvisible = dvisible

    def sample_parameters(self):
        """Return fixed RBM parameters ``(bias_visible, bias_hidden, B)``.

        Despite the name, nothing here is random. Visible biases are zero,
        hidden biases are one, and the only nonzero weights are
        ``B[0, 0] = 1`` and ``B[1, 0] = -1``. ``B`` is printed as a side
        effect.
        """
        bias_visible = np.zeros(self.dvisible)
        bias_hidden = np.ones(self.dhidden)
        B = np.zeros((self.dvisible, self.dhidden))
        B[0, 0] = 1
        B[1, 0] = -1
        print("B: ", B)
        return bias_visible, bias_hidden, B

    def sample_Y(self, bias_visible, bias_hidden, B, n_burn, n):
        """Run ``n`` parallel Gibbs chains for ``n_burn`` steps and return the visible units.

        The chains start from standard-normal visible units. The shapes of the
        final visible and hidden samples are printed. Returns an
        ``(n, len(bias_visible))`` tensor.
        """
        visible_units = np.random.randn(n, len(bias_visible))

        Y, sampled_hidden = self.gibbs_sampling(
            visible_units, W=B, b_hidden=bias_hidden, b_visible=bias_visible,
            num_steps=n_burn)

        print("Sampled visible units:")
        print(Y.shape)
        print("\nSampled hidden units:")
        print(sampled_hidden.shape)

        return torch.Tensor(Y)

    def sample_X_given_Y(self, Y):
        """Add Gaussian noise with covariance ``0.25 * I`` to each row of ``Y``."""
        return torch.distributions.multivariate_normal.MultivariateNormal(
            Y, .25 * torch.eye(Y.shape[1])).sample()

    def sample_A_given_X(self, X):
        """Sample binary A with logit ``0.2 * sum_j (X_j - 0.5)`` per row."""
        # Uniform 1/5 weights as a float64 tensor, which keeps the arithmetic
        # inside PyTorch rather than multiplying a tensor by a NumPy array.
        coef = torch.full((X.shape[1],), 1 / 5, dtype=torch.float64)
        logits = torch.sum((X - .5) * coef, 1)
        return torch.distributions.bernoulli.Bernoulli(logits=logits).sample()

    def sample_Z(self, n, n_burn, random_seed = 0):
        """Seed both RNGs and return ``(X, A, Y, bias_visible, bias_hidden, B)``."""
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)
        bias_visible, bias_hidden, B = self.sample_parameters()
        Y = self.sample_Y(bias_visible, bias_hidden, B, n_burn, n)
        X = self.sample_X_given_Y(Y)
        A = self.sample_A_given_X(X)
        return (X, A, Y, bias_visible, bias_hidden, B)

    @staticmethod
    def sigmoid(x):
        """Elementwise logistic function ``1 / (1 + exp(-x))``."""
        # staticmethod: the original plain ``def sigmoid(x)`` made
        # ``self.sigmoid(x)`` bind ``self`` to ``x`` and raise TypeError.
        return 1 / (1 + np.exp(-x))

    def sample_binary(self, probabilities):
        """Return a boolean array, True where a uniform draw is below ``probabilities``."""
        return np.random.rand(*probabilities.shape) < probabilities

    def sample_gaussian(self, mean, stddev):
        """Draw normals with the given ``mean`` array and standard deviation ``stddev``."""
        return np.random.normal(mean, stddev, mean.shape)

    def gibbs_sampling(self, visible_units, W, b_hidden, b_visible, num_steps=1):
        """Run ``num_steps`` block-Gibbs sweeps starting from ``visible_units``.

        Each sweep draws ``h ~ Bernoulli(sigmoid(v @ W + b_hidden))`` and then
        ``v ~ N(h @ W.T + b_visible, 0.25**2)``. Returns the final
        ``(visible_units, hidden_units)``. With ``num_steps=0`` the input
        visible units are returned together with all-zero hidden units.
        """
        num_samples = visible_units.shape[0]
        hidden_units = np.zeros((num_samples, W.shape[1]))

        for _ in range(num_steps):
            # Binary hidden units. The bias is added *inside* the sigmoid; the
            # previous inline form ``exp(-v @ W + b_hidden)`` flipped its sign.
            hidden_probabilities = self.sigmoid(np.dot(visible_units, W) + b_hidden)
            hidden_units = self.sample_binary(hidden_probabilities)

            # Gaussian visible units.
            visible_mean = np.dot(hidden_units, W.T) + b_visible
            visible_units = self.sample_gaussian(visible_mean, .25)

        return visible_units, hidden_units





