import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import distributed as tdist
from torch import einsum

class KAN_NET(nn.Module):
    """Container for one learnable vector: one scalar per support point.

    ``kan_param`` starts at the uniform value ``1 / batch_size`` and is used as
    the potential ``phi`` in ``VQWAE.compute_OT_loss``. This class defines no
    ``forward()``; callers read ``kan_param`` directly.

    Args:
        batch_size: number of entries of the potential (number of support points).
    """

    def __init__(self, batch_size):
        super().__init__()
        # Initial uniform mass, kept as an independent copy: nn.Parameter shares
        # storage with the tensor it wraps, so without clone() every optimizer
        # step on kan_param would silently overwrite this value as well.
        self.uniform_mass = torch.full((batch_size,), 1.0 / batch_size)
        self.kan_param = nn.Parameter(self.uniform_mass.clone())

class VQWAE(nn.Module):
    """Vector quantizer whose codebook is trained with an entropic OT loss.

    Encoder vectors are snapped to their nearest codebook entry, using a
    straight-through gradient to the encoder. In training mode the codebook
    loss is the entropic semi-dual OT objective between the flattened encoder
    vectors and the codebook, evaluated in both directions. Both directions use
    uniform marginals. The potentials ``kan_net`` (one entry per code) and
    ``kan_net2`` (one entry per flattened position) are fitted with a few Adam
    steps on every training forward pass and then re-initialised.
    """

    # Adam steps used to fit each potential per training forward pass.
    _KAN_INNER_STEPS = 5

    def __init__(self, args):
        super().__init__()
        self.codebook_size = args.codebook_size
        self.codebook_dim = args.codebook_dim
        self.temperature = 0.5
        self.beta = args.beta
        self.alpha = args.alpha

        # Potential over the codebook (one scalar per code) and its optimizer.
        self.kan_net = KAN_NET(self.codebook_size)
        self.optim_kan = self._make_kan_optimizer(self.kan_net)

        # Potential over the flattened encoder outputs of a full batch:
        # batch_size * int(resolution / factor) ** 2 entries (presumably the
        # encoder downsamples by args.factor -- TODO confirm against the encoder).
        self.data_size = args.batch_size * int(args.resolution / args.factor) * int(args.resolution / args.factor)
        self.kan_net2 = KAN_NET(self.data_size)
        self.optim_kan2 = self._make_kan_optimizer(self.kan_net2)

        self.theta = 0.1  # temperature of the log-sum-exp in compute_OT_loss
        self.phi_net_troff = 1.0  # weight of the mean-potential term in compute_OT_loss

        self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    @staticmethod
    def _make_kan_optimizer(kan_net):
        """Return the Adam optimizer used to fit one potential module."""
        return torch.optim.Adam(
            kan_net.parameters(),
            lr=0.1,
            weight_decay=0.1,
            amsgrad=True,
        )

    @staticmethod
    def _squared_distances(x, codebook):
        """Return the (len(x), len(codebook)) matrix of squared L2 distances.

        Uses ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e.
        """
        return (
            torch.sum(x ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(x, codebook.t())
        )

    def reset_kan_net(self, data_size=None):
        """Re-create both potentials at their uniform init, with fresh optimizers.

        The optimizers must be rebuilt as well. An optimizer keeps references to
        the parameter tensors it was constructed with, so the old optimizers
        would never update the new potentials. The new modules are placed on
        the codebook's device.

        Args:
            data_size: length of ``kan_net2``; defaults to ``self.data_size``.
        """
        if data_size is None:
            data_size = self.data_size
        device = self.embedding.weight.device
        self.kan_net = KAN_NET(self.codebook_size).to(device)
        self.optim_kan = self._make_kan_optimizer(self.kan_net)
        self.kan_net2 = KAN_NET(data_size).to(device)
        self.optim_kan2 = self._make_kan_optimizer(self.kan_net2)

    def compute_OT_loss(self, ot_cost, kan_net):
        """Entropic semi-dual OT loss for a cost matrix ``ot_cost`` of shape (n, m).

        ``kan_net.kan_param`` is the potential phi over the m columns. Both
        marginals are uniform:

            loss = mean_i[-theta * log((1/m) * sum_j exp((phi_j - C_ij) / theta))]
                   + phi_net_troff * mean_j(phi_j)
        """
        # phi = kan_net(x2).reshape(-1)  # use this if phi is a network
        phi = kan_net.kan_param
        exp_term = (phi - ot_cost) / self.theta
        # The uniform weight is 1/m with m = number of columns. When this is
        # called on d.t(), m is the number of data points, not codebook_size.
        log_mean_exp = torch.logsumexp(exp_term, dim=1) - math.log(ot_cost.shape[1])
        return torch.mean(-self.theta * log_mean_exp) + self.phi_net_troff * torch.mean(phi)

    def _kan_step(self, cost, kan_net, optim):
        """Take one ascent step on the OT objective w.r.t. ``kan_net``'s potential."""
        optim.zero_grad()
        (-self.compute_OT_loss(cost, kan_net)).backward()
        optim.step()

    def _fit_and_compute_ot_loss(self, d):
        """Fit both potentials on costs ``d`` (n_data, n_codes) and return the OT loss.

        The potentials are re-initialised afterwards, so every call starts from
        the uniform init.
        """
        if self.kan_net2.kan_param.numel() != d.shape[0]:
            # e.g. a final batch smaller than args.batch_size
            self.reset_kan_net(data_size=d.shape[0])

        # Only the potentials are optimised here. Detaching the costs keeps these
        # inner backward passes from writing gradients into the encoder/codebook,
        # and it removes the need for retain_graph.
        d_fixed = d.detach()
        with torch.enable_grad():
            for _ in range(self._KAN_INNER_STEPS):
                self._kan_step(d_fixed, self.kan_net, self.optim_kan)
                self._kan_step(d_fixed.t(), self.kan_net2, self.optim_kan2)

        # Final loss uses the non-detached costs so the codebook/encoder get gradients.
        loss = self.compute_OT_loss(d, self.kan_net) + self.compute_OT_loss(d.t(), self.kan_net2)
        self.reset_kan_net()
        return loss

    def forward(self, z):
        """Quantize ``z`` and return the codebook loss and usage statistics.

        Args:
            z: encoder output. It is permuted with (0, 2, 3, 1) and flattened to
                ``codebook_dim`` columns, so presumably it is
                (batch, codebook_dim, H, W).

        Returns:
            z_q: quantized tensor, same shape as ``z``, with a straight-through
                gradient to ``z``.
            loss: OT codebook loss in training mode; a zero scalar in eval mode.
            quant_error: squared distance to the chosen code, summed over
                channels and averaged over positions (no gradient).
            codebook_utilization: fraction of codes used in this batch (float).
            codebook_perplexity: perplexity of this batch's code histogram.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.codebook_dim)

        # Keep the graph into the codebook: the OT loss trains it through d.
        d = self._squared_distances(z_flattened, self.embedding.weight)
        token = torch.argmin(d, dim=1)

        # Exact lookup (same values as one-hot @ weight). Any gradient through it
        # is discarded by the straight-through step below.
        z_q = self.embedding(token).view(z.shape)

        if self.training:
            loss = self._fit_and_compute_ot_loss(d)
        else:
            loss = d.new_zeros(())

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        ## Criterion Triple defined in the paper
        quant_error = (z_q.detach() - z.detach()).square().sum(3).mean()

        histogram = token.bincount(minlength=self.codebook_size).float()
        codebook_utilization = (histogram > 0).float().sum().item() / self.codebook_size
        avg_probs = histogram / histogram.sum(0)
        codebook_perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, quant_error, codebook_utilization, codebook_perplexity

    def collect_eval_info(self, z):
        """Quantize ``z`` without straight-through.

        Returns ``(z_q, quant_error, histogram)``, where ``histogram`` holds the
        per-code counts (length ``codebook_size``).
        """
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.codebook_dim)

        d = self._squared_distances(z_flattened, self.embedding.weight.data)
        token = torch.argmin(d, dim=1)
        z_q = self.embedding(token).view(z.shape)

        quant_error = (z_q.detach() - z.detach()).square().sum(3).mean()
        histogram = token.bincount(minlength=self.codebook_size).float()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, quant_error, histogram

    def obtain_embedding_id(self, z):
        """Return the nearest-code index for every spatial position, shaped (b, h, w)."""
        b, _, h, w = z.shape
        z_flattened = z.permute(0, 2, 3, 1).contiguous().view(-1, self.codebook_dim)
        d = self._squared_distances(z_flattened, self.embedding.weight.data)
        return torch.argmin(d, dim=1).view(b, h, w)

    def obtain_codebook_entry(self, indices):
        """Look up codebook vectors.

        For (b, h, w) indices, as returned by ``obtain_embedding_id``, the
        result is (b, h, w, codebook_dim).
        """
        return self.embedding(indices)