import pytorch_lightning as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from typing import Dict, Optional, Tuple
import sys
from models.base_model import ShapesBaseModel
from models.simclr.simclr import SyncFunction, Projection
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
from pl_bolts.models.self_supervised import SimCLR as plSimCLR
import math


class SimCLRLie(plSimCLR):
    """Our main model: SimCLR with an additional "Lie" InfoNCE loss.

    ``compute_loss`` dispatches between the regular InfoNCE loss on a pair of
    views and the Lie InfoNCE loss, in which a third batch of embeddings is
    appended to the candidates every anchor is contrasted against.
    """

    def __init__(
        self,
        arch: str = "resnet50",
        hidden_mlp: int = 2048,
        feat_dim: int = 128,
        temperature: float = 0.1,
        first_conv: bool = True,
        maxpool1: bool = True,
    ):
        """
        Args:
            arch: backbone architecture
            hidden_mlp: forwarded unchanged to the pl_bolts SimCLR constructor
            feat_dim: forwarded unchanged to the pl_bolts SimCLR constructor
            temperature: the loss temperature
            first_conv: forwarded unchanged to the pl_bolts SimCLR constructor
            maxpool1: forwarded unchanged to the pl_bolts SimCLR constructor
        """
        # gpus, dataset, batch_size and num_samples are placeholders: they are
        # only passed because the pl_bolts constructor asks for them.
        super().__init__(
            gpus=1,
            dataset=None,
            batch_size=1,
            num_samples=0,
            arch=arch,
            hidden_mlp=hidden_mlp,
            feat_dim=feat_dim,
            temperature=temperature,
            first_conv=first_conv,
            maxpool1=maxpool1,
        )

    def save_hyperparameters(self):
        """Delegate to the pl_bolts implementation, ignoring ``datamodule``."""
        return plSimCLR.save_hyperparameters(self, ignore=["datamodule"])

    def compute_loss(self, z1, z2, z3, viewpair=False):
        """
        If viewpair, this is regular InfoNCE (augmented pair of views) and z3
        is not used.
        If not viewpair this is Lie InfoNCE.
        """
        if viewpair:
            return self.nt_xent_loss(z1, z2, self.temperature)
        return self.lie_nt_xent_loss(z1, z2, z3, self.temperature)

    def _generalized_nt_xent(self, views, temperature, eps):
        """Shared implementation of both contrastive losses.

        Args:
            views: sequence of [batch_size, dim] tensors. The first two form
                the positive pair; any further entries are only added to the
                set of candidates in the denominator.
            temperature: softmax temperature.
            eps: lower clamp for the denominator, also added inside the log.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the resulting loss is negative.
        """
        out_1, out_2 = views[0], views[1]

        # gather representations in case of distributed training
        # each gathered tensor: [batch_size * world_size, dim]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            views_dist = [SyncFunction.apply(view) for view in views]
        else:
            views_dist = list(views)

        # out: [2 * batch_size, dim]
        # out_dist: [len(views) * batch_size * world_size, dim]
        out = torch.cat([out_1, out_2], dim=0)
        out_dist = torch.cat(views_dist, dim=0)

        # cov and sim: [2 * batch_size, len(views) * batch_size * world_size]
        # neg: [2 * batch_size]
        cov = torch.mm(out, out_dist.t().contiguous())
        sim = torch.exp(cov / temperature)
        neg = sim.sum(dim=-1)

        # Remove each row's similarity with itself. That entry of cov is
        # <out_i, out_i> = ||out_i||^2, so the squared norm must be used (the
        # plain norm only coincides with it for unit-norm embeddings).
        row_sub = torch.exp(torch.sum(out * out, dim=-1) / temperature)
        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + eps)).mean()
        if loss < 0.0:
            print("Lie Contrastive loss can't be negative")
            raise ValueError("Lie Contrastive loss can't be negative")
        return loss

    def lie_nt_xent_loss(self, out_1, out_2, out_3, temperature, eps=1e-6):
        """
        DOES NOT assume out_1 and out_2 are normalized
        out_1: [batch_size, dim]
        out_2: [batch_size, dim]
        out_3: [batch_size, dim]

        (out_1, out_2) is the positive pair; out_3 only adds candidates to the
        denominator.
        """
        return self._generalized_nt_xent([out_1, out_2, out_3], temperature, eps)

    def nt_xent_loss(
        self, out_1, out_2, temperature, eps=1e-6
    ):  # Same as SimCLR but works if not normalized too
        """
        out_1: [batch_size, dim]
        out_2: [batch_size, dim]
        """
        return self._generalized_nt_xent([out_1, out_2], temperature, eps)


