from models.lie_ssl.shapes_model import SimCLRLieModule
import torch


class SimCLRLieFixedTModule(SimCLRLieModule):
    """SimCLR + Lie module whose coefficients ``t`` are fixed from ``delta``.

    Instead of being inferred, every one of the ``dim_d`` coordinates of ``t``
    is set to ``(delta - 180) / 180``. This maps ``delta`` in ``[0, 360]`` onto
    ``[-1, 1]``.

    NOTE(review): if ``delta`` actually lies in ``[-180, 180]``, this formula
    maps it to ``[-2, 0]``; a ``[-180, 180] -> [-1, 1]`` mapping would be
    ``delta / 180``. Confirm the range of ``delta`` produced by the data.
    """

    def __init__(self, **kwargs):
        # t is fixed rather than inferred, so ask the parent not to infer
        # t bounds (presumably what ``infer_t_bounds`` controls).
        super().__init__(infer_t_bounds=False, **kwargs)
        # Shape (2, dim_d): row 0 is all 1.0, row 1 is all -1.0 -- fixed
        # bounds matching the normalized range of t.
        self.t_bounds = torch.tensor([[1.0, -1.0]] * self.dim_d).T

    def on_train_start(self) -> None:
        # Intentionally a no-op: overrides the parent's train-start hook
        # (presumably t-bound inference), which is unnecessary because
        # t_bounds is fixed in __init__.
        return None

    def normalize_delta(self, delta: torch.Tensor) -> torch.Tensor:
        """Broadcast the normalized ``delta`` to Lie coefficients ``t``.

        Args:
            delta: tensor of shape ``(batch_size, 1)``; a ``(batch_size,)``
                tensor is accepted as well.

        Returns:
            Tensor of shape ``(batch_size, dim_d)`` in which every column
            equals ``(delta - 180) / 180``.
        """
        delta_normalized = (delta - 180.0) / 180.0
        # reshape + repeat instead of stack(...).T.squeeze(): ``.T`` on a 3-D
        # tensor is deprecated, and squeeze() also dropped the batch axis when
        # batch_size == 1 (and the dim_d axis when dim_d == 1), making the
        # shape of t depend on the batch size.
        return delta_normalized.reshape(-1, 1).repeat(1, self.dim_d)

    def lie_loss(self, z1_aug, z2_aug, delta, t=None):
        """Compute the Lie loss terms using ``t`` derived from ``delta``.

        The ``t`` argument is ignored (presumably kept for signature
        compatibility with the parent); ``t`` is always recomputed with
        :meth:`normalize_delta`.

        Returns:
            ``(lie_loss, euc_loss, l2_loss, t, g)`` where ``euc_loss`` is the
            squared difference between ``z2_aug`` and ``operate``'s prediction
            ``z2_aug_hat``, summed over the last dim and averaged over the
            first, and ``t``/``g`` are as returned by ``self.operate``.
        """
        t = self.normalize_delta(delta)
        z2_aug_hat, t, g = self.operate(z1_aug, z2_aug, delta, t)

        lie_loss = self.compute_lie_infonce(z2_aug, z2_aug_hat, z1_aug)

        euc_loss = ((z2_aug - z2_aug_hat) ** 2).sum(-1).mean(0)
        l2_loss = self.L2_constraint(None, None, delta, t)  # TODO not using any z yet

        return lie_loss, euc_loss, l2_loss, t, g

    def shared_step(self, batch, stage: str = "train", return_terms=False):
        """Compute and log the weighted total loss for one batch.

        Args:
            batch: 6-tuple whose first two items are 3-tuples of views
                ``(aug1, aug2, online)`` for x1 and x2, and whose last item is
                ``delta``.
            stage: stage name; used as the log key prefix and as the key into
                ``self.g_matrix``. Stages containing ``"canonical"`` use only
                the SimCLR loss on the x1 views.
            return_terms: if True, also return the unweighted loss terms.

        Returns:
            ``loss``, or ``(loss, (ssl_loss, lie_loss, euc_loss, l2_loss,
            ind_constraint))`` when ``return_terms`` is True.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, x1_aug2, _ = x1tuple
        x2_aug1, x2_aug2, _ = x2tuple
        z1_aug1 = self.compute_rep(x1_aug1)
        z1_aug2 = self.compute_rep(x1_aug2)
        # NOTE(review): the z2 reps are unused in the canonical branch; they
        # are still computed in case compute_rep has side effects (e.g.
        # normalization statistics) that callers rely on.
        z2_aug1 = self.compute_rep(x2_aug1)
        z2_aug2 = self.compute_rep(x2_aug2)
        if "canonical" in stage:
            # Doubled so its scale matches the two-term SimCLR sum used in the
            # other stages.
            ssl_loss = 2 * self.compute_simclr_infonce(z1_aug1, z1_aug2)
            lie_loss = 0
            euc_loss = 0
            l2_loss = 0
            ind_constraint = 0
        else:
            ssl_loss = self.compute_simclr_infonce(
                z1_aug1, z1_aug2
            ) + self.compute_simclr_infonce(z2_aug1, z2_aug2)
            lie_loss, euc_loss, l2_loss, t, g = self.lie_loss(z1_aug1, z2_aug1, delta)

            # Workaround for online probing
            self.g_matrix[stage] = g.detach().clone()
            ind_constraint = self.ind_constraint()

        loss = (
            self.lambda_ssl * ssl_loss
            + self.lambda_lie * lie_loss
            + self.lambda_euc * euc_loss
            + self.lambda_l2 * l2_loss
            + self.lambda_i * ind_constraint
        )

        batch_size = x1_aug1.shape[0]
        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            batch_size=batch_size,
            add_dataloader_idx=False,  # loader names are used instead
            on_step=True,
            on_epoch=True,
        )

        if return_terms:
            return loss, (ssl_loss, lie_loss, euc_loss, l2_loss, ind_constraint)
        else:
            return loss


class SimCLRLieNoLieInfoNCEModule(SimCLRLieModule):
    """Ablation of ``SimCLRLieModule`` with the Lie InfoNCE term disabled.

    ``compute_lie_infonce`` always contributes zero. The intent is for the
    model to rely on a distance term |g(z_1) - z_2| instead.

    NOTE(review): that distance term is not added in this class; confirm it
    is applied by the parent.
    """

    def __init__(self, **kwargs):
        # Explicit keyword-only pass-through; keeps the constructor signature
        # as ``(**kwargs)``.
        super(SimCLRLieNoLieInfoNCEModule, self).__init__(**kwargs)

    def compute_lie_infonce(self, z1, z2, z3):
        """Return a constant ``0.0`` in place of the Lie InfoNCE loss.

        All arguments are accepted only to match the call signature and are
        ignored.
        """
        del z1, z2, z3  # intentionally unused
        return 0.0
