import math
import torch
import numpy as np
from torch import optim
import torch.nn.functional as F
from sklearn.mixture import GaussianMixture
# from sklearn.utils.linear_assignment_ import linear_assignment
from hungarian import linear_assignment as linear_assignment

from models import Autoencoder, VaDE

from draw import draw_all, draw_together

import matplotlib.pyplot as plt

from scipy.linalg import sqrtm, pinv
from torch.linalg import pinv as tpinv


def weights_init_normal(m):
    """Fill the weight of a Linear-like module with draws from N(0, 0.02^2).

    Meant to be passed to ``Module.apply``. A module is selected when its
    class name contains the substring "Linear"; any other module is left
    untouched. Only ``m.weight`` is written.
    """
    if "Linear" not in type(m).__name__:
        return
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        
def weights_init_eye_normal(m):
    """Initialise a Linear-like module to ``0.1 * I`` plus N(0, 0.03^2) noise.

    Meant to be passed to ``Module.apply``. A module is selected when its
    class name contains the substring "Linear"; any other module is left
    untouched. The (possibly rectangular) identity has shape
    ``(m.out_features, m.in_features)``.
    """
    if "Linear" not in m.__class__.__name__:
        return
    torch.nn.init.normal_(m.weight.data, 0.0, 0.03)
    # Build the identity on the weight's own device/dtype; a default CPU
    # tensor cannot be added to a weight that lives on another device.
    identity = torch.eye(
        m.out_features,
        m.in_features,
        device=m.weight.device,
        dtype=m.weight.dtype,
    )
    m.weight.data = m.weight.data + 0.1 * identity
        
def weights_init_ort(m):
    """Apply orthogonal initialisation to the weight of a Linear-like module.

    Meant to be passed to ``Module.apply``. A module is selected when its
    class name contains the substring "Linear"; any other module is left
    untouched. Only ``m.weight`` is written.
    """
    is_linear = "Linear" in type(m).__name__
    if is_linear:
        torch.nn.init.orthogonal_(m.weight.data)