class SimCLR(plSimCLR):
    """
    Plain SimCLR model.

    The only addition over the pl_bolts class is `compute_loss()`, which
    exists to match the API of the remaining models.
    """

    def __init__(
        self,
        arch: str = "resnet50",
        hidden_mlp: int = 2048,
        feat_dim: int = 128,
        temperature: float = 0.1,
        first_conv: bool = True,
        maxpool1: bool = True,
    ):
        """Forward the architecture settings to the pl_bolts SimCLR constructor."""
        # The first four arguments are placeholders: they are only supplied
        # because the pl_bolts constructor asks for them.
        super().__init__(
            gpus=1,
            dataset=None,
            batch_size=1,
            num_samples=0,
            arch=arch,
            first_conv=first_conv,
            maxpool1=maxpool1,
            hidden_mlp=hidden_mlp,
            feat_dim=feat_dim,
            temperature=temperature,
        )

    def save_hyperparameters(self):
        """Delegate to the pl_bolts implementation, ignoring ``datamodule``."""
        return plSimCLR.save_hyperparameters(self, ignore=["datamodule"])

    def compute_loss(self, z1, z2):
        """
        Return the NT-Xent loss between z1 and z2 at the model's temperature.
        """
        return self.nt_xent_loss(z1, z2, self.temperature)


