import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
import numpy as np

from torch.linalg import pinv as tpinv



class VaDE(torch.nn.Module):
    """VaDE-style variational autoencoder with a learnable Gaussian-mixture prior.

    The encoder maps inputs to the mean and log-variance of a diagonal Gaussian
    posterior over the latent code. The decoder maps latent codes back to input
    space through a linear output layer. The mixture prior is parameterised by
    ``pi_prior`` (component weights), ``mu_prior`` (component means) and either
    ``log_var_prior`` (diagonal covariance) or ``sqrt_var_prior`` (full
    covariance; see ``compute_gamma``).

    Args:
        in_dim: Input feature dimension. Below 200 features a narrow
            (64/512/64) MLP is used; otherwise a wide (512/512/2048) MLP.
        latent_dim: Dimension of the latent code.
        n_classes: Number of mixture components.
        covariance: ``'full'`` for per-component full matrices. Any other value
            selects the diagonal parameterisation.
        sample_from_prior: When True, ``forward`` decodes, with probability
            ``0.7 * 0.85 ** epoch``, a sample drawn from each row's most likely
            mixture component instead of the posterior sample. Defaults to
            False, which always decodes the posterior sample.
    """

    # Constant term of the Gaussian log-density. A Python float works for any
    # device and dtype. A 1-element CPU tensor cannot be added to CUDA tensors.
    _LOG_2PI = float(np.log(2 * np.pi))

    def __init__(self, in_dim=2, latent_dim=2, n_classes=4, covariance='diag',
                 sample_from_prior=False):
        super().__init__()

        self.pi_prior = Parameter(torch.ones(n_classes) / n_classes)
        self.mu_prior = Parameter(10 * torch.randn(n_classes, latent_dim))
        self.covariance = covariance
        if covariance == 'full':
            self.sqrt_var_prior = Parameter(7 * torch.randn(n_classes, latent_dim, latent_dim))
            # Plain tensor (not a Parameter or buffer). It is initialised
            # randomly and never derived from sqrt_var_prior in this class.
            # Presumably the training code overwrites it; TODO confirm.
            self.var_log_det = torch.randn(n_classes)
        else:
            # Not read by any method in this class. Kept for state_dict compatibility.
            self.var_prior = Parameter(torch.randn(n_classes, latent_dim) / 10)
            self.log_var_prior = Parameter(torch.randn(n_classes, latent_dim) - 3)

        # forward() reads epoch to decay the prior-sampling probability.
        # Presumably the training loop advances it.
        self.epoch = 0
        self.sample_from_prior = sample_from_prior

        # Hidden widths: (first encoder layer, second encoder layer, bottleneck).
        if in_dim < 200:
            w1, w2, w3 = 64, 512, 64
        else:
            w1, w2, w3 = 512, 512, 2048

        # Encoder
        self.fc1 = nn.Linear(in_dim, w1)
        self.fc2 = nn.Linear(w1, w2)
        self.fc3 = nn.Linear(w2, w3)

        self.mu = nn.Linear(w3, latent_dim)       # latent mean head
        self.log_var = nn.Linear(w3, latent_dim)  # latent log-variance head

        # Decoder
        self.fc4 = nn.Linear(latent_dim, w3)
        self.fc5 = nn.Linear(w3, 512)
        self.fc6 = nn.Linear(512, 512)
        self.fc7 = nn.Linear(512, in_dim)

    def encode(self, x):
        """Run ``x`` through three leaky-ReLU layers and return ``(mu, log_var)``."""
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        h = F.leaky_relu(self.fc3(h))
        return self.mu(h), self.log_var(h)

    def decode(self, z):
        """Map latent codes back to input space.

        The output layer is linear, so values are unbounded. Apply a sigmoid if
        the inputs are binary.
        """
        h = F.leaky_relu(self.fc4(z))
        h = F.leaky_relu(self.fc5(h))
        h = F.leaky_relu(self.fc6(h))
        return self.fc7(h)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(log_var / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_hat, mu, log_var, z)``. ``mu`` and ``log_var`` are the
            encoder outputs and ``z`` is the reparameterised posterior sample.
            ``x_hat`` decodes ``z``, unless prior sampling is enabled and
            triggered (see the class docstring).
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)

        # getattr default keeps instances pickled before this flag existed working.
        use_prior = (getattr(self, 'sample_from_prior', False)
                     and np.random.rand() <= 0.7 * np.power(0.85, self.epoch))
        if not use_prior:
            return self.decode(z), mu, log_var, z

        # Decode a fresh sample from each row's most likely mixture component.
        gamma = self.compute_gamma(z, self.pi_prior)
        pred = torch.argmax(gamma, dim=1)
        if self.covariance == 'full':
            # compute_gamma scores ||S_k (z - mu_k)||^2, i.e. it treats
            # S_k^T S_k as the precision. For invertible S_k,
            # pinv(S_k) @ eps then has covariance (S_k^T S_k)^{-1}.
            transform = tpinv(self.sqrt_var_prior[pred])  # (batch, d, d)
            # Independent noise for every row, on the model's device/dtype.
            eps = torch.randn(transform.shape[0], transform.shape[1], 1,
                              dtype=transform.dtype, device=transform.device)
            z_new = self.mu_prior[pred] + torch.matmul(transform, eps).squeeze(-1)
        else:
            z_new = self.reparameterize(self.mu_prior[pred], self.log_var_prior[pred])
        return self.decode(z_new), mu, log_var, z

    def compute_gamma(self, z, p_c):
        """Return the mixture responsibilities p(c | z) for each row of ``z``.

        Args:
            z: Latent codes. Broadcast against ``mu_prior`` as
                ``z.unsqueeze(1) - mu_prior``, so presumably (batch, latent_dim).
            p_c: Mixture weights, one per component (e.g. ``pi_prior``).

        Returns:
            Tensor of shape (batch, n_classes) whose rows sum to 1. Every entry
            is kept strictly positive by a 1e-9 floor.
        """
        diff = z.unsqueeze(1) - self.mu_prior  # (batch, n_classes, latent_dim)
        log_pc = torch.log(p_c + 1e-9).unsqueeze(0)
        if self.covariance == 'full':
            h = torch.matmul(self.sqrt_var_prior, diff.permute(1, 2, 0)).permute(2, 0, 1).pow(2)
            h = h + self._LOG_2PI
            # var_log_det may be set externally. Move it to where the
            # activations live.
            log_det = torch.as_tensor(self.var_log_det, dtype=h.dtype, device=h.device)
            log_p = log_pc - 0.5 * torch.sum(h, dim=2) + 0.5 * log_det
        else:
            h = diff.pow(2) / self.log_var_prior.exp() + self.log_var_prior + self._LOG_2PI
            log_p = log_pc - 0.5 * torch.sum(h, dim=2)

        # Shift each row by its max before exp so distant points do not
        # underflow to all-zero likelihoods. Without the shift, the 1e-9 floor
        # would turn such rows into uniform responsibilities.
        log_p = log_p - log_p.max(dim=1, keepdim=True).values.detach()
        p_z_c = torch.exp(log_p) + 1e-9
        return p_z_c / torch.sum(p_z_c, dim=1, keepdim=True)


