from models.base_model import ShapesBaseModel, LieModelMixins
from typing import Tuple, Optional
import pytorch_lightning as pl
import torch
from torch import Tensor
from models.mae import mae as mae_module
from models.simclr.simclr import Projection
import numpy as np


class MAELieModule(ShapesBaseModel, LieModelMixins):
    """Extends MAE pretraining objective to include Lie loss.

    The training loss is
    ``lambda_ssl * mae_loss + lambda_lie * lie_loss + lambda_euc * euc_loss
    + l2_loss``; stages whose name contains "canonical" skip every Lie term
    (they contribute zero).

    Args:
        learning_rate: learning rate, stored on the instance (presumably read
            by the base model's optimizer setup)
        optimizer: optimizer name, stored on the instance
        momentum: optimizer momentum, stored on the instance
        weight_decay: optimizer weight decay, stored on the instance
        lambda_lie: factor multiplying lie loss
        lambda_ssl: factor multiplying original ssl loss for MAE
        lambda_euc: factor multiplying regularization for |g(z) - z'|
        ssl_temperature: temperature used in the Lie InfoNCE loss
        dim_d: number of Lie generators
        top_k: forwarded to the base model
        datamodule: forwarded to the base model
        infer_t_bounds: if True, allocate a non-trainable ``t_bounds``
            parameter of shape (2, dim_d) that stores only the min and max of t
        truncate_matrix_exp: stored flag; presumably read by
            ``matrix_exponential`` in the mixin (TODO confirm)
    """

    def __init__(
        self,
        learning_rate: float = 1e-5,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        lambda_lie: float = 1.0,
        lambda_ssl: float = 1.0,
        lambda_euc: float = 1.0,
        ssl_temperature: float = 0.1,
        dim_d: int = 100,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
        infer_t_bounds: bool = True,
        truncate_matrix_exp: bool = False,
    ):
        super().__init__(top_k=top_k, datamodule=datamodule)
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.lambda_lie_loss = lambda_lie
        self.lambda_ssl_loss = lambda_ssl
        self.lambda_euc_loss = lambda_euc
        self.temperature = ssl_temperature
        self.truncate_matrix_exp = truncate_matrix_exp
        # number of lie generators
        self.dim_d = dim_d

        self.mae = self.load_pretained_mae()
        self.embedding_dim = 768
        # for compatibility with attribute of online probe
        self.z_dim = self.embedding_dim

        # match SimCLR projection head
        self.projection = Projection(
            input_dim=self.embedding_dim, hidden_dim=self.embedding_dim, output_dim=128
        )

        # maps stage name -> detached g of the most recent batch (None for
        # canonical stages); overwritten by every shared_step call
        self.g_matrix = {}
        self.L_basis = self.initialize_L_basis(self.dim_d, self.embedding_dim)

        self.t_inference_network = self.create_t_inference_network(
            self.dim_d, self.embedding_dim
        )
        if infer_t_bounds:
            # store only min and max
            self.t_bounds = torch.nn.parameter.Parameter(
                data=torch.zeros(2, self.dim_d), requires_grad=False
            )

    def load_pretained_mae(self):
        """Build a ViT-B/16 MAE and load the public MAE checkpoint into it.

        NOTE(review): this loads the *fine-tuned classification* checkpoint
        with ``strict=False``. Such a checkpoint presumably has no decoder
        weights, which would leave the MAE decoder randomly initialised.
        Confirm this is intended.
        """
        mae = mae_module.mae_vit_base_patch16()
        checkpoint = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_base.pth"
        )
        mae.load_state_dict(checkpoint["model"], strict=False)
        return mae

    def forward(self, x) -> Tensor:
        """Generates an embedding z for the input.

        Args:
            x: batch_size, 3, 224, 224
        Returns:
            tensor of shape [batch_size, embedding_dim]: the first (cls) token
            of the unmasked MAE encoder output.
        """
        return self.compute_rep(x)

    def compute_rep(self, x) -> Tensor:
        """Encode ``x`` without masking and return token 0 (the cls token)."""
        # mask_ratio=0.0 so the representation sees the whole image
        latent_patches, _, _ = self.mae.forward_encoder(x, mask_ratio=0.0)
        return latent_patches[:, 0, :]

    def online_probe_forward(self, x1_online, x2_online, g, stage: str = "train"):
        """Embed both online views and, outside canonical stages, their g-images.

        Returns:
            ``[batch_size, 2, dim_m]`` (z1, z2) in canonical stages, otherwise
            ``[batch_size, 4, dim_m]`` ordered (z1, z2, g z1, g z2).
        """
        reps = torch.stack([self.forward(x1_online), self.forward(x2_online)], dim=1)
        if "canonical" in stage:
            return reps
        # g is [batch_size, dim_m, dim_m]; apply it to each of the two reps:
        # g_reps[b, k] = g[b] @ reps[b, k]  ->  [batch_size, 2, dim_m]
        g_reps = torch.einsum("bij,bkj->bki", g, reps)
        return torch.cat([reps, g_reps], dim=1)

    def mae_forward(self, x):
        """Run the MAE encoder/decoder on ``x``.

        Returns:
            ``(reconstruction_loss, encoder_latent)``; the encoder is called with
            its default ``mask_ratio``.
        """
        latent, mask, ids_restore = self.mae.forward_encoder(x)
        pred = self.mae.forward_decoder(latent, ids_restore)
        loss = self.mae.forward_loss(x, pred, mask)
        return loss, latent

    def shared_step(self, batch, stage: str = "train", return_terms=False):
        """Compute the weighted MAE + Lie objective for one batch and log its terms.

        Args:
            batch: ``(x1tuple, x2tuple, _, _, _, delta)`` where each ``xNtuple``
                unpacks into three views; only the first view of each and
                ``delta`` are used here.
            stage: logging prefix; stages containing "canonical" skip the Lie
                terms and store ``None`` in ``self.g_matrix[stage]``.
            return_terms: unused; kept for interface compatibility.

        Returns:
            the weighted total loss.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, _, _ = x1tuple
        x2_aug1, _, _ = x2tuple

        mae_loss1, _ = self.mae_forward(x1_aug1)
        mae_loss2, _ = self.mae_forward(x2_aug1)
        mae_loss = mae_loss1 + mae_loss2

        if "canonical" in stage:
            # skip lie. The zero placeholders are created on the device/dtype of
            # mae_loss, not as CPU tensors, so the weighted sum and the
            # sync_dist logging below never mix CPU and accelerator tensors.
            zero = mae_loss.new_zeros(())
            lie_loss = euc_loss = l2_loss = zero
            self.g_matrix[stage] = None
        else:
            # unmasked cls-token representations of the two views
            z1, z2 = self.compute_rep(x1_aug1), self.compute_rep(x2_aug1)
            z2_hat, g, t = self.predict_z2_hat(z1, z2, delta)
            lie_loss = self.compute_lie_loss(z1, z2, z2_hat)
            # squared L2 distance per sample, averaged over the batch
            euc_loss = ((z2 - z2_hat) ** 2).sum(-1).mean(0)
            l2_loss = self.l2_contraint(delta, t)

            # for online probing; detach() shares storage, so no clone is needed
            self.g_matrix[stage] = g.detach()

        loss = (
            mae_loss * self.lambda_ssl_loss
            + lie_loss * self.lambda_lie_loss
            + euc_loss * self.lambda_euc_loss
            + l2_loss
        )

        batch_size = x1_aug1.shape[0]
        self.log_loss(f"{stage}_mae_loss", mae_loss, batch_size)
        self.log_loss(f"{stage}_lie_loss", lie_loss, batch_size)
        self.log_loss(f"{stage}_euc_loss", euc_loss, batch_size)
        self.log_loss(f"{stage}_l2_loss", l2_loss, batch_size)
        self.log_loss(f"{stage}_loss", loss, batch_size)

        return loss

    def l2_contraint(self, delta, t):
        """Delta-weighted L2 penalty on the Lie coefficients ``t``.

        Each sample's ``||t||^2`` is weighted by ``1 / (1 + exp(|delta|))``
        (equal to 0.5 at ``delta == 0`` and decaying as ``|delta|`` grows), and
        the weighted values are averaged.
        """
        # NOTE(review): assumes delta broadcasts against (t ** 2).sum(-1),
        # i.e. has shape [batch_size] — confirm against the datamodule
        sim = 1 / (1 + torch.exp(delta.abs()))
        loss = sim * (t ** 2).sum(-1)
        return loss.mean()

    def predict_z2_hat(self, z1, z2, delta) -> Tuple[Tensor, Tensor, Tensor]:
        """Infer ``t`` from ``(z1, z2, delta)`` and return ``(g(t) z1, g(t), t)``."""
        t = self.infer_t(z1, z2, delta)
        z2_hat, g = self._apply_lie_transform(z1, t)
        return z2_hat, g, t

    def transform(self, z, t=None) -> Tensor:
        """Transforms z by applying g(z, t)

        If t is None, then a randomly sampled t is generated, with each of the
        ``dim_d`` coefficients drawn uniformly from [-1, 1) via numpy's RNG.
        """
        if t is None:
            batch_size = z.shape[0]
            t = torch.as_tensor(
                np.random.uniform(-1.0, 1.0, (batch_size, self.dim_d)),
                dtype=torch.get_default_dtype(),
                device=z.device,
            )
        z_transformed, _ = self._apply_lie_transform(z, t)
        return z_transformed

    def _apply_lie_transform(self, z, t) -> Tuple[Tensor, Tensor]:
        """Return ``(g z, g)`` with ``g = matrix_exponential(sum_i t[:, i] * L_i)``."""
        # t is [batch_size, dim_d]; L_basis is presumably [dim_d, dim_m, dim_m],
        # giving one [dim_m, dim_m] Lie-algebra element per sample
        lie_matrix = (t[..., None, None] * self.L_basis[None, ...]).sum(1)
        g = self.matrix_exponential(lie_matrix)
        # batched matrix-vector product g[b] @ z[b]
        z_transformed = torch.bmm(g, z[..., None]).squeeze(-1)
        return z_transformed, g

    def compute_lie_loss(self, z1, z2, z2_hat) -> Tensor:
        """Lie NT-Xent loss on the projections of ``z2``, ``z2_hat`` and ``z1``.

        The meaning of the three arguments is defined by ``lie_nt_xent_loss``
        (presumably in the Lie mixin).
        """
        lie_loss = self.lie_nt_xent_loss(
            self.projection(z2),
            self.projection(z2_hat),
            self.projection(z1),
            self.temperature,
        )
        return lie_loss

    def log_loss(self, name: str, value: Tensor, batch_size: int):
        """Log ``value`` under ``name`` per step and per epoch, synced across processes."""
        self.log(
            name,
            value,
            sync_dist=True,
            batch_size=batch_size,
            add_dataloader_idx=False,  # loader names are used instead
            on_step=True,
            on_epoch=True,
        )


class MAELieMaskedModule(MAELieModule):
    """Implements the Lie operator on the masked MAE embedding.

    Representations are token 0 (the cls token) of the encoder latent that the
    MAE reconstruction pass itself produces (``forward_encoder`` with its
    default ``mask_ratio``). They are not taken from a separate unmasked
    encoder pass.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x) -> Tensor:
        """Embed ``x`` as the cls token of the encoder output.

        NOTE(review): ``forward_encoder`` is called with its default
        ``mask_ratio`` (``MAELieModule.compute_rep`` passes 0.0 explicitly), so
        this embedding is presumably computed from a randomly masked input,
        even at eval time.
        """
        full_latent, _, _ = self.mae.forward_encoder(x)
        return self.compute_rep(full_latent)

    def compute_rep(self, latent) -> Tensor:
        """Select token 0 (the cls token) from an encoder output ``latent``."""
        return latent[:, 0, :]

    def shared_step(self, batch, stage: str = "train", return_terms=False):
        """Compute the weighted MAE + Lie objective using the masked latents.

        Args:
            batch: ``(x1tuple, x2tuple, _, _, _, delta)`` where each ``xNtuple``
                unpacks into three views; only the first view of each and
                ``delta`` are used here.
            stage: logging prefix; stages containing "canonical" skip the Lie
                terms and store ``None`` in ``self.g_matrix[stage]``.
            return_terms: unused; kept for interface compatibility.

        Returns:
            the weighted total loss.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, _, _ = x1tuple
        x2_aug1, _, _ = x2tuple

        # the latents of the MAE pass are reused as the Lie representations
        mae_loss1, latent1 = self.mae_forward(x1_aug1)
        mae_loss2, latent2 = self.mae_forward(x2_aug1)
        mae_loss = mae_loss1 + mae_loss2

        if "canonical" in stage:
            # skip lie. The zero placeholders are created on the device/dtype of
            # mae_loss, not as CPU tensors, so the weighted sum and the
            # sync_dist logging below never mix CPU and accelerator tensors.
            zero = mae_loss.new_zeros(())
            lie_loss = euc_loss = zero
            self.g_matrix[stage] = None
        else:
            z1, z2 = self.compute_rep(latent1), self.compute_rep(latent2)
            z2_hat, g, _ = self.predict_z2_hat(z1, z2, delta)
            lie_loss = self.compute_lie_loss(z1, z2, z2_hat)
            # squared L2 distance per sample, averaged over the batch
            euc_loss = ((z2 - z2_hat) ** 2).sum(-1).mean(0)

            # for online probing; detach() shares storage, so no clone is needed
            self.g_matrix[stage] = g.detach()

        # NOTE(review): unlike MAELieModule.shared_step, no l2_contraint term is
        # added here — confirm that is intended.
        loss = (
            mae_loss * self.lambda_ssl_loss
            + lie_loss * self.lambda_lie_loss
            + euc_loss * self.lambda_euc_loss
        )

        batch_size = x1_aug1.shape[0]
        self.log_loss(f"{stage}_mae_loss", mae_loss, batch_size)
        self.log_loss(f"{stage}_lie_loss", lie_loss, batch_size)
        self.log_loss(f"{stage}_euc_loss", euc_loss, batch_size)
        self.log_loss(f"{stage}_loss", loss, batch_size)

        return loss


class MAELinearEval(ShapesBaseModel):
    """Trains a fresh linear classifier while keeping other weights frozen.

    Features are the token-mean of the MAE ViT backbone output. They are passed
    through a non-affine BatchNorm and then a linear head with
    ``self.num_classes`` outputs.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        # The backbone only ever runs under torch.no_grad(). Turning off
        # requires_grad makes the "frozen" contract explicit, so nothing
        # downstream expects gradients for these parameters.
        self.backbone.requires_grad_(False)
        self.feature_dim = 768
        # batch norm applied before linear classification head
        self.batch_norm = torch.nn.BatchNorm1d(self.feature_dim, affine=False)
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Load the fine-tuned MAE ViT-B/16 and drop its final child (the head).

        NOTE(review): ``nn.Sequential(*children())`` only replays submodules.
        Parameters held directly on the ViT (e.g. ``pos_embed`` / ``cls_token``,
        if the project ViT defines them like timm's) are NOT applied, so this
        bypasses the model's own ``forward_features``. Confirm this is intended.
        """
        mae_encoder = mae_module.vit_base_patch16()
        checkpoint = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_base.pth"
        )
        mae_encoder.load_state_dict(checkpoint["model"], strict=False)
        mae_encoder_detached_head = torch.nn.Sequential(
            *list(mae_encoder.children())[:-1]
        )
        return mae_encoder_detached_head

    def on_train_epoch_start(self) -> None:
        """Keep the frozen backbone in eval mode (e.g. disables dropout).

        Only the backbone is switched; the BatchNorm before the head keeps
        whatever mode Lightning set.
        """
        self.backbone.eval()

    def forward(self, x):
        """Return class logits for ``x`` from frozen, token-averaged backbone features."""
        with torch.no_grad():
            # average over dim 1 (presumably the token axis)
            feats = self.backbone(x).mean(dim=1)
        # apply batchnorm to features following implementation in paper
        # batchnorm not used for finetuning
        out = self.linear_classifier(self.batch_norm(feats))
        return out


class MAEFinetuner(ShapesBaseModel):
    """Fine-tunes the MAE ViT backbone together with a fresh linear classifier.

    Unlike ``MAELinearEval``, the backbone is neither frozen nor run under
    ``torch.no_grad()`` here, so gradients flow into it. It is updated if the
    optimizer (configured outside this class) includes its parameters.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__(top_k=top_k, datamodule=datamodule)

        # optimizer settings, stored on the instance (presumably read by the
        # base model's optimizer setup — TODO confirm)
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        self.feature_dim = 768
        # no BatchNorm before the head when fine-tuning (MAELinearEval uses one);
        # num_classes presumably comes from the base model / datamodule
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Load the fine-tuned MAE ViT-B/16 and drop its final child (the head).

        NOTE(review): ``nn.Sequential(*children())`` only replays submodules.
        Parameters held directly on the ViT (e.g. ``pos_embed`` / ``cls_token``,
        if the project ViT defines them like timm's) are NOT applied, so this
        bypasses the model's own ``forward_features``. Confirm this is intended.
        """
        mae_encoder = mae_module.vit_base_patch16()
        checkpoint = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_base.pth"
        )
        # strict=False: keys missing from / unexpected in the checkpoint are ignored
        mae_encoder.load_state_dict(checkpoint["model"], strict=False)
        mae_encoder_detached_head = torch.nn.Sequential(
            *list(mae_encoder.children())[:-1]
        )
        return mae_encoder_detached_head

    def forward(self, x):
        """Return class logits for ``x``; gradients flow through the backbone."""
        # average over dim 1 (presumably the token axis)
        feats = self.backbone(x).mean(dim=1)
        out = self.linear_classifier(feats)
        return out
