from models.base_model import ShapesBaseModel
from models.lie_ssl.shapes_model import (
    SimCLRFramesModule,
    SimCLRFramesMoreParamsModule,
    SimCLRLieModule,
)
import torch
from typing import Optional, Tuple
import torch.nn.functional as F
import pytorch_lightning as pl
import math
from torch import Tensor


class BaseFineTuning(ShapesBaseModel):
    """Fine-tuning base: a new linear head on top of a fully trainable backbone.

    Concrete subclasses must implement ``load_backbone`` (which returns the
    feature extractor) and ``forward`` (which maps inputs to class scores).
    They may also override ``loss_function`` when ``forward`` does not return
    raw logits.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        num_neighbors: int = 10,
        use_identity: bool = False,
        lie_module_path: Optional[str] = None,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Store the hyper-parameters, load the backbone and build the head.

        Args:
            learning_rate: stored on the instance as given.
            optimizer: stored on the instance as given.
            momentum: stored on the instance as given.
            weight_decay: stored on the instance as given.
            top_k: forwarded to the parent constructor; ``shared_step`` looks
                up one ``{stage}_top_{k}_accuracy`` metric per entry.
            num_neighbors: stored on the instance; passed to the backbone by
                the Lie subclasses' ``forward``.
            use_identity: stored on the instance; passed to the backbone by
                the Lie subclasses' ``forward``.
            lie_module_path: checkpoint path read by ``load_backbone``.
            datamodule: forwarded to the parent constructor.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        # Checkpoint location must be set before ``load_backbone`` runs,
        # because the subclasses' implementations read it.
        self.lie_module_path = lie_module_path

        # Plain hyper-parameters, stored as given.
        self.use_identity = use_identity
        self.num_neighbors = num_neighbors
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.optimizer = optimizer
        self.learning_rate = learning_rate

        self.backbone = self.load_backbone()

        # Feature width: prefer the wrapped SSL model's ``hidden_mlp``; fall
        # back to the backbone's own ``z_dim`` when that attribute chain is
        # not available.
        try:
            feature_dim = self.backbone.ssl_model.hidden_mlp
        except AttributeError:
            feature_dim = self.backbone.z_dim

        # ``self.num_classes`` is not set in this class -- presumably it is
        # provided by ShapesBaseModel; verify against the parent.
        self.linear_classifier = torch.nn.Linear(feature_dim, self.num_classes)

    def load_backbone(self):
        """Return the backbone module; must be implemented by subclasses."""
        raise NotImplementedError

    def loss_function(self, y_hat, y):
        """Cross-entropy between the predicted scores ``y_hat`` and labels ``y``."""
        return F.cross_entropy(y_hat, y)

    def shared_step(self, batch, stage: str = "train"):
        """Run one step: forward pass, loss, and metric logging.

        The loss comes from ``self.loss_function`` so that subclasses whose
        ``forward`` returns log-probabilities can substitute NLL.

        Args:
            batch: a 3-tuple whose first two items are the inputs and the
                targets; the third item is ignored.
            stage: prefix used for the logged loss and for the accuracy
                metric attribute names.

        Returns:
            The loss computed for this batch.
        """
        inputs, targets, _ = batch
        predictions = self(inputs)
        loss = self.loss_function(predictions, targets)
        n_samples = inputs.shape[0]

        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            # loader names are used instead
            add_dataloader_idx=False,
            batch_size=n_samples,
        )

        for k in self.top_k:
            metric_name = f"{stage}_top_{k}_accuracy"
            metric = getattr(self, metric_name)
            # The softmax is applied whatever ``forward`` returns: it is
            # needed for logits (SimCLRFrames case) and does not hurt when
            # the Lie module already produced log-probabilities.
            metric(F.softmax(predictions, dim=-1), targets)
            self.log(
                metric_name,
                metric,
                prog_bar=True,
                sync_dist=True,
                on_epoch=True,
                on_step=False,
                batch_size=n_samples,
                # loader names are used instead
                add_dataloader_idx=False,
            )
        return loss

    def forward(self, x):
        """Map inputs to class scores; must be implemented by subclasses."""
        raise NotImplementedError


class SimCLRFramesFinetuning(BaseFineTuning):
    """Fine-tuning model whose backbone is a ``SimCLRFramesModule``."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseFineTuning``."""
        super().__init__(**kwargs)

    def load_backbone(self):
        """Build the ``SimCLRFramesModule`` backbone.

        When ``self.lie_module_path`` is set, the module is restored from
        that checkpoint with ``pretrained_ssl=False``. Otherwise a new module
        is constructed with ``pretrained_ssl=True``.
        """
        checkpoint = self.lie_module_path

        if checkpoint is None:
            print(
                "SSL path not given, using SimCLR and SimCLR ImageNet pretrained in backbone"
            )
            return SimCLRFramesModule(pretrained_ssl=True, datamodule=self.datamodule)

        backbone = SimCLRFramesModule.load_from_checkpoint(
            checkpoint, pretrained_ssl=False, datamodule=self.datamodule
        )
        print("Loaded backbone from", checkpoint)
        return backbone

    def forward(self, x):
        """Return the linear classifier's output on the backbone features."""
        # Dimension 1 of the backbone output is squeezed away before the
        # head; assumes it has size 1 here -- TODO confirm against the module.
        features = self.backbone.forward(x).squeeze(1)
        return self.linear_classifier(features)