class SimCLRLieModule(ShapesBaseModel):
    """Our main model's module.

    Wraps an SSL backbone (``ssl_model``) together with:
      * ``L_basis``: ``dim_d`` learnable ``dim_m x dim_m`` matrices spanning
        the Lie algebra,
      * ``h``: a small MLP inferring the Lie coordinates ``t`` from
        ``(z, z', delta)``,
      * ``t_bounds`` (optional): running min/max of the inferred ``t``, used by
        ``forward`` to sample neighbours.
    """

    def __init__(
        self,
        ssl_model: str = "SimCLRLie",
        ssl_arch: str = "resnet50",
        ssl_hidden_mlp: int = 2048,
        ssl_feat_dim: int = 128,
        ssl_temperature: float = 0.1,
        ssl_first_conv: bool = True,
        ssl_maxpool1: bool = True,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        dim_d: int = 100,
        lambda_ssl: float = 1.0,
        lambda_lie: float = 1.0,
        lambda_euc: float = 1.0,
        lambda_l2: float = 1.0,
        lambda_i: float = 1.0,
        datamodule: Optional[pl.LightningDataModule] = None,
        pretrained_ssl: bool = False,
        truncate_matrix_exp: bool = False,
        infer_t_bounds: bool = True,
        # top_k is here for loading compatibility
        top_k=None,
    ):
        """
        Args:
            ssl_model: name of the backbone class, looked up in this module.
            ssl_arch, ssl_hidden_mlp, ssl_feat_dim, ssl_temperature,
            ssl_first_conv, ssl_maxpool1: forwarded to the backbone class.
            learning_rate, optimizer, momentum, weight_decay: stored on the
                module (not otherwise used in this class).
            dim_d: number of Lie algebra basis matrices.
            lambda_ssl, lambda_lie, lambda_euc, lambda_l2, lambda_i: weights
                of the corresponding loss terms in ``shared_step``.
            datamodule: forwarded to the base class.
            pretrained_ssl: if True, load the pl_bolts ImageNet SimCLR weights.
            truncate_matrix_exp: if True, use ``I + A`` instead of ``exp(A)``.
            infer_t_bounds: if True, register ``t_bounds`` to track the range
                of inferred Lie coordinates.
            top_k: accepted for checkpoint compatibility only, ignored.
        """
        super().__init__(top_k=None, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.ssl_feat_dim = ssl_feat_dim
        self.truncate_matrix_exp = truncate_matrix_exp
        # property used by Online Evaluator
        self.z_dim = ssl_hidden_mlp

        self.dim_d = dim_d
        self.g_matrix = None
        ssl_class = getattr(sys.modules[__name__], ssl_model)
        backbone = ssl_class(
            arch=ssl_arch,
            hidden_mlp=ssl_hidden_mlp,
            feat_dim=ssl_feat_dim,
            temperature=ssl_temperature,
            first_conv=ssl_first_conv,
            maxpool1=ssl_maxpool1,
        )

        self.pretrained_ssl = pretrained_ssl
        if pretrained_ssl:
            backbone = self.load_pretrained_simclr(backbone)
        self.ssl_model = backbone
        # create a basis of the Lie algebra
        dim_m = self.ssl_model.hidden_mlp

        # Initialize like Linear https://pytorch.org/docs/stable/generated/torch.nn.Linear.html?highlight=linear#torch.nn.Linear
        init_bound = math.sqrt(1.0 / dim_m)
        init_L_basis = torch.FloatTensor(size=(dim_d, dim_m, dim_m)).uniform_(
            -init_bound, init_bound
        )
        self.L_basis = nn.parameter.Parameter(data=init_L_basis)

        # Inference model for t: takes as argument a concatenation of (z, z, delta) (one of them can be noised out) of size dim_m * 2 + 1
        self.h = nn.Sequential(
            nn.Linear(2 * dim_m + 1, 2 * dim_m + 1, bias=True),
            nn.LeakyReLU(),
            nn.Linear(2 * dim_m + 1, dim_d, bias=True),
        )
        self.lambda_lie = lambda_lie
        self.lambda_ssl = lambda_ssl
        self.lambda_euc = lambda_euc
        self.lambda_l2 = lambda_l2
        self.lambda_i = lambda_i
        if infer_t_bounds:
            # store only min and max
            self.t_bounds = nn.parameter.Parameter(
                data=torch.zeros(2, self.dim_d), requires_grad=False
            )

    def _get_t_bounds(self):
        """Return ``t_bounds``, or None when it was not registered
        (``infer_t_bounds=False``)."""
        return getattr(self, "t_bounds", None)

    def load_pretrained_simclr(self, ssl_model):
        """Reload ``ssl_model`` from the pl_bolts ImageNet SimCLR checkpoint,
        reusing its current hyperparameters (non-strict loading)."""
        weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
        print(" Loading SimCLR model from pl bots")
        ssl_model = ssl_model.load_from_checkpoint(
            weight_path, strict=False, **ssl_model._hparams
        )
        return ssl_model

    def on_train_start(self) -> None:
        """Make sure ``t_bounds`` (if any) lives on the module's device."""
        bounds = self._get_t_bounds()
        if bounds is not None:
            # Move the underlying storage instead of re-assigning the
            # attribute: ``Parameter.to`` may return a plain Tensor, which
            # cannot be assigned to a registered parameter name.
            bounds.data = bounds.data.to(self.device)

    def on_train_epoch_start(self) -> None:
        """
        Resets bounds at each epoch
        """
        bounds = self._get_t_bounds()
        if bounds is not None:
            with torch.no_grad():
                bounds.zero_()

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx) -> None:
        """Reset the per-loader operator cache used for online probing."""
        self.g_matrix = {name: None for name in self.val_loader_names}

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx) -> None:
        """Reset the per-loader operator cache used for online probing."""
        self.g_matrix = {name: None for name in self.train_loader_names}

    def on_validation_epoch_end(self) -> None:
        """
        Saves first epoch checkpoint
        """
        if self.current_epoch == 0:
            self.trainer.save_checkpoint(f"first_{self.model_name}.ckpt")

    def compute_rep(self, x):
        """Return the backbone representation of ``x`` (no projection head)."""
        z = self.ssl_model(x)  # no proj'
        return z

    def online_probe_forward(self, x1_online, x2_online, g, stage="train"):
        """Build the representations fed to the online probe.

        Returns a [batch_size, 2, dim_m] tensor for canonical stages, and a
        [batch_size, 4, dim_m] tensor otherwise (the two representations plus
        their images under the operator ``g``).
        """
        z1_online = self.forward(x1_online, use_identity=True)
        z2_online = self.forward(x2_online, use_identity=True)
        # reps is [batch_size, 2, dim_m]
        reps = torch.cat([z1_online, z2_online], dim=1)

        if "canonical" not in stage:
            # g is [batch_size, dim_m, dim_m]
            # reps is [batch_size, 2, dim_m]
            g_reps = torch.einsum("bij,bkj->bik", g, reps)
            # g_reps is [batch_size, dim_m, 2]
            g_reps = g_reps.permute(0, 2, 1)
            # g_reps is now [batch_size, 2, dim_m]
            reps = torch.cat([reps, g_reps], dim=1)
            # reps is now [batch_size, 4, dim_m]
        return reps

    def forward(
        self,
        x,
        num_neighbors=10,
        use_identity=False,
    ):
        """Encode ``x`` and optionally append sampled neighbours.

        Args:
            x: input batch.
            num_neighbors: number of Lie coordinates sampled per example.
            use_identity: if True, only the representation itself is returned.

        Returns:
            [batch_size, 1, dim_m] if ``use_identity``, otherwise
            [batch_size, 1 + num_neighbors, dim_m] with the representation
            first.

        Raises:
            RuntimeError: if neighbours are requested but ``t_bounds`` was not
                registered (``infer_t_bounds=False``).
        """
        z = self.compute_rep(x)

        if use_identity:
            # size (batch_size, 1, dim_m)
            return z[:, None, :]

        bounds = self._get_t_bounds()
        if bounds is None:
            raise RuntimeError(
                "Sampling neighbors requires t_bounds; build the module with "
                "infer_t_bounds=True or call forward with use_identity=True"
            )
        batch_size = x.size(0)
        low, high = bounds[0, :], bounds[1, :]
        # Sample Lie algebra coordinates
        # (b - a) * u + a to sample U(a,b); drawn directly on the target device
        u = torch.rand(batch_size, num_neighbors, self.dim_d, device=bounds.device)
        t = (high - low)[None, None, :] * u + low[None, None, :]
        z_primes = self.operate_fwd(z, t)
        # Append with regular z along 2nd dim
        return torch.cat([z[:, None, :], z_primes], dim=1)

    def L2_constraint(self, z, z_prime, delta, t):
        """
        Computes the L2 constraint on t, weighted per example by
        1 / (1 + exp(|delta|)): the smaller |delta|, the larger the weight on
        the squared norm of t.

        z and z_prime are currently unused.
        """
        sim = 1 / (1 + torch.exp(delta.abs()))
        sq_norm = (t ** 2).sum(-1)
        # delta carries a trailing singleton dim; align the weight with the
        # per-example squared norm so the product stays per-example instead of
        # broadcasting to a (batch, batch) outer product.
        loss = sim.reshape(sq_norm.shape) * sq_norm
        return loss.mean()

    def ind_constraint(self):
        """
        Constraint to ensure that L_basis are linear independent matrices
        """
        # TODO
        return 0.

    def matrix_exponential(self, A):
        """Return exp(A) over the last two dims, or its first-order
        approximation ``I + A`` when ``truncate_matrix_exp`` is set."""
        if self.truncate_matrix_exp:
            # uses a linearization: exp(A) ~ I + A (identity, not all-ones)
            identity = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
            return identity + A
        return torch.matrix_exp(A)

    def operate(self, z, z_prime, delta, t=None):
        """Apply the operator mapping ``z`` towards ``z_prime``.

        Args:
            z: (batch_size, dim_m) source representations.
            z_prime: (batch_size, dim_m) targets; only needed to infer ``t``.
            delta: per-example value concatenated to (z, z_prime) as the last
                input feature of ``h``; examples with delta == 0 get t = 0.
            t: optional (batch_size, dim_d) Lie coordinates; inferred with
                ``h`` when None.

        Returns:
            Tuple ``(z_prime_hat, t, g)`` with the transformed representation,
            the coordinates actually used and the (batch_size, dim_m, dim_m)
            operator.

        Raises:
            ValueError: if both ``z_prime`` and ``t`` are None.
        """
        if t is None:
            if z_prime is None:
                raise ValueError("z_prime and t cannot both be None")
            v = torch.cat([z, z_prime, delta], dim=-1)
            # Breaks back propagation wrt encoder on z (just learn the weights of h)
            t = 2 * torch.sigmoid(self.h(v.detach().clone())) - 1

        # Hard code that t is 0 where delta is 0
        t = t * (1 - (delta == 0).float())

        # t (batch_size, dim_d)
        # L_basis (dim_d, dim_m, dim_m)
        # Contract over dim_d without materializing the
        # (batch_size, dim_d, dim_m, dim_m) intermediate.
        lie_matrix = torch.einsum("bd,dij->bij", t, self.L_basis)
        g = self.matrix_exponential(lie_matrix)

        # g (batch_size, dim_m, dim_m)
        # z (batch_size, dim_m)
        z_prime_hat = torch.bmm(g, z[..., None]).squeeze(-1)
        return z_prime_hat, t, g

    def operate_fwd(self, z, t):
        """Apply one operator per sampled coordinate vector.

        Args:
            z: (batch_size, dim_m) representations.
            t: (batch_size, num_neighbors, dim_d) Lie coordinates.

        Returns:
            (batch_size, num_neighbors, dim_m) transformed representations.
        """
        # t (batch_size, num_neighbors, dim_d)
        # L_basis (dim_d, dim_m, dim_m)
        batch_size, num_neighbors, dim_d = t.size()
        _, dim_m, _ = self.L_basis.size()
        lie_matrix = torch.mm(
            t.reshape(batch_size * num_neighbors, dim_d),
            self.L_basis.view(dim_d, dim_m * dim_m),
        ).view(batch_size, num_neighbors, dim_m, dim_m)
        g = self.matrix_exponential(lie_matrix)
        # g (batch_size, num_neighbors, dim_m, dim_m)
        # z (batch_size, dim_m)
        z_primes = torch.einsum("bijk,bk->bij", g, z)
        return z_primes

    def compute_simclr_infonce(self, z1, z2):
        """Regular InfoNCE between the projections of ``z1`` and ``z2``."""
        return self.ssl_model.compute_loss(
            self.ssl_model.projection(z1),
            self.ssl_model.projection(z2),
            None,
            viewpair=True,
        )

    def compute_lie_infonce(self, z1, z2, z3):
        """Lie InfoNCE on the projections of ``z1``, ``z2`` and ``z3``."""
        return self.ssl_model.compute_loss(
            self.ssl_model.projection(z1),
            self.ssl_model.projection(z2),
            self.ssl_model.projection(z3),
            viewpair=False,
        )

    def lie_loss(self, z1_aug, z2_aug, delta, t=None):
        """Compute the operator-related loss terms.

        Returns:
            Tuple ``(lie_loss, euc_loss, l2_loss, t, g)``.
        """
        z2_aug_hat, t, g = self.operate(z1_aug, z2_aug, delta, t)

        lie_loss = self.compute_lie_infonce(z2_aug, z2_aug_hat, z1_aug)

        euc_loss = ((z2_aug - z2_aug_hat) ** 2).sum(-1).mean(0)
        l2_loss = self.L2_constraint(None, None, delta, t)  # TODO not using any z yet

        return lie_loss, euc_loss, l2_loss, t, g

    def shared_step(self, batch, stage: str = "train", return_terms=False):
        """Compute and log the loss for one batch of one loader.

        Args:
            batch: 6-tuple whose first two entries are
                ``(aug1, aug2, online)`` triples and whose last entry is
                ``delta``.
            stage: loader name, used as logging prefix; stages containing
                "canonical" only use the SSL loss.
            return_terms: if True, also return the individual loss terms.

        Returns:
            ``loss`` or ``(loss, terms)`` where ``terms`` is
            ``(ssl_loss, lie_loss, euc_loss, l2_loss, ind_constraint)``.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, x1_aug2, x1_online = x1tuple
        x2_aug1, x2_aug2, x2_online = x2tuple
        z1_aug1 = self.compute_rep(x1_aug1)
        z1_aug2 = self.compute_rep(x1_aug2)
        z2_aug1 = self.compute_rep(x2_aug1)
        z2_aug2 = self.compute_rep(x2_aug2)
        if "canonical" in stage:
            ssl_loss = 2 * self.compute_simclr_infonce(z1_aug1, z1_aug2)
            lie_loss = 0.
            euc_loss = 0.
            l2_loss = 0.
            ind_constraint = 0.

            loss = ssl_loss

        else:  # Note: in the case of val_combined, there is both canonical and diverse and yet we go through this "else"
            ssl_loss = self.compute_simclr_infonce(
                z1_aug1, z1_aug2
            ) + self.compute_simclr_infonce(z2_aug1, z2_aug2)
            lie_loss, euc_loss, l2_loss, t, g = self.lie_loss(z1_aug1, z2_aug1, delta)

            bounds = self._get_t_bounds()
            # No need to include canonical, it's always t = 0, and we initialize to 0 so it's been considered
            if bounds is not None and stage in ("train_diverse_2d", "train_diverse_3d"):
                # Not backprop to t
                t_detached = t.detach().clone()
                t_min = torch.min(t_detached, dim=0)[0]
                t_max = torch.max(t_detached, dim=0)[0]  # TODO: include -t ?
                bounds[0, :] = torch.minimum(bounds[0, :], t_min)
                bounds[1, :] = torch.maximum(bounds[1, :], t_max)

            # Workaround for online probing
            if self.g_matrix is None:
                # the batch-start hooks have not run (e.g. direct call)
                self.g_matrix = {}
            self.g_matrix[stage] = g.detach().clone()
            ind_constraint = self.ind_constraint()

            loss = (
                self.lambda_ssl * ssl_loss
                + self.lambda_lie * lie_loss
                + self.lambda_euc * euc_loss
                + self.lambda_l2 * l2_loss
                + self.lambda_i * ind_constraint
            )

        batch_size = x1_aug1.shape[0]
        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            batch_size=batch_size,
            # loader names are used instead
            add_dataloader_idx=False,
            on_step=True,
            on_epoch=True,
        )

        names = ["ssl_loss", "lie_loss", "euc_loss", "l2_loss", "ind_constraint"]
        terms = (ssl_loss, lie_loss, euc_loss, l2_loss, ind_constraint)
        for name, value in zip(names, terms):
            self.log(
                f"{stage}_{name}",
                value,
                sync_dist=True,
                batch_size=batch_size,
                add_dataloader_idx=False,
                on_step=True,
                on_epoch=True,
            )

        if return_terms:
            return loss, terms
        return loss

    def training_step(self, loaders, loader_idx):
        """Sum the loss and each loss term over all loaders in ``loaders``.

        Returns:
            Dict with keys "loss", "ssl", "lie", "euc", "l2" and "ind".
        """
        terms = {"loss": 0.0, "ssl": 0.0, "lie": 0.0, "euc": 0.0, "l2": 0.0, "ind": 0.0}
        for loader_name in loaders:
            batch = loaders[loader_name]
            loss_i, loss_terms = self.shared_step(
                batch, stage=loader_name, return_terms=True
            )
            terms["loss"] += loss_i
            terms["ssl"] += loss_terms[0]
            terms["lie"] += loss_terms[1]
            terms["euc"] += loss_terms[2]
            terms["l2"] += loss_terms[3]
            terms["ind"] += loss_terms[4]
        # I think this works for returning multiple values (I might use them later)
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3041
        return terms


class SimCLRLieBottleneckModule(SimCLRLieModule):
    """Lie module whose operators act on a linear bottleneck of the backbone.

    The backbone output is mapped to ``dim_m`` dimensions by ``bottleneck``;
    the Lie basis, the inference network ``h`` and the projection head are all
    rebuilt for that size.
    """

    def __init__(self, dim_m: int = 512, **kwargs):
        super().__init__(**kwargs)

        # property used by Online Evaluator
        self.z_dim = dim_m
        self.dim_m = dim_m

        # Override all Lie stuff with dim_m
        init_bound = math.sqrt(1.0 / dim_m)
        basis = torch.FloatTensor(size=(self.dim_d, dim_m, dim_m))
        basis.uniform_(-init_bound, init_bound)
        self.L_basis = nn.parameter.Parameter(data=basis)

        h_width = 2 * dim_m + 1
        self.h = nn.Sequential(
            nn.Linear(h_width, h_width, bias=True),
            nn.LeakyReLU(),
            nn.Linear(h_width, self.dim_d, bias=True),
        )

        self.bottleneck = nn.Linear(self.ssl_model.hidden_mlp, dim_m)
        # Ditch pretrained and use fresh projection to add smaller dim +
        # normalization to hypersphere
        fresh_projection = Projection(
            input_dim=dim_m,
            hidden_dim=dim_m,
            output_dim=self.ssl_model.feat_dim,
        )
        self.ssl_model.projection = fresh_projection

    def compute_rep(self, x):
        """Return the bottlenecked backbone representation of ``x``."""
        backbone_out = self.ssl_model(x)
        return self.bottleneck(backbone_out)


class SimCLRFramesModule(SimCLRLieModule):
    """
    Ablation of the Lie Model without transformations learning
    Should be called with SimCLR backbone

    Encourages both frames and augmentations to be similar
    """

    def __init__(self, **kwargs):
        # Default to the plain SimCLR backbone unless the caller picks one.
        kwargs.setdefault("ssl_model", "SimCLR")
        super().__init__(**kwargs)
        # property used by online evaluator
        self.z_dim = 2048

    def forward(
        self,
        x,
        use_identity=False,  # only for compatibility
    ):
        """Return the backbone output of ``x`` with a singleton neighbour
        axis, i.e. size (batch_size, 1, dim_m)."""
        z = self.ssl_model(x)
        # size (batch_size, 1, dim_m)
        return z[:, None, :]

    def compute_rep(self, x):
        """Return the backbone output of ``x``, size (batch_size, dim_m)."""
        z = self.ssl_model(x)  # size (batch_size, dim_m)
        return z

    def shared_step(self, batch, stage: str = "train"):
        """Compute and log the contrastive loss for one batch.

        For stages containing "canonical" the loss is twice the loss between
        the two augmentations of the first frame. Otherwise it is the sum of
        three terms: augmentations of frame 1, augmentations of frame 2, and
        frame 1 versus frame 2.
        """
        x1tuple, x2tuple, _, _, _, delta = batch
        x1_aug1, x1_aug2, x1_online = x1tuple
        x2_aug1, x2_aug2, x2_online = x2tuple
        z1_aug1 = self.ssl_model.projection(self.compute_rep(x1_aug1))
        z1_aug2 = self.ssl_model.projection(self.compute_rep(x1_aug2))
        z2_aug1 = self.ssl_model.projection(self.compute_rep(x2_aug1))
        z2_aug2 = self.ssl_model.projection(self.compute_rep(x2_aug2))

        if "canonical" in stage:
            loss = 2 * self.ssl_model.compute_loss(z1_aug1, z1_aug2)
        else:
            # Parenthesized on purpose: without the parentheses, lines that
            # start with "+" are separate statements whose value is discarded,
            # and only the first term would end up in the loss.
            loss = (
                self.ssl_model.compute_loss(z1_aug1, z1_aug2)
                + self.ssl_model.compute_loss(z2_aug1, z2_aug2)
                + self.ssl_model.compute_loss(z1_aug1, z2_aug1)
            )

        batch_size = x1_aug1.shape[0]
        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            batch_size=batch_size,
            # loader names are used instead
            add_dataloader_idx=False,
            on_step=True,
            on_epoch=True,
        )

        return loss

    def training_step(self, loaders, loader_idx):
        """Return the sum of the losses of every loader in ``loaders``.

        Accumulating (rather than keeping only the last loader's loss) makes
        every loader contribute to the gradient, consistently with
        ``SimCLRLieModule.training_step``.
        """
        total_loss = 0.0
        for loader_name in loaders:
            batch = loaders[loader_name]
            total_loss = total_loss + self.shared_step(batch, stage=loader_name)
        return total_loss

    def online_probe_forward(self, x1_online, x2_online, g=None, stage="train"):
        """Return the two online representations stacked along dim 1;
        ``g`` and ``stage`` are accepted for API compatibility only."""
        z1_online = self.forward(x1_online, use_identity=True)
        z2_online = self.forward(x2_online, use_identity=True)
        # reps is batch_size, 2, dim_m
        reps = torch.cat([z1_online, z2_online], dim=1)
        return reps


class SimCLRFramesMoreParamsModule(SimCLRFramesModule):
    """Frames ablation with one extra linear layer on top of the backbone."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # additional projection layer to match parameter count of Lie with 1 generator
        self.additional_projection = nn.Linear(2048, 2048)

    def compute_rep(self, x):
        """Return the backbone output passed through the additional layer."""
        backbone_out = self.ssl_model(x)  # size (batch_size, dim_m)
        return self.additional_projection(backbone_out)


class BaseLinearEval(ShapesBaseModel):
    """Linear evaluation on top of a Lie algebra backbone kept in eval mode.

    A fresh linear classifier is fitted on the pre-projection features (2048).
    The backbone is only run under ``torch.no_grad()``; class probabilities are
    averaged over the representation and its sampled neighbours.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        num_neighbors: int = 10,
        use_identity: bool = False,
        lie_module_path: Optional[str] = None,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__(top_k=top_k, datamodule=datamodule)

        # optimisation settings
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        # neighbour sampling settings forwarded to the backbone
        self.num_neighbors = num_neighbors
        self.use_identity = use_identity

        # the checkpoint path must be set before load_backbone() reads it
        self.lie_module_path = lie_module_path
        self.backbone = self.load_backbone()
        self.backbone.eval()

        in_features = self.backbone.ssl_model.hidden_mlp
        self.linear_classifier = torch.nn.Linear(in_features, self.num_classes)

    def load_backbone(self):
        """Load the Lie module from ``lie_module_path``.

        Raises:
            NotImplementedError: if no checkpoint path was given.
        """
        if self.lie_module_path is None:
            raise NotImplementedError("Lie SSL path must be given")

        backbone = SimCLRLieModule.load_from_checkpoint(
            self.lie_module_path,
            datamodule=self.datamodule,
            pretrained_ssl=False,
        )
        print("Loaded backbone from", self.lie_module_path)
        return backbone

    def on_train_epoch_start(self) -> None:
        """Put the backbone back in eval mode at the start of every epoch."""
        self.backbone.eval()

    def loss_function(self, y_hat, y):
        """NLL loss, because ``forward`` already returns log probabilities."""
        return F.nll_loss(y_hat, y)

    def shared_step(self, batch, stage: str = "train"):
        """Compute the loss and update/log the top-k accuracies for a batch.

        Same as ShapesBaseModel but using the NLL since a log prob is outputed
        by the forward model (due to averaging).
        """
        x, y, _ = batch
        y_hat = self(x)
        loss = self.loss_function(y_hat, y)
        batch_size = x.shape[0]

        self.log(
            f"{stage}_loss",
            loss,
            batch_size=batch_size,
            sync_dist=True,
            # loader names are used instead
            add_dataloader_idx=False,
        )

        # Note: we already have log p if using LieModule (but it does not hurt)
        # but keep for SimCLRFrames module case
        probabilities = F.softmax(y_hat, dim=-1)
        for k in self.top_k:
            metric_name = f"{stage}_top_{k}_accuracy"
            accuracy_metric = getattr(self, metric_name)
            accuracy_metric(probabilities, y)
            self.log(
                metric_name,
                accuracy_metric,
                prog_bar=True,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )
        return loss

    def forward(self, x):
        """Return log of the class probabilities averaged over neighbours."""
        with torch.no_grad():
            feats = self.backbone.forward(x, self.num_neighbors, self.use_identity)

        # Log softmax over class probabilies for each neighbor
        log_probs = F.log_softmax(self.linear_classifier(feats), dim=-1)

        # log proba = log mean over num_neighbors of probas
        n_candidates = feats.size(1)
        return torch.logsumexp(log_probs, dim=1) + math.log(1 / float(n_candidates))


class LinearEvalBackpropLosses(BaseLinearEval):
    """Linear evaluation that backpropagates a loss for every neighbour.

    ``forward`` returns per-candidate logits. The loss is the cross-entropy on
    the representation itself plus the mean cross-entropy over its neighbours;
    accuracy is computed on the representation only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def loss_function(self, y_hat, y):
        """Regular cross-entropy, since ``forward`` returns logits."""
        return F.cross_entropy(y_hat, y)

    def unpack_logits(self, logits: Tensor):
        """Split logits along dim 1 into ``(z_logits, neighbours_logits)``.

        ``neighbours_logits`` is None when there is a single entry on dim 1.
        """
        z_logits = logits[:, 0, :]
        if logits.shape[1] > 1:
            neighbours_logits = logits[:, 1:, :]
        else:
            neighbours_logits = None
        return z_logits, neighbours_logits

    def shared_step(self, batch, stage: str = "train"):
        """Compute the combined loss and update/log the top-k accuracies."""
        x, y, _ = batch
        y_hat = self(x)
        z_logits, neighbours_logits = self.unpack_logits(y_hat)
        z_loss = self.loss_function(z_logits, y)

        neighbors_loss = 0.0
        if isinstance(neighbours_logits, Tensor):
            n_neighbours = neighbours_logits.shape[1]
            for idx in range(n_neighbours):
                neighbors_loss += self.loss_function(neighbours_logits[:, idx, :], y)
            # normalization to make it independent of n_neighbours
            neighbors_loss = neighbors_loss / n_neighbours

        loss = z_loss + neighbors_loss

        # Pass the batch size explicitly, as BaseLinearEval.shared_step does,
        # instead of leaving Lightning to infer it from the tuple batch.
        batch_size = x.shape[0]
        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            # loader names are used instead
            add_dataloader_idx=False,
            batch_size=batch_size,
        )
        for k in self.top_k:
            accuracy_metric = getattr(self, f"{stage}_top_{k}_accuracy")
            # we use only z for accuracy
            accuracy_metric(F.softmax(z_logits, dim=-1), y)
            self.log(
                f"{stage}_top_{k}_accuracy",
                accuracy_metric,
                prog_bar=True,
                on_epoch=True,
                on_step=False,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )
        return loss

    # Forward function outputs logits
    def forward(self, x):
        """Return the classifier logits for the representation and each
        neighbour; the backbone is run without gradient tracking."""
        with torch.no_grad():
            feats = self.backbone.forward(x, self.num_neighbors, self.use_identity)
        out = self.linear_classifier(feats)
        return out


