import math
from typing import List, Tuple

import torch
import pytorch_lightning as pl
from torch import Tensor
from torch.nn import Identity
from torchvision.models import resnet50, resnet18

from lightly.loss.ntx_ent_loss import NTXentLoss
from lightly.models.utils import get_weight_decay_parameters
from lightly.transforms import SimCLRTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler

import math
from typing import List, Optional
from typing import Union

import torch.nn.functional as F

from model_zoo.utils import mlp
from model_zoo.base import BaseModel


class SupervisedModel(BaseModel):
    """Supervised ResNet classifier trained with cross-entropy.

    The torchvision ResNet's ``fc`` layer is replaced by a linear layer with
    ``self.num_classes`` outputs, so the backbone output is used directly as
    class logits (no separate online linear probe is trained).

    Optimization uses LARS, an optional ``"auto"`` learning rate derived from
    the global batch size, and a per-step cosine scheduler with warmup.
    """

    def __init__(
        self,
        lr: Union[float, str] = "auto",
        momentum: float = 0.9,
        weight_decay: float = 1e-6,
        network: str = "resnet50",
        low_res: bool = False,
        weight_decay_trick: bool = True,
        datamodule: Optional[pl.LightningDataModule] = None,
    ) -> None:
        """
        Args:
            lr: Learning rate, or ``"auto"`` to derive it from the global
                batch size in ``configure_optimizers``.
            momentum: LARS momentum.
            weight_decay: LARS weight decay.
            network: Backbone architecture, ``"resnet18"`` or ``"resnet50"``.
            low_res: If True, replace the ResNet stem with a 3x3 stride-1 conv
                and remove the max-pool (presumably for low-resolution inputs).
            weight_decay_trick: If True, split parameters into a decayed and a
                non-decayed group via ``get_weight_decay_parameters``.
            datamodule: Data module; this class reads ``batch_size``,
                ``val_dataset_names`` and, optionally, ``input_channels``.
        """
        super().__init__(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            network=network,
            low_res=low_res,
            weight_decay_trick=weight_decay_trick,
            datamodule=datamodule,
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def load_modules(self):
        """Build ``self.backbone``: a ResNet with a ``num_classes``-way linear head.

        ``self.num_classes`` is not set in this class (presumably provided by
        ``BaseModel`` — verify).

        Raises:
            ValueError: If ``self.network`` is not a supported architecture.
        """
        constructors = {"resnet18": resnet18, "resnet50": resnet50}
        if self.network not in constructors:
            raise ValueError(
                f"Unsupported network {self.network!r}; "
                f"expected one of {sorted(constructors)}."
            )
        resnet = constructors[self.network]()
        # Reuse the original head's input width (512 for resnet18, 2048 for
        # resnet50) so the new head always matches the backbone.
        resnet.fc = torch.nn.Linear(resnet.fc.in_features, self.num_classes)

        if self.low_res:
            # Replace the default stem with a 3x3 stride-1 conv and no max-pool,
            # so the stem does not downsample the input spatially.
            input_channels = getattr(self.datamodule, "input_channels", 3)
            resnet.conv1 = torch.nn.Conv2d(
                input_channels,
                64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )
            resnet.maxpool = Identity()

        self.backbone = resnet

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of images."""
        return self.backbone(x)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Compute cross-entropy over every augmented view in the batch.

        ``batch[0]`` is a list of view tensors; a single tensor is accepted
        and treated as one view. ``batch[1]`` holds the targets (presumably a
        1-D tensor of class indices), which are tiled so each view is
        supervised with its sample's label.
        """
        views, targets = batch[0], batch[1]
        if isinstance(views, Tensor):
            # torch.cat needs a sequence of tensors, not a single tensor.
            views = [views]
        images = torch.cat(views, dim=0)
        # torch.cat lays the views out one after another along dim 0, so the
        # targets are repeated in the same order.
        targets = targets.repeat(len(views))
        outputs = self.forward(images)
        loss = self.criterion(outputs, targets)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log loss and accuracy for validation dataloader ``dataloader_idx``.

        The dataloader index selects the dataset name used as the metric
        prefix from ``self.datamodule.val_dataset_names``.
        """
        val_dataset_name = self.datamodule.val_dataset_names[dataloader_idx]
        images, targets = batch[0], batch[1]
        logits = self.forward(images).flatten(start_dim=1)
        return self.compute_online_probe(logits, targets, val_dataset_name)

    def compute_online_probe(self, z: Tensor, y: Tensor, stage: str) -> Tensor:
        """Log cross-entropy loss and accuracy of logits ``z`` against ``y``.

        No linear probe here: ``z`` is directly the model's prediction. The
        accuracy metric is looked up as ``online_{stage}_accuracy`` on
        ``self`` (presumably registered by ``BaseModel`` — verify).

        Returns:
            The cross-entropy loss tensor.
        """
        # Drop only a trailing singleton dim, e.g. (B, 1) -> (B,). A bare
        # ``squeeze()`` would also collapse a batch of size 1 into a 0-d
        # tensor, which cross_entropy rejects against (1, C) logits.
        if y.ndim > 1:
            y = y.squeeze(-1)
        loss_online_probe = F.cross_entropy(z, y)
        self.log(f"{stage}_online_linear_probe_loss", loss_online_probe, sync_dist=True)

        accuracy_metric = getattr(self, f"online_{stage}_accuracy")
        accuracy_metric(F.softmax(z, dim=-1), y)
        self.log(
            f"online_{stage}_accuracy",
            accuracy_metric,
            prog_bar=True,
            sync_dist=True,
            on_epoch=True,
            on_step=False,
        )
        return loss_online_probe

    def configure_optimizers(self):
        """Create the LARS optimizer and a per-step cosine warmup scheduler.

        The scheduler is stepped every batch (``"interval": "step"``), so its
        ``warmup_epochs`` / ``max_epochs`` arguments are expressed in
        optimizer steps. Warmup spans 10 epochs' worth of steps.
        """
        # Set learning rate.
        if self.lr == "auto":
            global_batch_size = self.datamodule.batch_size * self.trainer.world_size
            if global_batch_size <= 2048 and self.trainer.max_epochs <= 200:
                # Square root learning rate scaling improves performance for
                # small batch sizes (<=2048) and few training epochs (<=200).
                lr = 0.075 * math.sqrt(global_batch_size)
            else:
                # Linear scaling for larger batches and longer training.
                # See Appendix B.1. in the SimCLR paper https://arxiv.org/abs/2002.05709
                lr = 0.3 * global_batch_size / 256
        else:
            lr = self.lr

        if self.weight_decay_trick:
            # Exclude normalization-layer and bias parameters from weight decay.
            # NOTE(review): the ``fc`` classification head lives inside
            # ``self.backbone``; ``get_weight_decay_parameters`` is expected to
            # keep its weight in the decayed group — confirm against lightly.
            params, params_no_weight_decay = get_weight_decay_parameters(
                [self.backbone]
            )
            # Group names are kept as-is since LR monitors may key on them.
            optimizer = LARS(
                [
                    {"name": "simclr", "params": params},
                    {
                        "name": "simclr_no_weight_decay",
                        "params": params_no_weight_decay,
                        "weight_decay": 0.0,
                    },
                ],
                lr=lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        else:
            optimizer = LARS(
                self.parameters(),
                lr=lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )

        total_steps = self.trainer.estimated_stepping_batches
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # 10 epochs of warmup, converted to optimizer steps.
                warmup_epochs=int(total_steps / self.trainer.max_epochs * 10),
                max_epochs=int(total_steps),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]
