# This file is copied from ProTrek
# Original license: MIT License
import abc
import copy
import os

import pytorch_lightning as pl
import torch
from torch import distributed as dist

from .lr_scheduler import *


class AbstractModel(pl.LightningModule):
    """Base LightningModule with optimizer/scheduler setup, metric registration,
    checkpoint loading/saving and best-model tracking.

    Subclasses must implement ``initialize_model`` (which must create
    ``self.model``), ``forward``, ``initialize_metrics`` and ``loss_func``.
    """

    def __init__(
        self,
        lr_scheduler_kwargs: dict = None,
        optimizer_kwargs: dict = None,
        save_path: str = None,
        from_checkpoint: str = None,
        load_prev_scheduler: bool = False,
        save_weights_only: bool = True,
    ):
        """

        Args:
            lr_scheduler_kwargs: Kwargs for the lr scheduler. The "class" key names
                the scheduler class (resolved with ``eval`` in this module's
                namespace) and "init_lr" is also used as the optimizer's initial
                learning rate; the remaining keys go to the scheduler constructor.
                Defaults to a ConstantLRScheduler with init_lr=0.
            optimizer_kwargs: Kwargs for the optimizer. The "class" key names a
                class in ``torch.optim`` and "weight_decay" is required; the
                remaining keys go to the optimizer constructor. Defaults to AdamW.
            save_path: Where ``check_save_condition`` saves the best model. It may
                contain f-string placeholders that are evaluated at save time.
            from_checkpoint: Load model weights from this checkpoint (a file, or a
                directory ``<dir>`` containing ``<basename(dir)>.pt``).
            load_prev_scheduler: Whether to restore step/epoch/best value, lr
                scheduler and optimizer state from the checkpoint at train start.
            save_weights_only: Whether to save only weights, or also optimizer and
                lr_scheduler state.

        """
        super().__init__()
        self.initialize_model()

        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes so Lightning/nn.Module tracks them
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)

            self.metrics[stage] = stage_metrics

        if lr_scheduler_kwargs is None:
            # Default lr_scheduler
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print(
                "No lr_scheduler_kwargs provided. The default learning rate is 0."
            )

        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs

        if optimizer_kwargs is None:
            # Default optimizer
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print(
                "No optimizer_kwargs provided. The default optimizer is AdamW."
            )
        else:
            self.optimizer_kwargs = optimizer_kwargs
        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only

        # temp_step counts micro-batches within one gradient-accumulation cycle;
        # step counts completed optimizer steps (used as the logging step).
        self.temp_step = 0
        self.step = 0
        self.epoch = 0

        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    @abc.abstractmethod
    def initialize_model(self) -> None:
        """
        All model initialization should be done here
        Note that the whole model must be named as "self.model" for model saving and loading
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """
        Forward propagation
        """
        raise NotImplementedError

    @abc.abstractmethod
    def initialize_metrics(self, stage: str) -> dict:
        """
        Initialize metrics for each stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric objects
        """
        raise NotImplementedError

    @abc.abstractmethod
    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """

        Args:
            stage: "train", "valid" or "test"
            outputs: model outputs for calculating loss
            labels: labels for calculating loss

        Returns:
            loss

        """
        raise NotImplementedError

    @staticmethod
    def _is_rank_zero() -> bool:
        """Return True on the main process, or when torch.distributed is not in use.

        ``dist.get_rank()`` raises if no process group has been initialized
        (e.g. single-device training), so guard it here.
        """
        if not dist.is_available() or not dist.is_initialized():
            return True
        return dist.get_rank() == 0

    @staticmethod
    def load_weights(model, weights):
        """Non-strictly load ``weights`` into ``model``.

        Keys present in both are copied; keys only in the model are reported as
        "not initialized", keys only in ``weights`` as "not used".

        Args:
            model: Module whose ``state_dict`` keys are matched against ``weights``.
            weights: Mapping of parameter/buffer names to values.
        """
        model_dict = model.state_dict()

        unused_params = []
        loaded = set()

        for k, v in weights.items():
            if k in model_dict:
                model_dict[k] = v
                loaded.add(k)

            else:
                unused_params.append(k)

        # Preserve the model's key order in the report (set lookup keeps this O(n)).
        missed_params = [k for k in model_dict if k not in loaded]

        if len(missed_params) > 0:
            print(
                f"\033[31mSome weights of {type(model).__name__} were not "
                f"initialized from the model checkpoint: {missed_params}\033[0m"
            )

        if len(unused_params) > 0:
            print(
                f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m"
            )

        model.load_state_dict(model_dict)

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer,
        optimizer_closure=None,
    ) -> None:
        """Run the optimizer step and advance ``self.step`` once per full
        gradient-accumulation cycle.

        NOTE: for pytorch-lightning 1.9.5 this hook additionally takes
        ``optimizer_idx``, ``on_tpu``, ``using_native_amp`` and ``using_lbfgs``,
        which must be forwarded to ``super().optimizer_step``.
        """
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0

    def on_train_epoch_end(self):
        self.epoch += 1

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(inputs, labels)`` batch;
        ``inputs`` is unpacked as keyword arguments to ``forward``."""
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func("train", outputs, labels)

        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and collect it in ``self.valid_outputs``."""
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func("valid", outputs, labels)
        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss and collect it in ``self.test_outputs``."""
        inputs, labels = batch
        outputs = self(**inputs)

        loss = self.loss_func("test", outputs, labels)
        self.test_outputs.append(loss)
        return loss

    def on_train_start(self) -> None:
        """Restore training progress, lr scheduler and optimizer state loaded by
        ``load_checkpoint`` when ``load_prev_scheduler`` is set.

        Raises:
            Exception: If any part of the previous state cannot be restored
                (chained to the underlying error).
        """
        # NOTE: the attribute keeps its historical spelling ("schechuler") because
        # it is set by ``load_checkpoint`` and may be referenced by subclasses.
        if getattr(self, "prev_schechuler", None) is None:
            return

        try:
            self.step = self.prev_schechuler["global_step"]
            self.epoch = self.prev_schechuler["epoch"]
            self.best_value = self.prev_schechuler["best_value"]
            self.lr_scheduler.load_state_dict(
                self.prev_schechuler["lr_scheduler"]
            )
            print(f"Previous training global step: {self.step}")
            print(f"Previous training epoch: {self.epoch}")
            print(f"Previous best value: {self.best_value}")
            print(
                f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}"
            )

            # Load optimizer state
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                # For DeepSpeed strategy the optimizer state lives in DeepSpeed's
                # own checkpoint; restoring it is best-effort.
                try:
                    self.trainer.strategy.deepspeed_engine.load_checkpoint(
                        self.from_checkpoint
                    )
                except Exception as e:
                    print(e)

            else:
                # For DDP strategy
                self.optimizer.load_state_dict(
                    self.prev_schechuler["optimizer"]
                )

        except Exception as e:
            raise Exception(
                "Error in loading previous scheduler. Please set load_prev_scheduler=False"
            ) from e

    def on_validation_epoch_start(self) -> None:
        self.valid_outputs = []

    def on_test_epoch_start(self) -> None:
        self.test_outputs = []

    def load_checkpoint(self, from_checkpoint: str) -> None:
        """Load model weights (and optionally the rest of the saved state).

        Args:
            from_checkpoint: Path to a checkpoint file, or a directory ``<dir>``
                containing ``<basename(dir)>.pt``.
        """

        # If ``from_checkpoint`` is a directory, load the checkpoint in it
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])

        if self.load_prev_scheduler:
            # Keep everything except the weights; applied in ``on_train_start``.
            state_dict.pop("model")
            self.prev_schechuler = state_dict

    def save_checkpoint(
        self,
        save_path: str,
        save_info: dict = None,
        save_weights_only: bool = True,
    ) -> None:
        """
        Save model to save_path
        Args:
            save_path: Path to save model
            save_info: Other info to save (copied; the caller's dict is not modified)
            save_weights_only: Whether only save model weights
        """
        save_dir = os.path.dirname(save_path)
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Shallow copy so the caller's ``save_info`` is not mutated.
        state_dict = {} if save_info is None else dict(save_info)

        model_state = self.model.state_dict()
        # Store floating-point weights in fp32. Integer/bool buffers and
        # non-tensor extra state keep their type; casting them to float would
        # corrupt their dtype (or fail outright).
        for k, v in model_state.items():
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                model_state[k] = v.float()
        state_dict["model"] = model_state

        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, "best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()

            # If not using DeepSpeed, save optimizer state
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = (
                    self.optimizers().optimizer.state_dict()
                )

        torch.save(state_dict, save_path)

    def check_save_condition(
        self, now_value: float, mode: str, save_info: dict = None
    ) -> None:
        """
        Check whether to save model. If save_path is not None and now_value is the best, save model.
        Args:
            now_value: Current metric value
            mode: "min" or "max", meaning whether the lower the better or the higher the better
            save_info: Other info to save
        """

        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is None:
            return

        # In case there are variables to be included in the save path.
        # SECURITY: ``save_path`` is evaluated as an f-string via ``eval`` and must
        # come from trusted configuration only. Placeholders may reference this
        # method's locals (``self``, ``now_value``, ``mode``, ``save_info``).
        save_path = eval(f"f'{self.save_path}'")

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Only save when the value improves on the best seen so far
        best_value = getattr(self, "best_value", None)
        if best_value is not None:
            if (mode == "min" and now_value >= best_value) or (
                mode == "max" and now_value <= best_value
            ):
                return

        self.best_value = now_value

        # For DeepSpeed strategy ``save_path`` is a directory
        if hasattr(self.trainer.strategy, "deepspeed_engine"):
            if not self.save_weights_only:
                # Called on every rank (DeepSpeed saves sharded state)
                self.trainer.strategy.deepspeed_engine.save_checkpoint(
                    save_path, tag="deepspeed_ckpt"
                )

            # Save a complete checkpoint
            if self._is_rank_zero():
                basename = os.path.basename(save_path)
                ckpt_path = os.path.join(save_path, f"{basename}.pt")
                self.save_checkpoint(
                    ckpt_path, save_info, self.save_weights_only
                )

        # For normal situation
        elif self._is_rank_zero():
            self.save_checkpoint(save_path, save_info, self.save_weights_only)

    def reset_metrics(self, stage) -> None:
        """
        Reset metrics for given stage
        Args:
            stage: "train", "valid" or "test"
        """
        for metric in self.metrics[stage].values():
            metric.reset()

    def get_log_dict(self, stage: str) -> dict:
        """
        Get log dict for the stage
        Args:
            stage: "train", "valid" or "test"

        Returns:
            A dictionary of metrics for the stage. Keys are metric names and values are metric values

        """
        return {
            name: metric.compute()
            for name, metric in self.metrics[stage].items()
        }

    def log_info(self, info: dict) -> None:
        """
        Record metrics during training and testing (main process only)
        Args:
            info: dict of metrics; "learning_rate" and "epoch" are added in place
        """
        if getattr(self, "logger", None) is not None and self._is_rank_zero():
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        """Build ``self.optimizer`` and ``self.lr_scheduler`` from the stored kwargs.

        LayerNorm weights and biases are put in a parameter group without
        weight decay.

        Raises:
            KeyError: If "weight_decay"/"class" is missing from optimizer_kwargs,
                or "init_lr"/"class" from lr_scheduler_kwargs.
        """
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)

        # No decay for layer norm and bias
        no_decay = ["LayerNorm.weight", "bias"]
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")

        decay_params = []
        no_decay_params = []
        for name, param in self.model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        optimizer_cls = getattr(torch.optim, copy_optimizer_kwargs.pop("class"))
        self.optimizer = optimizer_cls(
            optimizer_grouped_parameters,
            lr=self.lr_scheduler_kwargs["init_lr"],
            **copy_optimizer_kwargs,
        )

        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        lr_scheduler = tmp_kwargs.pop("class")
        # The class name is resolved with eval so names from
        # ``from .lr_scheduler import *`` (and dotted torch paths) work.
        # SECURITY: only pass trusted configuration here.
        self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)

    def configure_optimizers(self):
        """Return the prebuilt optimizer and a per-step lr scheduler to Lightning."""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
