import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical
import ot

def sample_gumbel(shape, eps=1e-10, device=None):
    """Draw Gumbel(0, 1) noise of the requested shape.

    The sample is produced with the inverse-CDF transform
    ``-log(-log(U))`` where ``U`` is uniform on ``[0, 1)``.

    Args:
        shape: Shape of the tensor to sample.
        eps: Small constant added inside both logarithms so that neither
            ``log`` is ever evaluated at zero.
        device: Device on which the noise is allocated.  When ``None`` the
            function uses CUDA if it is available (the historical behaviour
            of this function) and falls back to the CPU otherwise instead
            of raising.

    Returns:
        A tensor of shape ``shape`` containing the Gumbel noise.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Return a Gumbel-softmax relaxation of a categorical sample.

    Gumbel(0, 1) noise is added to ``logits`` and the result is passed
    through a temperature-scaled softmax over the last dimension.

    Args:
        logits: Tensor of unnormalised log-probabilities; the categories are
            indexed by the last dimension.
        temperature: Positive scalar dividing the perturbed logits before the
            softmax.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums
        to one.
    """
    eps = 1e-10
    # The noise is sampled directly on the device of `logits` so that the
    # addition below never mixes devices (e.g. CPU logits with CUDA noise).
    u = torch.rand(logits.size(), device=logits.device)
    g = -torch.log(-torch.log(u + eps) + eps)
    y = logits + g
    return F.softmax(y / temperature, dim=-1)


def calc_distance(z_continuous, codebook, dim_dict):
    """Squared Euclidean distance between every latent and every code vector.

    Uses the expansion ``|z - e|^2 = |z|^2 + |e|^2 - 2 <z, e>`` so that the
    whole distance matrix comes out of a single matrix product.

    Args:
        z_continuous: Tensor viewed as rows of length ``dim_dict``.
        codebook: 2-D tensor whose rows are the code vectors.
        dim_dict: Length of each latent / code vector.

    Returns:
        Tensor of shape ``(num_latents, num_codes)``.
    """
    flat = z_continuous.view(-1, dim_dict)
    latent_sq_norm = flat.pow(2).sum(dim=1, keepdim=True)
    code_sq_norm = codebook.pow(2).sum(dim=1)
    inner = torch.matmul(flat, codebook.t())
    return latent_sq_norm + code_sq_norm - 2 * inner



