import torch
import torch.nn as nn
from .base import BaseLosses


class CommitLoss(nn.Module):
    """Pass-through "loss" module: hands back a value that was already computed.

    It exposes the same call shape as the two-argument ``torch.nn`` losses
    (``module(outputs, inputs)``), but performs no computation of its own.
    """

    def __init__(self, **kwargs):
        """Build the module; every keyword argument is accepted and discarded."""
        super(CommitLoss, self).__init__()

    def forward(self, commit, commit2, **kwargs):
        """Return ``commit`` unchanged.

        ``commit2`` and any extra keyword arguments are ignored.
        """
        precomputed = commit
        return precomputed


class GPTLosses(BaseLosses):
    """Loss container for the ``vae`` stage and the language-model stages.

    The constructor selects which named losses are active, their weights
    (read from ``cfg.LOSS``) and the loss class used for each, then hands
    everything to ``BaseLosses``. ``update`` combines the terms produced by
    ``self._update_loss`` for one result set.
    """

    def __init__(self, cfg, stage, num_joints, rec=False, rec_only=False, **kwargs):
        """Select the losses for ``stage`` and initialise the base class.

        Args:
            cfg: Config object; reads ``cfg.LOSS.ABLATION.RECONS_LOSS`` and the
                ``cfg.LOSS.LAMBDA_*`` weights of the selected losses.
            stage: ``"vae"``, ``"lm_pretrain"`` or ``"lm_instruct"``. Any other
                value selects no stage-specific loss.
            num_joints: Forwarded to ``BaseLosses``.
            rec: When true, ``"recons_feature"`` is also registered.
            rec_only: Stored on the instance; ``update`` uses it to decide
                whether ``"gpt_loss"`` is skipped.
            **kwargs: Forwarded to ``BaseLosses``.

        Raises:
            NotImplementedError: If a selected loss name has no known loss
                class, or a reconstruction loss is selected but
                ``cfg.LOSS.ABLATION.RECONS_LOSS`` is not one of ``"l1"``,
                ``"l2"`` or ``"l1_smooth"``.
        """
        # Save parameters
        self.stage = stage
        self.rec = rec
        self.rec_only = rec_only
        recons_loss = cfg.LOSS.ABLATION.RECONS_LOSS

        # Define losses
        losses = []
        params = {}
        if stage == "vae":
            losses.append("recons_feature")
            params["recons_feature"] = cfg.LOSS.LAMBDA_FEATURE

            losses.append("recons_velocity")
            params["recons_velocity"] = cfg.LOSS.LAMBDA_VELOCITY

            losses.append("vq_commit")
            params["vq_commit"] = cfg.LOSS.LAMBDA_COMMIT

        elif stage in ["lm_pretrain", "lm_instruct"]:
            losses.append("gpt_loss")
            params["gpt_loss"] = cfg.LOSS.LAMBDA_CLS

        if rec:
            # The "vae" stage already lists this loss; keep names unique.
            if "recons_feature" not in losses:
                losses.append("recons_feature")
            params["recons_feature"] = cfg.LOSS.LAMBDA_FEATURE

        # Define loss functions & weights
        recons_loss_types = {
            "l1": nn.L1Loss,
            "l2": nn.MSELoss,
            "l1_smooth": nn.SmoothL1Loss,
        }
        losses_func = {}
        for loss in losses:
            name_parts = loss.split("_")
            if name_parts[0] == "recons":
                if recons_loss not in recons_loss_types:
                    # Fail here instead of leaving the loss without a function.
                    raise NotImplementedError(
                        f"Reconstruction loss {recons_loss} not implemented."
                    )
                losses_func[loss] = recons_loss_types[recons_loss]
            elif name_parts[1] in ["commit", "loss", "gpt", "m2t2m", "t2m2t"]:
                # These values are computed upstream; CommitLoss passes them through.
                losses_func[loss] = CommitLoss
            elif name_parts[1] in ["cls", "lm"]:
                losses_func[loss] = nn.CrossEntropyLoss
            else:
                raise NotImplementedError(f"Loss {loss} not implemented.")

        super().__init__(cfg, losses, params, losses_func, num_joints, **kwargs)

    def update(self, rs_set):
        """Update the losses

        Args:
            rs_set: Mapping of results. The ``vae`` stage reads ``"m_rst"``,
                ``"m_ref"`` and ``"loss_commit"``. The LM stages read
                ``rs_set["losses"]["pred"]`` and, when ``"m_rst"`` is present,
                ``"m_rst"`` / ``"m_ref"``.

        Returns:
            The sum of the values returned by ``self._update_loss`` for the
            active terms (presumably a weighted tensor -- confirm against
            ``BaseLosses``), or the float ``0.0`` when no term was active.

        Raises:
            NotImplementedError: In the ``vae`` stage, if the feature size is
                neither 263 nor 135 + 263 while the ``"recons_velocity"``
                weight is non-zero.
        """
        total = 0.0

        if self.stage == "vae":
            total += self._update_loss(
                "recons_feature", rs_set["m_rst"], rs_set["m_ref"]
            )

            # Start index of the velocity slice for each supported feature size.
            vel_start_by_nfeats = {263: 4, 135 + 263: 135 + 4}
            nfeats = rs_set["m_rst"].shape[-1]
            if nfeats in vel_start_by_nfeats:
                vel_start = vel_start_by_nfeats[nfeats]
                vel_slice = slice(vel_start, (self.num_joints - 1) * 3 + vel_start)
                total += self._update_loss(
                    "recons_velocity",
                    rs_set["m_rst"][..., vel_slice],
                    rs_set["m_ref"][..., vel_slice],
                )
            elif self._params["recons_velocity"] != 0.0:
                raise NotImplementedError(
                    "Velocity not implemented for nfeats = {}".format(nfeats)
                )

            total += self._update_loss(
                "vq_commit", rs_set["loss_commit"], rs_set["loss_commit"]
            )

        if self.stage in ["lm_pretrain", "lm_instruct"]:
            # NOTE(review): "recons_feature" is only registered in __init__ when
            # rec is true -- confirm callers never provide "m_rst" otherwise.
            if "m_rst" in rs_set:
                total += self._update_loss(
                    "recons_feature", rs_set["m_rst"], rs_set["m_ref"]
                )
            if not (self.rec_only and "m_ref" in rs_set):
                total += self._update_loss(
                    "gpt_loss", rs_set["losses"]["pred"], rs_set["losses"]["pred"]
                )

        # Update the total loss
        if isinstance(total, torch.Tensor):
            self.total += total.detach()
        else:
            # No term was added, so total is still a plain number and has no
            # .detach(); accumulate it directly.
            self.total += total
        self.count += 1

        return total