class LieLinearEvalBneck(BaseLinearEval):
    """Linear evaluation for the bottleneck variant of the Lie model.

    A fresh linear classifier is trained on the representation taken after
    the bottleneck (``dim_m`` features).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Override linear_classifier for correct dim (dim_m)
        bottleneck_dim = self.backbone.dim_m
        self.linear_classifier = torch.nn.Linear(bottleneck_dim, self.num_classes)

    def load_backbone(self):
        """Load the bottleneck Lie module from ``lie_module_path``.

        Raises:
            NotImplementedError: if no checkpoint path was given.
        """
        if self.lie_module_path is None:
            raise NotImplementedError("Lie SSL path must be given")

        backbone = SimCLRLieBottleneckModule.load_from_checkpoint(
            self.lie_module_path,
            datamodule=self.datamodule,
            pretrained_ssl=False,
        )
        print("Loaded backbone from", self.lie_module_path)
        return backbone


class LieLinearEvalNoLieBneck(BaseLinearEval):
    """Linear evaluation that classifies the raw ``ssl_model`` output of the
    loaded backbone, without sampling any neighbour."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        in_features = self.backbone.ssl_model.hidden_mlp
        self.linear_classifier = torch.nn.Linear(in_features, self.num_classes)

    def forward(self, x):
        """Return class log probabilities for ``x``."""
        with torch.no_grad():
            # singleton axis stands in for the neighbour axis
            feats = self.backbone.ssl_model(x)[:, None, :]

        # Log softmax over class probabilies for each neighbor
        log_probs = F.log_softmax(self.linear_classifier(feats), dim=-1)

        # log proba = log mean over num_neighbors of probas
        n_candidates = feats.size(1)
        return torch.logsumexp(log_probs, dim=1) + math.log(1 / float(n_candidates))


