import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical
import ot


def sample_gumbel(shape, eps=1e-10, device="cuda"):
    """Draw Gumbel(0, 1) noise by inverse-CDF sampling of uniform noise.

    Args:
        shape: size of the returned tensor (anything ``torch.rand`` accepts).
        eps: small constant added inside both logarithms so neither sees 0.
        device: device on which to sample. Defaults to ``"cuda"``; pass e.g.
            ``logits.device`` (or ``"cpu"``) to sample elsewhere.

    Returns:
        Tensor of ``shape`` with entries ``-log(-log(U + eps) + eps)`` where
        ``U ~ Uniform[0, 1)`` (default float dtype).
    """
    uniform = torch.rand(shape, device=device)
    return -torch.log(-torch.log(uniform + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed one-hot sample from the Gumbel-softmax distribution.

    Args:
        logits: unnormalised scores; the softmax is taken over the last
            dimension.
        temperature: divisor applied before the softmax; smaller values push
            the sample closer to one-hot. Must be non-zero.

    Returns:
        Tensor with the same shape as ``logits`` whose last-dimension slices
        are non-negative and sum to 1.
    """
    eps = 1e-10
    # Draw the noise on the logits' own device rather than a fixed "cuda", so
    # CPU inputs and inputs on any GPU work. The dtype is left at the default
    # float type, so eps does not underflow for half-precision logits.
    uniform = torch.rand(logits.shape, device=logits.device)
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)
    return F.softmax((logits + gumbel) / temperature, dim=-1)




class OTVectorQuantizer(nn.Module):
    """Vector quantizer trained against an optimal-transport soft target.

    Encoder features are scored against an L2-normalised codebook. The
    resulting logits are used in two ways:

    * They are passed, without gradient, to ``ot.log_sinkhorn``. Its output
      serves as a fixed target for the assignment loss. It presumably has the
      same shape as the logits; confirm against the ``ot`` module.
    * They drive the quantization step that produces the features handed to
      the decoder.
    """

    def __init__(self, size_dict, dim_dict, eps, ot_iter):
        """
        Args:
            size_dict: number of codebook entries (classes of the one-hot
                encodings).
            dim_dict: feature size of each flattened encoder position. It is
                expected to equal the encoder's channel dimension and the
                codebook's second dimension.
            eps: forwarded to ``ot.log_sinkhorn`` as ``eps``.
            ot_iter: forwarded to ``ot.log_sinkhorn`` as ``max_iter``.
        """
        super().__init__()
        self.eps = eps
        self.ot_iter = ot_iter
        self.size_dict = size_dict
        self.dim_dict = dim_dict
        # Gumbel-softmax temperature; only used on the training path.
        self.temperature = 1

    def set_temperature(self, value):
        """Set the Gumbel-softmax temperature used during training."""
        self.temperature = value

    def _quantize(self, logit, probabilities, codebook, flg_train, flg_quant_det):
        """Convert per-row logits into quantized vectors and average code usage.

        Args:
            logit: (N, size_dict) similarity logits.
            probabilities: ``softmax(logit)`` over the last dimension.
            codebook: (size_dict, D) code vectors.
            flg_train: use relaxed Gumbel-softmax assignments when True.
            flg_quant_det: outside training, take the argmax code instead of
                sampling one.

        Returns:
            ``(z_quantized, avg_probs)``. ``z_quantized`` is (N, D) and
            ``avg_probs`` is (size_dict,).
        """
        if flg_train:
            # Differentiable relaxed one-hot assignments.
            encodings = gumbel_softmax_sample(logit, self.temperature)
            avg_probs = torch.mean(probabilities.detach(), dim=0)
        else:
            if flg_quant_det:
                indices = torch.argmax(logit, dim=1)
            else:
                indices = Categorical(probabilities).sample()
            # one_hot follows the indices' (i.e. logit's) device. type_as then
            # matches the codebook's dtype/device, so no device is hard-coded.
            encodings = F.one_hot(indices, num_classes=self.size_dict).type_as(codebook)
            # Deterministic branch: empirical code usage. Sampled branch: mean
            # predicted distribution.
            usage = encodings if flg_quant_det else probabilities
            avg_probs = torch.mean(usage, dim=0)
        return torch.mm(encodings, codebook), avg_probs

    def forward(self, z_from_encoder, codebook, temp, flg_train=True, flg_quant_det=False):
        """Quantize encoder features against ``codebook``.

        Args:
            z_from_encoder: 4-D tensor unpacked as (bs, dim_z, width, height);
                dim 1 is the feature channel.
            codebook: (size_dict, dim_dict) tensor of code vectors.
            temp: tensor; ``temp.exp()`` scales the similarity logits.
            flg_train: use Gumbel-softmax assignments when True.
            flg_quant_det: when not training, pick the argmax code instead of
                sampling one.

        Returns:
            ``(z_to_decoder, loss, perplexity)``.
            ``z_to_decoder`` has the same shape as ``z_from_encoder``.
            ``loss`` is ``-mean(q_ot * log_softmax(logit))``, averaged over all
            N * size_dict entries.
            ``perplexity`` is ``exp`` of the entropy of the average code usage.
        """
        bs, dim_z, width, height = z_from_encoder.shape
        # Channels-last, so each spatial position becomes one row of z_flat.
        z_flat = z_from_encoder.permute(0, 2, 3, 1).contiguous().view(-1, self.dim_dict)

        # Similarity to unit-norm codes, scaled by exp(temp).
        codebook_norm = F.normalize(codebook, dim=-1)
        logit = z_flat.mm(codebook_norm.T) * temp.exp()
        probabilities = torch.softmax(logit, dim=-1)
        log_probabilities = torch.log_softmax(logit, dim=-1)

        # The OT solution is a fixed target: no gradient flows through it.
        with torch.no_grad():
            q_ot = ot.log_sinkhorn(logit, eps=self.eps, max_iter=self.ot_iter)

        z_quantized, avg_probs = self._quantize(
            logit, probabilities, codebook, flg_train, flg_quant_det
        )
        z_to_decoder = (
            z_quantized.view(bs, width, height, dim_z).permute(0, 3, 1, 2).contiguous()
        )

        loss = -torch.mean(q_ot * log_probabilities)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))
        return z_to_decoder, loss, perplexity

    
    
    