import torch
from .base import BaseLosses


class MVQLosses(BaseLosses):
    """Loss terms for the MVQ model (presumably a vector-quantized motion model).

    Registers reconstruction (feature, joint, velocity, jerk), codebook
    commitment, perplexity and generator/discriminator adversarial terms, each
    weighted by a ``cfg.LOSS.LAMBDA_*`` value. Loss functions are built by
    ``BaseLosses._get_loss_func`` (defined outside this class).
    """

    def __init__(self, cfg, stage, **kwargs):
        """Register every loss term with its weight from ``cfg.LOSS``.

        Args:
            cfg: Config object; the ``LAMBDA_*`` weights and
                ``ABLATION.RECONS_LOSS`` are read from ``cfg.LOSS``.
            stage: Stored on ``self.stage``.
            **kwargs: Forwarded to ``BaseLosses.__init__``.
        """
        # Save parameters
        self.stage = stage

        # Define losses (name -> weight)
        losses_params = {
            "recons_feature": cfg.LOSS.LAMBDA_FEATURE,
            "recons_velocity": cfg.LOSS.LAMBDA_VELOCITY,
            "recons_joints": cfg.LOSS.LAMBDA_JOINT,
            "identity_commit": cfg.LOSS.LAMBDA_COMMIT,
            "recons_jerk": cfg.LOSS.LAMBDA_JERK,
            "identity_perplexity": cfg.LOSS.LAMBDA_IDENTITY,
            "adversarial_g": cfg.LOSS.LAMBDA_ADV_G,
            "adversarial_d": cfg.LOSS.LAMBDA_ADV_D,
        }
        losses = list(losses_params.keys())

        # Define loss functions & weights
        losses_func = self._get_loss_func(
            losses, recons_loss=cfg.LOSS.ABLATION.RECONS_LOSS
        )

        super().__init__(cfg, losses, losses_params, losses_func, **kwargs)

    def update(self, rs_set, stage: str = "generator") -> "torch.Tensor | float":
        """Compute this step's losses and accumulate the running total.

        Args:
            rs_set: Dict of model outputs/targets. Keys read depend on
                ``stage``; see ``_generator_total`` and ``_discriminator_total``.
            stage: ``"generator"`` or ``"discriminator"``.

        Returns:
            The sum of the values returned by ``_update_loss`` for this step,
            or ``0.0`` when no term was added.

        Raises:
            NotImplementedError: For an unknown ``stage``, or when a velocity
                loss is requested for an unsupported feature size.
        """
        if stage == "generator":
            total = self._generator_total(rs_set)
        elif stage == "discriminator":
            total = self._discriminator_total(rs_set)
        else:
            raise NotImplementedError(f"Stage {stage} not implemented.")

        # ``total`` remains the float 0.0 when nothing was added (e.g. the
        # discriminator stage without ``D_fake``); floats have no ``detach``.
        self.total += total.detach() if torch.is_tensor(total) else total
        self.count += 1

        return total

    def _generator_total(self, rs_set):
        """Sum the generator-stage terms for one step.

        Reads ``m_rst``/``m_ref``, ``joints_rst``/``joints_ref``,
        ``loss_commit`` and ``perplexity``. Also reads
        ``joints_trans_rst``/``joints_trans_ref`` when the jerk weight is
        non-zero for the 251/263/398-dim layouts, and ``D_fake`` when present.
        """
        total = 0.0

        # Reconstruction of the raw features and of the joint positions.
        total += self._update_loss(
            "recons_feature", rs_set["m_rst"], rs_set["m_ref"]
        )
        total += self._update_loss(
            "recons_joints", rs_set["joints_rst"], rs_set["joints_ref"]
        )

        nfeats = rs_set["m_rst"].shape[-1]
        if nfeats in (263, 251, 135 + 263):
            # Velocity channels are a contiguous slice of the feature vector;
            # in the 135 + 263 layout the slice is shifted by 135 features.
            vel_start = 135 + 4 if nfeats == 135 + 263 else 4
            vel_end = vel_start + (self.num_joints - 1) * 3
            total += self._update_loss(
                "recons_velocity",
                rs_set["m_rst"][..., vel_start:vel_end],
                rs_set["m_ref"][..., vel_start:vel_end],
            )

            # Transition jerk loss, skipped entirely when its weight is zero.
            if self._params["recons_jerk"] != 0.0:
                total += self._update_loss(
                    "recons_jerk",
                    self._jerk(rs_set["joints_trans_rst"]),
                    self._jerk(rs_set["joints_trans_ref"]),
                )
        elif nfeats == 135:
            # For 135-dim features the velocity target is the frame-to-frame
            # difference of the joints.
            if self._params["recons_velocity"] != 0.0:
                total += self._update_loss(
                    "recons_velocity",
                    self._frame_diff(rs_set["joints_rst"]),
                    self._frame_diff(rs_set["joints_ref"]),
                )
        elif self._params["recons_velocity"] != 0.0:
            raise NotImplementedError(
                f"Velocity not implemented for nfeats = {nfeats}"
            )

        # Codebook commitment loss.
        total += self._update_loss("identity_commit", rs_set["loss_commit"])

        # Perplexity is recorded via _update_loss but not added to the total.
        self._update_loss("identity_perplexity", rs_set["perplexity"])

        # Adversarial term: the generator is scored against the "real" label
        # (ones). The target is built from D_fake itself so that the generator
        # step does not require D_real.
        d_fake = rs_set.get("D_fake")
        if d_fake is not None:
            total += self._update_loss(
                "adversarial_g", d_fake, torch.ones_like(d_fake)
            )

        return total

    def _discriminator_total(self, rs_set):
        """Sum the discriminator-stage terms (real -> ones, fake -> zeros).

        Reads ``D_fake`` and, when it is present and not ``None``, ``D_real``.
        Returns ``0.0`` when ``D_fake`` is missing or ``None``.
        """
        total = 0.0
        d_fake = rs_set.get("D_fake")
        if d_fake is not None:
            d_real = rs_set["D_real"]
            total += self._update_loss(
                "adversarial_d", d_real, torch.ones_like(d_real)
            )
            total += self._update_loss(
                "adversarial_d", d_fake, torch.zeros_like(d_fake)
            )
        return total

    @staticmethod
    def _frame_diff(joints):
        """First-order difference along dim 1 (assumed to be time — confirm)."""
        return joints[:, 1:] - joints[:, :-1]

    @staticmethod
    def _jerk(joints):
        """Third-order difference along dim 1, with the last 3 frames zeroed.

        Assumes ``joints`` is (batch, frames, ...); the result has 3 fewer
        frames than the input.
        """
        vel = MVQLosses._frame_diff(joints)
        acc = MVQLosses._frame_diff(vel)
        jerk = MVQLosses._frame_diff(acc)
        # NOTE(review): the last three jerk frames are excluded from the loss;
        # the reason is not visible here — confirm intent.
        jerk[:, -3:] = 0.0
        return jerk
