import pytorch_lightning as pl
import torch.nn as nn
import torch
import torch.nn.functional as F
import torchmetrics
from models.base_model import BaseModel
from typing import Optional


class LinearFineTuner(BaseModel):
    """Train a linear classification head on top of an embedding model.

    The backbone maps an input batch to embeddings of size
    ``embedding_model.embedding_dim``; a single ``nn.Linear`` layer maps those
    embeddings to ``num_classes`` logits, trained with cross-entropy.

    Args:
        embedding_model: backbone used to compute the representation. Required
            despite the ``None`` default; must expose ``embedding_dim`` (and
            ``freeze()`` when ``freeze`` is true).
        freeze: if true, the backbone is frozen and only the linear head is
            handed to the optimizer.
        learning_rate: initial learning rate of the optimizer.
        optimizer: ``"sgd"`` or ``"adam"``.
        weight_decay: weight decay passed to the optimizer.
        scheduler_type: learning-rate schedule used with SGD; ``"cosine"`` or
            ``None`` for a constant learning rate.
        final_lr: minimum learning rate reached by the cosine schedule.
        num_classes: number of output classes of the linear head.
        epochs: total number of epochs; period of the cosine schedule.
        datamodule: optional datamodule, stored on the instance as-is.

    Raises:
        ValueError: if ``embedding_model`` is not provided.
    """

    def __init__(
        self,
        embedding_model: Optional[pl.LightningModule] = None,
        freeze: bool = True,
        learning_rate: float = 1e-1,
        optimizer: str = "sgd",
        weight_decay: float = 1e-6,
        scheduler_type: str = "cosine",
        final_lr: float = 0.0,
        num_classes: int = 10,
        epochs: int = 100,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__()

        # Fail early with a clear message instead of an AttributeError on None.
        if embedding_model is None:
            raise ValueError("embedding_model must be provided")

        if freeze:
            embedding_model.freeze()
        self.embedding_model = embedding_model
        # NOTE(review): this boolean presumably shadows a ``freeze()`` method
        # inherited through BaseModel (LightningModule defines one); the name
        # is kept for backward compatibility.
        self.freeze = freeze
        self.num_classes = num_classes
        self.epochs = epochs
        self.datamodule = datamodule

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.scheduler_type = scheduler_type
        self.final_lr = final_lr

        self.embedding_dim = self.embedding_model.embedding_dim
        self.linear = nn.Linear(self.embedding_dim, self.num_classes, bias=True)

        # One metric object per stage so their accumulated states never mix.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        """Return class probabilities (softmax over the last dim) for ``x``."""
        z = self.embedding_model(x)
        preds = self.linear(z)
        return F.softmax(preds, dim=-1)

    def on_train_epoch_start(self) -> None:
        """Put a frozen backbone back into eval mode at each training epoch."""
        super().on_train_epoch_start()
        if self.freeze:
            self.embedding_model.eval()

    def unpack_batch(self, batch):
        """Return ``(inputs, labels)`` from a video dict or an indexable batch.

        A dict containing a ``"video"`` key yields ``batch["video"]`` and
        ``batch["label"]``; anything else yields ``batch[0]`` and ``batch[1]``.
        """
        if isinstance(batch, dict) and "video" in batch:
            return batch["video"], batch["label"]
        return batch[0], batch[1]

    def _shared_step(self, batch, stage):
        """Compute the loss of one batch and log loss/accuracy under ``stage``."""
        x, y = self.unpack_batch(batch)
        logits = self.linear(self.embedding_model(x))

        # cross_entropy expects raw logits; the metric is fed probabilities.
        loss = F.cross_entropy(logits, y)
        acc_metric = getattr(self, f"{stage}_accuracy")
        acc_metric(F.softmax(logits, dim=-1), y)

        # ``model_name`` is not set in this class; presumably BaseModel provides it.
        self.log(f"{stage}_loss_{self.model_name}", loss)
        self.log(
            f"{stage}_accuracy",
            acc_metric,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch and return its cross-entropy loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx, stage="val"):
        """Run one evaluation batch; ``stage`` selects the metric and log prefix."""
        return self._shared_step(batch, stage)

    def test_step(self, batch, batch_idx):
        """Run one test batch, logging under the ``test`` prefix."""
        return self.validation_step(batch, batch_idx, stage="test")

    def configure_optimizers(self):
        """Build the optimizer (and, for SGD, the scheduler) from ``self.optimizer``.

        Only the linear head is optimized when the backbone is frozen.

        Raises:
            ValueError: if ``self.optimizer`` is neither ``"adam"`` nor ``"sgd"``.
        """
        parameters = self.linear.parameters() if self.freeze else self.parameters()
        if self.optimizer == "adam":
            return torch.optim.Adam(
                parameters, lr=self.learning_rate, weight_decay=self.weight_decay
            )
        if self.optimizer == "sgd":
            return self.sgd(parameters)
        raise ValueError(f"optimizer {self.optimizer} not implemented")

    def sgd(self, parameters):
        """Return ``([optimizer], [schedulers])`` for SGD with momentum 0.9.

        Raises:
            ValueError: if ``self.scheduler_type`` is neither ``"cosine"`` nor
                ``None``.
        """
        optimizer = torch.optim.SGD(
            parameters,
            lr=self.learning_rate,
            nesterov=False,
            momentum=0.9,
            weight_decay=self.weight_decay,
        )

        # No schedule requested: keep the two-list return shape, empty schedulers.
        if self.scheduler_type is None:
            return [optimizer], []

        if self.scheduler_type == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.epochs,  # total epochs to run
                eta_min=self.final_lr,
            )
            return [optimizer], [scheduler]

        # Previously an unknown value reached the return with ``scheduler``
        # unbound and crashed with UnboundLocalError.
        raise ValueError(f"scheduler {self.scheduler_type} not implemented")


# class SimCLRLinearFineTuner(LinearFineTuner):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#     def unpack_batch(self, batch):
#         """Unpack the x used for online evaluation and ignore views"""
#         # (x1, x2, x): first two are augmented views
#         (_, _, x), y = batch
#         return x, y
