
import torch
# import numpy as np
from torch import nn
from torch.nn import functional as F


class Quantize(nn.Module):

    """
    Nearest-neighbour vector quantization against a fixed random codebook.

    The codebook ``self.embed`` has shape
    [num_codebooks, num_embeddings, dim // num_codebooks] and every code
    vector is scaled to unit L2 norm. It is a plain tensor attribute (neither
    a parameter nor a buffer), so it is not trained and is not part of
    ``state_dict``.

    Args:
        dim (int)         : dimensionality of each latent vector (D in paper)
        num_embeddings    : number of embeddings in each codebook (K in paper)
        num_codebooks     : number of codebooks (N); each one covers
                            dim // num_codebooks dimensions
        size              : unused, kept for interface compatibility
        embed_grad_update : stored as ``self.egu``. Gradient-based codebook
                            updates are not implemented: the codebook carries
                            no gradient and ``forward`` returns no loss, so a
                            warning is emitted when this is True.
        decay             : stored as ``self.decay``; no EMA update is
                            performed by this class
        eps               : stored as ``self.eps``; currently unused

    """
    def __init__(self, dim, num_embeddings, num_codebooks=1, size=1, embed_grad_update=False,
                 decay=0.99, eps=1e-5) :
        super().__init__()

        # Bookkeeping attributes kept so that external code reading them keeps working.
        self.i   = 0
        self.dim = dim
        self.eps = eps
        self.count = 1
        self.decay = decay
        self.egu  = embed_grad_update
        self.update_unused  = False
        self.num_codebooks  = num_codebooks
        self.num_embeddings = num_embeddings

        self.mix = 1

        if embed_grad_update:
            import warnings
            warnings.warn(
                "Quantize: embed_grad_update=True is not implemented; "
                "the codebook stays fixed.",
                stacklevel=2,
            )

        # Sample on the CPU generator (as before) and then move, so seeded runs
        # keep producing the same codebook; fall back to CPU without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        embed = torch.randn(num_codebooks, num_embeddings, dim // num_codebooks).to(device)
        # keepdim=True normalises every code vector for any number of
        # codebooks (the previous `.t()` form only broadcast for N == 1).
        self.embed = embed / torch.norm(embed, dim=2, keepdim=True)

    def forward(self, x):
        """
        Replace every latent vector of ``x`` by its nearest codebook entry.

        ``x`` is flattened in row-major order to [N, M, D] (N codebooks,
        D dims per code); no axis permutation is applied, so the number of
        elements of ``x`` must be divisible by N * D.

        Args:
            x (T)        : tensor to quantize
        Returns:
            quantize (T) : tensor with the shape of ``x`` whose values are the
                           selected code vectors. Gradients flow straight
                           through to ``x`` (straight-through estimator).
        """
        # Follow the input's device instead of assuming CUDA.
        embed = self.embed.to(x.device)
        N, K, D = embed.size()

        flat = x.detach().reshape(N, -1, D)

        # Squared distances |e|^2 + |x|^2 - 2 x.e, shape [N, M, K].
        distances = torch.baddbmm(
            torch.sum(embed ** 2, dim=2).unsqueeze(1) +
            torch.sum(flat ** 2, dim=2, keepdim=True),
            flat, embed.transpose(1, 2),
            alpha=-2.0, beta=1.0)
        indices = torch.argmin(distances, dim=-1)

        quantized = torch.gather(embed, 1, indices.unsqueeze(-1).expand(-1, -1, D))
        quantized = quantized.view_as(x)

        # Straight-through estimator: forward value is the code, gradient is identity.
        return x + (quantized - x).detach()

    def quantize(self, x):
        """Run ``forward`` with ``self.training`` temporarily set to False."""
        tr = self.training
        self.training = False
        try:
            z_q = self.forward(x)
        finally:
            # Restore the flag even if forward raises.
            self.training = tr
        return z_q

