import logging
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

import hydra
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from torch.optim import Optimizer

from nn_core.model_logging import NNLogger

from model_merging.data.datamodule import MetaData
from model_merging.data.dataset import maybe_dictionarize
from model_merging.utils.utils import torch_load, torch_save

# Module-level logger, named after this module's import path.
pylogger = logging.getLogger(__name__)


class ImageClassifier(pl.LightningModule):
    """Lightning module that chains an encoder with a classification head.

    The module trains with cross-entropy, tracks top-1 accuracy per split and
    can optionally add "mergeability" regularization terms to the training
    loss (see :meth:`set_regularization_config`).

    Notes:
        * ``self.hparams`` must provide ``x_key``, ``y_key`` and ``optimizer``
          (and optionally ``lr_scheduler``); they are read in :meth:`_step`
          and :meth:`configure_optimizers`.
        * :meth:`set_task` must be called before any train/val/test step:
          every logged key embeds ``self.task_name``, which ``__init__`` does
          not set.
    """

    logger: NNLogger

    def __init__(
        self, encoder, classifier, metadata: Optional[MetaData] = None, *args, **kwargs
    ) -> None:
        """Build the classifier.

        Args:
            encoder: Module mapping inputs to embeddings.
            classifier: Classification head; must expose ``out_features``,
                which is used as the number of classes.
            metadata: Optional dataset metadata (kept out of ``hparams``).
            *args: Unused by this constructor.
            **kwargs: Extra hyperparameters captured into ``self.hparams``.
        """
        super().__init__()

        # Populate self.hparams with args and kwargs automagically!
        # We want to skip metadata since it is saved separately by the NNCheckpointIO object.
        # Be careful when modifying this instruction. If in doubt, don't do it :]
        self.save_hyperparameters(logger=False, ignore=("metadata",))

        self.metadata = metadata

        # Sets self.num_classes and the train/val/test accuracy metrics.
        self.set_metrics(classifier.out_features)

        self.encoder = encoder
        self.classification_head = classifier

        self.finetuning_accuracy = None

        # Mergeability regularization configuration (disabled by default).
        self.pretrained_state_dict = None
        self.enable_moderate_update = False
        self.lambda_moderate_update = 0.0
        self.enable_grad_magnitude = False
        self.lambda_grad_magnitude = 0.0

    def log_fn(self, metric: str, val: Any) -> None:
        """Log ``val`` under ``metric`` both per step and aggregated per epoch.

        This is a regular method rather than a lambda stored on the instance
        so that the module remains picklable with the standard ``pickle``
        protocol.
        """
        self.log(metric, val, on_step=True, on_epoch=True)

    def set_encoder(self, encoder: torch.nn.Module):
        """Set the encoder of the model.

        Args:
            encoder (torch.nn.Module): The new encoder to set.
        """
        self.encoder = encoder

    def set_head(self, head: torch.nn.Module):
        """Set the classification head of the model.

        Args:
            head (torch.nn.Module): The new classification head to set.
        """
        self.classification_head = head

    def set_metrics(self, num_classes):
        """(Re)create the per-split top-1 accuracy metrics.

        Args:
            num_classes: Number of classes of the multiclass task; also stored
                as ``self.num_classes``.
        """
        self.num_classes = num_classes

        metric = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, top_k=1
        )

        # Independent copies so that each split accumulates its own state.
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the classification head's logits."""
        embeddings = self.encoder(x)

        logits = self.classification_head(embeddings)

        return logits

    def _step(self, batch: Dict[str, torch.Tensor], split: str) -> Mapping[str, Any]:
        """Run one forward pass, update the split's accuracy and log metrics.

        Args:
            batch: Batch as yielded by the dataloader; it is normalized through
                ``maybe_dictionarize`` using the configured x/y keys.
            split: One of ``"train"``, ``"val"`` or ``"test"``; selects the
                ``{split}_acc`` metric and the logging prefix.

        Returns:
            Mapping with the detached ``"logits"`` and the cross-entropy
            ``"loss"`` (still attached to the graph).
        """
        batch = maybe_dictionarize(batch, self.hparams.x_key, self.hparams.y_key)

        x = batch[self.hparams.x_key]
        gt_y = batch[self.hparams.y_key]

        logits = self(x)

        loss = F.cross_entropy(logits, gt_y)
        preds = torch.softmax(logits, dim=-1)

        metrics = getattr(self, f"{split}_acc")
        metrics.update(preds, gt_y)

        self.log_fn(f"acc/{split}/{self.task_name}", metrics)
        self.log_fn(f"loss/{split}/{self.task_name}", loss)

        return {"logits": logits.detach(), "loss": loss}

    def _moderate_update_penalty(self, like: torch.Tensor) -> torch.Tensor:
        """Return the squared L2 distance of trainable encoder weights from the pretrained ones.

        Only parameters whose name is a key of ``self.pretrained_state_dict``
        and that require gradients contribute.

        Args:
            like: Tensor whose device/dtype is used for the zero penalty that
                is returned when no parameter contributes.
        """
        penalty = None
        for name, param in self.encoder.named_parameters():
            if name in self.pretrained_state_dict and param.requires_grad:
                # Detach so that no gradient can ever flow into the reference weights.
                reference = self.pretrained_state_dict[name].detach().to(param.device)
                term = torch.sum((param - reference) ** 2)
                penalty = term if penalty is None else penalty + term

        if penalty is None:
            # No matching trainable parameter: keep the penalty a tensor so that
            # callers can uniformly call ``.detach()`` on it.
            penalty = like.new_zeros(())
        return penalty

    def _grad_magnitude_penalty(self) -> Optional[torch.Tensor]:
        """Return the summed squared gradients currently stored on the encoder.

        Returns:
            The sum of ``param.grad ** 2`` over trainable encoder parameters
            that have a gradient, or ``None`` when no gradient is available.
        """
        penalty = None
        for param in self.encoder.parameters():
            if param.requires_grad and param.grad is not None:
                term = torch.sum(param.grad ** 2)
                penalty = term if penalty is None else penalty + term
        return penalty

    def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Compute the training loss, optionally adding mergeability regularizers.

        Regularization is only considered when a pretrained state dict has been
        configured through :meth:`set_regularization_config`.

        Returns:
            The mapping produced by :meth:`_step`, with ``"loss"`` replaced by
            the regularized loss when the total regularization is positive.
        """
        result = self._step(batch=batch, split="train")
        loss = result["loss"]

        # Add mergeability regularization terms
        if self.pretrained_state_dict is not None:
            reg_loss = None

            # R2: Moderate Update Regularization
            # Penalize large deviations from pretrained weights
            if self.enable_moderate_update and self.lambda_moderate_update > 0:
                moderate_update_loss = self._moderate_update_penalty(loss)
                reg_loss = self.lambda_moderate_update * moderate_update_loss
                self.log_fn(
                    f"reg/moderate_update/{self.task_name}",
                    moderate_update_loss.detach(),
                )

            # R3: Gradient Magnitude Regularization
            # Encourage moderate gradient magnitudes
            # NOTE(review): the stored ``param.grad`` tensors normally carry no
            # autograd graph (unless backward ran with create_graph=True), so
            # this term changes the reported loss value but looks like it
            # cannot influence the gradients used for the update. Confirm
            # whether a differentiable gradient penalty was intended.
            if self.enable_grad_magnitude and self.lambda_grad_magnitude > 0:
                grad_magnitude_loss = self._grad_magnitude_penalty()

                # Only add if gradients exist (after first backward pass)
                if grad_magnitude_loss is not None and grad_magnitude_loss > 0:
                    weighted = self.lambda_grad_magnitude * grad_magnitude_loss
                    reg_loss = weighted if reg_loss is None else reg_loss + weighted
                    self.log_fn(
                        f"reg/grad_magnitude/{self.task_name}",
                        grad_magnitude_loss.detach(),
                    )

            if reg_loss is not None and reg_loss > 0:
                loss = loss + reg_loss
                self.log_fn(f"reg/total/{self.task_name}", reg_loss.detach())

        result["loss"] = loss
        return result

    def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Run :meth:`_step` on the validation split."""
        return self._step(batch=batch, split="val")

    def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Run :meth:`_step` on the test split."""
        return self._step(batch=batch, split="test")

    def freeze_head(self):
        """Disable gradients for every parameter of the classification head.

        Iterating over the parameters (instead of touching ``weight`` and
        ``bias`` explicitly) also supports heads created without a bias.
        """
        for param in self.classification_head.parameters():
            param.requires_grad_(False)

    def configure_optimizers(
        self,
    ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.

        The optimizer is instantiated from ``self.hparams.optimizer`` over all
        module parameters. If ``self.hparams`` contains ``lr_scheduler``, a
        scheduler is instantiated from it and bound to that optimizer.

        Return:
            ``[optimizer]`` when no scheduler is configured, otherwise
            ``([optimizer], [scheduler])``.
        """
        opt = hydra.utils.instantiate(self.hparams.optimizer, params=self.parameters())
        if "lr_scheduler" not in self.hparams:
            return [opt]
        scheduler = hydra.utils.instantiate(self.hparams.lr_scheduler, optimizer=opt)
        return [opt], [scheduler]

    def __call__(self, inputs):
        """Run the model on ``inputs``.

        Delegates to ``torch.nn.Module.__call__`` so that registered forward
        hooks are honoured, instead of invoking :meth:`forward` directly.
        """
        return super().__call__(inputs)

    def save(self, filename):
        """Save this classifier to ``filename`` via ``torch_save``."""
        print(f"Saving image classifier to {filename}")
        torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Load a classifier from ``filename`` via ``torch_load``.

        The result of ``torch_load`` is returned as-is; ``cls`` is not used to
        construct or validate the object.
        """
        print(f"Loading image classifier from {filename}")
        return torch_load(filename)

    def set_task(self, task_name):
        """Set the task name embedded in every logged metric key."""
        self.task_name = task_name

    def set_finetuning_accuracy(self, finetuning_accuracy):
        """Set the reference accuracy used to log the normalized test accuracy."""
        self.finetuning_accuracy = finetuning_accuracy

    def set_regularization_config(
        self,
        pretrained_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        enable_moderate_update: bool = False,
        lambda_moderate_update: float = 0.0,
        enable_grad_magnitude: bool = False,
        lambda_grad_magnitude: float = 0.0,
    ):
        """Configure mergeability regularization.

        Args:
            pretrained_state_dict: State dict of pretrained encoder weights
            enable_moderate_update: Enable R2 (moderate update) regularization
            lambda_moderate_update: Weight for R2 regularization
            enable_grad_magnitude: Enable R3 (gradient magnitude) regularization
            lambda_grad_magnitude: Weight for R3 regularization
        """
        # Stored as given (no copy, no device move); each tensor is moved to
        # its parameter's device when used in training_step.
        self.pretrained_state_dict = pretrained_state_dict
        self.enable_moderate_update = enable_moderate_update
        self.lambda_moderate_update = lambda_moderate_update
        self.enable_grad_magnitude = enable_grad_magnitude
        self.lambda_grad_magnitude = lambda_grad_magnitude

    def on_test_epoch_end(self):
        """Log the epoch-level test accuracy and, if available, its normalized value.

        The normalized accuracy is the test accuracy divided by
        ``self.finetuning_accuracy`` and is only logged when the latter is set.
        """
        # Compute and log the test accuracy
        accuracy = self.test_acc.compute().cpu().item()
        self.log(f"acc/test/{self.task_name}", accuracy, on_step=False, on_epoch=True)

        if self.finetuning_accuracy is not None:
            normalized_acc = accuracy / self.finetuning_accuracy
            self.log(f"normalized_acc/test/{self.task_name}", normalized_acc, on_step=False, on_epoch=True)
