import collections
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_loss import Loss, DistillKL
from codebook import MidEncoder


class CodebookReconstructLoss(Loss):
    """Bundle of classification, distillation and reconstruction loss terms.

    ``forward`` returns each term under its own key (``"cls"``,
    ``"cls_model"``, ``"kd"``, ``"reconstruct"``); the terms are not summed
    or weighted here.
    """

    # Maps the ``reconstruct_loss`` constructor argument to the criterion
    # class instantiated (with default arguments) for the reconstruction term.
    ReconstructLosses = {
        "mse": nn.MSELoss,
        "L1": nn.L1Loss,
        "smooth_L1": nn.SmoothL1Loss
    }

    def __init__(
        self,
        mid_encoder: MidEncoder,
        reconstruct_loss: str = "mse",
        T: float = 4
    ):
        """Set up the distillation and reconstruction criteria.

        Args:
            mid_encoder: Object whose ``mid_dict`` is read in ``forward`` for
                the ``"encoded_seq"`` and ``"origin_seq"`` entries.
            reconstruct_loss: Key into ``ReconstructLosses``; an unknown key
                raises ``KeyError``.
            T: Forwarded to ``DistillKL`` as ``T`` (presumably the
                distillation temperature -- confirm in ``base_loss``).
        """
        super().__init__()
        self.kd_loss = DistillKL(T=T)
        criterion_cls = self.ReconstructLosses[reconstruct_loss]
        self.reconstruct_fn = criterion_cls()
        self.mid_encoder = mid_encoder

    def loss_cls(self, pred: torch.Tensor, gt: torch.LongTensor, model_label: torch.LongTensor):
        """Cross-entropy of ``pred`` against two sets of class targets.

        Args:
            pred: Predictions passed to ``F.cross_entropy`` as its input.
            gt: Ground-truth targets.
            model_label: Targets derived from the origin model's predictions.

        Returns:
            Ordered dict with ``"cls"`` (against ``gt``) followed by
            ``"cls_model"`` (against ``model_label``).
        """
        return collections.OrderedDict([
            ("cls", F.cross_entropy(pred, gt)),
            ("cls_model", F.cross_entropy(pred, model_label)),
        ])

    def loss_origin_model(self, pred: torch.Tensor, model_pred: torch.Tensor):
        """Distillation term between ``pred`` and the origin model's output.

        Returns:
            Ordered dict with the single key ``"kd"`` holding
            ``self.kd_loss(pred, model_pred)``.
        """
        return collections.OrderedDict(kd=self.kd_loss(pred, model_pred))

    def forward(
        self,
        output: Dict[str, torch.Tensor],
        output_origin: Dict[str, torch.Tensor],
        target: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Compute every loss term for one batch.

        Args:
            output: Must contain ``"pred"``, the predictions being trained.
            output_origin: Must contain ``"pred"``, the origin model's
                predictions.
            target: Must contain ``"label"``, the ground-truth targets.

        Returns:
            Ordered dict with keys ``"cls"``, ``"cls_model"``, ``"kd"`` and
            ``"reconstruct"``, in that order.
        """
        student_pred = output["pred"]
        origin_pred = output_origin["pred"]
        labels = target["label"]
        # The origin model's arg-max along dim 1 is used as a second set of
        # hard targets (assumes dim 1 is the class axis -- TODO confirm).
        origin_labels = origin_pred.argmax(dim=1)

        losses: Dict[str, torch.Tensor] = collections.OrderedDict()
        losses.update(self.loss_cls(student_pred, labels, origin_labels))
        losses.update(self.loss_origin_model(student_pred, origin_pred))

        # NOTE(review): mid_dict is presumably filled by the encoder during
        # the model's forward pass -- confirm it is fresh for this batch.
        mid = self.mid_encoder.mid_dict
        losses["reconstruct"] = self.reconstruct_fn(
            mid["encoded_seq"],
            mid["origin_seq"],
        )
        return losses
