import os
import warnings
from typing import List, Optional, Union

import torch
from mmcv.runner import BaseModule
from torch import nn

from ..builder import HEADS
from ..utils import DALLEncoder


@HEADS.register_module()
class BEiTHead(BaseModule):
    """Pretrain Head for BEiT.

    Compute the cross entropy loss between the decoder logits and the
    discrete targets. In addition, this head also generates the prediction
    target with the DALL-E encoder.

    Args:
        tokenizer_path (str, optional): The path of the DALLE tokenizer
            checkpoint. If it is None or the file does not exist, a warning
            is emitted and no checkpoint is loaded into the encoder.
            Defaults to None.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path: Optional[str] = None,
                 init_cfg: Optional[Union[dict, List[dict]]] = None):
        super().__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()

    def _load_encoder(self) -> nn.Module:
        """Build the DALL-E encoder and load its checkpoint if available.

        Returns:
            nn.Module: The encoder. Its weights are loaded from
            ``self.tokenizer_path`` only when that file exists.
        """
        encoder = DALLEncoder()
        # ``None`` is the default value of ``tokenizer_path`` but is not a
        # valid argument for ``os.path.exists``, so check it explicitly.
        if self.tokenizer_path is not None and os.path.exists(
                self.tokenizer_path):
            # Deserialize onto CPU so that the checkpoint can be read
            # regardless of the device it was saved from;
            # ``load_state_dict`` copies the values into the encoder.
            state_dict = torch.load(self.tokenizer_path, map_location='cpu')
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Do not find {self.tokenizer_path}, please download from '
                'https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target.

        The encoder output is reduced with ``argmax`` over dim 1 and all
        dimensions after the first one are flattened.

        Args:
            img_target (torch.Tensor): Input fed to the DALL-E encoder.

        Returns:
            torch.Tensor: Index tensor with two dimensions
            (presumably ``(B, L)`` token ids -- TODO confirm).
        """
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self,
                logits: torch.Tensor,
                img_target: torch.Tensor,
                mask: torch.Tensor,
                return_all_tokens: bool = False) -> dict:
        """Classification loss for BEiT.

        Args:
            logits (torch.Tensor): Logits generated by decoder. In the
                masked path they are indexed by the flattened mask, so
                the first dimension must match the number of mask entries
                (assumed ``(B * L, C)`` -- TODO confirm).
            img_target (torch.Tensor): Input of the DALL-E encoder, from
                which the target for decoder prediction is generated.
            mask (torch.Tensor): Mask selecting the tokens that contribute
                to the loss. It is flattened and cast to bool. Only used
                when ``return_all_tokens`` is False.
            return_all_tokens (bool): If True, the loss is computed on all
                tokens instead of the masked ones only. Defaults to False.

        Returns:
            dict: A dict whose key ``loss`` holds the cross entropy loss.
        """
        losses = dict()
        # The target is produced under ``torch.no_grad`` by ``argmax``, so
        # it carries no gradient and needs no extra ``detach``.
        target = self._generate_target(img_target)

        if not return_all_tokens:
            mask = mask.flatten(0).to(torch.bool)
            logits, target = logits[mask], target.flatten(0)[mask]
        elif logits.dim() == 2:
            # Per-token logits need a 1-D target; without the flatten,
            # ``CrossEntropyLoss`` rejects the 2-D index target.
            target = target.flatten(0)
        else:
            # Keep the previous behaviour for any other logits layout.
            target = target.squeeze(-1)

        losses['loss'] = self.loss_cross_entropy(logits, target)

        return losses


@HEADS.register_module()
class CAEHead(BaseModule):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    generates the prediction target with the DALL-E encoder.

    Args:
        tokenizer_path (str, optional): The path of the DALLE tokenizer
            checkpoint. If it is None or the file does not exist, a warning
            is emitted and no checkpoint is loaded into the encoder.
            Defaults to None.
        lambd (float): The weight for the align loss. Defaults to 2.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 tokenizer_path: Optional[str] = None,
                 lambd: float = 2,
                 init_cfg: Optional[Union[dict, List[dict]]] = None):
        super().__init__(init_cfg=init_cfg)
        self.tokenizer_path = tokenizer_path
        self.lambd = lambd
        self.encoder = self._load_encoder()
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def _load_encoder(self) -> nn.Module:
        """Build the DALL-E encoder and load its checkpoint if available.

        Returns:
            nn.Module: The encoder. Its weights are loaded from
            ``self.tokenizer_path`` only when that file exists.
        """
        encoder = DALLEncoder()
        # ``None`` is the default value of ``tokenizer_path`` but is not a
        # valid argument for ``os.path.exists``, so check it explicitly.
        if self.tokenizer_path is not None and os.path.exists(
                self.tokenizer_path):
            # Deserialize onto CPU so that the checkpoint can be read
            # regardless of the device it was saved from;
            # ``load_state_dict`` copies the values into the encoder.
            state_dict = torch.load(self.tokenizer_path, map_location='cpu')
            encoder.load_state_dict(state_dict)
        else:
            warnings.warn(
                f'Do not find {self.tokenizer_path}, please download from '
                'https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth'
            )
        return encoder

    @torch.no_grad()
    def _generate_target(self, img_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target.

        The encoder output is reduced with ``argmax`` over dim 1 and all
        dimensions after the first one are flattened.

        Args:
            img_target (torch.Tensor): Input fed to the DALL-E encoder.

        Returns:
            torch.Tensor: Index tensor with two dimensions
            (presumably ``(B, L)`` token ids -- TODO confirm).
        """
        logits = self.encoder(img_target)
        target = torch.argmax(logits, dim=1)
        return target.flatten(1)

    def forward(self, img_target: torch.Tensor, outputs: torch.Tensor,
                latent_pred: torch.Tensor, latent_target: torch.Tensor,
                mask: torch.Tensor) -> dict:
        """Classification and MSE losses for CAE.

        Args:
            img_target (torch.Tensor): Input of the DALL-E encoder, from
                which the target for decoder prediction is generated.
            outputs (torch.Tensor): Logits generated by decoder for the
                tokens selected by ``mask``.
            latent_pred (torch.Tensor): Latent prediction by regressor.
            latent_target (torch.Tensor): Target for latent prediction,
                generated by teacher. It is detached before use.
            mask (torch.Tensor): Mask selecting the target tokens that
                contribute to the main loss. It is cast to bool before
                indexing (assumed to match the generated target in shape
                -- TODO confirm against the caller).

        Returns:
            dict: A dict with keys ``loss`` (sum of both terms), ``main``
            (cross entropy loss) and ``align`` (MSE loss scaled by
            ``lambd``).
        """
        losses = dict()
        target = self._generate_target(img_target)
        # Cast explicitly: indexing with an integer 0/1 tensor would be
        # interpreted as index selection rather than masking.
        target = target[mask.to(torch.bool)]
        loss_main = self.loss_cross_entropy(outputs, target)
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd

        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align

        return losses