class TrainerVaDE:
    """This is the trainer for the Variational Deep Embedding (VaDE).
    """
    def __init__(self, args, device, dataloader, covariance = "diag"):
        """Build the autoencoder and the VaDE model and store the run settings.

        Args:
            args: settings object; ``in_dim``, ``latent_dim`` and ``n_classes``
                are read here, other fields are read by the other methods.
            device: device both models are moved to.
            dataloader: iterable of ``(x, label)`` batches used by every
                training and evaluation loop of this class.
            covariance: ``"full"`` selects the full-covariance code paths; any
                other value is handled by the diagonal branches.
        """
        self.args = args
        self.device = device
        self.dataloader = dataloader
        self.covariance = covariance

        # The autoencoder is constructed before the VaDE, as before, so the
        # order in which the models draw their initial weights is unchanged.
        autoencoder = Autoencoder(args.in_dim, args.latent_dim)
        self.autoencoder = autoencoder.to(device)
        vade = VaDE(args.in_dim, args.latent_dim, args.n_classes, covariance = covariance)
        self.VaDE = vade.to(device)


    def pretrain(self, epochs=15, lr=0.0001):
        """Pretrain the autoencoder, then initialise the VaDE from it.

        The autoencoder is trained on ``self.dataloader`` with a mean squared
        reconstruction error scaled by ``1 / (2 * args.noise_var)``. Afterwards
        a GMM is fitted on its latent codes (``train_GMM``) and the result is
        copied into the VaDE and written to disk (``save_weights_for_VaDE``).

        Args:
            epochs: number of passes over the dataloader (default 15).
            lr: Adam learning rate (default 1e-4).
        """
        optimizer = optim.Adam(self.autoencoder.parameters(), lr=lr)
        # Orthogonal initialisation of every Linear layer (weights_init_ort).
        self.autoencoder.apply(weights_init_ort)
        self.autoencoder.train()
        print('Training the autoencoder...')
        for epoch in range(epochs):
            total_loss = 0
            for x, _ in self.dataloader:
                optimizer.zero_grad()
                x = x.to(self.device)
                x_hat = self.autoencoder(x)
                # Gaussian reconstruction term for non-binary data.
                loss = F.mse_loss(x_hat, x, reduction='mean') * 1. / (2 * self.args.noise_var)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            # total_loss is the sum of the per-batch losses of this epoch.
            print('Training Autoencoder... Epoch: {}, Loss: {}'.format(epoch, total_loss))
        self.train_GMM()  # fit the GMM that initialises the VaDE priors
        self.save_weights_for_VaDE()  # copy weights/priors into the VaDE and save


    def train_GMM(self):
        """Fit a Gaussian mixture on the autoencoder's latent codes.

        Every batch of ``self.dataloader`` is concatenated, reshaped to
        ``(-1, args.in_dim)`` and encoded by the autoencoder. The fitted
        mixture is stored in ``self.gmm``; ``save_weights_for_VaDE`` reads its
        weights, means and covariances to initialise the VaDE priors.
        """
        print('Fiting Gaussian Mixture Model...')
        batches = [batch[0] for batch in self.dataloader]
        x = torch.cat(batches).view(-1, self.args.in_dim).to(self.device)  # all x samples
        # Only the values are needed, so do not build an autograd graph over
        # the whole dataset.
        with torch.no_grad():
            z = self.autoencoder.encode(x)
        covariance_type = 'full' if self.covariance == 'full' else 'diag'
        self.gmm = GaussianMixture(n_components=self.args.n_classes,
                                   covariance_type=covariance_type)
        self.gmm.fit(z.cpu().numpy())
        
        
    def retrain_GMM(self):
        """Refit ``self.gmm`` on latent codes produced by the VaDE encoder.

        Same procedure as ``train_GMM`` but the codes are the first value
        returned by ``self.VaDE.encode``. Only ``self.gmm`` is replaced; the
        VaDE priors are not touched here.
        """
        batches = [batch[0] for batch in self.dataloader]
        x = torch.cat(batches).view(-1, self.args.in_dim).to(self.device)  # all x samples
        # Only the values are needed, so do not build an autograd graph over
        # the whole dataset.
        with torch.no_grad():
            z, _ = self.VaDE.encode(x)
        covariance_type = 'full' if self.covariance == 'full' else 'diag'
        self.gmm = GaussianMixture(n_components=self.args.n_classes,
                                   covariance_type=covariance_type)
        self.gmm.fit(z.cpu().numpy())
        
        
        


    def save_weights_for_VaDE(self):
        """Saving the pretrained weights for the encoder, decoder, pi, mu, var.
        """
        print('Saving weights.')
        state_dict = self.autoencoder.state_dict()

        self.VaDE.load_state_dict(state_dict, strict=False)
        self.VaDE.pi_prior.data = torch.from_numpy(self.gmm.weights_).float().to(self.device)
        self.VaDE.mu_prior.data = torch.from_numpy(self.gmm.means_).float().to(self.device)
        if self.covariance == 'full':
            cov_shape = self.gmm.covariances_.shape
            B = np.asarray([pinv(sqrtm(self.gmm.covariances_[i])).astype(float) for i in range(cov_shape[0])])
            #for i in range(cov_shape[0]):
            #    print(np.dot(np.dot(B[i], B[i]), self.gmm.covariances_[i]))
            self.VaDE.sqrt_var_prior.data = torch.from_numpy(B).float().to(self.device)
            print(self.VaDE.sqrt_var_prior.shape, self.VaDE.sqrt_var_prior) 
            #self.VaDE.var_prior.data = torch.from_numpy(self.gmm.covariances_).float().to(self.device)
        else:
            self.VaDE.var_prior.data = torch.from_numpy(self.gmm.covariances_).float().to(self.device)
            self.VaDE.log_var_prior.data = torch.log(torch.from_numpy(self.gmm.covariances_)).float().to(self.device)
        torch.save(self.VaDE.state_dict(), self.args.pretrained_path)    

    def train(self):
        """
        """
        if self.args.pretrain==True:
            self.VaDE.load_state_dict(torch.load(self.args.pretrained_path,
                                                 map_location=self.device))
        else:
            self.VaDE.apply(weights_init_ort)
        self.optimizer = optim.Adam(self.VaDE.parameters(), lr=self.args.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer, step_size=10, gamma=0.7)
        print('Training VaDE...')
        for epoch in range(self.args.epochs):
            self.train_VaDE(epoch)
            self.test_VaDE(epoch)
            lr_scheduler.step()
            self.VaDE.epoch = epoch

                
                

    def train_VaDE(self, epoch):
        self.VaDE.train()

        total_loss = 0
        for x, _ in self.dataloader:
            self.optimizer.zero_grad()
            x = x.to(self.device)
            x_hat, mu, log_var, z = self.VaDE(x)
            
            with torch.no_grad():
                if self.covariance == 'full':
                    for i in range(self.VaDE.sqrt_var_prior.shape[0]):
                        self.VaDE.var_log_det[i] = torch.log(torch.det(self.VaDE.sqrt_var_prior[i]).pow(2))
            
            
            #print('Before backward: {}'.format(self.VaDE.pi_prior))
            loss = self.compute_loss(x, x_hat, mu, log_var, z)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            #print('After backward: {}'.format(self.VaDE.pi_prior))
        print('Training VaDE... Epoch: {}, Loss: {}'.format(epoch, total_loss))


    def test_VaDE(self, epoch):
        """Evaluate the VaDE on ``self.dataloader`` and report clustering accuracy.

        Each sample is assigned to the component with the largest
        responsibility (``compute_gamma``); the assignments are scored against
        the labels yielded by the dataloader with ``cluster_acc``. Every fifth
        epoch, 10000 points are sampled from the learnt mixture prior and
        handed to ``draw_together`` together with the latents, inputs and
        reconstructions collected during the pass.

        Args:
            epoch: epoch index, used in the progress message, to decide whether
                to plot, and forwarded to ``draw_together``.
        """
        self.VaDE.eval()
        with torch.no_grad():
            total_loss = 0
            y_true, y_pred = [], []
            z_vals, x_hat_vals, x_vals = [], [], []
            for x, true in self.dataloader:
                x = x.to(self.device)
                x_hat, mu, log_var, z = self.VaDE(x)
                gamma = self.compute_gamma(z, self.VaDE.pi_prior)
                pred = torch.argmax(gamma, dim=1)
                loss = self.compute_loss(x, x_hat, mu, log_var, z)
                total_loss += loss.item()
                y_true.extend(true.numpy())
                # Move to host memory first; .numpy() is only valid on CPU tensors.
                y_pred.extend(pred.cpu().numpy())
                z_vals.extend(z.cpu().numpy())
                x_hat_vals.extend(x_hat.cpu().numpy())
                x_vals.extend(x.cpu().numpy())

            acc = self.cluster_acc(np.array(y_true), np.array(y_pred))
            print('Testing VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(epoch, total_loss, acc[0]))

            # Visualizing the learnt prior
            if epoch % 5 == 0:
                # Draw a component label per point, then sample from that component.
                n_data_points = 10000
                labels = torch.multinomial(self.VaDE.pi_prior, n_data_points, replacement=True)
                # Noise is created on the same device as the prior parameters.
                noise = torch.randn(n_data_points, self.args.latent_dim,
                                    device=self.VaDE.mu_prior.device)
                if self.covariance == "full":
                    # sqrt_var_prior holds inverse square roots (see
                    # save_weights_for_VaDE); its pseudo-inverse maps white
                    # noise back to the component's scale. pinv is batched.
                    transform = tpinv(self.VaDE.sqrt_var_prior)[labels]
                    print(transform.shape)
                    latent_gmm = self.VaDE.mu_prior[labels] + torch.matmul(
                        transform, noise.unsqueeze(-1)).squeeze(-1)
                else:
                    latent_gmm = self.VaDE.mu_prior[labels] + noise * torch.exp(
                        0.5 * self.VaDE.log_var_prior[labels])
                print("-------------------")
                print("MEANS", self.VaDE.mu_prior)
                print("--------------------")

                z_vals = np.array(z_vals)
                x_hat_vals = np.array(x_hat_vals)
                y_pred = np.array(y_pred)
                x_vals = np.array(x_vals)

                # NOTE(review): y_pred is passed twice, as in the original
                # call; confirm against draw_together's signature.
                draw_together(z_vals, latent_gmm.cpu(), x_vals, x_hat_vals, labels.cpu(),
                              y_pred, y_pred, epoch, self.args.output_dir)




    def compute_loss(self, x, x_hat, mu, log_var, z):
        p_c = self.VaDE.pi_prior
        gamma = self.compute_gamma(z, p_c)
        # print(x_hat, x)

        # print("xhat =", x_hat)
        # print("x = ", x)
        log_p_x_given_z = F.mse_loss(x_hat, x, reduction = 'sum') * 1. / (2 * self.args.noise_var)  # Use this for non-binary (also, I think L = 1)
        # print("vade loss =", log_p_x_given_z)
        # log_p_x_given_z = F.binary_cross_entropy(x_hat, x, reduction='sum')
        if self.covariance == 'full':
            h = (z.unsqueeze(1) - self.VaDE.mu_prior)
            h = torch.sum(torch.matmul(self.VaDE.sqrt_var_prior, h.permute(1, 2, 0)).permute(2, 0, 1).pow(2), dim = 2)+0.5*self.VaDE.var_log_det
            log_p_z_given_c = 0.5 * torch.sum(gamma * h)
        else:
            h = log_var.exp().unsqueeze(1) + (mu.unsqueeze(1) - self.VaDE.mu_prior).pow(2)
            h = torch.sum(self.VaDE.log_var_prior + h / self.VaDE.log_var_prior.exp(), dim=2)
            log_p_z_given_c = 0.5 * torch.sum(gamma * h)
        log_p_c = torch.sum(gamma * torch.log(p_c + 1e-9))
        log_q_c_given_x = torch.sum(gamma * torch.log(gamma + 1e-9))
        log_q_z_given_x = 0.5 * torch.sum(1 + log_var)

        loss = log_p_x_given_z + log_p_z_given_c - log_p_c +  log_q_c_given_x - log_q_z_given_x
        loss /= x.size(0)
        return loss

    def compute_gamma(self, z, p_c):
        if self.covariance == 'full':
            #samples = self.gmm.sample(1000)[0]
            #plt.scatter(samples[:, 0], samples[:, 1])
            #znp = np.asarray(z.detach())
            #plt.scatter(znp[:,0], znp[:, 1], s=20, c = "#2ca02c")
            #print(self.gmm.covariances_)
            #plt.show()
            h = (z.unsqueeze(1) - self.VaDE.mu_prior)
            h = torch.matmul(self.VaDE.sqrt_var_prior, h.permute(1, 2, 0)).permute(2, 0, 1).pow(2)
            h += torch.Tensor([np.log(np.pi*2)]).to(self.device)
            #print(torch.sum(h, dim=2))
            #print(torch.log(p_c + 1e-9).unsqueeze(0))
            #print(self.VaDE.var_log_det)
            p_z_c = torch.exp(torch.log(p_c + 1e-9).unsqueeze(0) - 0.5 * torch.sum(h, dim=2)+ 0.5*self.VaDE.var_log_det)  + 1e-9
            #print(p_z_c)
        else:
            h = (z.unsqueeze(1) - self.VaDE.mu_prior).pow(2) / self.VaDE.log_var_prior.exp()
            h += self.VaDE.log_var_prior
            h += torch.Tensor([np.log(np.pi*2)]).to(self.device)
            #print(h.shape)
            #print(h.shape)
            #print(torch.sum(h, dim=2))
            p_z_c = torch.exp(torch.log(p_c + 1e-9).unsqueeze(0) - 0.5 * torch.sum(h, dim=2)) + 1e-9
        gamma = p_z_c / torch.sum(p_z_c, dim=1, keepdim=True)
        #input()
        return gamma

    def cluster_acc(self, real, pred):
        D = max(pred.max(), real.max())+1
        w = np.zeros((D,D), dtype=np.int64)
        for i in range(pred.size):
            w[pred[i], real[i]] += 1
        ind = linear_assignment(w.max() - w)
        return sum([w[i,j] for i,j in ind])*1.0/pred.size*100, w