class SimCLRLinearEval(BaseLinearEval):
    """Linear evaluation on top of a regular SimCLR backbone.

    A fresh linear classifier is trained on the pre-projection features
    (2048); the backbone is only run under ``torch.no_grad()``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def loss_function(self, y_hat, y):
        """Regular cross-entropy, since ``forward`` returns logits."""
        return F.cross_entropy(y_hat, y)

    def load_backbone(self):
        """Load a ``SimCLRFramesModule`` from ``lie_module_path``, or build one
        with ImageNet-pretrained SimCLR weights when no path is given."""
        if self.lie_module_path is None:
            print(
                "SSL path not given, using SimCLR and SimCLR ImageNet pretrained in backbone"
            )
            return SimCLRFramesModule(pretrained_ssl=True, datamodule=self.datamodule)

        backbone = SimCLRFramesModule.load_from_checkpoint(
            self.lie_module_path, pretrained_ssl=False, datamodule=self.datamodule
        )
        print("Loaded backbone from", self.lie_module_path)
        return backbone

    # Forward function outputs logits
    def forward(self, x):
        """Return the classifier logits for ``x``."""
        with torch.no_grad():
            # drop the singleton neighbour axis returned by the backbone
            feats = self.backbone.forward(x).squeeze(1)
        return self.linear_classifier(feats)

class SimCLRMoreParamsLinearEval(SimCLRLinearEval):
    """Linear evaluation whose backbone is a ``SimCLRFramesMoreParamsModule``.

    Identical to ``SimCLRLinearEval`` except for the backbone class that is
    loaded.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_backbone(self):
        """Load a ``SimCLRFramesMoreParamsModule`` from ``lie_module_path``, or
        build one with ImageNet-pretrained SimCLR weights when no path is
        given."""
        if self.lie_module_path is None:
            print(
                "SSL path not given, using SimCLR and SimCLR ImageNet pretrained in backbone"
            )
            return SimCLRFramesMoreParamsModule(
                pretrained_ssl=True, datamodule=self.datamodule
            )

        backbone = SimCLRFramesMoreParamsModule.load_from_checkpoint(
            self.lie_module_path, pretrained_ssl=False, datamodule=self.datamodule
        )
        print("Loaded backbone from", self.lie_module_path)
        return backbone


    