class SimCLRFramesMoreParamsFinetuning(SimCLRFramesFinetuning):
    """Same as ``SimCLRFramesFinetuning`` but backed by ``SimCLRFramesMoreParamsModule``."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent class."""
        super().__init__(**kwargs)

    def load_backbone(self):
        """Build the ``SimCLRFramesMoreParamsModule`` backbone.

        When ``self.lie_module_path`` is set, the module is restored from
        that checkpoint with ``pretrained_ssl=False``. Otherwise a new module
        is constructed with ``pretrained_ssl=True``.
        """
        checkpoint = self.lie_module_path

        if checkpoint is None:
            print(
                "SSL path not given, using SimCLR and SimCLR ImageNet pretrained in backbone"
            )
            return SimCLRFramesMoreParamsModule(
                pretrained_ssl=True, datamodule=self.datamodule
            )

        backbone = SimCLRFramesMoreParamsModule.load_from_checkpoint(
            checkpoint, pretrained_ssl=False, datamodule=self.datamodule
        )
        print("Loaded backbone from", checkpoint)
        return backbone


class SimCLRLieFinetuning(BaseFineTuning):
    """Fine-tuning model on a ``SimCLRLieModule`` backbone.

    ``forward`` returns log-probabilities averaged (in probability space)
    over dimension 1 of the backbone features, so the loss is NLL rather
    than cross-entropy.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``BaseFineTuning``."""
        super().__init__(**kwargs)

    def load_backbone(self):
        """Restore the ``SimCLRLieModule`` from ``self.lie_module_path``.

        Raises:
            NotImplementedError: if no checkpoint path was given.
        """
        if self.lie_module_path is None:
            raise NotImplementedError("Lie SSL path must be given")

        backbone = SimCLRLieModule.load_from_checkpoint(
            self.lie_module_path,
            datamodule=self.datamodule,
            pretrained_ssl=False,
        )
        print("Loaded backbone from", self.lie_module_path)
        return backbone

    def loss_function(self, y_hat, y):
        """Negative log-likelihood; ``y_hat`` holds log-probabilities here."""
        return F.nll_loss(y_hat, y)

    def forward(self, x):
        """Return class log-probabilities averaged over dimension 1 of the features."""
        feats = self.backbone.forward(x, self.num_neighbors, self.use_identity)
        scores = self.linear_classifier(feats)

        # Per-entry log-probabilities over the classes (last dimension).
        log_probs = F.log_softmax(scores, dim=-1)

        # log(mean_k p_k) = logsumexp_k(log p_k) + log(1 / K), where K is the
        # size of dimension 1 (treated by this code as the neighbor axis).
        pooled = torch.logsumexp(log_probs, dim=1)
        return pooled + math.log(1 / float(feats.size(1)))


class SimCLRLieFinetuningBackpropLosses(SimCLRLieFinetuning):
    """Lie fine-tuning variant that back-propagates one loss per feature slot.

    ``forward`` returns raw logits for every entry along dimension 1 of the
    backbone features. The training loss is the cross-entropy on entry 0
    plus the mean cross-entropy over the remaining entries. Accuracy is
    computed from entry 0 only.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to the parent class."""
        super().__init__(**kwargs)

    def loss_function(self, y_hat, y):
        """Cross-entropy; ``forward`` returns logits in this variant."""
        return F.cross_entropy(y_hat, y)

    def unpack_logits(self, logits: Tensor):
        """Split logits into the first entry and the remaining entries along dim 1.

        Args:
            logits: tensor indexed as ``[batch, entry, class]``.

        Returns:
            ``(z_logits, neighbours_logits)`` where ``z_logits`` is
            ``logits[:, 0, :]`` and ``neighbours_logits`` is
            ``logits[:, 1:, :]``, or ``None`` when dimension 1 has a single
            entry.
        """
        z_logits = logits[:, 0, :]
        neighbours_logits = logits[:, 1:, :] if logits.shape[1] > 1 else None
        return z_logits, neighbours_logits

    def shared_step(self, batch, stage: str = "train"):
        """Run one step: forward pass, combined loss, and metric logging.

        Args:
            batch: a 3-tuple whose first two items are the inputs and the
                targets; the third item is ignored.
            stage: prefix used for the logged loss and for the accuracy
                metric attribute names.

        Returns:
            ``z_loss + neighbors_loss`` for this batch.
        """
        x, y, _ = batch
        y_hat = self(x)
        z_logits, neighbours_logits = self.unpack_logits(y_hat)
        z_loss = self.loss_function(z_logits, y)
        # Passed explicitly to ``self.log`` (as in BaseFineTuning.shared_step)
        # so Lightning does not have to infer it from the tuple batch.
        batch_size = x.shape[0]

        neighbors_loss = 0.0
        # ``unpack_logits`` returns None when there are no extra entries, so
        # test for that directly rather than for an exact tensor type, which
        # would silently skip Tensor subclasses.
        if neighbours_logits is not None:
            n_neighbours = neighbours_logits.shape[1]
            for k in range(n_neighbours):
                # Out-of-place accumulation: avoids mutating an autograd
                # tensor in place.
                neighbors_loss = neighbors_loss + self.loss_function(
                    neighbours_logits[:, k, :], y
                )
            # Normalization to make the term independent of n_neighbours.
            neighbors_loss = neighbors_loss / n_neighbours

        loss = z_loss + neighbors_loss

        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            # loader names are used instead
            add_dataloader_idx=False,
            batch_size=batch_size,
        )
        for k in self.top_k:
            metric_name = f"{stage}_top_{k}_accuracy"
            accuracy_metric = getattr(self, metric_name)
            # Only the first entry (z) is used for accuracy.
            accuracy_metric(F.softmax(z_logits, dim=-1), y)
            self.log(
                metric_name,
                accuracy_metric,
                prog_bar=True,
                sync_dist=True,
                on_epoch=True,
                on_step=False,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )
        return loss

    def forward(self, x):
        """Return the linear classifier's raw logits on the backbone features."""
        feats = self.backbone.forward(x, self.num_neighbors, self.use_identity)
        return self.linear_classifier(feats)
