import math
import torch
import torch.nn as nn
import sys
from torch import nn, optim
import time
import numpy as np
import random
import os
import scipy.io as sio
from torch.nn import functional as F
from scipy.io import savemat
import torchvision

# Reproducibility: seed every RNG the script touches (Python, NumPy, torch CPU
# and CUDA) and pin cuDNN to deterministic kernels with autotuning disabled.
seed = 12
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Experiment configuration.
max_steps = 2000       # number of codebook-update steps
lr_init = 1.0          # learning rate; no optimizer is constructed by the active code below
codebook_size = 16384  # number of code vectors in the dictionary
codebook_dim = 8       # width of each code vector

class Dictionary(nn.Module):
    """Codebook of ``codebook_size`` vectors of width ``codebook_dim``, updated without gradients.

    The codebook is stored in an ``nn.Embedding`` whose weight has
    ``requires_grad = False``. ``quantize`` assigns inputs to their nearest code
    and then pulls every code towards an anchor input row. The pull weight is
    close to 1 for codes with low running usage (``embed_prob``) and decays
    towards 0 for frequently used codes.
    """

    def __init__(self, codebook_size, codebook_dim):
        super().__init__()
        # Draw the initial codes before building the Embedding (whose own init
        # also consumes the RNG), then overwrite the Embedding's weights with them.
        initial = torch.randn(codebook_size, codebook_dim)
        self.embedding = nn.Embedding(codebook_size, codebook_dim)
        self.embedding.weight.data.copy_(initial)
        self.embedding.weight.requires_grad = False
        self.codebook_size = codebook_size

        # EMA factor for the running code-usage estimate.
        self.decay = 0.9
        # Running (EMA) fraction of assignments each code receives. It is registered
        # as a buffer so it follows .to()/.cuda() and is stored in state_dict.
        self.register_buffer("embed_prob", torch.zeros(self.codebook_size))

    def _squared_distances(self, z):
        """Return the (N, codebook_size) matrix of squared L2 distances from each row of ``z`` to each code.

        Uses ||z||^2 + ||c||^2 - 2 z c^T on detached inputs, so no autograd graph is built.
        """
        flat = z.detach()
        codebook = self.embedding.weight.data
        distance = flat.square().sum(dim=1, keepdim=True) + codebook.square().sum(dim=1)
        distance.addmm_(flat, codebook.T, alpha=-2, beta=1)
        return distance

    @staticmethod
    def _mean_and_covariance(x):
        """Return the column mean and population (divide-by-N) covariance of a 2-D tensor."""
        mean = x.mean(0)
        centered = x - mean
        return mean, centered.T @ centered / x.size(0)

    def calc_wasserstein_distance(self, z):
        """Return the 2-Wasserstein distance between Gaussian fits of ``z`` and of the codebook.

        Each point set is summarised by its mean and population covariance, and the
        closed form for Gaussians is applied:

            W2^2 = ||mu_z - mu_c||^2
                   + tr(S_z + S_c - 2 (S_z^{1/2} S_c S_z^{1/2})^{1/2})

        Args:
            z: 2-D tensor of samples, one row per sample, as wide as the codebook.

        Returns:
            0-dim tensor holding the distance. A 1e-8 term guards the final sqrt near 0.
        """
        codebook = self.embedding.weight
        z_mean, z_covariance = self._mean_and_covariance(z)
        c_mean, c_covariance = self._mean_and_covariance(codebook)

        part_mean = torch.sum((z_mean - c_mean).square())

        # tr((S_z S_c)^{1/2}) is the sum of sqrt-eigenvalues of S_z S_c. That product
        # is generally NOT symmetric, and eigh/eigvalsh only read one triangle of
        # their input. The symmetric PSD matrix S_z^{1/2} S_c S_z^{1/2} has the
        # same eigenvalues, so decompose that instead.
        s_z, q_z = torch.linalg.eigh(z_covariance)
        z_cov_sqrt = q_z @ torch.diag(torch.sqrt(F.relu(s_z))) @ q_z.T
        middle = z_cov_sqrt @ c_covariance @ z_cov_sqrt
        middle = 0.5 * (middle + middle.T)  # remove round-off asymmetry
        eigvals = torch.linalg.eigvalsh(middle)
        trace_sqrt = torch.sqrt(F.relu(eigvals) + 1e-8).sum()

        part_covariance = F.relu(
            torch.trace(z_covariance) + torch.trace(c_covariance) - 2.0 * trace_sqrt
        )
        return torch.sqrt(part_mean + part_covariance + 1e-8)

    def quantize(self, z):
        """Assign ``z`` to its nearest codes, update code usage, and move the codes in place.

        Args:
            z: 2-D tensor with one row per sample and the codebook's width.
                It is detached internally, so no gradients flow.

        Returns:
            0-dim tensor: mean squared L2 distance between each row of ``z`` and
            its nearest code, measured before the codebook is moved.
        """
        distance = self._squared_distances(z)

        token = torch.argmin(distance, dim=1)
        embed = self.embedding(token)
        quant_error = (embed - z.detach()).square().sum(1).mean()

        # Fold this batch's assignment frequencies into the running usage estimate.
        # bincount avoids materialising an N x codebook_size one-hot matrix.
        avg_probs = token.bincount(minlength=self.codebook_size).to(z.dtype) / token.numel()
        self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay)

        # Anchor for each code: the row of z with the LARGEST squared distance to it.
        # This is the same row as distance.sort(dim=0) indices[-1] (up to ties), but
        # it does not sort the whole N x codebook_size matrix.
        # NOTE(review): if this follows CVQ-VAE's "closest" anchor (which applies
        # indices[-1] to *negated* distances, i.e. picks the nearest row), it should
        # be argmin here -- confirm the farthest row is intended.
        anchor_index = torch.argmax(distance, dim=0)
        random_feat = z.detach()[anchor_index]

        # Per-code pull weight: exp(-1e-3) ~ 0.999 for codes with zero running usage,
        # shrinking towards 0 as usage grows. The shape is (codebook_size, 1), which
        # broadcasts over the code width instead of relying on a global width.
        replace_weight = torch.exp(
            -(self.embed_prob * self.codebook_size * 10) / (1 - self.decay) - 1e-3
        ).unsqueeze(1)
        weight = self.embedding.weight.data
        self.embedding.weight.data = weight * (1 - replace_weight) + random_feat * replace_weight
        return quant_error

    def calc_metrics(self, z):
        """Evaluate how well the codebook covers the samples ``z`` (codebook is not modified).

        Returns:
            quant_error: 0-dim tensor, mean squared L2 distance to the nearest code.
            codebook_utilization: Python float, fraction of codes chosen at least once.
            codebook_perplexity: 0-dim tensor, exp(entropy) of the code-usage histogram.
            wasserstein_distance: 0-dim tensor from ``calc_wasserstein_distance``.
        """
        distance = self._squared_distances(z)
        token = torch.argmin(distance, dim=1)
        del distance  # N x codebook_size; release it before the remaining work

        embed = self.embedding(token)
        quant_error = (embed - z.detach()).square().sum(1).mean()

        codebook_histogram = token.bincount(minlength=self.codebook_size).float()
        codebook_usage_counts = (codebook_histogram > 0).float().sum()
        codebook_utilization = codebook_usage_counts.item() / self.codebook_size

        avg_probs = codebook_histogram / codebook_histogram.sum(0)
        codebook_perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        wasserstein_distance = self.calc_wasserstein_distance(z)
        return quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance


# Run on the GPU when one is available; fall back to CPU so the script can still start without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_samples = 50000    # latents drawn per codebook-update step
eval_samples = 1000000   # latents per evaluation (distance matrix is eval_samples x codebook_size)
data_shift = 4.0         # latents are drawn from N(data_shift, I)

Dict = Dictionary(codebook_size, codebook_dim).to(device)


def sample_latents(num_samples):
    """Return a (num_samples, codebook_dim) sample from N(data_shift, I) on ``device``.

    Values are drawn with the CPU generator and then moved, so for a given seed
    the sampled values do not depend on which device is used.
    """
    return torch.randn(num_samples, codebook_dim).to(device) + data_shift


def report_eval(step):
    """Compute codebook metrics on a fresh evaluation sample and print them."""
    quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance = Dict.calc_metrics(
        sample_latents(eval_samples)
    )
    print('eval step:{}/{}, quant_error:{:.4f}, codebook_utilization:{:.4f}, codebook_perplexity:{:.4f}, wasserstein_distance:{:.4f}'.format(
        step, max_steps, quant_error.item(), codebook_utilization, codebook_perplexity.item(), wasserstein_distance.item()))


# Step 0: metrics of the freshly initialised codebook. These were previously
# computed but never reported.
report_eval(0)

for step in range(1, max_steps + 1):
    quant_error = Dict.quantize(sample_latents(train_samples))

    if step == 1 or step % 10 == 0:
        print('train step:{}/{}, quant_error:{:.4f}'.format(step, max_steps, quant_error.item()))
    if step == 1 or step % 100 == 0:
        report_eval(step)