class VectorQuantizer(nn.Module):
    """Deterministic nearest-neighbour vector quantizer.

    Every latent vector is replaced by the closest row of an externally
    supplied codebook (squared Euclidean distance).  Gradients are passed
    through the quantization step with the straight-through estimator.

    Subclasses provide alternative ``_quantize`` implementations and the two
    ``_calc_distance_*`` hooks, which are abstract here.

    Args:
        size_dict: Number of code vectors in the codebook.
        dim_dict: Dimensionality of each code vector.
        temperature: Stored on the instance; not used by this class itself.
    """

    def __init__(self, size_dict, dim_dict, temperature=0.5):
        super(VectorQuantizer, self).__init__()
        self.size_dict = size_dict
        self.dim_dict = dim_dict
        self.temperature = temperature
        # Weight applied to the commitment term of the training loss.
        self.beta = 0.25

    def forward(self, z_from_encoder, param_q, codebook, flg_train, flg_quant_det=False):
        """Quantize ``z_from_encoder``; see :meth:`_quantize` for details."""
        return self._quantize(z_from_encoder, param_q, codebook,
                              flg_train=flg_train, flg_quant_det=flg_quant_det)

    def _nearest_codes(self, z, codebook):
        """Assign every latent vector to its nearest code vector.

        Args:
            z: 4-D tensor with the embedding dimension on axis 1.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.

        Returns:
            Tuple ``(z, z_q, min_encodings, min_encoding_indices)`` where
            ``z`` is the input moved to channels-last layout, ``z_q`` holds
            the selected code vectors in that same layout,
            ``min_encodings`` is the one-hot assignment matrix and
            ``min_encoding_indices`` the column vector of selected indices.
        """
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.dim_dict)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, codebook.t())

        # one-hot encoding of the closest code for every latent vector
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.size_dict, device=z.device)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # look up the quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook).view(z.shape)
        return z, z_q, min_encodings, min_encoding_indices

    @staticmethod
    def _perplexity(min_encodings):
        """Return exp(entropy) of the average one-hot code usage."""
        e_mean = torch.mean(min_encodings, dim=0)
        return torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

    def _quantize(self, z, var_q, codebook, flg_train, flg_quant_det=False):
        """Quantize ``z`` and compute the VQ training loss.

        Args:
            z: 4-D encoder output with the embedding dimension on axis 1.
            var_q: Unused here; kept so all quantizers share one signature.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            flg_train: When true the loss is computed, otherwise it is
                ``0.0``.
            flg_quant_det: Unused here; kept for signature compatibility.

        Returns:
            Tuple ``(z_q, loss, perplexity)`` with ``z_q`` in the same layout
            as the input ``z``.
        """
        z, z_q, min_encodings, _ = self._nearest_codes(z, codebook)

        if flg_train:
            # beta-weighted commitment term (gradient reaches z only) plus
            # the codebook term (gradient reaches the codebook only).
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = 0.0

        # straight-through estimator: forward value is z_q, gradient is
        # copied to z unchanged
        z_q = z + (z_q - z).detach()

        perplexity = self._perplexity(min_encodings)

        # reshape back to match original input layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, perplexity

    def _inference(self, z, codebook):
        """Quantize ``z`` without a loss or straight-through gradient.

        Args:
            z: 4-D encoder output with the embedding dimension on axis 1.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.

        Returns:
            Tuple ``(z_q, min_encodings, min_encoding_indices, perplexity)``
            with ``z_q`` in the same layout as the input ``z``.
        """
        _, z_q, min_encodings, min_encoding_indices = self._nearest_codes(z, codebook)
        perplexity = self._perplexity(min_encodings)

        # reshape back to match original input layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q, min_encodings, min_encoding_indices, perplexity

    def set_temperature(self, value):
        """Replace the stored temperature with ``value``."""
        self.temperature = value

    def _calc_distance_bw_enc_codes(self):
        """Abstract hook; concrete subclasses override it."""
        raise NotImplementedError()

    def _calc_distance_bw_enc_dec(self):
        """Abstract hook; concrete subclasses override it."""
        raise NotImplementedError()


class KAN_NET(nn.Module):
    """Container for a single learnable vector initialised to uniform mass.

    Args:
        batch_size: Length of the vector; every entry starts at
            ``1 / batch_size``.
    """

    def __init__(self, batch_size):
        super().__init__()
        # Kept as a plain tensor attribute (it is not registered as a buffer).
        self.uniform_mass = torch.Tensor([1 / batch_size] * batch_size)
        # The learnable parameter is built directly from that tensor.
        self.kan_param = nn.Parameter(self.uniform_mass)


