import numpy as np
import torch
import torch.nn as nn
import torch.distributions as D


class G1D_learner(nn.Module):
    """Gaussian model with a learnable mean and a fixed scale.

    Despite its name, ``var`` is used as the *standard deviation* throughout.
    It is passed as ``scale`` to ``torch.distributions.Normal``, and
    ``log_prob``/``score`` divide by ``var ** 2``.
    """

    def __init__(self, mean=None, var=None):
        """
        Args:
            mean: initial mean tensor (default ``torch.Tensor([1.])``). It is
                wrapped in an ``nn.Parameter``, which shares its storage.
            var: fixed scale (standard deviation) tensor (default
                ``torch.Tensor([1.])``). It is not trainable.
        """
        super().__init__()
        # Build fresh default tensors per instance. nn.Parameter shares storage
        # with the tensor it wraps, so a tensor default in the signature would
        # be mutated for all later instances once an optimizer updated it.
        if mean is None:
            mean = torch.Tensor([1.])
        if var is None:
            var = torch.Tensor([1.])
        self.mean = nn.Parameter(mean)
        self.var = var
        # NOTE(review): torch.Tensor(t) on a tensor appears to return an alias,
        # so this distribution may follow in-place updates of self.mean.
        # Confirm whether sampling is meant to track the learned mean.
        self.dist = D.Normal(torch.Tensor(mean), torch.Tensor(var))

    def sample(self, n):
        """Draw ``n`` samples from ``self.dist``.

        Returns a tensor of shape ``(n,) + self.dist.batch_shape``.
        """
        samp = self.dist.sample(torch.Size([n, ]))
        return samp

    def log_prob(self, x):
        """Gaussian log-density of ``x`` under the current mean and fixed scale.

        ``x`` is transposed before evaluation, so an ``(n, 1)`` input yields a
        ``(1, n)`` result.
        """
        mu1 = self.mean
        v1 = self.var
        # log(sqrt(2*pi) * sigma) is computed entirely in torch, so it works
        # for multi-element scales and for tensors on any device.
        log_norm = torch.log(v1 * float(np.sqrt(2 * np.pi)))
        x1 = x.T
        output = -log_norm - 0.5 * (x1 - mu1) ** 2 / v1 ** 2
        return output

    def score(self, x):
        """Gradient of the log-density w.r.t. ``x``: ``-(x - mean) / var**2``.

        The computation is done on ``x.t()`` and transposed back, so the
        result has the same layout as ``x``.
        """
        mu1 = self.mean
        v1 = self.var
        x1 = x.t()
        dx1 = -(x1 - mu1) / v1 ** 2
        return dx1.t()
    
    
class INTRACTABLE_learner(nn.Module):
    """Five-dimensional model specified only through its score function.

    The learnable scalars ``theta4`` and ``theta5`` scale the derivative of
    ``tanh`` added to the fourth and fifth score components. That is, they
    weight ``theta * tanh(x_k)`` terms on coordinates 4 and 5.
    """

    def __init__(self, theta4=None, theta5=None):
        """
        Args:
            theta4: initial value for ``theta4`` (default ``torch.Tensor([1])``).
            theta5: initial value for ``theta5`` (default ``torch.Tensor([1])``).
        """
        super().__init__()
        # Fresh default tensors per instance. nn.Parameter shares storage with
        # the wrapped tensor, so a tensor default in the signature would be
        # mutated for later instances by in-place optimizer updates.
        if theta4 is None:
            theta4 = torch.Tensor([1])
        if theta5 is None:
            theta5 = torch.Tensor([1])
        self.theta4 = nn.Parameter(theta4)
        self.theta5 = nn.Parameter(theta5)

    def score(self, x):
        """Return the model score evaluated at ``x``.

        ``x.t()`` is unpacked into exactly five rows, so ``x`` must have five
        columns. It is presumably ``(n, 5)``; the result is then ``(n, 5)``.
        """
        theta4 = self.theta4
        theta5 = self.theta5
        x1, x2, x3, x4, x5 = x.t()
        dx1 = -x1 + 0.6 * x2 + 0.2 * (x3 + x4 + x5)
        dx2 = -x2 + 0.6 * x1
        dx3 = -x3 + 0.2 * x1
        # (1 - tanh(z)**2) is d/dz tanh(z).
        dx4 = -x4 + 0.2 * x1 + theta4 * (1 - torch.tanh(x4) ** 2)
        dx5 = -x5 + 0.2 * x1 + theta5 * (1 - torch.tanh(x5) ** 2)
        return torch.stack((dx1, dx2, dx3, dx4, dx5)).t()


