# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import logging
from typing import Any, Callable, Dict, List, Tuple, Union

import lightning.pytorch as pl
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR, ReduceLROnPlateau

from solo.utils.lars import LARS
from solo.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.utils.metrics import accuracy_at_k, weighted_mean, multi_label_metrics, compute_balanced_accuracy
from solo.utils.misc import (
    omegaconf_select,
    param_groups_layer_decay,
    remove_bias_and_norm_from_weight_decay,
)

# import radial loss
from solo.losses.radialvicreg import chi2_radial_nll_loss_for_lightning_logging

class LinearModel(pl.LightningModule):
    """Lightning module for linear / finetune evaluation of a backbone.

    A single ``nn.Linear`` classifier is placed on top of the features produced
    by the backbone. The backbone parameters are frozen unless ``cfg.finetune``
    is set.
    """

    # Maps the optimizer name given in ``cfg.optimizer.name`` to the optimizer
    # class instantiated in ``configure_optimizers``.
    _OPTIMIZERS = {
        "sgd": torch.optim.SGD,
        "lars": LARS,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    # Scheduler names accepted for ``cfg.scheduler.name``; these are the same
    # names that ``configure_optimizers`` branches on.
    _SCHEDULERS = [
        "reduce",
        "warmup_cosine",
        "step",
        "exponential",
        "none",
    ]

    def __init__(
        self,
        backbone: nn.Module,
        cfg: omegaconf.DictConfig,
        loss_func: Callable = None,
        mixup_func: Callable = None,
    ):
        """Implements linear and finetune evaluation.

        .. note:: Cfg defaults are set in init by calling `cfg = add_and_assert_specific_cfg(cfg)`

        Args:
            backbone (nn.Module): backbone architecture for feature extraction. Its output
                dimension is read from ``backbone.inplanes`` when present, otherwise from
                ``backbone.num_features``.
            cfg (omegaconf.DictConfig): configuration with the basic structure described below.
            loss_func (Callable, optional): loss function to use (for mixup, label smoothing
                or default). When None, ``nn.CrossEntropyLoss`` is used. Ignored for
                multi-label datasets, which always use ``nn.BCEWithLogitsLoss``.
                Defaults to None.
            mixup_func (Callable, optional): function to convert data and targets
                with mixup/cutmix. Defaults to None.

        Cfg basic structure:
            data:
                dataset (str): name of the dataset; "CelebA" is treated as multi-label.
                num_classes (int): number of classes in the dataset.
            max_epochs (int): total number of epochs.

            optimizer:
                name (str): name of the optimizer.
                batch_size (int): number of samples in the batch.
                lr (float): learning rate.
                weight_decay (float): weight decay for optimizer.
                kwargs (Dict): extra named arguments for the optimizer.
            scheduler:
                name (str): name of the scheduler.
                min_lr (float): minimum learning rate for warmup scheduler. Defaults to 0.0.
                warmup_start_lr (float): initial learning rate for warmup scheduler.
                    Defaults to 0.00003.
                warmup_epochs (float): number of warmup epochs. Defaults to 10.
                lr_decay_steps (Sequence, optional): steps to decay the learning rate
                    if scheduler is step. Defaults to None.
                interval (str): interval to update the lr scheduler. Defaults to 'step'.

            finetune (bool): whether or not to finetune the backbone. Defaults to False.

            performance:
                disable_channel_last (bool). Disables channel last conversion operation which
                speeds up training considerably. Defaults to False.
                https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models
        """

        super().__init__()

        # add default values and assert that config has the basic needed settings
        cfg = self.add_and_assert_specific_cfg(cfg)

        # backbone
        self.backbone = backbone
        if hasattr(self.backbone, "inplanes"):
            features_dim = self.backbone.inplanes
        else:
            features_dim = self.backbone.num_features

        # classifier
        self.classifier = nn.Linear(features_dim, cfg.data.num_classes)  # type: ignore

        # mixup/cutmix function
        self.mixup_func: Callable = mixup_func

        self.dataset_name = cfg.data.dataset
        self.multilabel_datasets = ["CelebA"]
        self.is_multilabel = self.dataset_name in self.multilabel_datasets

        if self.is_multilabel:
            # multi-label targets: one independent binary decision per attribute
            self.loss_func = nn.BCEWithLogitsLoss()
            if self.dataset_name == "CelebA":
                self.dataset_attr_names = [
                    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes",
                    "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair",
                    "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin",
                    "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
                    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard",
                    "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline",
                    "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair",
                    "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick",
                    "Wearing_Necklace", "Wearing_Necktie", "Young"
                ]
            else:
                # generic attribute names for any other multi-label dataset
                self.dataset_attr_names = [str(i) for i in range(cfg.data.num_classes)]
        elif loss_func is None:
            self.loss_func = nn.CrossEntropyLoss()
        else:
            # use the loss supplied by the caller (previously referenced an undefined name)
            self.loss_func = loss_func

        # training related
        self.max_epochs: int = cfg.max_epochs
        self.accumulate_grad_batches: Union[int, None] = cfg.accumulate_grad_batches

        # optimizer related
        self.optimizer: str = cfg.optimizer.name
        self.batch_size: int = cfg.optimizer.batch_size
        self.lr: float = cfg.optimizer.lr
        self.weight_decay: float = cfg.optimizer.weight_decay
        self.extra_optimizer_args: Dict[str, Any] = cfg.optimizer.kwargs
        self.exclude_bias_n_norm_wd: bool = cfg.optimizer.exclude_bias_n_norm_wd
        self.layer_decay: float = cfg.optimizer.layer_decay

        # scheduler related
        self.scheduler: str = cfg.scheduler.name
        self.lr_decay_steps: Union[List[int], None] = cfg.scheduler.lr_decay_steps
        self.min_lr: float = cfg.scheduler.min_lr
        self.warmup_start_lr: float = cfg.scheduler.warmup_start_lr
        self.warmup_epochs: int = cfg.scheduler.warmup_epochs
        self.scheduler_interval: str = cfg.scheduler.interval
        assert self.scheduler_interval in ["step", "epoch"]
        if self.scheduler_interval == "step":
            # logging.warn is a deprecated alias of logging.warning
            logging.warning(
                f"Using scheduler_interval={self.scheduler_interval} might generate "
                "issues when resuming a checkpoint."
            )

        # if finetuning the backbone
        self.finetune: bool = cfg.finetune

        # for performance
        self.no_channel_last = cfg.performance.disable_channel_last

        # linear evaluation: the backbone is frozen and only the classifier is trained
        if not self.finetune:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # keep track of validation metrics
        self.validation_step_outputs = []

        # keep track of training metrics (running sums, reset in on_train_epoch_end)
        self.training_step_outputs = {
            "curr_radial_loss": 0.0,
            "curr_feats_l2_norm_mean": 0.0,
            "curr_feats_l2_norm_var": 0.0,

            "curr_radial_loss_before_relu": 0.0,
            "curr_feats_l2_norm_mean_before_relu": 0.0,
            "curr_feats_l2_norm_var_before_relu": 0.0,

            "num_training_examples": 0,
        }

        # Balanced-accuracy support for multi-label datasets (CelebA).
        if self.is_multilabel:
            # Place-holders that will be reset every epoch in on_validation_epoch_start
            self.validation_balanced_acc_buffers = None  # will hold running TP / TN / FP / FN sums

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Writes the settings used by linear evaluation back onto the config.

        Each setting is looked up with ``omegaconf_select`` using the listed fallback
        (presumably the existing value is kept when present -- confirm against
        ``omegaconf_select``) and the result is stored on ``cfg`` in place.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        def fill_default(dotted_key: str, fallback: Any) -> None:
            # Resolve the value first, then store it on the node owning the last path component.
            value = omegaconf_select(cfg, dotted_key, fallback)
            *parents, leaf = dotted_key.split(".")
            owner = cfg
            for name in parents:
                owner = getattr(owner, name)
            setattr(owner, leaf, value)

        # optimizer: weight-decay exclusions, extra kwargs (pytorch's defaults when
        # not given) and layer-wise decay
        fill_default("optimizer.exclude_bias_n_norm_wd", False)
        fill_default("optimizer.kwargs", {})
        fill_default("optimizer.layer_decay", 0.0)

        # whether or not to finetune the backbone
        fill_default("finetune", False)

        # gradient accumulation
        fill_default("accumulate_grad_batches", 1)

        # scheduler
        fill_default("scheduler.lr_decay_steps", None)
        fill_default("scheduler.min_lr", 0.0)
        fill_default("scheduler.warmup_start_lr", 3e-5)
        fill_default("scheduler.warmup_epochs", 10)
        fill_default("scheduler.interval", "step")

        # performance optimization; the section itself must exist before its member is set
        fill_default("performance", {})
        fill_default("performance.disable_channel_last", False)

        return cfg

    def configure_optimizers(self) -> Union[torch.optim.Optimizer, Tuple[List, List]]:
        """Collects learnable parameters and configures the optimizer and learning rate scheduler.

        Returns:
            Union[torch.optim.Optimizer, Tuple[List, List]]: the bare optimizer when the
            scheduler is "none"; otherwise two lists containing the optimizer and the scheduler.

        Raises:
            ValueError: if the configured scheduler name is not supported.
        """

        if self.layer_decay > 0:
            assert self.finetune, "Only with use layer weight decay with finetune on."
            msg = (
                "Method should implement no_weight_decay() that returns "
                "a set of parameter names to ignore from weight decay"
            )
            assert hasattr(self.backbone, "no_weight_decay"), msg

            learnable_params = param_groups_layer_decay(
                self.backbone,
                self.weight_decay,
                no_weight_decay_list=self.backbone.no_weight_decay(),
                layer_decay=self.layer_decay,
            )
            learnable_params.append({"name": "classifier", "params": self.classifier.parameters()})
        else:
            # only the classifier is optimized unless the backbone is being finetuned
            learnable_params = (
                self.classifier.parameters()
                if not self.finetune
                else [
                    {"name": "backbone", "params": self.backbone.parameters()},
                    {"name": "classifier", "params": self.classifier.parameters()},
                ]
            )

        # exclude bias and norm from weight decay
        if self.exclude_bias_n_norm_wd:
            learnable_params = remove_bias_and_norm_from_weight_decay(learnable_params)

        assert self.optimizer in self._OPTIMIZERS
        optimizer = self._OPTIMIZERS[self.optimizer]

        optimizer = optimizer(
            learnable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
            **self.extra_optimizer_args,
        )

        # select scheduler
        if self.scheduler == "none":
            return optimizer

        if self.scheduler == "warmup_cosine":
            # when stepping per batch, warmup and total lengths are expressed in optimizer steps
            max_warmup_steps = (
                self.warmup_epochs * (self.trainer.estimated_stepping_batches / self.max_epochs)
                if self.scheduler_interval == "step"
                else self.warmup_epochs
            )
            max_scheduler_steps = (
                self.trainer.estimated_stepping_batches
                if self.scheduler_interval == "step"
                else self.max_epochs
            )
            scheduler = {
                "scheduler": LinearWarmupCosineAnnealingLR(
                    optimizer,
                    warmup_epochs=max_warmup_steps,
                    max_epochs=max_scheduler_steps,
                    warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                    eta_min=self.min_lr,
                ),
                "interval": self.scheduler_interval,
                "frequency": 1,
            }
        elif self.scheduler == "reduce":
            # Lightning requires a monitored metric for ReduceLROnPlateau; "val_loss" is
            # the loss logged at the end of every validation epoch by this module.
            scheduler = {
                "scheduler": ReduceLROnPlateau(optimizer),
                "monitor": "val_loss",
            }
        elif self.scheduler == "step":
            scheduler = MultiStepLR(optimizer, self.lr_decay_steps, gamma=0.1)
        elif self.scheduler == "exponential":
            # NOTE(review): the weight decay value is used as the decay factor (gamma)
            # here -- confirm this is intended.
            scheduler = ExponentialLR(optimizer, self.weight_decay)
        else:
            raise ValueError(f"{self.scheduler} not in {tuple(self._SCHEDULERS)}")

        return [optimizer], [scheduler]

    # def on_train_epoch_start(self):
    #     self.near_zero_count = 0
    #     self.total_count = 0

    # def on_validation_epoch_start(self):
    #     self.near_zero_count = 0
    #     self.total_count = 0

    # def update_sparsity_counters(self, embeddings, threshold=1e-5):
    #     near_zero = (embeddings.abs() < threshold)

    #     self.near_zero_count += near_zero.sum().item()
    #     self.total_count += embeddings.numel()

    #     self.current_ratio = (self.near_zero_count / self.total_count) * 100
    #     self.log("sparsity/current_ratio", self.current_ratio, on_epoch=True, sync_dist=True)

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Performs forward pass of the backbone and the linear layer for evaluation.

        Gradients flow through the backbone only when finetuning. In training mode,
        running statistics of the features (radial loss and mean/variance of the
        per-sample L2 norm) are accumulated and logged; when the backbone returns a
        ``(feats, feats_before_relu)`` tuple the same statistics are also tracked
        for the second element.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing features and logits.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)

        with torch.set_grad_enabled(self.finetune):
            backbone_out = self.backbone(X)
            if isinstance(backbone_out, tuple):
                feats, feats_before_relu = backbone_out
            else:
                # backbones that only return the final features
                feats, feats_before_relu = backbone_out, None

        batch_size = feats.shape[0]

        # log feature statistics over the training set seen so far in this epoch
        if self.training and batch_size > 0:
            # Count the current batch first so that every running average below divides
            # by a non-zero total that already includes this batch.
            self.training_step_outputs["num_training_examples"] += batch_size

            if feats_before_relu is not None:
                self._accumulate_and_log_feature_stats(
                    feats_before_relu, batch_size, suffix="_before_relu"
                )
            self._accumulate_and_log_feature_stats(feats, batch_size)

        logits = self.classifier(feats)
        return {"logits": logits, "feats": feats}

    @torch.no_grad()
    def _accumulate_and_log_feature_stats(
        self, feats: torch.Tensor, batch_size: int, suffix: str = ""
    ) -> None:
        """Adds the statistics of one batch to the running sums and logs the running average.

        Args:
            feats (torch.Tensor): features of the current batch; norms are taken along dim 1.
            batch_size (int): number of samples in the current batch.
            suffix (str): appended to the metric names to select the set of running sums.
        """

        norms = torch.norm(feats, dim=1)
        batch_stats = {
            f"curr_radial_loss{suffix}": chi2_radial_nll_loss_for_lightning_logging(feats).item(),
            f"curr_feats_l2_norm_mean{suffix}": norms.mean().item(),
            f"curr_feats_l2_norm_var{suffix}": norms.var().item(),
        }

        seen = self.training_step_outputs["num_training_examples"]
        for key, value in batch_stats.items():
            self.training_step_outputs[key] += value
            # sum of per-batch values / samples seen * batch size, i.e. the per-batch
            # average when all batches have the same size
            self.log(
                f"linear_eval/{key}",
                (self.training_step_outputs[key] / seen) * batch_size,
                on_epoch=True,
            )

    def shared_step(self, batch: Tuple, batch_idx: int) -> Dict[str, Any]:
        """Helper function that performs the forward pass and loss calculation for a batch.

        Args:
            batch (Tuple): a batch of data in the format of [X, Y].
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: always contains "loss". Multi-label datasets add
            "exact_match", "hamming", "jaccard" and the entries returned by
            ``multi_label_metrics``. Single-label datasets add "acc1" and "acc5",
            except for training batches transformed by mixup/cutmix, where only
            the loss is returned.
        """

        X, targets = batch

        if self.is_multilabel:
            logits = self(X)["logits"]
            loss = self.loss_func(logits, targets.float())

            exact_match, hamming, jaccard, separate_metrics = multi_label_metrics(
                logits, targets, self.dataset_attr_names, nn_type="linear"
            )

            return {
                "loss": loss,
                "exact_match": exact_match,
                "hamming": hamming,
                "jaccard": jaccard,
                # propagate TP / FN / TN / FP for balanced-accuracy accumulation
                **separate_metrics,
            }

        if self.training and self.mixup_func is not None:
            # mixup/cutmix is a training-only augmentation; the mixed targets are
            # scored with the configured loss function and accuracies are skipped
            # (training_step does not log them when mixup is enabled).
            X, targets = self.mixup_func(X, targets)
            logits = self(X)["logits"]
            return {"loss": self.loss_func(logits, targets)}

        logits = self(X)["logits"]
        loss = F.cross_entropy(logits, targets)
        acc1, acc5 = accuracy_at_k(logits, targets, top_k=(1, 5))
        return {
            "loss": loss,
            "acc1": acc1,
            "acc5": acc5,
        }

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Runs one training iteration of the linear evaluation and logs its metrics.

        Args:
            batch (torch.Tensor): a batch of images in the tensor format.
            batch_idx (int): the index of the batch.

        Returns:
            torch.Tensor: the loss computed by ``shared_step`` for this batch.
        """

        # a frozen backbone is kept in eval mode while the classifier trains
        if not self.finetune:
            self.backbone.eval()

        step_out = self.shared_step(batch, batch_idx)
        loss = step_out["loss"]

        metrics = {"train_loss": loss}
        if self.is_multilabel:
            metrics["train_exact_match"] = step_out["exact_match"]
            metrics["train_hamming"] = step_out["hamming"]
            metrics["train_jaccard"] = step_out["jaccard"]
        elif self.mixup_func is None:
            # accuracies are only logged when no mixup/cutmix function is configured
            metrics["train_acc1"] = step_out["acc1"]
            metrics["train_acc5"] = step_out["acc5"]

        self.log_dict(metrics, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, Any]:
        """Evaluates one validation batch and stores its metrics for the epoch-end average.

        Args:
            batch (torch.Tensor): a batch in the format of [X, Y].
            batch_idx (int): the index of the batch.

        Returns:
            Dict[str, Any]: the output of ``shared_step`` plus the "batch_size" entry.
        """

        images, _ = batch
        num_samples = images.size(0)

        step_out = self.shared_step(batch, batch_idx)

        if self.is_multilabel:
            # add this batch's TP / FN / TN / FP to the running sums used for balanced accuracy
            for quantity in ("tps", "fns", "tns", "fps"):
                out_key = f"linear_{quantity}"
                if out_key in step_out:
                    self.validation_balanced_acc_buffers[quantity] += step_out[out_key]

        step_out["batch_size"] = num_samples
        self.validation_step_outputs.append(step_out)
        return step_out

    def on_train_epoch_end(self):
        """Resets the running feature statistics at the end of every training epoch."""

        # empty the previous container before replacing it
        self.training_step_outputs.clear()

        stat_names = ("curr_radial_loss", "curr_feats_l2_norm_mean", "curr_feats_l2_norm_var")
        fresh_outputs = {
            f"{stat}{suffix}": 0.0 for suffix in ("", "_before_relu") for stat in stat_names
        }
        fresh_outputs["num_training_examples"] = 0

        self.training_step_outputs = fresh_outputs

    def on_validation_epoch_end(self):
        """Logs the batch-size-weighted validation metrics and resets the epoch state.

        For multi-label datasets the balanced accuracy computed from the accumulated
        TP / FN / TN / FP sums is logged as well, both averaged and per attribute.
        """

        outputs = self.validation_step_outputs

        if self.is_multilabel:
            metric_names = ("loss", "exact_match", "hamming", "jaccard")
        else:
            metric_names = ("loss", "acc1", "acc5")
        log = {
            f"val_{name}": weighted_mean(outputs, name, "batch_size") for name in metric_names
        }

        # ----- Balanced accuracy logging for multi-label -----
        buffers = self.validation_balanced_acc_buffers if self.is_multilabel else None
        if buffers is not None:
            bal_acc = compute_balanced_accuracy(
                buffers["tps"],
                buffers["fns"],
                buffers["tns"],
                buffers["fps"],
            )

            # averaged value first, then one entry per attribute
            log["val_balanced_acc"] = bal_acc.mean()
            for attr_name, attr_bal_acc in zip(self.dataset_attr_names, bal_acc):
                log[f"val_separate_metrics/linear_{attr_name}"] = attr_bal_acc

        self.log_dict(log, sync_dist=True)

        # clear outputs
        outputs.clear()

        # clear BA buffers
        if self.is_multilabel:
            self.validation_balanced_acc_buffers = None

        # # sparsity
        # if self.total_count > 0:
        #     global_sparsity = (self.near_zero_count / self.total_count) * 100
        # else:
        #     global_sparsity = 0.0

        # self.log("val/global_sparsity", global_sparsity, on_epoch=True)

    def on_validation_epoch_start(self):
        """Creates zeroed running sums for balanced accuracy (multi-label only)."""
        if not self.is_multilabel:
            return

        num_attr = len(self.dataset_attr_names)
        # one counter per attribute for true positives (tps), false negatives (fns),
        # true negatives (tns) and false positives (fps)
        self.validation_balanced_acc_buffers = {
            quantity: torch.zeros(num_attr, device=self.device)
            for quantity in ("tps", "fns", "tns", "fps")
        }
