import torch
import torch.nn as nn
import torch.nn.functional as F

from tokenizer.lpips import LPIPS
from .clip import clip
import timm
from transformers import SiglipImageProcessor, SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipVisionModel

def hinge_d_loss(logits_real, logits_fake):
    """Hinge loss for the discriminator.

    Averages ``relu(1 - logits_real)`` and ``relu(1 + logits_fake)`` over all
    elements and returns half of the sum of the two averages.
    """
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def vanilla_d_loss(logits_real, logits_fake):
    """Softplus-based discriminator loss.

    Averages ``softplus(-logits_real)`` and ``softplus(logits_fake)`` over all
    elements and returns half of the sum of the two averages.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)


def non_saturating_d_loss(logits_real, logits_fake):
    """Binary-cross-entropy discriminator loss.

    Real logits are scored against an all-ones target and fake logits against
    an all-zeros target; the two mean losses are averaged.

    Args:
        logits_real: discriminator logits for real samples.
        logits_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor ``0.5 * (bce(real, 1) + bce(fake, 0))``.
    """
    # binary_cross_entropy_with_logits takes (input, target): the logits must
    # come first, the constant labels second. Its default reduction is already
    # the mean, so no extra torch.mean is needed.
    loss_real = F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
    loss_fake = F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
    return 0.5 * (loss_real + loss_fake)


def hinge_gen_loss(logit_fake):
    """Generator hinge loss: the negated mean of the fake logits."""
    mean_logit = logit_fake.mean()
    return -mean_logit


def non_saturating_gen_loss(logit_fake):
    """Binary-cross-entropy generator loss.

    Scores the discriminator's logits for generated samples against an
    all-ones target, so the loss shrinks as the logits grow.

    Args:
        logit_fake: discriminator logits for generated samples.

    Returns:
        Scalar tensor with the mean binary cross-entropy.
    """
    # binary_cross_entropy_with_logits takes (input, target): logits first,
    # labels second. The default reduction is already the mean.
    return F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake))


def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``."""
    return value if global_step < threshold else weight


