import os
import warnings

import torch
from mmcv.runner import BaseModule
from torch import nn

from ..builder import HEADS
from ..utils import Encoder


@HEADS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Computes the main (cross entropy) loss and the align (MSE) loss. The
    classification target of the main loss is produced inside this head by
    running the tokenizer ``Encoder`` (the dalle encoder) on the target image
    and taking the arg-max token index at every position.

    Args:
        tokenizer_path (str, optional): The path of the tokenizer weights.
            If it is ``None`` or does not exist, a warning is emitted and the
            tokenizer keeps its freshly constructed weights.
            Defaults to None.
        lambd (float): The weight for the align loss. Defaults to 2.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path=None,
                 lambd=2,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        """Build the tokenizer and load its weights when they are available.

        Returns:
            nn.Module: The tokenizer ``Encoder``. Its weights come from
            ``self.tokenizer_path`` if that file exists; otherwise the module
            is returned as constructed and a warning is emitted.
        """
        encoder = Encoder()
        download_url = (
            'https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth')

        # ``tokenizer_path`` defaults to None, and ``os.path.exists(None)``
        # raises ``TypeError`` on most Python versions, so handle the unset
        # case explicitly instead of letting it crash the constructor.
        if self.tokenizer_path is None:
            warnings.warn(
                '`tokenizer_path` is not set, so no tokenizer weights are '
                f'loaded, please download from {download_url}')
            return encoder

        if os.path.exists(self.tokenizer_path):
            # NOTE: ``torch.load`` unpickles the file, only point
            # ``tokenizer_path`` at checkpoints from a trusted source.
            # Load onto CPU so a checkpoint saved from a GPU also works on
            # hosts without CUDA; ``load_state_dict`` copies the values into
            # the module's own parameters regardless of where they were
            # loaded.
            state_dict = torch.load(self.tokenizer_path, map_location='cpu')
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Do not find {self.tokenizer_path}, please download from '
                f'{download_url}')
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        """Turn the target image into token indices with the tokenizer.

        Runs without gradient tracking, so the tokenizer receives no
        gradient from the losses.

        Args:
            img_target (torch.Tensor): Input of the tokenizer.

        Returns:
            torch.Tensor: Arg-max indices over dim 1 of the tokenizer
            output, with every dimension after the first flattened into one.
        """
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        """Compute the main loss, the align loss and their sum.

        Args:
            img_target (torch.Tensor): Image fed to the tokenizer to generate
                the classification target.
            outputs (torch.Tensor): Predictions scored against the selected
                targets with cross entropy. Presumably one row of class
                logits per selected token -- TODO confirm against caller.
            latent_pred (torch.Tensor): Predicted latent representation.
            latent_target (torch.Tensor): Target latent representation. It is
                detached, so the align loss does not back-propagate into it.
            mask (torch.Tensor): Index used to select entries of the
                generated target. Presumably a boolean mask with the same
                shape as the flattened target -- TODO confirm against caller.

        Returns:
            dict: ``loss`` (main + align), ``main`` (cross entropy loss) and
            ``align`` (MSE loss already scaled by ``lambd``).
        """
        # Keep only the target tokens selected by ``mask``.
        target = self._generate_target(img_target)[mask]

        loss_main = self.loss_cross_entropy(outputs, target)
        loss_align = self.lambd * self.loss_mse(latent_pred,
                                                latent_target.detach())

        return {
            'loss': loss_main + loss_align,
            'main': loss_main,
            'align': loss_align,
        }
