"""
Based on https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
"""
from models.base_model import ShapesBaseModel, LieModelMixins
import pytorch_lightning as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
from models.simclr.simclr import SyncFunction
from typing import Dict, Optional, Tuple
import numpy as np
# import models.vicreg.resnet as resnet

def hack_load_ddp_model(state_dict):
    """Return a copy of ``state_dict`` with the leading ``module.`` removed from keys.

    Checkpoints saved from a ``DistributedDataParallel``-wrapped model prefix
    every parameter name with ``module.``. This helper removes that prefix so
    the weights can be loaded into an unwrapped model.

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        An ``OrderedDict`` with the same values and insertion order. Keys that
        start with ``module.`` have the prefix removed; all other keys are
        kept unchanged.
    """
    from collections import OrderedDict

    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip only when the prefix is really there; slicing unconditionally
        # would silently corrupt keys of checkpoints saved without DDP.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict

class VICRegLie(nn.Module):
    """VICReg network: a pl_bolts ResNet backbone plus an MLP projector.

    ``forward`` returns only the backbone features. The projector is stored as
    ``self.projector`` and has to be applied explicitly by the caller.

    Args:
        arch: backbone name, either ``"resnet18"`` or ``"resnet50"``.
        maxpool1: forwarded to the backbone constructor.
        mlp_tuple_dim: dash-separated projector layer widths, e.g.
            ``"8192-8192-8192"``. The last entry is stored as ``num_features``.
        sim_coeff: accepted for signature compatibility; not used by this class.
        std_coeff: accepted for signature compatibility; not used by this class.
        cov_coeff: accepted for signature compatibility; not used by this class.

    Raises:
        ValueError: if ``arch`` is not one of the supported backbones.
    """

    def __init__(
        self,
        arch: str = "resnet50",
        maxpool1: bool = True,
        mlp_tuple_dim="8192-8192-8192",
        sim_coeff=25.0,
        std_coeff=25.0,
        cov_coeff=1.0,
    ):
        super().__init__()
        # Output width of the projector (last entry of the spec string).
        self.num_features = int(mlp_tuple_dim.split("-")[-1])
        self.arch = arch
        self.maxpool1 = maxpool1
        if self.arch == "resnet18":
            backbone = resnet18
            # NOTE(review): a standard ResNet-18 produces 512-dim features;
            # confirm that 1024 matches what the pl_bolts backbone returns.
            self.embedding_dim = 1024
        elif self.arch == "resnet50":
            backbone = resnet50
            self.embedding_dim = 2048
        else:
            # Fail early with a clear message instead of an UnboundLocalError
            # on `backbone` a few lines further down.
            raise ValueError(
                f"Unsupported arch {self.arch!r}; expected 'resnet18' or 'resnet50'."
            )

        self.backbone = backbone(
            first_conv=True,
            maxpool1=self.maxpool1,
            return_all_feature_maps=False,
        )

        # Projector input width must match the backbone feature width.
        self.projector = Projector(mlp_tuple_dim, self.embedding_dim)

    def forward(self, x):
        """Return the last feature tensor produced by the backbone for ``x``."""
        # bolts resnet returns a list
        return self.backbone(x)[-1]


