import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import PIL 
from tqdm import tqdm
import matplotlib.pyplot as plt
import utils
from encoders import *
from decoders import *

class TaxonomicLayer(nn.Module):
    """Merge sibling pairs of diagonal Gaussians into one parent Gaussian each.

    Consecutive input rows ``(2k, 2k + 1)`` are treated as the two children of
    parent ``k``. Every parent is the moment-matched diagonal Gaussian of the
    two-component mixture of its children, using a learnable mixing weight.

    Args:
        n_hidden (int): Dimensionality of each Gaussian.
        n_clusters (int): Number of parents produced by this layer.
    """

    def __init__(self, n_hidden, n_clusters):
        super(TaxonomicLayer, self).__init__()
        self.n_hidden = n_hidden
        self.n_clusters = n_clusters

        # One logit per parent. sigmoid(0) = 0.5, so both children start with
        # equal weight.
        self.cluster_weight = nn.Parameter(torch.zeros(n_clusters, 1))

        self.sigmoid = nn.Sigmoid()

    def kl_div(self, mean1, logvar1, mean2, logvar2):
        """
        Compute KL(N1 || N2) between two diagonal Gaussian distributions.
        Args:
            mean1 (Tensor): Mean of the first distribution.
            logvar1 (Tensor): Log variance of the first distribution.
            mean2 (Tensor): Mean of the second distribution.
            logvar2 (Tensor): Log variance of the second distribution.
        Returns:
            Tensor: KL divergence summed over dim 1. Shape: (batch_size,)
        """
        return 0.5 * torch.sum(logvar2 - logvar1 + (logvar1.exp() + (mean1 - mean2)**2) / logvar2.exp() - 1, dim=1)

    def forward(self, mean_leaf, logvar_leaf, eps=1e-8):
        """
        Args:
            mean_leaf (Tensor): Child means with shape (n_clusters*2, n_hidden).
            logvar_leaf (Tensor): Child log variances (diagonal covariances)
                with shape (n_clusters*2, n_hidden).
            eps (float): Added to the parent variance before taking the log.

        Returns:
            mean_root (Tensor): Parent means with shape (n_clusters, n_hidden).
            logvar_root (Tensor): Parent log variances with shape (n_clusters, n_hidden).
            alpha (Tensor): Flattened mixing weights with shape (n_clusters*2,),
                ordered ``[a_0, 1 - a_0, a_1, 1 - a_1, ...]``.
            sym_kl (Tensor): Symmetrized KL divergence between the two children
                of every parent, shape (n_clusters, 1).
        """
        alpha = self.sigmoid(self.cluster_weight)  # shape: n_clusters, 1
        alpha = torch.cat((alpha, 1 - alpha), dim=1).unsqueeze(-1)  # shape: n_clusters, 2, 1

        # Group the flat child list into (parent, child, feature).
        mean_pairs = mean_leaf.view(-1, 2, self.n_hidden)  # shape: n_clusters, 2, n_hidden
        logvar_pairs = logvar_leaf.view(-1, 2, self.n_hidden)  # shape: n_clusters, 2, n_hidden

        # pick the two children per cluster
        m_child0, m_child1 = mean_pairs[:, 0], mean_pairs[:, 1]  # each (n_clusters, n_hidden)
        lv_child0, lv_child1 = logvar_pairs[:, 0], logvar_pairs[:, 1]

        # forward and reverse KL, summed into a symmetric measure
        kl_0_1 = self.kl_div(m_child0, lv_child0, m_child1, lv_child1)  # (n_clusters,)
        kl_1_0 = self.kl_div(m_child1, lv_child1, m_child0, lv_child0)  # (n_clusters,)
        sym_kl = (kl_0_1 + kl_1_0).unsqueeze(1)  # (n_clusters, 1)

        # Moment matching of the two-component mixture:
        #   mean = sum_i a_i * m_i
        #   var  = sum_i a_i * (v_i + (m_i - mean)^2)
        mean_root = (alpha * mean_pairs).sum(dim=1)  # shape: n_clusters, n_hidden
        var_leaf = torch.exp(logvar_pairs)
        diff_sq = (mean_pairs - mean_root.unsqueeze(1))**2  # shape: n_clusters, 2, n_hidden
        var_root = (alpha * (var_leaf + diff_sq)).sum(dim=1)  # shape: n_clusters, n_hidden

        logvar_root = torch.log(var_root + eps)  # shape: n_clusters, n_hidden

        return mean_root, logvar_root, alpha.reshape(-1), sym_kl

