import torch
import os.path as osp
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
from torchmetrics import MinMetric, MeanMetric
import lightning as L
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.loops.fit_loop import _FitLoop
from lightning.pytorch.loops.loop import _Loop


from src.utils.data_processing import denormalize_data
from src.datamodules.kfold_biomechEst_datamodule import BaseKFoldDataModule


class EnsembleVotingModel(L.LightningModule):
    """Evaluate the mean prediction of several fold checkpoints on the test set.

    Every path in ``checkpoint_paths`` is restored with
    ``model_cls.load_from_checkpoint`` and the ensemble prediction is the
    element-wise mean of the restored models' outputs.

    Args:
        model_cls: LightningModule subclass used to restore each checkpoint.
        checkpoint_paths: Paths of the checkpoints to ensemble (one per fold).
        iModel: Index embedded in the logged metric names, e.g.
            ``allFolds_avg_val_loss/model_<iModel>``.
    """

    def __init__(
        self,
        model_cls: Type[L.LightningModule],
        checkpoint_paths: List[str],
        iModel: int,
    ) -> None:
        super().__init__()
        # One restored model per checkpoint (i.e. per fold).
        self.models = torch.nn.ModuleList(
            [model_cls.load_from_checkpoint(p) for p in checkpoint_paths]
        )
        # Mean MSE on the targets exactly as they come out of the dataloader.
        self.test_loss = MeanMetric()
        # Mean MSE after `denormalize_data`; turned into an RMSE at epoch end.
        self.norm_test_mse = MeanMetric()
        self.criterion = torch.nn.MSELoss()
        self.iModel = iModel
        print("\n- Calculating avg performans of the folds!")

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Score the averaged prediction of all models on one test batch.

        ``batch[0]`` is fed to every model and ``batch[1]`` is the target.
        """
        inputs, targets = batch[0], batch[1]
        # Average the per-model predictions along a new leading "model" axis.
        preds = torch.stack([m(inputs) for m in self.models]).mean(0)

        # MSELoss averages over all elements, so weighting each batch by its
        # element count makes the epoch mean equal the MSE over the whole test
        # set, even when the last batch is smaller than the others.
        loss = self.criterion(preds, targets)
        self.test_loss(loss, weight=targets.numel())
        self.log(
            f"allFolds_avg_val_loss/model_{self.iModel}",
            self.test_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        # Same error measured in the original (denormalized) units.
        y_hat = denormalize_data(preds)
        y_val = denormalize_data(targets)
        self.norm_test_mse(self.criterion(y_hat, y_val), weight=y_val.numel())

    def on_test_epoch_end(self) -> None:
        """Log the epoch-level RMSE in denormalized units.

        Lightning 2.x removed the ``test_epoch_end(outputs)`` hook and rejects
        modules that still define it, so the epoch-end work lives here.
        """
        rmse = torch.sqrt(self.norm_test_mse.compute())
        self.log(f"allFolds_avg_val_rmse/model_{self.iModel}", rmse, prog_bar=True)
        # `norm_test_mse` is never handed to `self.log`, so Lightning does not
        # reset it automatically; clear it so a later test run starts fresh.
        self.norm_test_mse.reset()


class KFoldLoop(_Loop):
    """Outer loop that runs the wrapped fit loop once per cross-validation fold.

    For every fold it does four things:

    1. selects the fold on the datamodule;
    2. fits;
    3. saves a checkpoint ``model.<k>.pt`` (k = 1..num_folds) into ``export_path``;
    4. restores the LightningModule to the weights captured before the first fold.

    After the last fold, the saved checkpoints are wrapped in an
    ``EnsembleVotingModel`` and run through the trainer's test loop.

    Attribute lookups that fail on this object are forwarded to the wrapped fit
    loop (see ``__getattr__``).

    NOTE(review): the hooks used here (``connect``, ``done``, ``on_run_start``,
    ``advance``, ``replace``, ``trainer.reset_*_dataloader``) follow the
    Lightning 1.x custom-loop API, while ``_Loop``/``_FitLoop`` are imported
    from the 2.x private modules. Confirm against the installed version.

    Args:
        num_folds: Number of folds to train.
        export_path: Directory the per-fold checkpoints are written to.
        iModel: Index forwarded to ``EnsembleVotingModel`` for metric names.
    """

    def __init__(self, num_folds: int, export_path: str, iModel: int) -> None:
        super().__init__()
        self.num_folds = num_folds
        # Index of the next fold to train; persisted via `on_save_checkpoint`.
        self.current_fold: int = 0
        self.export_path = export_path
        self.iModel = iModel

    @property
    def done(self) -> bool:
        """Return True once every fold has been trained."""
        return self.current_fold >= self.num_folds

    def connect(self, fit_loop: _FitLoop) -> None:
        """Attach the fit loop that is run once per fold."""
        self.fit_loop = fit_loop

    def reset(self) -> None:
        """Nothing to reset in this loop."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Split the data into folds and snapshot the untrained weights.

        The snapshot is restored after every fold so that each fold starts from
        the same initial weights.
        """
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_folds(self.num_folds)
        # state_dict() tensors share storage with the live parameters, so copy
        # them to keep training from overwriting the snapshot.
        self.lightning_module_state_dict = deepcopy(
            self.trainer.lightning_module.state_dict()
        )

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Point the datamodule at the fold about to be trained."""
        print(f"STARTING FOLD {self.current_fold}")
        assert isinstance(self.trainer.datamodule, BaseKFoldDataModule)
        self.trainer.datamodule.setup_fold_index(self.current_fold)

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Run fitting on the current fold, then advance the fold counter."""
        self._reset_fitting()  # the trainer must be back in the FITTING stage.
        self.fit_loop.run()

        # Only switches the trainer into the TESTING stage. Per-fold testing is
        # not run here; the ensemble of all folds is tested once in `on_run_end`.
        self._reset_testing()
        # Incremented before `on_advance_end`, so checkpoints are numbered 1..N.
        self.current_fold += 1

    def on_advance_end(self) -> None:
        """Save this fold's weights, then restore the initial weights and optimizers."""
        self.trainer.save_checkpoint(
            osp.join(self.export_path, f"model.{self.current_fold}.pt")
        )
        # Restore the original weights, then rebuild optimizers and schedulers.
        self.trainer.lightning_module.load_state_dict(self.lightning_module_state_dict)
        self.trainer.strategy.setup_optimizers(self.trainer)
        # Swap in a fresh fit loop so the next fold starts with clean progress.
        # `_FitLoop` is the fit-loop class imported by this module.
        # NOTE(review): `replace` is part of the Lightning 1.x `Loop` API;
        # confirm the installed `_Loop` provides it.
        self.replace(fit_loop=_FitLoop)

    def on_run_end(self) -> None:
        """Test an ensemble built from every saved fold checkpoint."""
        # Matches the 1-based names written by `on_advance_end`.
        checkpoint_paths = [
            osp.join(self.export_path, f"model.{fold}.pt")
            for fold in range(1, self.num_folds + 1)
        ]
        voting_model = EnsembleVotingModel(
            type(self.trainer.lightning_module), checkpoint_paths, self.iModel
        )
        voting_model.trainer = self.trainer
        # Connect the new model to the strategy and move it to the right device.
        self.trainer.strategy.connect(voting_model)
        self.trainer.strategy.model_to_device()
        self.trainer.test_loop.run()

    def on_save_checkpoint(self) -> Dict[str, int]:
        """Persist the fold counter so training can resume mid-way."""
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        """Restore the fold counter written by `on_save_checkpoint`."""
        self.current_fold = state_dict["current_fold"]

    def _reset_fitting(self) -> None:
        """Rebuild train/val dataloaders for the current fold and enter FITTING."""
        # NOTE(review): `Trainer.reset_*_dataloader` is 1.x API; confirm it exists.
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True

    def _reset_testing(self) -> None:
        """Rebuild the test dataloader and enter TESTING."""
        self.trainer.reset_test_dataloader()
        self.trainer.state.fn = TrainerFn.TESTING
        self.trainer.testing = True

    def __getattr__(self, key: str) -> Any:
        """Forward unknown attributes to the wrapped fit loop.

        Python only calls this after normal lookup has failed. Other code reads
        fit-loop attributes through this object, hence the forwarding.

        Raises:
            AttributeError: if ``fit_loop`` has not been connected yet.
        """
        state = self.__dict__
        if key in state:
            return state[key]
        # Without this guard, `self.fit_loop` below would re-enter __getattr__
        # and recurse forever, either before `connect` or while an instance is
        # being unpickled or copied.
        if "fit_loop" not in state:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        return getattr(state["fit_loop"], key)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Defined on the class so unpickling finds it directly, instead of
        # falling through to `__getattr__` before `fit_loop` is restored.
        self.__dict__.update(state)
