import math
import torch
import torch.nn as nn
import sys
from torch import nn, optim
import time
import numpy as np
import random
import os
import scipy.io as sio
from torch.nn import functional as F
from scipy.io import savemat
import torchvision

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
seed = 12                               # one seed shared by every RNG seeded below
max_steps, lr_init = 2000, 3.0          # number of optimisation steps / SGD learning rate
codebook_size, codebook_dim = 16384, 8  # number of code vectors / dimension of each
alpha = 0.1                             # NOTE(review): not referenced in the visible part of this file

# ---------------------------------------------------------------------------
# Reproducibility: seed the PyTorch (CPU + CUDA), Python and NumPy generators
# with the same value, then set the cuDNN flags (deterministic on, benchmark
# off). The seeding calls are independent of one another.
# ---------------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class Dictionary(nn.Module):
    """Learnable codebook with a Gaussian 2-Wasserstein distance to a feature batch.

    The codebook is stored as the weight of an ``nn.Embedding`` of shape
    ``(codebook_size, codebook_dim)``. All public methods take ``z``, a 2-D
    tensor whose rows are feature vectors of length ``codebook_dim``.

    The Wasserstein quantities compare only the first two moments (mean and
    covariance) of the features and of the code vectors, using the closed
    form for two Gaussians:

        W^2 = ||m_z - m_c||^2 + Tr(S_z + S_c - 2 (S_z^{1/2} S_c S_z^{1/2})^{1/2})
    """

    def __init__(self, codebook_size, codebook_dim):
        super(Dictionary, self).__init__()
        # Draw the initial codebook *before* constructing nn.Embedding so the
        # random-number stream is consumed in the same order as before.
        initial = torch.randn(codebook_size, codebook_dim)
        self.embedding = nn.Embedding(codebook_size, codebook_dim)
        self.embedding.weight.data.copy_(initial)
        self.embedding.weight.requires_grad = True
        self.codebook_size = codebook_size

    @staticmethod
    def _mean_and_covariance(x):
        """Return the column mean and the biased (divide-by-rows) covariance of ``x``."""
        mean = x.mean(dim=0)
        centered = x - mean
        covariance = torch.mm(centered.t(), centered) / x.size(0)
        return mean, covariance

    @staticmethod
    def _psd_sqrt(matrix):
        """Return the symmetric square root of a symmetric PSD matrix via ``eigh``."""
        eigvals, eigvecs = torch.linalg.eigh(matrix)
        # relu removes tiny negative eigenvalues caused by round-off; the small
        # epsilon keeps the sqrt differentiable at zero.
        roots = torch.sqrt(F.relu(eigvals) + 1e-8)
        return torch.mm(eigvecs * roots, eigvecs.t())

    def _gaussian_wasserstein(self, z):
        """Shared implementation of the Gaussian 2-Wasserstein distance to the codebook."""
        z_mean, z_covariance = self._mean_and_covariance(z)
        c_mean, c_covariance = self._mean_and_covariance(self.embedding.weight)

        ### part 1: squared distance between the means
        part_mean = (z_mean - c_mean).square().sum()

        ### part 2: Bures term between the covariances.
        # The product S_z @ S_c is generally NOT symmetric, so it must not be
        # fed to eigh (which only reads one triangle). Instead decompose the
        # symmetric matrix S_z^{1/2} S_c S_z^{1/2}, which has the same
        # eigenvalues as the product. The square root is taken on the feature
        # side so the codebook gradient only flows through eigenvalues.
        z_sqrt = self._psd_sqrt(z_covariance)
        inner = torch.mm(torch.mm(z_sqrt, c_covariance), z_sqrt)
        inner = 0.5 * (inner + inner.t())  # remove round-off asymmetry
        cross_trace = torch.sqrt(F.relu(torch.linalg.eigvalsh(inner)) + 1e-8).sum()

        part_covariance = F.relu(
            torch.trace(z_covariance) + torch.trace(c_covariance) - 2.0 * cross_trace
        )
        # Epsilon keeps the outer sqrt differentiable when the distance is ~0.
        return torch.sqrt(part_mean + part_covariance + 1e-8)

    def _nearest_tokens(self, z, chunk_size=32768):
        """Return, for every row of ``z``, the index of the closest code vector.

        Squared Euclidean distances are evaluated ``chunk_size`` rows at a time
        so the temporary distance matrix is ``(chunk_size, codebook_size)``
        instead of ``(len(z), codebook_size)``. No gradient is tracked.
        """
        codebook = self.embedding.weight.detach()
        code_sq_norm = codebook.square().sum(dim=1)
        codebook_t = codebook.t()

        tokens = []
        for rows in z.detach().split(chunk_size):
            # ||r||^2 + ||c||^2 - 2 r.c, with the last term accumulated in place.
            distance = rows.square().sum(dim=1, keepdim=True) + code_sq_norm
            distance.addmm_(rows, codebook_t, alpha=-2, beta=1)
            tokens.append(torch.argmin(distance, dim=1))
        return torch.cat(tokens)

    def calc_wasserstein_distance(self, z):
        """Return the Gaussian 2-Wasserstein distance (scalar tensor) between ``z`` and the codebook."""
        return self._gaussian_wasserstein(z)

    def calc_commit_loss(self, z):
        """Return the mean squared distance from each row of ``z`` to its nearest code.

        ``z`` is detached, so gradients reach only the selected code vectors.
        """
        z = z.detach()
        token = self._nearest_tokens(z)
        embed = self.embedding(token)
        commit_loss = (embed - z).square().sum(1).mean()
        return commit_loss

    def calc_wasserstein_loss(self, z):
        """Return the Gaussian 2-Wasserstein distance used as the training loss.

        Same quantity as ``calc_wasserstein_distance``; kept as a separate
        entry point for existing callers.
        """
        return self._gaussian_wasserstein(z)

    @torch.no_grad()
    def calc_metrics(self, z):
        """Compute evaluation metrics for ``z`` without tracking gradients.

        Returns:
            quant_error: scalar tensor, mean squared distance to the nearest code.
            codebook_utilization: Python float, fraction of codes selected at least once.
            codebook_perplexity: scalar tensor, exp of the entropy of the code-usage histogram.
            wasserstein_distance: scalar tensor from ``calc_wasserstein_distance``.
        """
        z = z.detach()
        token = self._nearest_tokens(z)
        embed = self.embedding(token)

        quant_error = (embed - z).square().sum(1).mean()
        codebook_histogram = token.bincount(minlength=self.codebook_size).float()
        codebook_usage_counts = (codebook_histogram > 0).float().sum()
        codebook_utilization = codebook_usage_counts.item() / self.codebook_size

        avg_probs = codebook_histogram / codebook_histogram.sum(0)
        # 1e-10 guards log(0) for codes that were never selected.
        codebook_perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        wasserstein_distance = self.calc_wasserstein_distance(z)

        return quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance


# Codebook under test and its optimiser: SGD with momentum over the embedding
# table only. This script requires a CUDA device (`.cuda()` calls below).
Dict = Dictionary(codebook_size, codebook_dim).cuda()
optimizer = torch.optim.SGD(Dict.embedding.parameters(), lr=lr_init, momentum=0.9)

##### zero-steps
# Baseline metrics of the untrained codebook. Features are sampled on the CPU
# and then moved to the GPU; they are standard normal shifted by +4.0.
z = torch.randn(1000000, codebook_dim).cuda() + 4.0
with torch.no_grad():
    quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance = Dict.calc_metrics(z)
# Report the baseline instead of silently discarding it.
print('eval step:{}/{}, quant_error:{:.4f}, codebook_utilization:{:.4f}, codebook_perplexity:{:.4f}, wasserstein_distance:{:.4f}'.format(0, max_steps, quant_error.item(), codebook_utilization, codebook_perplexity.item(), wasserstein_distance.item()))

for step in range(1, max_steps+1):
    z = torch.randn(50000, codebook_dim).cuda() + 4.0

    # The commit loss is only printed, never optimised, so run its
    # nearest-code search only on logging steps. It is still evaluated
    # *before* the optimiser update so the printed value is unchanged.
    should_log = step == 1 or step%10 == 0
    if should_log:
        with torch.no_grad():
            commit_loss = Dict.calc_commit_loss(z)

    wasserstein_loss = Dict.calc_wasserstein_loss(z)

    # Only the (scaled) Wasserstein term drives the codebook update.
    loss =  10 * wasserstein_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if should_log:
        print('train step:{}/{}, commit loss:{:.4f}, wasserstein_loss:{:.4f}'.format(step, max_steps, commit_loss.item(), wasserstein_loss.item()))
    if step == 1 or step%100 == 0:
        # Periodic evaluation on a fresh, larger sample; no gradients needed.
        z = torch.randn(1000000, codebook_dim).cuda() + 4.0
        with torch.no_grad():
            quant_error, codebook_utilization, codebook_perplexity, wasserstein_distance = Dict.calc_metrics(z)
        print('eval step:{}/{}, quant_error:{:.4f}, codebook_utilization:{:.4f}, codebook_perplexity:{:.4f}, wasserstein_distance:{:.4f}'.format(step, max_steps, quant_error.item(), codebook_utilization, codebook_perplexity.item(), wasserstein_distance.item()))










        

   