class DeepTaxonNet(nn.Module):
    """VAE whose latent prior is a binary tree of diagonal Gaussians.

    The ``2**n_layers`` leaf Gaussians (``mu_c``, ``logvar_c``) are learnable.
    ``n_layers`` stacked ``TaxonomicLayer`` modules merge sibling pairs bottom-up,
    so the full mixture has ``2**(n_layers + 1) - 1`` nodes (leaves first, root
    last).

    Args:
        input_dim (int): Number of scalar values per input sample; used to
            rescale the mean-squared reconstruction error and, for the
            ``"resnet"`` encoder/decoder, to derive the image side length
            assuming 3 channels.
        enc_hidden_dim (int): Size of the flattened encoder output fed to the
            projection heads.
        dec_hidden_dim: Indexed as ``dec_hidden_dim[0..2]`` to build the
            decoder input shape, so it must be a 3-element sequence.
            NOTE(review): the integer default cannot be indexed and will fail
            once a valid encoder/decoder name is supplied.
        latent_dim (int): Dimensionality of the latent space.
        n_layers (int): Depth of the Gaussian tree.
        encoder_name (str): One of ``"mnist"``, ``"omniglot"``, ``"resnet"``.
        decoder_name (str): One of ``"mnist"``, ``"omniglot"``, ``"resnet"``.
        kl1_weight (float): Multiplier applied to both ``kl1`` and ``kl2``.
        recon_weight (float): Multiplier applied to the reconstruction loss.
        noise_strength (float): Passed as ``alpha`` to ``utils.GumbelSoftmax``
            while the module is in training mode (0.0 is used in eval mode).
        dkl_margin (float): Margin forwarded to ``utils.dkl_weight_warmup``.
        dkl_weight_lambda (float): Forwarded to ``utils.dkl_weight_warmup``.
        convex_weight_lambda (float): Forwarded to ``utils.convex_weight_decay``.
        logvar_init_range (float): Leaf log variances are initialised uniformly
            in ``[logvar_init_range, logvar_init_range + 1]``.

    Raises:
        ValueError: If ``encoder_name`` or ``decoder_name`` is not recognised.
    """

    def __init__(self,
                 input_dim=3*32*32,
                 enc_hidden_dim=128,
                 dec_hidden_dim=128, # tuple for ResNet
                 latent_dim=32,
                 n_layers=5,
                 encoder_name=None, # string
                 decoder_name=None, # string
                 kl1_weight=1.0,
                 recon_weight=1.0,
                 noise_strength=0.0,
                 dkl_margin=0.0,
                 dkl_weight_lambda=0.0,
                 convex_weight_lambda=0.0,
                 logvar_init_range=-4,
                 ):
        super(DeepTaxonNet, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.enc_hidden_dim = enc_hidden_dim
        self.dec_hidden_dim = dec_hidden_dim
        self.n_layers = n_layers
        self.encoder_name = encoder_name
        self.decoder_name = decoder_name
        self.kl1_weight = kl1_weight
        self.recon_weight = recon_weight
        # Honour the constructor argument (it was previously overwritten with 0.0).
        self.noise_strength = noise_strength
        self.dkl_margin = dkl_margin
        self.dkl_weight_lambda = dkl_weight_lambda
        self.convex_weight_lambda = convex_weight_lambda
        self.logvar_init_range = logvar_init_range

        ## Encoder
        if self.encoder_name == "mnist":
            self.encoder = Encoder28x28()
        elif self.encoder_name == "omniglot":
            self.encoder = OmniglotEncoder()
        elif self.encoder_name == "resnet":
            # Side length of a square 3-channel image with input_dim values.
            size = int((self.input_dim / 3)**0.5)
            self.encoder = Resnet_Encoder(s0=4, nf=32, nf_max=256, size=size)
        else:
            raise ValueError("Unknown encoder type")

        ## Decoder
        if self.decoder_name == "mnist":
            self.decoder_raw = Decoder28x28()
        elif self.decoder_name == "omniglot":
            self.decoder_raw = OmniglotDecoder()
        elif self.decoder_name == "resnet":
            size = int((self.input_dim / 3)**0.5)
            self.decoder_raw = Resnet_Decoder(s0=4, nf=32, nf_max=256, size=size)
        else:
            raise ValueError("Unknown decoder type")

        ## Projection heads: encoder features -> posterior mean / log variance
        self.fc_mu = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(self.enc_hidden_dim),
            nn.Linear(self.enc_hidden_dim, self.latent_dim),
        )
        self.fc_logvar = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(self.enc_hidden_dim),
            nn.Linear(self.enc_hidden_dim, self.latent_dim),
        )

        # Encoder features -> 64-d embedding returned by forward() as z_contrastive.
        self.contrastive_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.enc_hidden_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 64),
        )

        ## De-Projection head: latent -> feature map -> raw decoder
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self.dec_hidden_dim[0] * self.dec_hidden_dim[1] * self.dec_hidden_dim[2]),
            nn.ReLU(),
            nn.Unflatten(1, (self.dec_hidden_dim[0], self.dec_hidden_dim[1], self.dec_hidden_dim[2])),
            self.decoder_raw
        )

        n_nodes = 2**(self.n_layers+1)-1  # every node of the full binary tree
        n_leaves = 2**self.n_layers

        # Node-assignment probabilities -> 64-d embedding (pcx_contrastive).
        self.pcx_projection = nn.Sequential(
            nn.Linear(n_nodes, 64),
        )

        # GMM Prior Parameters:
        # 1. Cluster prior logits, shape (n_nodes,).
        #    NOTE(review): gmm_params() currently returns a fixed uniform prior
        #    and never reads this parameter.
        self.pi_logits = nn.Parameter(torch.zeros(n_nodes))
        # 2. Leaf means: shape (n_leaves, latent_dim), uniform in [-limit, limit]
        limit = 1 / (2 ** self.n_layers)
        self.mu_c = nn.Parameter(torch.nn.init.uniform_(torch.empty(n_leaves, self.latent_dim), -limit, limit))

        # 3. Leaf log variances: shape (n_leaves, latent_dim), uniform in [limit, limit + 1]
        limit = self.logvar_init_range
        self.logvar_c = nn.Parameter(torch.nn.init.uniform_(torch.empty(n_leaves, self.latent_dim), limit, limit+1))

        # Bottom-up merge layers producing 2**(n_layers-1), ..., 2, 1 parents.
        self.layers = nn.ModuleList(
            [TaxonomicLayer(self.latent_dim, 2**i) for i in reversed(range(0, self.n_layers))]
        )

    def pretrain_autoencoder(self, x):
        """Return the MSE of decoding the raw encoder output of ``x``.

        The encoder output is fed straight into ``self.decoder``, whose first
        layer is ``nn.Linear(latent_dim, ...)``; the encoder output's last
        dimension therefore has to equal ``latent_dim`` for this to run.
        """
        h = self.encoder(x)
        x_recon = self.decoder(h)
        recon_loss = F.mse_loss(x_recon, x)
        return recon_loss

    def encode(self, x):
        """Return the posterior mean and log variance, each (batch_size, latent_dim)."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Noise is sampled in both training and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` through the de-projection head and raw decoder."""
        x = self.decoder(z)
        return x

    def forward(self, x, tau=1.0):
        """Compute the training loss and its components for a batch.

        Args:
            x (Tensor): Input batch accepted by the configured encoder.
            tau (float): Temperature forwarded to ``utils.GumbelSoftmax``.

        Returns:
            tuple: ``(loss, recon_loss, kl1, kl2, H, pcx, pi, dkl_list,
            z_contrastive, pcx_contrastive)`` where

            * ``loss`` is ``recon_loss - kl1 - kl2`` plus the two regularisers,
            * ``recon_loss`` is the weighted, ``input_dim``-scaled MSE,
            * ``kl1`` / ``kl2`` are the weighted ELBO terms that are subtracted
              from the loss,
            * ``H``, ``pi`` and ``dkl_list`` come from :meth:`gmm_params`,
            * ``pcx`` has shape (batch_size, n_nodes),
            * ``z_contrastive`` and ``pcx_contrastive`` are 64-d projections.
        """
        # NOTE(review): the encoder runs here and again inside encode().
        z_contrastive = self.contrastive_projection(self.encoder(x))
        mu, logvar = self.encode(x)  # each (batch_size, latent_dim)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)

        # Prior over all tree nodes; mu_c / logvar_c are (n_nodes, latent_dim).
        pi, mu_c, logvar_c, H, alphas, dkl_list = self.gmm_params()
        log_pi = torch.log(pi.unsqueeze(0))  # (1, n_nodes)

        ########################################################
        # 1. reconstruction loss
        ########################################################
        # Mean over all elements, rescaled by input_dim (per-sample sum).
        recon_loss = F.mse_loss(x_recon, x)
        recon_loss = recon_loss * self.input_dim
        recon_loss = recon_loss * self.recon_weight

        #####################################################
        # 1.5 calculate q(c|x) as approximated by p(c|z)
        #####################################################
        # p(c|z) is p(c)p(z|c)/sum(p(z|c)p(c))
        z_l = self.reparameterize(mu, logvar)  # sample a new z
        log_pdf = self.gaussian_pdf(z_l, mu_c, logvar_c)  # log p(z|c), (batch_size, n_nodes)
        # Exploration noise is only applied in training mode.
        noise = self.noise_strength if self.training else 0.0
        logpcpzc = log_pdf + log_pi  # unnormalised log p(c|z) = log p(z|c) + log p(c)
        # presumably returns (probabilities, log-probabilities) over dim=-1;
        # verify against utils.GumbelSoftmax
        pcx, logpcx = utils.GumbelSoftmax(logpcpzc,
                                          tau=tau,
                                          alpha=noise,
                                          hard=False,  # MUST be False
                                          dim=-1)

        #############################################################
        # 2: KL divergence between q(z|x) and p(z|c)
        #############################################################
        # Per (sample, node) term, summed over the latent dimensions.
        qzxlogpzc = torch.sum(
            (logvar_c.unsqueeze(0) + torch.exp(logvar.unsqueeze(1) - logvar_c.unsqueeze(0)) +
             (mu.unsqueeze(1) - mu_c.unsqueeze(0))**2 / torch.exp(logvar_c.unsqueeze(0))),
            dim=2
        )  # shape: (batch_size, n_nodes)
        E_logpzc = -0.5 * torch.sum(qzxlogpzc * pcx, dim=1)  # shape: (batch_size,)
        E_logqzx = -torch.sum(  # summing over the latent dimensions
            0.5 * (1 + logvar),
            dim=1
        )  # shape: (batch_size,)
        kl1 = torch.mean(E_logpzc - E_logqzx)
        kl1 = kl1 * self.kl1_weight

        ###############################################################
        # 3: KL divergence between p(c|x) and p(c)
        ###############################################################
        kl2 = torch.sum(
            pcx * (log_pi - logpcx),
            dim=1
        )
        kl2 = torch.mean(kl2)
        kl2 = kl2 * self.kl1_weight  # shares kl1_weight; there is no separate kl2 weight

        loss = recon_loss - kl1 - kl2

        pcx_contrastive = self.pcx_projection(pcx)

        ## Regularization

        # Encourage the mixing weights to be uniform: normalise them into a
        # distribution and take its KL to the uniform distribution.
        alphas = alphas / alphas.sum(dim=0, keepdim=True)
        U = torch.ones_like(alphas) / alphas.shape[0]

        kl_alpha = alphas * (torch.log(alphas + 1e-10) - torch.log(U + 1e-10))
        # presumably per-entry weights broadcastable against alphas; verify
        # against utils.convex_weight_decay
        kl_alpha_weights = utils.convex_weight_decay(self.n_layers, self.convex_weight_lambda)
        kl_alpha = kl_alpha * kl_alpha_weights.to(alphas.device)
        kl_alpha = kl_alpha.sum(dim=0)

        loss = loss + kl_alpha

        # Hinge penalty on sibling pairs whose halved symmetric KL is below the margin.
        dkl_list_flatten = torch.cat(dkl_list, dim=0) / 2  # shape: (2 ** n_layers - 1, 1)
        margin = self.dkl_margin
        # presumably a tensor of per-parent margins; verify against utils.dkl_weight_warmup
        margin = utils.dkl_weight_warmup(self.n_layers, margin, self.dkl_weight_lambda)
        dkl_penalty = torch.relu(margin.to(dkl_list_flatten.device) - dkl_list_flatten.squeeze(1))

        loss = loss + dkl_penalty.sum()

        return loss, recon_loss, kl1, kl2, H, pcx, pi, dkl_list, z_contrastive, pcx_contrastive

    def gmm_params(self):
        """Build the parameters of every tree node from the leaf Gaussians.

        Returns:
            tuple: ``(pi, mu_c, logvar_c, H, alphas, dkl_list)`` where

            * ``pi`` is a uniform prior of shape (n_nodes,),
            * ``mu_c`` / ``logvar_c`` have shape (n_nodes, latent_dim), leaves
              first and root last,
            * ``H`` is a list with one tensor per tree level holding
              ``0.5 * sum(logvar)`` per node (Gaussian entropy up to an
              additive constant),
            * ``alphas`` is the 1-D concatenation of every layer's mixing
              weights, length ``2 ** (n_layers + 1) - 2``,
            * ``dkl_list`` holds each layer's symmetric sibling KL, shape
              (n_parents, 1) per entry,

            with ``n_nodes = 2 ** (n_layers + 1) - 1``.
        """
        layer_mu = [self.mu_c]
        layer_logvar = [self.logvar_c]
        pi_list = []
        dkl_list = []

        node_mu = self.mu_c
        node_logvar = self.logvar_c

        # Walk up the tree: each layer halves the number of nodes.
        for layer in self.layers:
            node_mu, node_logvar, alpha, dkl = layer(node_mu, node_logvar)
            layer_mu.append(node_mu)
            layer_logvar.append(node_logvar)
            pi_list.append(alpha)
            dkl_list.append(dkl)

        alphas = torch.cat(pi_list, dim=0)  # shape: (2 ** (n_layers + 1) - 2,)

        # 0.5 * sum(logvar) for every node of each level
        H = [0.5 * torch.sum(logvar_layer, dim=1) for logvar_layer in layer_logvar]

        # Concatenate the means and log variances from all layers
        mu_c = torch.cat(layer_mu, dim=0)  # shape: (2 ** (n_layers + 1) - 1, latent_dim)
        logvar_c = torch.cat(layer_logvar, dim=0)

        # Fixed uniform prior over all nodes.
        n_nodes = 2**(self.n_layers+1)-1
        pi = torch.full((n_nodes,), 1/n_nodes, device=mu_c.device)

        return pi, mu_c, logvar_c, H, alphas, dkl_list

    def gaussian_pdf(self, x, mu, logvar):
        """
        Compute the log-density log N(x | mu, diag(exp(logvar))) of every sample
        under every cluster.
        Args:
            x (Tensor): Input tensor of shape (batch_size, latent_dim).
            mu (Tensor): Mean tensor of shape (n_clusters, latent_dim).
            logvar (Tensor): Log variance tensor of shape (n_clusters, latent_dim).
        Returns:
            Tensor: Log-density summed over the latent dimensions, shape
                    (batch_size, n_clusters). Meaning the logprob of each
                    cluster for each sample.
        """
        var = torch.exp(logvar)  # shape (n_clusters, latent_dim)
        logpdf = -0.5 * (np.log(2*np.pi) + logvar.unsqueeze(0) + (x.unsqueeze(1) - mu.unsqueeze(0))**2 / var.unsqueeze(0))
        return logpdf.sum(-1)

    def vae_forward(self, x): # for testing
        """Plain VAE loss with a standard-normal prior (no tree prior).

        Returns:
            tuple: ``(loss, loss_recon, kld, 0)``.
        """
        mu, logvar = self.encode(x)  # shape: (batch_size, latent_dim)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)

        loss_recon = F.mse_loss(x_recon, x)
        loss_recon = loss_recon * self.input_dim
        # NOTE(review): kld is summed over the whole batch while loss_recon is
        # a batch mean, so their relative scale depends on batch size.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = loss_recon + kld
        return loss, loss_recon, kld, 0