class FWSVectorQuantizer(VectorQuantizer):
    """Nearest-neighbour quantizer trained with a dual-form OT-style loss.

    Quantization is the same hard nearest-code assignment as in
    ``VectorQuantizer``; only the training loss differs.  The loss is
    computed from the latent/code distance matrix and two learnable vectors
    held in ``KAN_NET`` modules (presumably Kantorovich dual potentials, one
    per transport direction -- confirm against the accompanying paper).
    Each forward pass in training mode first runs a few Adam steps on those
    vectors and then returns the loss evaluated with the refreshed vectors.

    ``_inference`` and ``set_temperature`` are inherited unchanged from
    ``VectorQuantizer``.

    Args:
        size_dict: Number of code vectors in the codebook.
        dim_dict: Dimensionality of each code vector.
        temperature: Stored on the instance; not used by this class itself.
        num_samples: Length of the second learnable vector.  It must match
            the number of latent vectors per forward pass (batch * height *
            width) because it is broadcast against the transposed distance
            matrix.  Defaults to 2048, the previously hard-coded value.
    """

    def __init__(self, size_dict, dim_dict, temperature=0.5, num_samples=2048):
        super(FWSVectorQuantizer, self).__init__(size_dict, dim_dict, temperature)
        self.beta = 1.0
        self.num_samples = num_samples
        # Scale used inside the log-sum-exp of the loss.
        self.theta = 0.1
        # Weight of the mean-of-vector term of the loss.
        self.phi_net_troff = 1.0

        self.kan_net = KAN_NET(self.size_dict)
        self.optim_kan = self._make_kan_optimizer(self.kan_net)
        self.kan_net2 = KAN_NET(self.num_samples)
        self.optim_kan2 = self._make_kan_optimizer(self.kan_net2)

    @staticmethod
    def _make_kan_optimizer(kan_net):
        """Build the Adam optimizer that updates one ``KAN_NET``."""
        return torch.optim.Adam(
            kan_net.parameters(),
            lr=0.1,
            weight_decay=0.1,
            amsgrad=True,
        )

    def reset_kan_net(self):
        """Re-initialise both learnable vectors and their optimizers.

        The fresh modules are placed on the device of the current ones, and
        the optimizers are rebuilt so that they update the new parameters
        rather than the discarded ones.
        """
        device = self.kan_net.kan_param.device
        self.kan_net = KAN_NET(self.size_dict).to(device)
        self.kan_net2 = KAN_NET(self.num_samples).to(device)
        self.optim_kan = self._make_kan_optimizer(self.kan_net)
        self.optim_kan2 = self._make_kan_optimizer(self.kan_net2)

    def forward(self, z_from_encoder, codebook, flg_train):
        """Quantize ``z_from_encoder``; see :meth:`_quantize` for details."""
        return self._quantize(z_from_encoder, codebook, flg_train=flg_train)

    def compute_OT_loss(self, ot_cost, kan_net):
        """Evaluate the loss for one cost matrix and one learnable vector.

        With ``c = ot_cost`` and ``phi = kan_net.kan_param`` the value is::

            mean_i[-theta * (log(1 / size_dict)
                             + logsumexp_j((phi_j - c_ij) / theta))]
            + phi_net_troff * mean(phi)

        Args:
            ot_cost: 2-D cost matrix; ``phi`` is broadcast along its rows, so
                its number of columns must equal the length of ``phi``.
            kan_net: Module exposing the vector as ``kan_param``.

        Returns:
            Scalar tensor.
        """
        # phi_network = kan_net(x2).reshape(-1)  # Use this if \phi is a network
        phi_network = kan_net.kan_param

        exp_term = (- ot_cost + phi_network) / self.theta
        phi_loss = torch.mean(phi_network)

        OT_loss = torch.mean(
            - self.theta * (torch.log(torch.tensor(1.0 / self.size_dict)) + torch.logsumexp(exp_term, dim=1))
        ) + self.phi_net_troff * phi_loss
        return OT_loss

    def _fit_potentials(self, cost, n_steps=5):
        """Run ``n_steps`` Adam updates on both learnable vectors.

        The optimizers minimise the negated loss, i.e. they ascend
        ``compute_OT_loss`` with respect to the vectors.  The loss is
        re-evaluated at every step so each update uses the gradient at the
        current value of the vector.

        Args:
            cost: Distance matrix of shape ``(num_latents, size_dict)``.  It
                must be detached from the encoder/codebook graph so these
                inner updates do not accumulate gradients there.
            n_steps: Number of update steps per vector.
        """
        cost_t = cost.t()
        for _ in range(n_steps):
            self.optim_kan.zero_grad()
            (-self.compute_OT_loss(cost, self.kan_net)).backward()
            self.optim_kan.step()

            self.optim_kan2.zero_grad()
            (-self.compute_OT_loss(cost_t, self.kan_net2)).backward()
            self.optim_kan2.step()

    def _quantize(self, z, codebook, flg_train):
        """Quantize ``z`` and, in training mode, compute the loss.

        Args:
            z: 4-D encoder output with the embedding dimension on axis 1.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            flg_train: When true the learnable vectors are updated and the
                loss is returned, otherwise the loss is ``0.0``.

        Returns:
            Tuple ``(z_q, loss, perplexity)`` with ``z_q`` in the same layout
            as the input ``z``.
        """
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.dim_dict)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, codebook.t())

        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.size_dict, device=z.device)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook)

        if flg_train:
            # Inner updates see a detached cost matrix: only the learnable
            # vectors receive gradients from the negated loss.
            self._fit_potentials(d.detach())

            # Loss for the outer optimizer, evaluated on the attached
            # distances with the refreshed vectors.
            loss1 = self.compute_OT_loss(d, self.kan_net)
            loss2 = self.compute_OT_loss(d.t(), self.kan_net2)
            loss = loss1 + loss2
            # NOTE(review): earlier revisions had a bare `self.reset_kan_net`
            # expression here (no call), so the vectors have always been
            # carried over between batches.  That behaviour is kept; call
            # `self.reset_kan_net()` explicitly if a per-batch reset is
            # wanted.
        else:
            loss = 0.0

        z_q = z_q.view(z.shape)
        # straight-through estimator: forward value is z_q, gradient is
        # copied to z unchanged
        z_q = z + (z_q - z).detach()

        # perplexity of the average one-hot code usage
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, perplexity