class RBM_learner(nn.Module):
    """Score model parameterised by RBM-style weights and biases.

    Learnable parameters:
    - ``B``: coupling matrix.
    - ``b``: visible bias.
    - ``c``: hidden bias.

    Only the score is exposed.
    """

    def __init__(self, B, bias_visible, bias_hidden):
        """
        Args:
            B: coupling matrix. ``score`` computes ``x @ B``, so its first
                dimension must match the last dimension of ``x``.
            bias_visible: bias added directly to the score (stored as ``b``).
            bias_hidden: bias added to ``x @ B`` before the nonlinearity
                (stored as ``c``).
        """
        super().__init__()
        self.B = nn.Parameter(B)
        self.b = nn.Parameter(bias_visible)
        self.c = nn.Parameter(bias_hidden)

    def transf(self, y):
        """Elementwise hyperbolic tangent of ``y``.

        This uses ``torch.tanh`` rather than the explicit
        ``(exp(2y) - 1) / (exp(2y) + 1)``. The explicit formula overflows to
        ``inf / inf = nan`` once ``exp(2y)`` exceeds the float range (roughly
        ``y > 44`` in float32), while ``torch.tanh`` saturates at ``+-1``.
        """
        return torch.tanh(y)

    def score(self, x):
        """Return ``b - tanh(x @ B + c) @ B.T``.

        Assumes ``x`` is ``(n, d)`` with ``d == B.shape[0]`` (TODO confirm).
        The result then has shape ``(n, d)``.

        NOTE(review): unlike the common Gaussian-Bernoulli RBM score
        ``b - x + tanh(x @ B + c) @ B.T``, there is no ``-x`` term here and the
        hidden term is subtracted. Confirm this matches the intended energy.
        """
        B = self.B
        hidden = self.transf(x @ B + self.c)
        return self.b - hidden @ B.t()


class MNIST_learner(nn.Module):
    """Score model that is linear in learnable coefficients.

    Precomputed features are loaded from disk in ``__init__``. ``score``
    returns ``sco @ coef``, and only ``coef`` is trainable.
    """

    def __init__(self, coef, n, digit, n_layer, *,
                 root="/home/causal_ksd/models_MNIST", map_location=None):
        """
        Args:
            coef: initial coefficient tensor (wrapped in ``nn.Parameter``).
            n: number of leading entries kept along dim 0 of the loaded tensor.
            digit: selects the file ``sco{digit}n_layer{n_layer}.pt``.
            n_layer: selects the file ``sco{digit}n_layer{n_layer}.pt``.
            root: directory containing the ``.pt`` files. Keyword-only; the
                default is the previously hard-coded location.
            map_location: forwarded to ``torch.load``; for example, ``"cpu"``
                loads tensors that were saved on a GPU.
        """
        import os

        super().__init__()
        self.coef = nn.Parameter(coef)
        path = os.path.join(root, "sco" + str(digit) + "n_layer" + str(n_layer) + ".pt")
        # SECURITY: torch.load may unpickle arbitrary objects (depending on the
        # installed torch's weights_only default). Only load trusted files.
        # The slice requires the stored tensor to be (at least) 3-D.
        # NOTE: self.sco is a plain attribute, not a registered buffer. It is
        # therefore neither moved by .to() nor included in state_dict().
        self.sco = torch.load(path, map_location=map_location)[:n, :, :]

    def score(self, x):
        """Return ``self.sco @ self.coef``.

        ``x`` is not used. The loaded features presumably already correspond
        to the evaluation points (TODO confirm against callers).
        """
        return self.sco @ self.coef
    