class VQLoss(nn.Module):
    """Training loss for a VQ tokenizer with a discriminator and a frozen teacher.

    ``forward`` computes one of two losses, selected by ``optimizer_idx``:

    * ``0`` - generator loss: reconstruction + perceptual (LPIPS) + adversarial
      + the first three entries of ``codebook_loss`` + an optional cosine
      distillation loss against features of the frozen teacher model.
    * ``1`` - discriminator loss on detached inputs and reconstructions.

    NOTE(review): ``nn.Module.train()`` called on this module propagates to the
    frozen ``perceptual_loss`` and ``teacher_model`` submodules, which are only
    put in eval mode once in ``__init__`` - confirm the training script keeps
    them in eval mode.
    """

    def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
                 disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False,
                 gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0, 
                 codebook_weight=1.0, perceptual_weight=1.0, teacher="clipb_224"
    ):
        """Build the discriminator, the perceptual loss and the frozen teacher.

        Args:
            disc_start: global step before which the adversarial weight is 0.
            disc_loss: discriminator loss, "hinge", "vanilla" or "non-saturating".
            disc_dim: ``ndf`` passed to the PatchGAN discriminator.
            disc_type: "patchgan" or "stylegan".
            image_size: ``image_size`` passed to the StyleGAN discriminator.
            disc_num_layers: ``n_layers`` passed to the PatchGAN discriminator.
            disc_in_channels: ``input_nc`` passed to either discriminator.
            disc_weight: weight of the adversarial terms once ``disc_start`` is reached.
            disc_adaptive_weight: if True, additionally scale the generator's
                adversarial term by a gradient-norm ratio (needs ``last_layer``
                in ``forward``).
            gen_adv_loss: generator adversarial loss, "hinge" or "non-saturating".
            reconstruction_loss: "l1" or "l2".
            reconstruction_weight: weight of the reconstruction term.
            codebook_weight: stored on the module; not read anywhere in this class.
            perceptual_weight: weight of the LPIPS term.
            teacher: "clipb_224", "vitamin_xlarge_256" or "siglip_384".

        Raises:
            AssertionError / ValueError: on an unknown option value.
        """
        super().__init__()
        # discriminator
        # NOTE(review): PatchGANDiscriminator / StyleGANDiscriminator are not
        # imported in the visible part of this file - confirm they are in scope.
        assert disc_type in ["patchgan", "stylegan"]
        assert disc_loss in ["hinge", "vanilla", "non-saturating"]
        if disc_type == "patchgan":
            self.discriminator = PatchGANDiscriminator(
                input_nc=disc_in_channels,
                n_layers=disc_num_layers,
                ndf=disc_dim,
            )
        elif disc_type == "stylegan":
            self.discriminator = StyleGANDiscriminator(
                input_nc=disc_in_channels,
                image_size=image_size,
            )
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")

        # discriminator loss
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "non-saturating":
            self.disc_loss = non_saturating_d_loss
        else:
            raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
        self.discriminator_iter_start = disc_start
        self.disc_weight = disc_weight
        self.disc_adaptive_weight = disc_adaptive_weight

        # generator adversarial loss
        assert gen_adv_loss in ["hinge", "non-saturating"]
        if gen_adv_loss == "hinge":
            self.gen_adv_loss = hinge_gen_loss
        elif gen_adv_loss == "non-saturating":
            self.gen_adv_loss = non_saturating_gen_loss
        else:
            raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")

        # perceptual loss
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # reconstruction loss
        if reconstruction_loss == "l1":
            self.rec_loss = F.l1_loss
        elif reconstruction_loss == "l2":
            self.rec_loss = F.mse_loss
        else:
            raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
        self.rec_weight = reconstruction_weight

        # codebook loss
        self.codebook_weight = codebook_weight

        # teacher model and the input normalisation that goes with it
        # option: ["clipb_224", "vitamin_xlarge_256", "siglip_384"]
        self.teacher = teacher
        if self.teacher == 'clipb_224':
            self.scaling_layer = ScalingLayerForClip()
            print("Use clipb as teacher model.")
            self.teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False,
                                              download_root='./clip_model')
        elif self.teacher == 'vitamin_xlarge_256':
            self.scaling_layer = ScalingLayerForClip()
            print("Use vitamin-xlarge-256 as teacher model.")
            self.teacher_model = timm.create_model('vitamin_xlarge_256', pretrained=True)
        elif self.teacher == 'siglip_384':
            self.scaling_layer = ScalingLayerForSigLip()
            print("Use siglip-so400m-patch14-384 as teacher model.")
            self.teacher_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
        else:
            raise ValueError(f'teacher {self.teacher} not supported')

        # The teacher is frozen: no gradients, eval mode.
        for param in self.teacher_model.parameters():
            param.requires_grad = False
        self.teacher_model.eval()

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Return ``||d nll_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4)``.

        The ratio is clamped to ``[0, 1e4]`` and detached.

        Raises:
            ValueError: if ``last_layer`` is None.
        """
        if last_layer is None:
            raise ValueError("last_layer is required when disc_adaptive_weight is enabled.")
        # retain_graph=True keeps the graph alive for the second grad call and
        # for the caller's own backward pass.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        return torch.clamp(d_weight, 0.0, 1e4).detach()

    @torch.no_grad()
    def get_regress_target(self, x, **kwargs):
        """Return the frozen teacher's features for images ``x`` (no gradients).

        ``x`` is first passed through ``self.scaling_layer``. ``kwargs`` is
        accepted but not used.

        NOTE(review): no resizing happens here; presumably callers supply
        images at the resolution the teacher expects - confirm.
        """
        norm_imgs = self.scaling_layer(x) # TODO
        if self.teacher == 'clipb_224':
            target = self.teacher_model.encode_image(norm_imgs, return_all_tokens=True) @ self.teacher_model.visual.proj
        elif self.teacher == 'vitamin_xlarge_256':
            target = self.teacher_model.forward_features(norm_imgs)
        elif self.teacher == 'siglip_384':
            # second-to-last hidden state of the SigLIP vision model
            target = self.teacher_model(norm_imgs, output_hidden_states=True).hidden_states[-2]
        else:
            raise NotImplementedError(f"Unknown pretrain clip '{self.teacher}'.")

        return target

    def calculate_clip_rec_loss(self, rec, target):
        """Return the mean of ``1 - cosine_similarity(rec, target)`` over the last dim.

        ``F.normalize`` clamps the norm away from zero, so all-zero feature
        vectors contribute a similarity of 0 instead of producing NaN.
        """
        target = F.normalize(target, dim=-1)
        rec = F.normalize(rec, dim=-1)
        return (1 - (target * rec).sum(-1)).mean()

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, 
                logger=None, log_every=100):
        """Compute the generator (``optimizer_idx == 0``) or discriminator (``1``) loss.

        Args:
            codebook_loss: indexable with at least six entries. Entries 0-2 are
                added to the generator loss (reported as vq / commit / entropy
                loss); entries 3-5 are only reported (codebook_usage, d_vqkd,
                d_vqgan). Only read in the generator branch.
            inputs: target images.
            reconstructions: reconstructed images, or a tuple
                ``(vqkd_recon, reconstructions)`` where ``vqkd_recon`` is
                compared against the teacher features of ``inputs``.
            optimizer_idx: 0 for the generator update, 1 for the discriminator.
            global_step: current step; gates the adversarial weight and logging.
            last_layer: tensor the adaptive weight differentiates with respect
                to; required only when ``disc_adaptive_weight`` is enabled.
            logger: optional logger; when None nothing is logged.
            log_every: log every this many steps.

        Returns:
            ``(loss, loss_dict)``. ``loss_dict`` holds the weighted loss terms
            (generator) or the loss and mean logits (discriminator); its
            contents do not depend on whether the step was logged.

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        if optimizer_idx not in (0, 1):
            raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}.")

        if isinstance(reconstructions, tuple):
            vqkd_recon, reconstructions = reconstructions
        else:
            vqkd_recon = None

        should_log = logger is not None and global_step % log_every == 0

        if optimizer_idx == 0:
            return self._generator_step(codebook_loss, inputs, reconstructions, vqkd_recon,
                                        global_step, last_layer, logger, should_log)
        return self._discriminator_step(inputs, reconstructions, global_step, logger, should_log)

    def _generator_step(self, codebook_loss, inputs, reconstructions, vqkd_recon,
                        global_step, last_layer, logger, should_log):
        """Generator branch of ``forward``: total loss plus its weighted parts."""
        if vqkd_recon is not None:
            clip_target = self.get_regress_target(inputs)
            vqkd_loss = self.calculate_clip_rec_loss(vqkd_recon, clip_target)
        else:
            # scalar zero with the dtype and device of the inputs
            vqkd_loss = inputs.new_zeros(())

        # reconstruction loss
        rec_loss = self.rec_weight * self.rec_loss(inputs.contiguous(), reconstructions.contiguous())

        # perceptual loss
        p_loss = self.perceptual_weight * torch.mean(
            self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
        )

        # adversarial loss; gradients flow through the discriminator to the generator
        logits_fake = self.discriminator(reconstructions.contiguous())
        raw_adv_loss = self.gen_adv_loss(logits_fake)

        if self.disc_adaptive_weight:
            disc_adaptive_weight = self.calculate_adaptive_weight(
                rec_loss + p_loss, raw_adv_loss, last_layer=last_layer
            )
        else:
            disc_adaptive_weight = 1
        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
        generator_adv_loss = disc_adaptive_weight * disc_weight * raw_adv_loss

        loss = rec_loss + \
            p_loss + \
            generator_adv_loss + \
            codebook_loss[0] + codebook_loss[1] + codebook_loss[2] + \
            vqkd_loss

        if should_log:
            logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
                        f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
                        f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, "
                        f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}, " 
                        f"vqkd_loss: {vqkd_loss:.4f}, d_vqkd: {codebook_loss[4]:.4f}, d_vqgan: {codebook_loss[5]:.4f}")

        # Always report the weighted terms, i.e. the same values that are logged.
        loss_dict = {
            'rec_loss': rec_loss,
            'p_loss': p_loss,
            'vq_loss': codebook_loss[0],
            'commit_loss': codebook_loss[1],
            'entropy_loss': codebook_loss[2],
            'vqkd_loss':vqkd_loss,
            'codebook_usage': codebook_loss[3],
            'generator_adv_loss': generator_adv_loss,
            'disc_adaptive_weight': disc_adaptive_weight,
            'disc_weight': disc_weight,
            'd_vqkd': codebook_loss[4],
            'd_vqgan': codebook_loss[5],
        }
        return loss, loss_dict

    def _discriminator_step(self, inputs, reconstructions, global_step, logger, should_log):
        """Discriminator branch of ``forward``: loss on detached real/fake images."""
        # detach so that no gradient reaches the generator
        logits_real = self.discriminator(inputs.contiguous().detach())
        logits_fake = self.discriminator(reconstructions.contiguous().detach())

        disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
        d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)

        # Always report scalar means, i.e. the same values that are logged.
        mean_logits_real = logits_real.detach().mean()
        mean_logits_fake = logits_fake.detach().mean()

        if should_log:
            logger.info(f"(Discriminator) " 
                        f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
                        f"logits_real: {mean_logits_real:.4f}, logits_fake: {mean_logits_fake:.4f}")

        loss_dict = {
            'discriminator_adv_loss': d_adversarial_loss,
            'logits_real': mean_logits_real,
            'logits_fake': mean_logits_fake,
            'disc_weight': disc_weight,
        }
        return d_adversarial_loss, loss_dict