class WSVectorQuantizer(VectorQuantizer):
    """Nearest-neighbour quantizer trained with an exact OT cost.

    Quantization is the same hard nearest-code assignment as in
    ``VectorQuantizer``; the training loss is ``ot.emd2`` between uniform
    weights over the code vectors and uniform weights over the latent
    vectors, using their squared Euclidean distances as the cost matrix.

    ``_inference`` and ``set_temperature`` are inherited unchanged from
    ``VectorQuantizer``.

    Args:
        size_dict: Number of code vectors in the codebook.
        dim_dict: Dimensionality of each code vector.
        temperature: Stored on the instance; not used by this class itself.
    """

    def __init__(self, size_dict, dim_dict, temperature=0.5):
        super(WSVectorQuantizer, self).__init__(size_dict, dim_dict, temperature)
        # Multiplier applied to the OT cost.
        self.beta = 1.0

    def forward(self, z_from_encoder, codebook, flg_train):
        """Quantize ``z_from_encoder``; see :meth:`_quantize` for details."""
        return self._quantize(z_from_encoder, codebook, flg_train=flg_train)

    def _quantize(self, z, codebook, flg_train):
        """Quantize ``z`` and, in training mode, compute the OT loss.

        Args:
            z: 4-D encoder output with the embedding dimension on axis 1.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            flg_train: When true the OT loss is computed, otherwise the loss
                is ``0.0``.

        Returns:
            Tuple ``(z_q, loss, perplexity)`` with ``z_q`` in the same layout
            as the input ``z``.
        """
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.dim_dict)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, codebook.t())

        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.size_dict, device=z.device)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook)

        if flg_train:
            # uniform marginals over the latent vectors and the code vectors
            size_batch = z_flattened.shape[0]
            sample_weight = torch.ones(size_batch, device=z.device) / size_batch
            codeword_weight = torch.ones(self.size_dict, device=z.device) / self.size_dict

            # cost matrix is transposed so rows index codes, columns latents
            loss = self.beta * ot.emd2(codeword_weight, sample_weight, d.t(), numItermax=500000)
        else:
            loss = 0.0

        z_q = z_q.view(z.shape)
        # straight-through estimator: forward value is z_q, gradient is
        # copied to z unchanged
        z_q = z + (z_q - z).detach()

        # perplexity of the average one-hot code usage
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input layout
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, perplexity

