import lightning as L
from pathlib import Path
from typing import Any

from saws.checkpointer import CheckpointManager
from saws.config.warmstart_config import WarmstartConfig


class Warmstarter:
    """Warmstart a target model from a previously saved base-model checkpoint.

    The base checkpoint is loaded through a ``CheckpointManager``, the
    configured warmstart function transfers the loaded model into
    ``target_model``, and the remaining checkpoint state is filtered down to
    the entries listed in ``valid_warmstart_state_candidates``.
    """

    # Keys of the loaded checkpoint state that `trigger` is allowed to return.
    valid_warmstart_state_candidates = [
        "model",
        "optimizer",
        "torch_scheduler",
        "train_tokens",
    ]

    def __init__(self, fabric: "L.Fabric", config: "WarmstartConfig", target_model: Any) -> None:
        """Store the collaborators needed to perform the warmstart.

        Args:
            fabric: Fabric instance handed to the ``CheckpointManager``.
            config: Warmstart configuration. This class reads its
                ``base_model_path``, ``warmstarting_args``,
                ``restart_dataloader``, ``retain_optimizer`` and ``warmer``
                attributes.
            target_model: Model object that receives the warmstarted state.
        """
        self.fabric = fabric
        self.warmstart_config = config
        self.base_model_path = config.base_model_path
        self.target_model = target_model

    def trigger(self) -> dict[str, Any]:
        """Warmstart the target model using the base model.

        Returns:
            The filtered checkpoint state. ``"model"`` always maps to
            ``self.target_model``; ``"optimizer"`` and ``"torch_scheduler"``
            are only present when ``retain_optimizer`` is set and the
            checkpoint provided them.

        Raises:
            KeyError: If the loaded checkpoint state has no ``"model"`` entry
                (assuming the loaded state is a plain ``dict``).
        """
        ckpt_manager = CheckpointManager(
            fabric=self.fabric,
            save_dir=None,
            load_dir=Path(self.warmstart_config.base_model_path),
            update_every_k=None,
            save_every_k=1,  # crucial that one of the `save_*` is set to store `model_0.pth`
        )
        # Load the base model
        _states, _ = ckpt_manager.load_checkpoint(
            state=None,  # `None` here ensures `_states` does not overwrite the `state` values
            train_step=self.warmstart_config.warmstarting_args.get("base_model_step", None),
        )
        # When `restart_dataloader` is set, the token counter is reset to 0.
        # Otherwise the `train_tokens` value loaded from the base checkpoint
        # (if any) is passed on unchanged.
        if self.warmstart_config.restart_dataloader:
            _states["train_tokens"] = 0

        # Execute the warmstart function: `warmer()` returns the callable
        # (a partial function) that maps the base model onto the target model.
        self.warmstart_config.warmer()(_states["model"], self.target_model)
        _states["model"] = self.target_model

        return self._select_states(
            _states,
            retain_optimizer=self.warmstart_config.retain_optimizer,
        )

    def _select_states(self, states: dict[str, Any], *, retain_optimizer: bool) -> dict[str, Any]:
        """Return a new dict holding only the state entries that may be passed on.

        Args:
            states: Loaded checkpoint state; it is not modified.
            retain_optimizer: When falsy, the ``"optimizer"`` and
                ``"torch_scheduler"`` entries are dropped as well.
        """
        allowed = self.valid_warmstart_state_candidates
        selected = {key: value for key, value in states.items() if key in allowed}
        if not retain_optimizer:
            # `pop` with a default tolerates checkpoints that carry no
            # optimizer/scheduler state instead of raising `KeyError`.
            selected.pop("optimizer", None)
            selected.pop("torch_scheduler", None)
        return selected