class Autoencoder(torch.nn.Module):
    """Deterministic MLP autoencoder: ``x -> latent code -> reconstruction``.

    Args:
        in_dim: Input feature dimension. Below 200 features a narrow
            (64/512/64) encoder is built; otherwise a wide (512/512/2048) one.
        latent_dim: Dimension of the latent code.
    """

    def __init__(self, in_dim=2, latent_dim=2):
        super().__init__()
        narrow = in_dim < 200
        # (first encoder width, second encoder width, bottleneck width)
        enc_a, enc_b, neck = (64, 512, 64) if narrow else (512, 512, 2048)

        # Encoder
        self.fc1 = nn.Linear(in_dim, enc_a)
        self.fc2 = nn.Linear(enc_a, enc_b)
        self.fc3 = nn.Linear(enc_b, neck)
        self.mu = nn.Linear(neck, latent_dim)  # latent code head

        # Decoder
        self.fc4 = nn.Linear(latent_dim, neck)
        self.fc5 = nn.Linear(neck, 512)
        self.fc6 = nn.Linear(512, 512)
        self.fc7 = nn.Linear(512, in_dim)

    def encode(self, x):
        """Apply three leaky-ReLU layers, then the linear latent head."""
        out = x
        for layer in (self.fc1, self.fc2, self.fc3):
            out = F.leaky_relu(layer(out))
        return self.mu(out)

    def decode(self, z):
        """Map latent codes back to input space through a linear output layer.

        The output is unbounded. Apply a sigmoid if the inputs are binary.
        """
        out = z
        for layer in (self.fc4, self.fc5, self.fc6):
            out = F.leaky_relu(layer(out))
        return self.fc7(out)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.decode(self.encode(x))