class GaussianVectorQuantizer(VectorQuantizer):
    """Stochastic quantizer with precision-weighted squared distances.

    Assignment logits are the negated, precision-weighted squared distances
    between latents and code vectors.  Training draws a relaxed assignment
    with Gumbel-softmax; evaluation either takes the arg-max or samples from
    the categorical distribution.

    Args:
        size_dict: Number of code vectors in the codebook.
        dim_dict: Dimensionality of each code vector.
        temperature: Gumbel-softmax temperature used during training.
        param_var_q: One of ``"gaussian_1"`` .. ``"gaussian_4"``; selects how
            the precision weight is reshaped before it multiplies the
            distances (see :meth:`_calc_distance_bw_enc_codes`).
    """

    def __init__(self, size_dict, dim_dict, temperature=0.5, param_var_q="gaussian_1"):
        super(GaussianVectorQuantizer, self).__init__(size_dict, dim_dict, temperature)
        self.param_var_q = param_var_q

    def _quantize(self, z_from_encoder, var_q, codebook, flg_train=True, flg_quant_det=False):
        """Quantize the encoder output and compute the latent loss.

        Args:
            z_from_encoder: 4-D tensor unpacked as
                ``(bs, dim_z, width, height)``.
            var_q: Variance parameter; clamped to at least ``1e-10`` and
                inverted to obtain the precision.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            flg_train: When true a Gumbel-softmax relaxation is used.
            flg_quant_det: Only read when ``flg_train`` is false; true
                selects the arg-max code, false samples a code.

        Returns:
            Tuple ``(z_to_decoder, loss, perplexity)`` with ``z_to_decoder``
            in the same layout as ``z_from_encoder``.
        """
        bs, dim_z, width, height = z_from_encoder.shape
        z_from_encoder_permuted = z_from_encoder.permute(0, 2, 3, 1).contiguous()
        precision_q = 1. / torch.clamp(var_q, min=1e-10)

        logit = -self._calc_distance_bw_enc_codes(z_from_encoder_permuted, codebook, 0.5 * precision_q)
        probabilities = torch.softmax(logit, dim=-1)
        log_probabilities = torch.log_softmax(logit, dim=-1)

        # Quantization
        if flg_train:
            encodings = gumbel_softmax_sample(logit, self.temperature)
            z_quantized = torch.mm(encodings, codebook).view(bs, width, height, dim_z)
            avg_probs = torch.mean(probabilities.detach(), dim=0)
        else:
            if flg_quant_det:
                indices = torch.argmax(logit, dim=1).unsqueeze(1)
                # Allocate next to the logits rather than on a fixed device.
                encodings_hard = torch.zeros(indices.shape[0], self.size_dict, device=logit.device)
                encodings_hard.scatter_(1, indices, 1)
                avg_probs = torch.mean(encodings_hard, dim=0)
            else:
                dist = Categorical(probabilities)
                indices = dist.sample().view(bs, width, height)
                encodings_hard = F.one_hot(indices, num_classes=self.size_dict).type_as(codebook)
                avg_probs = torch.mean(probabilities, dim=0)
            z_quantized = torch.matmul(encodings_hard, codebook).view(bs, width, height, dim_z)
        z_to_decoder = z_quantized.permute(0, 3, 1, 2).contiguous()

        # Latent loss: sum of p * log p over all positions and codes divided
        # by the batch size, plus the weighted encoder/decoder distance.
        kld_discrete = torch.sum(probabilities * log_probabilities, dim=(0,1)) / bs
        kld_continuous = self._calc_distance_bw_enc_dec(z_from_encoder, z_to_decoder, 0.5 * precision_q).mean()
        loss = kld_discrete + kld_continuous
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))

        return z_to_decoder, loss, perplexity

    def _inference(self, z_from_encoder, var_q, codebook, flg_train=True, flg_quant_det=False):
        """Deterministically quantize with the arg-max code.

        Args:
            z_from_encoder: 4-D tensor unpacked as
                ``(bs, dim_z, width, height)``.
            var_q: Variance parameter; clamped to at least ``1e-10`` and
                inverted to obtain the precision.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            flg_train: Unused; kept for signature compatibility.
            flg_quant_det: Unused; kept for signature compatibility.

        Returns:
            Tuple ``(z_to_decoder, min_encodings, min_encoding_indices,
            loss)``.  Note that the last element is the latent loss, not a
            perplexity.
        """
        bs, dim_z, width, height = z_from_encoder.shape
        z_from_encoder_permuted = z_from_encoder.permute(0, 2, 3, 1).contiguous()
        precision_q = 1. / torch.clamp(var_q, min=1e-10)

        logit = -self._calc_distance_bw_enc_codes(z_from_encoder_permuted, codebook, 0.5 * precision_q)

        # Largest logit == smallest weighted distance.
        min_encoding_indices = torch.argmax(logit, dim=1).unsqueeze(1)
        # Allocate next to the logits rather than on a fixed device.
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.size_dict, device=logit.device)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        z_quantized = torch.matmul(min_encodings, codebook).view(bs, width, height, dim_z)
        z_to_decoder = z_quantized.permute(0, 3, 1, 2).contiguous()

        # Latent loss
        probabilities = torch.softmax(logit, dim=-1)
        log_probabilities = torch.log_softmax(logit, dim=-1)
        kld_discrete = torch.sum(probabilities * log_probabilities, dim=(0,1)) / bs
        kld_continuous = self._calc_distance_bw_enc_dec(z_from_encoder, z_to_decoder, 0.5 * precision_q).mean()
        loss = kld_discrete + kld_continuous

        return z_to_decoder, min_encodings, min_encoding_indices, loss

    def _calc_distance_bw_enc_codes(self, z_from_encoder, codebook, weight):
        """Weighted squared distances between latents and code vectors.

        Args:
            z_from_encoder: Channels-last encoder output.
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``.
            weight: Precision-derived weight; how it is reshaped depends on
                ``self.param_var_q``.

        Returns:
            Tensor of shape ``(num_latents, size_dict)``.

        Raises:
            ValueError: If ``self.param_var_q`` is not a supported mode.
        """
        if self.param_var_q == "gaussian_1":
            # weight is applied as-is (broadcast against the distances)
            distances = weight * calc_distance(z_from_encoder, codebook, self.dim_dict)
        elif self.param_var_q == "gaussian_2":
            # NOTE(review): the 8 x 8 tiling is hard-coded; presumably the
            # latent grid is 8 x 8 in this mode -- confirm with the caller.
            weight = weight.tile(1, 1, 8, 8).view(-1,1)
            distances = weight * calc_distance(z_from_encoder, codebook, self.dim_dict)
        elif self.param_var_q == "gaussian_3":
            # one weight per latent vector
            weight = weight.view(-1,1)
            distances = weight * calc_distance(z_from_encoder, codebook, self.dim_dict)
        elif self.param_var_q == "gaussian_4":
            # one weight per latent vector and per embedding dimension, so
            # the distance has to be accumulated dimension by dimension
            z_from_encoder_flat = z_from_encoder.view(-1, self.dim_dict).unsqueeze(2)
            codebook = codebook.t().unsqueeze(0)
            weight = weight.permute(0, 2, 3, 1).contiguous().view(-1, self.dim_dict).unsqueeze(2)
            distances = torch.sum(weight * ((z_from_encoder_flat - codebook) ** 2), dim=1)
        else:
            # Fail with a clear message instead of an UnboundLocalError.
            raise ValueError(f"Unsupported param_var_q: {self.param_var_q!r}")

        return distances

    def _calc_distance_bw_enc_dec(self, x1, x2, weight):
        """Weighted squared difference of ``x1`` and ``x2`` summed over dims 1-3."""
        return torch.sum((x1-x2)**2 * weight, dim=(1,2,3))
    