class Quantize_gauss(nn.Module):

    """
    Nearest-neighbour vector quantization against a fixed Gaussian codebook.

    The codebook ``self.embed`` is sampled once from a standard normal
    distribution, without normalization, with shape
    [num_codebooks, num_embeddings, dim // num_codebooks]. It is a plain
    tensor attribute (neither a parameter nor a buffer), so it is not trained
    and is not part of ``state_dict``.

    Args:
        args              : object exposing a ``device`` attribute; the
                            codebook is created on that device
        dim (int)         : dimensionality of each latent vector (D in paper)
        num_embeddings    : number of embeddings in each codebook (K in paper)
        num_codebooks     : number of codebooks (N); each one covers
                            dim // num_codebooks dimensions
        size              : unused, kept for interface compatibility
        embed_grad_update : stored as ``self.egu``; not used by ``forward``
        decay             : stored as ``self.decay``; not used by ``forward``
        eps               : stored as ``self.eps``; not used by ``forward``

    """
    def __init__(self, args, dim, num_embeddings, num_codebooks=1, size=1, embed_grad_update=False,
                 decay=0.99, eps=1e-5) :
        super().__init__()

        # Bookkeeping attributes kept so that external code reading them keeps working.
        self.i   = 0
        self.dim = dim
        self.eps = eps
        self.count = 1
        self.decay = decay
        self.egu  = embed_grad_update
        self.update_unused  = False
        self.num_codebooks  = num_codebooks
        self.num_embeddings = num_embeddings

        self.mix = 1

        self.device = args.device

        # Raw N(0, 1) samples, intentionally left un-normalized.
        self.embed = torch.randn(num_codebooks, num_embeddings, dim // num_codebooks).to(args.device)

    def forward(self, x):
        """
        Replace every latent vector of ``x`` by its nearest codebook entry.

        ``x`` is flattened in row-major order to [N, M, D] (N codebooks,
        D dims per code); no axis permutation is applied, so the number of
        elements of ``x`` must be divisible by N * D.

        Args:
            x (T)        : tensor to quantize
        Returns:
            quantize (T) : tensor with the shape of ``x`` whose values are the
                           selected code vectors. Gradients flow straight
                           through to ``x`` (straight-through estimator).
        """
        # Follow the input's device; a no-op when x already lives on args.device.
        embed = self.embed.to(x.device)
        N, K, D = embed.size()

        flat = x.detach().reshape(N, -1, D)

        # Squared distances |e|^2 + |x|^2 - 2 x.e, shape [N, M, K].
        distances = torch.baddbmm(
            torch.sum(embed ** 2, dim=2).unsqueeze(1) +
            torch.sum(flat ** 2, dim=2, keepdim=True),
            flat, embed.transpose(1, 2),
            alpha=-2.0, beta=1.0)
        indices = torch.argmin(distances, dim=-1)

        quantized = torch.gather(embed, 1, indices.unsqueeze(-1).expand(-1, -1, D))
        quantized = quantized.view_as(x)

        # Straight-through estimator: forward value is the code, gradient is identity.
        return x + (quantized - x).detach()

    def quantize(self, x):
        """Run ``forward`` with ``self.training`` temporarily set to False."""
        tr = self.training
        self.training = False
        try:
            z_q = self.forward(x)
        finally:
            # Restore the flag even if forward raises.
            self.training = tr
        return z_q
    
    
    
class Quantize_gauss_delaunay(nn.Module):

    """
    Stochastic quantization between neighbouring levels of a sorted codebook.

    The codebook ``self.embed`` is sampled from a standard normal
    distribution with shape
    [num_codebooks, num_embeddings, dim // num_codebooks] and then sorted
    along the embedding axis, independently for every coordinate. It is a
    plain tensor attribute (neither a parameter nor a buffer).

    For every input coordinate the nearest code is looked up, its neighbour
    on the side of the input (index +1 or -1) is selected, and one of the two
    levels is drawn at random: the neighbour is chosen with probability
    |x - nearest| / |neighbour - nearest|.

    Args:
        dim (int)         : dimensionality of each latent vector
        num_embeddings    : number of embeddings in each codebook (K)
        num_codebooks     : number of codebooks (N); each one covers
                            dim // num_codebooks dimensions
        size              : unused, kept for interface compatibility
        embed_grad_update : stored as ``self.egu``; not used by ``forward``
        decay             : stored as ``self.decay``; not used by ``forward``
        eps               : stored as ``self.eps``; not used by ``forward``

    """
    def __init__(self, dim, num_embeddings, num_codebooks=1, size=1, embed_grad_update=False,
                 decay=0.99, eps=1e-5) :
        super().__init__()

        # Bookkeeping attributes kept so that external code reading them keeps working.
        self.i   = 0
        self.dim = dim
        self.eps = eps
        self.count = 1
        self.decay = decay
        self.egu  = embed_grad_update
        self.update_unused  = False
        self.num_codebooks  = num_codebooks
        self.num_embeddings = num_embeddings

        self.mix = 1

        # Sample on the CPU generator (as before) and then move, so seeded runs
        # keep producing the same codebook; fall back to CPU without CUDA.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        embed = torch.randn(num_codebooks, num_embeddings, dim // num_codebooks).to(device)  # no normalization
        # Sorted levels make index +/- 1 the adjacent level along each coordinate.
        self.embed = torch.sort(embed, 1)[0]

    def forward(self, x):
        """
        Stochastically round ``x`` to one of two adjacent codebook levels.

        ``x`` is first clamped to lie strictly inside the overall value range
        of the codebook, then flattened in row-major order to [N, M, D]; the
        number of elements of ``x`` must be divisible by N * D.

        Args:
            x (T)        : tensor to quantize
        Returns:
            quantize (T) : tensor with the shape of ``x`` holding the drawn
                           levels. Gradients flow straight through to the
                           clamped input (straight-through estimator).
        """
        # Follow the input's device instead of assuming CUDA.
        embed = self.embed.to(x.device)
        N, K, D = embed.size()

        # Keep inputs strictly inside the codebook range so that a neighbour
        # exists on the side of the input.
        x = torch.clamp(x, min=torch.min(embed) + 1e-3, max=torch.max(embed) - 1e-3)
        flat = x.detach().reshape(N, -1, D)

        # Squared distances |e|^2 + |x|^2 - 2 x.e, shape [N, M, K].
        distances = torch.baddbmm(
            torch.sum(embed ** 2, dim=2).unsqueeze(1) +
            torch.sum(flat ** 2, dim=2, keepdim=True),
            flat, embed.transpose(1, 2),
            alpha=-2.0, beta=1.0)
        nearest_idx = torch.argmin(distances, dim=-1).unsqueeze(-1).expand(-1, -1, D)
        nearest = torch.gather(embed, 1, nearest_idx)

        # Neighbour on the side where x lies: +1 if x is above the nearest
        # level, -1 otherwise. Clamped so the index is always valid (for D > 1
        # a coordinate can lie outside its own per-coordinate range).
        step = 2 * (flat > nearest).long() - 1
        other_idx = torch.clamp(nearest_idx + step, 0, K - 1)
        other = torch.gather(embed, 1, other_idx)

        # Element-wise distances (identical to the former norm over the
        # size-1 codebook axis when N == 1).
        offset = torch.abs(flat - nearest)
        gap = torch.abs(other - nearest)
        # Probability of moving to the neighbour; guarded against a zero gap
        # and kept inside [0, 1] so it is always a valid Bernoulli parameter.
        proba = torch.where(gap > 0, offset / gap, torch.zeros_like(offset))
        proba = torch.clamp(proba, 0.0, 1.0)

        keep_nearest = torch.bernoulli(1 - proba).bool()
        quantized = torch.where(keep_nearest, nearest, other).view_as(x)

        # Straight-through estimator: forward value is the level, gradient is identity.
        return x + (quantized - x).detach()

    def quantize(self, x):
        """Run ``forward`` with ``self.training`` temporarily set to False."""
        tr = self.training
        self.training = False
        try:
            z_q = self.forward(x)
        finally:
            # Restore the flag even if forward raises.
            self.training = tr
        return z_q
    
    
    
class QSGD_rotation(nn.Module):

    """
    QSGD-style stochastic quantization applied in a randomly rotated basis.

    ``quantize`` draws a fresh random orthogonal matrix on every call,
    rotates the input with it, quantizes every column with
    ``quantize_QSGD`` and rotates the result back.

    Args:
        dim (int)         : size of the rotated axis; the input of
                            ``quantize`` is reshaped to [dim, -1]
        num_embeddings    : stored as ``self.num_embeddings``; unused here
        num_codebooks     : stored as ``self.num_codebooks``; unused here
        size              : unused, kept for interface compatibility
        embed_grad_update : stored as ``self.egu``; unused here
        decay             : stored as ``self.decay``; unused here
        eps               : stored as ``self.eps``; unused here

    """
    def __init__(self, dim, num_embeddings, num_codebooks=1, size=1, embed_grad_update=False,
                 decay=0.99, eps=1e-5) :
        super().__init__()

        # Bookkeeping attributes kept so that external code reading them keeps working.
        self.i   = 0
        self.dim = dim
        self.eps = eps
        self.count = 1
        self.decay = decay
        self.egu  = embed_grad_update
        self.update_unused  = False
        self.num_codebooks  = num_codebooks
        self.num_embeddings = num_embeddings

    def quantize_QSGD(self, x, s):
        """
        Stochastically round every column of ``x`` to ``s`` magnitude levels.

        Each entry is expressed as a fraction of its column's L2 norm (taken
        over dim 0), scaled by ``s`` and rounded down or up at random; the
        probability of rounding up equals the fractional part.

        Args:
            x (T) : tensor whose columns (along dim 0) are quantized
            s     : number of quantization levels (number or 0-dim tensor)
        Returns:
            T     : tensor shaped like ``x``; all-zero columns stay zero
        """
        x_norm = torch.norm(x, dim=0)
        # Avoid 0 / 0 = NaN for all-zero columns; their output is 0 anyway
        # because it is multiplied by x_norm below.
        safe_norm = torch.where(x_norm > 0, x_norm, torch.ones_like(x_norm))
        scaled = torch.mul(torch.div(torch.abs(x), safe_norm), s)
        lower = torch.floor(scaled)
        # Round up with probability equal to the fractional part.
        round_up = (torch.rand_like(lower) < (scaled - lower)).to(scaled.dtype)
        xi = (lower + round_up) / s
        return x_norm * torch.sign(x) * xi

    def quantize(self, x):
        """
        Quantize ``x`` in a random orthogonal basis and rotate back.

        Args:
            x (T) : tensor whose number of elements is divisible by
                    ``self.dim``; it is detached and reshaped to [dim, -1]
        Returns:
            T     : quantized tensor with the shape of ``x`` (carries no
                    gradient because the input is detached)
        """
        tr = self.training
        self.training = False
        try:
            # The left singular vectors of a Gaussian matrix form a random
            # orthogonal matrix. Sampled on CPU (as before), then moved to
            # the input's device and dtype.
            u, s, _ = torch.svd(torch.randn(self.dim, self.dim))
            if torch.min(torch.abs(s)) < 1e-6:
                print('SVD Error')
            u = u.to(device=x.device, dtype=x.dtype)

            # floor(sqrt(dim)) quantization levels.
            levels = float(int(self.dim ** 0.5))

            rotated = torch.matmul(u, x.detach().reshape(self.dim, -1))
            # Undo the rotation with u^T, the inverse of the orthogonal u
            # (using the unrelated right singular vectors would return a
            # randomly rotated tensor instead of an approximation of x).
            z_q = torch.matmul(u.t(), self.quantize_QSGD(rotated, levels)).view_as(x)
        finally:
            # Restore the flag even if the computation raises.
            self.training = tr
        return z_q