class ScalingLayerForClip(nn.Module):
    """Normalise images with fixed per-channel ``shift`` / ``scale`` buffers.

    ``forward`` first maps the input linearly so that -1 becomes 0 and 1
    becomes 1 (clamping anything outside), then applies
    ``(x - shift) / scale`` channel-wise.
    """

    def __init__(self):
        super().__init__()
        shift = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
        scale = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        # rescale to [0, 1.] via the 0..255 pixel range
        pixels = ((inp + 1.) * 127.5).clamp(0, 255.)
        unit = pixels / 255.
        return (unit - self.shift) / self.scale

class ScalingLayerForSigLip(nn.Module):
    """Normalise images with ``shift`` and ``scale`` buffers of 0.5 per channel.

    ``forward`` first maps the input linearly so that -1 becomes 0 and 1
    becomes 1 (clamping anything outside), then applies
    ``(x - shift) / scale`` channel-wise.
    """

    def __init__(self):
        super().__init__()
        shift = torch.Tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
        scale = torch.Tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        # rescale to [0, 1.] via the 0..255 pixel range
        pixels = ((inp + 1.) * 127.5).clamp(0, 255.)
        unit = pixels / 255.
        return (unit - self.shift) / self.scale

class SegVQLoss(nn.Module):
    """
    Pass-through loss module for VQ segmentation models.

    The segmentation loss is computed by the caller (e.g. a Hungarian /
    bipartite set loss); this module only returns it together with a
    one-entry report dict. It holds no parameters.
    """
    def __init__(self):
        super().__init__()

    def forward(self, seg_loss, *args, **kwargs):
        """
        Args:
            seg_loss: precomputed segmentation loss.
            *args, **kwargs: accepted and ignored.
        Returns:
            (seg_loss, {'seg_loss': seg_loss})
        """
        loss_dict = {'seg_loss': seg_loss}
        return seg_loss, loss_dict