import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from taming.modules.discriminator.model import NLayerDiscriminator, MultiClass_NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import lsgan_d_loss, hinge_d_loss, vanilla_d_loss


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss where each batch element carries its own weight.

    The hinge terms are averaged over dims 1, 2 and 3 of the logits to get one
    value per batch element, then combined as a weighted mean using ``weights``.

    Args:
        logits_real: Discriminator logits for real samples (reduced over dims 1-3).
        logits_fake: Discriminator logits for generated samples (same layout).
        weights: Per-example weights; first dimension must match the logits.

    Returns:
        Scalar tensor: half the sum of the weighted real and fake hinge terms.
    """
    n_examples = weights.shape[0]
    assert n_examples == logits_real.shape[0] == logits_fake.shape[0]

    # One hinge value per example.
    per_example_real = F.relu(1. - logits_real).mean(dim=[1, 2, 3])
    per_example_fake = F.relu(1. + logits_fake).mean(dim=[1, 2, 3])

    # Weighted averages share the same normaliser.
    weight_total = weights.sum()
    weighted_real = (weights * per_example_real).sum() / weight_total
    weighted_fake = (weights * per_example_fake).sum() / weight_total

    return 0.5 * (weighted_real + weighted_fake)

def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``.

    Args:
        weight: Weight returned once ``global_step`` has reached ``threshold``.
        global_step: Current step counter.
        threshold: Step at which ``weight`` starts being returned.
        value: Replacement returned before the threshold is reached.
    """
    return value if global_step < threshold else weight


def measure_perplexity(predicted_indices, n_embed):
    """Compute codebook perplexity and the number of codes in use.

    Args:
        predicted_indices: Integer tensor of code indices (any shape).
        n_embed: Number of codebook entries (one-hot width).

    Returns:
        Tuple ``(perplexity, cluster_use)``: the exponential of the entropy of
        the average code distribution, and the count of codes with non-zero
        average usage.
    """
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # Perplexity equals n_embed when every code is used exactly equally often.
    one_hot = F.one_hot(predicted_indices, n_embed).float()
    usage = one_hot.reshape(-1, n_embed).mean(0)

    # The small epsilon keeps log() finite for unused codes.
    entropy = -(usage * torch.log(usage + 1e-10)).sum()

    return entropy.exp(), torch.sum(usage > 0)

def l1(x, y):
    """Return the element-wise absolute difference between ``x`` and ``y``."""
    difference = x - y
    return torch.abs(difference)


def l2(x, y):
    """Return the element-wise squared difference between ``x`` and ``y``."""
    difference = x - y
    return torch.pow(difference, 2)

class VQLPIPS(nn.Module):
    """Pixel reconstruction loss plus a weighted codebook (quantization) loss.

    Args:
        codebook_weight: Multiplier applied to the mean codebook loss.
        pixelloss_weight: Stored as ``pixel_weight``; it is not applied inside
            ``forward``.
        pixel_loss: ``"l1"`` or ``"l2"``; selects the element-wise pixel loss.
        n_classes: Codebook size used for the perplexity / cluster-usage
            metrics. Only needed when ``predicted_indices`` is passed to
            ``forward``.
    """

    def __init__(self, codebook_weight=1.0, pixelloss_weight=1.0, pixel_loss="l1", n_classes=None):
        super().__init__()
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        # Previously never assigned, so requesting codebook metrics in
        # ``forward`` failed with an AttributeError on ``self.n_classes``.
        self.n_classes = n_classes

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

    def forward(self, codebook_loss, inputs, reconstructions, global_step, split="train", predicted_indices=None):
        """Compute the total loss and a dictionary of detached log values.

        Args:
            codebook_loss: Tensor holding the quantization loss; its mean is used.
            inputs: Target tensor.
            reconstructions: Reconstructed tensor compared against ``inputs``.
            global_step: Accepted for interface compatibility; unused here.
            split: Prefix for the log keys (e.g. ``"train"``).
            predicted_indices: Optional code indices; when given, perplexity
                and cluster usage are added to the log.

        Returns:
            Tuple ``(loss, log)``.

        Raises:
            ValueError: If ``predicted_indices`` is given but ``n_classes``
                was never configured.
        """
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())

        nll_loss = torch.mean(rec_loss)

        loss = nll_loss + self.codebook_weight * codebook_loss.mean()

        log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
               "{}/quant_loss".format(split): codebook_loss.detach().mean(),
               "{}/nll_loss".format(split): nll_loss.detach().mean(),
               "{}/rec_loss".format(split): rec_loss.detach().mean(),
               }

        if predicted_indices is not None:
            if self.n_classes is None:
                raise ValueError(
                    "n_classes must be set (pass n_classes=... to the constructor) "
                    "to compute perplexity from predicted_indices"
                )
            # Metrics only; keep them out of the autograd graph.
            with torch.no_grad():
                perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
            log[f"{split}/perplexity"] = perplexity
            log[f"{split}/cluster_usage"] = cluster_usage

        return loss, log


class VQLPIPSJoint(nn.Module):
    """Pixel reconstruction loss without a codebook term.

    Args:
        pixelloss_weight: Stored as ``pixel_weight``; it is not applied inside
            ``forward``.
        pixel_loss: ``"l1"`` or ``"l2"``; selects the element-wise pixel loss.
        n_classes: Codebook size used for the perplexity / cluster-usage
            metrics. Only needed when ``predicted_indices`` is passed to
            ``forward``.
    """

    def __init__(self, pixelloss_weight=1.0, pixel_loss="l1", n_classes=None):
        super().__init__()
        assert pixel_loss in ["l1", "l2"]
        self.pixel_weight = pixelloss_weight
        # Previously never assigned, so requesting codebook metrics in
        # ``forward`` failed with an AttributeError on ``self.n_classes``.
        self.n_classes = n_classes

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

    def forward(self, inputs, reconstructions, mask=None, global_step=0, split="train", predicted_indices=None):
        """Compute the reconstruction loss and a dictionary of detached log values.

        Args:
            inputs: Target tensor.
            reconstructions: Reconstructed tensor compared against ``inputs``.
            mask: Accepted for interface compatibility; unused here.
            global_step: Accepted for interface compatibility; unused here.
            split: Prefix for the log keys (e.g. ``"train"``).
            predicted_indices: Optional code indices; when given, perplexity
                and cluster usage are added to the log.

        Returns:
            Tuple ``(loss, log)``.

        Raises:
            ValueError: If ``predicted_indices`` is given but ``n_classes``
                was never configured.
        """
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())

        nll_loss = torch.mean(rec_loss)

        loss = nll_loss

        log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
               "{}/nll_loss".format(split): nll_loss.detach().mean(),
               "{}/rec_loss".format(split): rec_loss.detach().mean(),
               }

        if predicted_indices is not None:
            if self.n_classes is None:
                raise ValueError(
                    "n_classes must be set (pass n_classes=... to the constructor) "
                    "to compute perplexity from predicted_indices"
                )
            # Metrics only; keep them out of the autograd graph.
            with torch.no_grad():
                perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
            log[f"{split}/perplexity"] = perplexity
            log[f"{split}/cluster_usage"] = cluster_usage

        return loss, log