class VICRegLieModule(ShapesBaseModel, LieModelMixins):
    """Training module that fine-tunes a pretrained VICReg network with a Lie operator.

    The module wraps a ``VICRegLie`` network initialised from the public VICReg
    checkpoint and adds a learnable basis of Lie generators (``L_basis``) plus a
    network that infers the coefficients ``t`` of those generators. The total
    loss in ``shared_step`` is a weighted sum of a VICReg loss, a Lie InfoNCE
    loss, a Euclidean reconstruction loss and a constraint on ``t``.

    Several helpers used here (``initialize_L_basis``,
    ``create_t_inference_network``, ``infer_t``, ``matrix_exponential``,
    ``lie_nt_xent_loss``, ``log`` and the online accuracy metric) are not
    defined in this class and are expected to come from the base classes.

    Args:
        learning_rate: stored as ``self.learning_rate``.
        optimizer: stored as ``self.optimizer``.
        momentum: stored as ``self.momentum``.
        weight_decay: stored as ``self.weight_decay``.
        lambda_lie: weight of the Lie InfoNCE loss.
        lambda_ssl: weight of the VICReg loss.
        lambda_euc: weight of the Euclidean loss between ``z2`` and ``z2_hat``.
        lambda_l2: weight of the constraint on ``t``.
        ssl_temperature: temperature passed to the Lie InfoNCE loss; stored as
            ``self.temperature``.
        dim_d: number of Lie generators.
        top_k: forwarded to the base class constructor.
        datamodule: forwarded to the base class constructor.
        infer_t_bounds: if true, a non-trainable ``t_bounds`` parameter of
            shape ``(2, dim_d)`` is created (min and max).
        truncate_matrix_exp: stored as ``self.truncate_matrix_exp``.
        arch: backbone name; also selects the pretrained checkpoint URL.
        maxpool1: forwarded to ``VICRegLie``.
        mlp_tuple_dim: projector spec forwarded to ``VICRegLie``.
        sim_coeff: weight of the invariance (MSE) term of the VICReg loss.
        std_coeff: weight of the variance term of the VICReg loss.
        cov_coeff: weight of the covariance term of the VICReg loss.
        temperature: currently unused.
            NOTE(review): ``self.temperature`` is taken from
            ``ssl_temperature`` instead - confirm which one is intended.
    """

    def __init__(
        self,
        learning_rate: float = 1e-5,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        lambda_lie: float = 1.0,
        lambda_ssl: float = 1.0,
        lambda_euc: float = 1.0,
        lambda_l2: float = 1.0,
        ssl_temperature: float = 0.1,
        dim_d: int = 100,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
        infer_t_bounds: bool = True,
        truncate_matrix_exp: bool = False,
        arch: str = "resnet50",
        maxpool1: bool = True,
        mlp_tuple_dim="8192-8192-8192",
        sim_coeff=25.0,
        std_coeff=25.0,
        cov_coeff=1.0,
        temperature: float = 0.1,
    ):

        super().__init__(top_k=top_k, datamodule=datamodule)

        # VICReg loss weights
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff

        self.arch = arch
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        # weights of the individual terms of the total loss
        self.lambda_lie = lambda_lie
        self.lambda_ssl = lambda_ssl
        self.lambda_euc = lambda_euc
        self.lambda_l2 = lambda_l2
        self.temperature = ssl_temperature
        self.truncate_matrix_exp = truncate_matrix_exp
        # number of lie generators
        self.dim_d = dim_d
        vicreg = VICRegLie(
            arch, maxpool1, mlp_tuple_dim, sim_coeff, std_coeff, cov_coeff
        )
        self.ssl_model = self.load_pretained_vicreg(vicreg)
        self.embedding_dim = self.ssl_model.embedding_dim
        # for compatibility with attribute of online probe
        self.z_dim = self.embedding_dim

        # stage to g_matrix. This is reset for every training batch
        self.g_matrix = dict()
        self.L_basis = self.initialize_L_basis(self.dim_d, self.embedding_dim)

        self.t_inference_network = self.create_t_inference_network(
            self.dim_d, self.embedding_dim
        )
        if infer_t_bounds:
            # store only min and max
            self.t_bounds = torch.nn.parameter.Parameter(
                data=torch.zeros(2, self.dim_d), requires_grad=False
            )

    def load_pretained_vicreg(self, vicreg):
        """Load the public VICReg checkpoint for ``self.arch`` into ``vicreg``.

        Args:
            vicreg: the ``VICRegLie`` network to initialise.

        Returns:
            The same ``vicreg`` object, after ``load_state_dict``.
        """
        weight_path = f"https://dl.fbaipublicfiles.com/vicreg/{self.arch}_fullckpt.pth"
        # Deserialize on CPU: the checkpoint was saved from GPU training, so
        # without map_location it cannot be loaded on CPU-only hosts and every
        # process would materialise it on the GPU it was saved from.
        # load_state_dict copies the values into the existing parameters.
        checkpoint = torch.hub.load_state_dict_from_url(
            weight_path, map_location="cpu"
        )
        new_state_dict = hack_load_ddp_model(checkpoint["model"])
        # strict=False: keys that do not match between the checkpoint and
        # this network are skipped without an error.
        vicreg.load_state_dict(new_state_dict, strict=False)
        return vicreg

    def online_probe_forward(self, x1_online, x2_online, g, stage: str = "train"):
        """Build the representations consumed by the online probe.

        Args:
            x1_online: first batch of inputs.
            x2_online: second batch of inputs.
            g: batch of group matrices applied to the representations; unused
                when ``stage`` contains ``"canonical"``.
            stage: name of the current stage.

        Returns:
            ``[batch_size, 2, dim_m]`` for canonical stages, otherwise
            ``[batch_size, 4, dim_m]`` where the last two entries along dim 1
            are the ``g``-transformed versions of the first two.
        """
        z1_online = self.forward(x1_online)  # [batch_size, dim_m]
        z2_online = self.forward(x2_online)  # [batch_size, dim_m]
        # reps is [batch_size, 2, dim_m]
        reps = torch.stack([z1_online, z2_online], dim=1)

        if "canonical" not in stage:
            # g is [batch_size, dim_m, dim_m]
            # reps is [batch_size, 2, dim_m]
            g_reps = torch.einsum("bij,bkj->bik", g, reps)
            # g_reps is [batch_size, dim_m, 2]
            g_reps = g_reps.permute(0, 2, 1)
            # g_reps is now [batch_size, 2, dim_m]
            reps = torch.cat([reps, g_reps], dim=1)
            # reps is now [batch_size, 4, dim_m]
        return reps

    def forward(self, x) -> Tensor:
        """Return the backbone representation of ``x`` (alias of ``compute_rep``)."""
        return self.compute_rep(x)

    def compute_rep(self, x) -> Tensor:
        """Return the output of ``self.ssl_model`` for ``x`` (no projector applied)."""
        z = self.ssl_model(x)  # no proj'
        return z

    def compute_loss(self, z1, z2, z3, viewpair=False):
        """Dispatch to the VICReg loss or the Lie InfoNCE loss.

        If ``viewpair`` is true this is the regular VICReg loss on ``z1`` and
        ``z2`` (an augmented pair of views) and ``z3`` is ignored. Otherwise
        it is the Lie InfoNCE loss on ``z1``, ``z2`` and ``z3`` with
        ``self.temperature``.
        """
        if viewpair:
            loss = self.vicreg_loss(z1, z2)
        else:
            loss = self.lie_nt_xent_loss(z1, z2, z3, self.temperature)

        return loss

    def compute_vicreg_loss(self, z1, z2):
        """Apply the projector to ``z1`` and ``z2`` and return their VICReg loss."""
        return self.compute_loss(
            self.ssl_model.projector(z1),
            self.ssl_model.projector(z2),
            None,
            viewpair=True,
        )

    def compute_lie_infonce(self, z1, z2, z3):
        """Return the Lie InfoNCE loss on projected, L2-normalized inputs."""
        # The projector outputs are normalized along dim 1 before the loss.
        return self.compute_loss(
            F.normalize(self.ssl_model.projector(z1), dim=1),
            F.normalize(self.ssl_model.projector(z2), dim=1),
            F.normalize(self.ssl_model.projector(z3), dim=1),
            viewpair=False,
        )

    def vicreg_loss(self, out_1, out_2):
        """Compute the VICReg loss between two batches of projected embeddings.

        The loss is ``sim_coeff * invariance + std_coeff * variance +
        cov_coeff * covariance``. The invariance term is the MSE between the
        local embeddings; the variance and covariance terms are computed on
        the embeddings gathered across processes when distributed training is
        initialised.

        Args:
            out_1: projector output for the first view.
            out_2: projector output for the second view.

        Returns:
            Scalar loss tensor.
        """
        # out_1 and out_2 have already been processed by the model's SSL Resnet backbone + the Projector
        repr_loss = F.mse_loss(out_1, out_2)

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            out_1_dist = SyncFunction.apply(out_1)
            out_2_dist = SyncFunction.apply(out_2)
        else:
            out_1_dist = out_1
            out_2_dist = out_2

        batch_size = out_1_dist.shape[0]

        # Center every embedding dimension over the batch. Without this step
        # `x.T @ x / (n - 1)` below is a second-moment matrix, not the
        # covariance matrix the VICReg covariance term is defined on.
        out_1_dist = out_1_dist - out_1_dist.mean(dim=0)
        out_2_dist = out_2_dist - out_2_dist.mean(dim=0)

        # Variance term: hinge on the per-dimension standard deviation
        # (the variance itself is unaffected by the centering above).
        std_out_1 = torch.sqrt(out_1_dist.var(dim=0) + 0.0001)
        std_out_2 = torch.sqrt(out_2_dist.var(dim=0) + 0.0001)
        std_loss = (
            torch.mean(F.relu(1 - std_out_1)) / 2
            + torch.mean(F.relu(1 - std_out_2)) / 2
        )

        # Covariance term: squared off-diagonal covariance entries, scaled by
        # the projector output width.
        cov_out_1 = (out_1_dist.T @ out_1_dist) / (batch_size - 1)
        cov_out_2 = (out_2_dist.T @ out_2_dist) / (batch_size - 1)
        cov_loss = off_diagonal(cov_out_1).pow_(2).sum().div(
            self.ssl_model.num_features
        ) + off_diagonal(cov_out_2).pow_(2).sum().div(self.ssl_model.num_features)

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss

    def predict_z2_hat(self, z1, z2, delta) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict ``z2`` from ``z1`` by applying the inferred group element.

        Returns:
            ``(z2_hat, g, t)`` where ``t`` comes from ``self.infer_t``, ``g`` is
            the matrix exponential of the ``t``-weighted sum of ``L_basis`` and
            ``z2_hat = g @ z1``.
        """
        t = self.infer_t(z1, z2, delta)
        # assumes t is [batch_size, dim_d] and L_basis is [dim_d, m, m] - TODO confirm
        lie_matrix = (t[..., None, None] * self.L_basis[None, ...]).sum(1)
        g = self.matrix_exponential(lie_matrix)

        z2_hat = torch.bmm(g, z1[..., None]).squeeze(-1)
        return z2_hat, g, t

    def L2_constraint(self, delta, t):
        """Return the mean squared norm of ``t`` weighted by a function of ``delta``.

        The weight is ``1 / (1 + exp(|delta|))``, so it is largest (0.5) when
        ``delta`` is zero and decays towards zero as ``|delta|`` grows: large
        coefficients ``t`` are penalised most when ``delta`` is small.
        """
        # NOTE(review): sim is broadcast against (t ** 2).sum(-1); confirm the
        # shape of delta is compatible with a per-sample weight.
        sim = 1 / (1 + torch.exp(delta.abs()))
        loss = sim * (t ** 2).sum(-1)
        return loss.mean()

    def shared_step(self, batch, stage: str = "train", return_terms=False):
        """Compute and log the losses for one batch.

        Args:
            batch: 6-tuple whose first two entries are triples
                ``(aug1, aug2, online)`` for the two inputs and whose last
                entry is ``delta``; the remaining entries are ignored here.
            stage: stage name, used as prefix of the logged metrics and as key
                into ``self.g_matrix``. Stages containing ``"canonical"`` only
                use the VICReg loss on the first input.
            return_terms: currently unused.

        Returns:
            The total loss for the batch.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, x1_aug2, _ = x1tuple
        x2_aug1, x2_aug2, _ = x2tuple

        z1_aug1 = self.compute_rep(x1_aug1)
        z1_aug2 = self.compute_rep(x1_aug2)
        z2_aug1 = self.compute_rep(x2_aug1)
        z2_aug2 = self.compute_rep(x2_aug2)

        if "canonical" in stage:
            # Factor 2 keeps the scale comparable to the two-pair sum used in
            # the non-canonical branch.
            ssl_loss = 2 * self.compute_vicreg_loss(z1_aug1, z1_aug2)
            lie_loss = 0.0
            euc_loss = 0.0
            l2_loss = 0.0
            loss = ssl_loss
            self.g_matrix[stage] = None

        else:
            ssl_loss = self.compute_vicreg_loss(
                z1_aug1, z1_aug2
            ) + self.compute_vicreg_loss(z2_aug1, z2_aug2)

            z2_aug1_hat, g, t = self.predict_z2_hat(z1_aug1, z2_aug1, delta)
            lie_loss = self.compute_lie_infonce(z2_aug1, z2_aug1_hat, z1_aug1)
            euc_loss = ((z2_aug1 - z2_aug1_hat) ** 2).sum(-1).mean(0)
            l2_loss = self.L2_constraint(delta, t)

            # for online probing
            # no need to clone based on PyTorch docs
            self.g_matrix[stage] = g.detach()

            loss = (
                self.lambda_ssl * ssl_loss
                + self.lambda_lie * lie_loss
                + self.lambda_euc * euc_loss
                + self.lambda_l2 * l2_loss
            )

        batch_size = x1_aug1.shape[0]
        self.log_loss(f"{stage}_vicreg_loss", ssl_loss, batch_size)
        self.log_loss(f"{stage}_lie_loss", lie_loss, batch_size)
        self.log_loss(f"{stage}_euc_loss", euc_loss, batch_size)
        self.log_loss(f"{stage}_loss", loss, batch_size)

        return loss

    def log_loss(self, name: str, value: Tensor, batch_size: int):
        """Log ``value`` under ``name`` on every step and aggregated per epoch."""
        self.log(
            name,
            value,
            sync_dist=True,
            batch_size=batch_size,
            # loader names are used instead
            add_dataloader_idx=False,
            on_step=True,
            on_epoch=True,
        )

    def transform(self, z, t=None) -> Tensor:
        """Transforms z by applying g(z, t)

        If t is None, then a randomly sampled t is generated: one coefficient
        per generator and sample, drawn uniformly from [-1, 1) with NumPy.
        """
        if t is None:
            batch_size = z.shape[0]
            t = torch.Tensor(np.random.uniform(-1.0, 1.0, (batch_size, self.dim_d))).to(
                z.device
            )

        lie_matrix = (t[..., None, None] * self.L_basis[None, ...]).sum(1)
        g = self.matrix_exponential(lie_matrix)

        z_transformed = torch.bmm(g, z[..., None]).squeeze(-1)
        return z_transformed

    def training_epoch_end(self, outs):
        """Print the online canonical top-1 train accuracy at the end of an epoch."""
        # log epoch metric; compute on a clone so the original metric object
        # is not touched by compute()
        tmp = self.online_train_canonical_top_1_accuracy.clone()
        acc = tmp.compute().item()
        print(f"Online train canonical {acc}")

def Projector(mlp_tuple_dim, embedding_dim):
    """Build the MLP projector as an ``nn.Sequential``.

    Args:
        mlp_tuple_dim: dash-separated layer widths, e.g. ``"8192-8192-8192"``.
        embedding_dim: width of the projector input.

    Returns:
        A stack of ``Linear -> BatchNorm1d -> ReLU(inplace)`` stages, one per
        hidden width, followed by a final bias-free ``Linear`` layer.
    """
    # Full chain of widths: the input width followed by the requested ones.
    widths = [int(token) for token in f"{embedding_dim}-{mlp_tuple_dim}".split("-")]

    modules = []
    # Every consecutive pair except the last one forms a hidden stage.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        modules.extend(
            [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU(True)]
        )
    # The output layer has no bias, normalization or activation.
    modules.append(nn.Linear(widths[-2], widths[-1], bias=False))
    return nn.Sequential(*modules)


def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Entries are returned in row-major order; the input must be 2-D and square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last element and regrouping into rows of length n + 1 puts
    # every remaining diagonal entry into column 0, which is then cut off.
    flat_without_last = x.flatten()[:-1]
    shifted = flat_without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()
