from pathlib import Path
from omegaconf import OmegaConf
import torch
from transformers import optimization as hf_optim
import pytorch_lightning as pl

from _modeling.configuration_tsp import TSPConfig
from _optimization import AdamW_NoBiasCorrection, split_params

# Parent of the directory containing this file (presumably the project root).
# Local checkpoints are resolved under ROOT_DIR / "checkpoints" by TrainingModule.
ROOT_DIR = Path(__file__).parent.parent


class TrainingModule(pl.LightningModule):
    """LightningModule base with shared model-loading and optimizer/scheduler setup.

    NOTE(review): ``configure_optimizers`` reads ``self.config``, which is not set
    in this class; presumably subclasses assign it in ``__init__`` — confirm.
    """

    @classmethod
    def load_pretrained_model(
        cls, config, model_cls, state_dict=None, task: str = None, except_classifier=False, **model_kwargs
    ):
        """Instantiate ``model_cls`` from in-memory weights, a local checkpoint, or a hub path.

        Resolution order:
          1. ``state_dict`` given: build the model from it using ``config``.
          2. ``config.load_ckpt_path`` has no file suffix: pass it straight to
             ``model_cls.from_pretrained`` (with ``trust_remote_code=True``).
          3. Otherwise load a Lightning checkpoint from
             ``ROOT_DIR / "checkpoints" / config.load_ckpt_path``. A ``.finetuning``
             suffix names a directory that must hold exactly one ``{task}*.ckpt``.

        Args:
            config: Run config; reads ``load_ckpt_path`` and ``scale``.
            model_cls: Model class exposing ``from_pretrained``.
            state_dict: Optional weights to load directly, skipping checkpoint lookup.
            task: Task-name prefix used to select the checkpoint in a
                ``.finetuning`` directory.
            except_classifier: In the checkpoint branch only, drop every parameter
                whose name contains ``"head.classifier"`` before loading.
            **model_kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

        Returns:
            The instantiated model.

        Raises:
            AssertionError: The ``.finetuning`` directory does not hold exactly one
                matching checkpoint, the checkpoint's ``scale`` differs from
                ``config.scale``, or the checkpoint has no ``model.*`` parameters.
        """
        if state_dict is not None:
            return model_cls.from_pretrained(
                pretrained_model_name_or_path=None,
                config=TSPConfig.from_config(config),
                state_dict=state_dict,
                **model_kwargs,
            )

        file_ext = Path(config.load_ckpt_path).suffix
        if not file_ext:
            # Not a local checkpoint file: defer to `from_pretrained` path/id handling.
            # `model_kwargs` is forwarded here too, consistent with the other branches.
            return model_cls.from_pretrained(
                config.load_ckpt_path, trust_remote_code=True, **model_kwargs
            )

        # Identify checkpoint path
        ckpt_path = ROOT_DIR / "checkpoints" / config.load_ckpt_path
        if file_ext == ".finetuning":
            task_ckpts = list(ckpt_path.glob(f"{task}*.ckpt"))
            _msg = f"Assure there's only one checkpoint for {task} in {ckpt_path}."
            assert len(task_ckpts) == 1, _msg
            ckpt_path = task_ckpts[0]

        # Load the checkpoint
        print(f"loading model checkpoint from: {ckpt_path}")
        # SECURITY NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        hparams = ckpt["hyper_parameters"]
        try:
            # Hyperparameters may be saved either nested under "config" or as the config itself.
            loaded_config = hparams["config"]
        except (KeyError, TypeError):
            loaded_config = hparams
        assert config.scale == loaded_config.scale, (
            f"config.scale={config.scale!r} does not match checkpoint scale "
            f"{loaded_config.scale!r} in {ckpt_path}"
        )

        # Keep only the internal `transformers.PreTrainedModel` weights (keys under
        # "model.") and strip ONLY the leading prefix; `str.replace` would also mangle
        # nested names such as "model.submodel.x".
        prefix = "model."
        state_dict = {
            k[len(prefix):]: v
            for k, v in ckpt["state_dict"].items()
            if k.startswith(prefix)
        }
        assert state_dict, f"No '{prefix}*' parameters found in checkpoint {ckpt_path}"
        if except_classifier:
            state_dict = {
                k: v for k, v in state_dict.items() if "head.classifier" not in k
            }

        # Load the model
        return model_cls.from_pretrained(
            pretrained_model_name_or_path=None,
            config=TSPConfig.from_config(loaded_config),
            state_dict=state_dict,
            **model_kwargs,
        )

    def configure_optimizers(self):
        """Create the optimizer and a per-step linear warmup/decay LR scheduler.

        Reads from ``self.config``: ``optimizer_bias_correction``, ``learning_rate``,
        ``weight_decay``, ``optimizer_eps``, ``num_steps``, ``lr_warmup_steps`` and
        ``lr_warmup_fraction``. Exactly one of the two warmup settings must be truthy.

        Returns:
            ``{"optimizer": ..., "lr_scheduler": {...}}`` with the scheduler stepped
            every optimizer step.
        """
        cfg = self.config
        param_groups = split_params(cfg, model=self)
        # Defaults for both optimizers; lr / weight_decay set inside individual
        # param groups take priority over these. `lr` was previously missing for
        # torch.optim.Adam, which silently fell back to torch's default of 1e-3.
        optim_kwargs = dict(
            lr=cfg.learning_rate,
            weight_decay=cfg.weight_decay,
            eps=cfg.optimizer_eps,
        )
        if cfg.optimizer_bias_correction:
            # NOTE(review): torch.optim.Adam applies weight_decay as a coupled L2
            # penalty, while the other branch is named AdamW (presumably decoupled
            # decay). torch.optim.AdamW may be the intended choice — confirm.
            optimizer = torch.optim.Adam(param_groups, **optim_kwargs)
        else:
            optimizer = AdamW_NoBiasCorrection(param_groups, **optim_kwargs)

        total_steps = cfg.num_steps
        if not total_steps:
            # NOTE(review): counts batches from a freshly built train dataloader;
            # gradient accumulation / multi-device sharding is not accounted for.
            total_steps = self.trainer.max_epochs * len(
                self.trainer.datamodule.train_dataloader()
            )

        assert bool(cfg.lr_warmup_steps) ^ bool(cfg.lr_warmup_fraction), (
            "Set exactly one of `lr_warmup_steps` and `lr_warmup_fraction`."
        )
        # The HF scheduler documents `num_warmup_steps` as an int.
        num_warmup_steps = cfg.lr_warmup_steps or int(total_steps * cfg.lr_warmup_fraction)

        scheduler = {
            "scheduler": hf_optim.get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
