import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from quantizer import OTVectorQuantizer
import networks.mnist as net_mnist
import networks.fashion_mnist as net_fashionmnist
import networks.cifar10 as net_cifar10
import networks.celeba as net_celeba
import networks.celebamask_hq as net_celebamask_hq
from third_party.ive import ive


def weights_init(m):
    """Initialize Conv* and BatchNorm* modules in place (intended for ``Module.apply``).

    Conv layers: weight ~ N(0, 0.02). BatchNorm layers: weight ~ N(1, 0.02),
    bias = 0. Matching is by substring of the class name. Parameters that are
    absent or ``None`` are skipped. This covers ``BatchNorm(affine=False)`` and
    custom containers whose name contains "Conv" but that hold no ``weight``.
    Other modules are left unchanged.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)



class OTVAE(nn.Module):
    """Autoencoder whose latent codes are quantized by ``OTVectorQuantizer``.

    Encoder outputs are L2-normalized along dim 1, quantized against a
    learnable codebook, and decoded back to data space.

    Args:
        cfgs: config object. Reads ``dataset.{name, dim_x}``, ``network``
            (including ``network.name``), ``quantization.{dim_dict, size_dict,
            eps, ot_iter, temp}`` and ``loss.eta``.
        flgs: flag object. ``flgs.bn`` is forwarded to the encoder/decoder.
    """

    def __init__(self, cfgs, flgs):
        super(OTVAE, self).__init__()
        # Data space
        self.dim_x = cfgs.dataset.dim_x
        self.dataset = cfgs.dataset.name

        # Encoder/decoder: classes net_<dataset>.{Encoder,Decoder}Vq_<network.name>.
        # Resolved by explicit lookup rather than eval() on config strings.
        encoder_cls = self._network_class(self.dataset, "EncoderVq", cfgs.network.name)
        decoder_cls = self._network_class(self.dataset, "DecoderVq", cfgs.network.name)
        self.encoder = encoder_cls(cfgs.quantization.dim_dict, cfgs.network, flgs.bn)
        self.decoder = decoder_cls(cfgs.quantization.dim_dict, cfgs.network, flgs.bn)
        self.apply(weights_init)

        # Codebook: size_dict vectors of dimension dim_dict.
        self.size_dict = cfgs.quantization.size_dict
        self.dim_dict = cfgs.quantization.dim_dict
        self.codebook = nn.Parameter(torch.randn(self.size_dict, self.dim_dict))

        # OT
        self.eps = cfgs.quantization.eps
        self.ot_iter = cfgs.quantization.ot_iter
        self.eta = cfgs.loss.eta
        # float() so an integer temperature in the config still yields a
        # floating-point Parameter (integer tensors cannot require grad).
        self.temp = nn.Parameter(torch.tensor(float(cfgs.quantization.temp)))

        self.quantizer = OTVectorQuantizer(self.size_dict, self.dim_dict, self.eps, self.ot_iter)

    @staticmethod
    def _network_class(dataset, prefix, network_name):
        """Return class ``<prefix>_<network_name>`` from module ``net_<dataset.lower()>``.

        Raises:
            ValueError: if no ``net_<dataset>`` name is imported in this module.
            AttributeError: if that module defines no such class.
        """
        module = globals().get("net_{}".format(dataset.lower()))
        if module is None:
            raise ValueError("Unsupported dataset: {}".format(dataset))
        return getattr(module, "{}_{}".format(prefix, network_name))

    def forward(self, x, flg_train=False, flg_quant_det=True):
        """Encode, quantize and decode ``x``.

        ``flg_train`` and ``flg_quant_det`` are forwarded to the quantizer.

        Returns:
            ``(x_reconst, latents, loss)``. ``latents`` holds the normalized
            encoder output and the quantized latent. ``loss`` is the dict from
            ``_calc_loss`` with the quantizer's ``"perplexity"`` added.
        """
        z_from_encoder = F.normalize(self.encoder(x), p=2.0, dim=1)

        # Quantization
        z_quantized, loss_latent, perplexity = self.quantizer(
            z_from_encoder, self.codebook, self.temp, flg_train, flg_quant_det)
        latents = dict(z_from_encoder=z_from_encoder, z_to_decoder=z_quantized)

        # Decoding
        x_reconst = self.decoder(z_quantized)

        # Loss
        loss = self._calc_loss(x_reconst, x, loss_latent)
        loss["perplexity"] = perplexity

        return x_reconst, latents, loss

    def _calc_loss(self, x_reconst, x, loss_latent):
        """Return ``{"all": mse + eta * loss_latent, "mse": mse}``.

        ``mse`` is the element-wise mean squared error between ``x_reconst``
        and ``x``.
        """
        mse = F.mse_loss(x_reconst, x)
        loss_all = mse + loss_latent * self.eta
        return dict(all=loss_all, mse=mse)
    

class OTVAEMask(OTVAE):
    """``OTVAE`` variant whose loss treats the decoder output as per-pixel class scores.

    Each pixel's score vector is L2-normalized, and the loss rewards a large
    normalized score at the ground-truth label.
    """

    def __init__(self, cfgs, flgs):
        super(OTVAEMask, self).__init__(cfgs, flgs)
        # Half the number of score channels. _calc_loss uses
        # int(2 * ceil(num_class / 2)) channels, i.e. num_class rounded up to
        # an even number.
        self.__m = np.ceil(cfgs.network.num_class / 2)
        self.n_interval = cfgs.network.num_class - 1

    def _calc_loss(self, x_reconst, x, loss_latent):
        """Return ``{"all": ..., "acc": ...}`` for label-map reconstruction.

        Args:
            x_reconst: decoder scores. They are permuted with (0, 2, 3, 1), so
                this assumes (batch, channels, H, W) with
                ``channels == 2 * ceil(num_class / 2)`` (TODO: confirm against
                the decoder).
            x: integer-valued labels, one per pixel of ``x_reconst``.
            loss_latent: quantizer loss, weighted by ``self.eta``.

        Returns:
            dict with two entries:
            - ``all``: the negative mean of the normalized score at the true
              label, plus ``eta * loss_latent``.
            - ``acc``: the fraction of pixels whose argmax score equals the label.
        """
        num_channels = int(self.__m * 2)
        x_shape = x.shape

        # Reconstruction loss
        # reshape (not view) so non-contiguous label tensors are accepted.
        labels = x.reshape(-1, 1)
        labels_long = labels.long()
        x_reconst_viewed = (x_reconst.permute(0, 2, 3, 1).contiguous()
                            .view(-1, num_channels))
        x_reconst_normed = F.normalize(x_reconst_viewed, p=2.0, dim=-1)
        x_one_hot = F.one_hot(labels_long[:, 0], num_classes=num_channels).type_as(labels)
        x_reconst_selected = (x_one_hot * x_reconst_normed).sum(-1).view(x_shape)
        loss_reconst = -1. * x_reconst_selected.mean()

        # Entire loss
        loss_all = loss_reconst + loss_latent * self.eta

        # Accuracy: exact integer comparison (labels are class indices, so no
        # float tolerance as in torch.isclose is wanted).
        idx_estimated = torch.argmax(x_reconst_normed, dim=-1, keepdim=True)
        acc = (labels_long == idx_estimated).sum() / idx_estimated.numel()
        return dict(all=loss_all, acc=acc)