class VmfVectorQuantizer(VectorQuantizer):
    """Stochastic quantizer with inner-product logits on a normalised codebook.

    The codebook rows are L2-normalised and the assignment logits are
    ``kappa_q`` times the inner product between each latent and each
    normalised code vector.  Training draws a relaxed assignment with
    Gumbel-softmax; evaluation either takes the arg-max or samples from the
    categorical distribution.

    Args:
        size_dict: Number of code vectors in the codebook.
        dim_dict: Dimensionality of each code vector.
        temperature: Gumbel-softmax temperature used during training.
    """

    def __init__(self, size_dict, dim_dict, temperature=0.5):
        super(VmfVectorQuantizer, self).__init__(size_dict, dim_dict, temperature)

    def _quantize(self, z_from_encoder, kappa_q, codebook, flg_train=True, flg_quant_det=False):
        """Quantize the encoder output and compute the latent loss.

        Args:
            z_from_encoder: 4-D tensor unpacked as
                ``(bs, dim_z, width, height)``.
            kappa_q: Scale multiplying the inner products (presumably the
                von Mises-Fisher concentration -- confirm with the caller).
            codebook: 2-D tensor of shape ``(size_dict, dim_dict)``; it is
                L2-normalised row-wise before use.
            flg_train: When true a Gumbel-softmax relaxation is used.
            flg_quant_det: Only read when ``flg_train`` is false; true
                selects the arg-max code, false samples a code.

        Returns:
            Tuple ``(z_to_decoder, loss, perplexity)`` with ``z_to_decoder``
            in the same layout as ``z_from_encoder``.
        """
        bs, dim_z, width, height = z_from_encoder.shape
        z_from_encoder_permuted = z_from_encoder.permute(0, 2, 3, 1).contiguous()
        codebook_norm = F.normalize(codebook, p=2.0, dim=1)

        logit = -self._calc_distance_bw_enc_codes(z_from_encoder_permuted, codebook_norm, kappa_q)
        probabilities = torch.softmax(logit, dim=-1)
        log_probabilities = torch.log_softmax(logit, dim=-1)

        # Quantization
        if flg_train:
            encodings = gumbel_softmax_sample(logit, self.temperature)
            z_quantized = torch.mm(encodings, codebook_norm).view(bs, width, height, dim_z)
            avg_probs = torch.mean(probabilities.detach(), dim=0)
        else:
            if flg_quant_det:
                indices = torch.argmax(logit, dim=1).unsqueeze(1)
                # Allocate next to the logits rather than on a fixed device.
                encodings_hard = torch.zeros(indices.shape[0], self.size_dict, device=logit.device)
                encodings_hard.scatter_(1, indices, 1)
                avg_probs = torch.mean(encodings_hard, dim=0)
            else:
                dist = Categorical(probabilities)
                indices = dist.sample().view(bs, width, height)
                encodings_hard = F.one_hot(indices, num_classes=self.size_dict).type_as(codebook)
                avg_probs = torch.mean(probabilities, dim=0)
            z_quantized = torch.matmul(encodings_hard, codebook_norm).view(bs, width, height, dim_z)
        z_to_decoder = z_quantized.permute(0, 3, 1, 2).contiguous()

        # Latent loss: sum of p * log p over all positions and codes divided
        # by the batch size, plus the weighted encoder/decoder term.
        kld_discrete = torch.sum(probabilities * log_probabilities, dim=(0,1)) / bs
        kld_continuous = self._calc_distance_bw_enc_dec(z_from_encoder, z_to_decoder, kappa_q).mean()
        loss = kld_discrete + kld_continuous
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))

        return z_to_decoder, loss, perplexity

    def _calc_distance_bw_enc_codes(self, z_from_encoder, codebook, kappa_q):
        """Return ``-kappa_q * <z, e>`` for every latent/code pair.

        Args:
            z_from_encoder: Channels-last encoder output.
            codebook: 2-D tensor whose rows are the code vectors.
            kappa_q: Scale applied to the inner products.

        Returns:
            Tensor of shape ``(num_latents, num_codes)``.
        """
        z_from_encoder_flat = z_from_encoder.view(-1, self.dim_dict)
        distances = -kappa_q * torch.matmul(z_from_encoder_flat, codebook.t())

        return distances

    def _calc_distance_bw_enc_dec(self, x1, x2, weight):
        """Return ``x1 * (x1 - x2) * weight`` summed over dims 1-3."""
        return torch.sum(x1 * (x1-x2) * weight, dim=(1,2,3))
