from pathlib import Path
from typing import Callable, Literal

import torch
from pydantic import BaseModel, model_validator
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from transformers.training_args import TrainingArguments

from src.data.utils import CustomColName
from src.trainer.sst.threshold_update import ThresholdUpdater
from src.utils.logging_utils import get_logger

logger = get_logger(__name__, level="DEBUG")


class SstStateCallback(TrainerCallback):
    """Trainer callback that records per-example losses and re-weights the sampler for SST.

    The training loop reports per-example losses through :meth:`on_compute_train_loss_end`.

    Once SST has started (see ``Config.sst_start_strategy``), every ``sst_update_steps`` steps the
    ``threshold_updater`` decides whether an update applies. If it does, the sampler weights are
    rebuilt as a triangle over the seen-loss range:

    - 0 at the minimum seen loss,
    - rising linearly to 1 at the threshold,
    - falling linearly back to 0 at the maximum seen loss.

    Examples not seen yet get weight 1, and the weights are normalized to sum to 1.

    The state is also saved and logged periodically, on each trainer checkpoint, and at the end
    of training.
    """

    class Config(BaseModel):
        sst_start_strategy: IntervalStrategy
        """When to start SST"""

        sst_update_steps: int | float = -1
        """How often to check for an SST weight update, in update steps counted from the SST start.
        A float in [0, 1] is a fraction of the total training steps; -1 checks on every step"""

        sst_start_steps: int | float = -1
        """When to start SST in num steps if sst_start_strategy is STEPS (a float in [0, 1] is a fraction of
        the total training steps) or num epochs if sst_start_strategy is EPOCH"""

        sst_state_save_steps: int | float = -1
        """How often to save the SST state in training update steps (<= 0 disables; a float in [0, 1] is a
        fraction of the total training steps)"""

        sst_state_log_steps: int | float = -1
        """How often to log to the tracker (W&B) the SST state in training update steps (<= 0 disables; a
        float in [0, 1] is a fraction of the total training steps)"""

        # NOTE(review): not read anywhere in SstStateCallback; ex_losses are always saved.
        save_ex_losses: bool = True
        """Whether to save the example losses"""

        state_checkpoint_file: Path | None = None
        """Path to the state checkpoint file to load"""

        stats_compute_percentiles: bool = True
        """Whether to compute the percentiles of the example losses for logging purposes"""

        stats_compute_mean: bool = True
        """Whether to compute the mean of the example losses for logging purposes"""

        root_output_dir: Path
        """Output directory to save the weights or debug information"""

        output_folder_name: str = "sst_state"
        """Output folder name. Will be created under root_output_dir"""

        _output_dir: Path | None = None

        @model_validator(mode="after")
        def validate(self):
            """Check the start configuration and create ``root_output_dir / output_folder_name``."""
            # Both STEPS and EPOCH compare training progress against sst_start_steps, so it must be set.
            if (
                self.sst_start_strategy in (IntervalStrategy.STEPS, IntervalStrategy.EPOCH)
                and self.sst_start_steps < 0
            ):
                raise ValueError(
                    f"sst_start_steps should be >= 0 if sst_start_strategy is {self.sst_start_strategy.value.upper()}"
                )

            # Raise instead of assert: asserts are stripped under `python -O`.
            if not self.root_output_dir.exists():
                raise ValueError(f"root_output_dir does not exist: {self.root_output_dir}")
            self._output_dir = self.root_output_dir / self.output_folder_name
            self._output_dir.mkdir(parents=True, exist_ok=True)

            return self

        @property
        def output_dir(self):
            """Directory where periodic SST state files are written."""
            assert self._output_dir is not None
            return self._output_dir

    # Quantiles reported by `_compute_percentiles` (keys are built from these Python floats).
    _LOSS_PERCENTILES = (0.25, 0.5, 0.75, 0.90, 0.95, 0.99)

    def __init__(
        self,
        *,
        log_callback: Callable,
        threshold_updater: ThresholdUpdater | None = None,
        **kwargs,
    ) -> None:
        """Create the callback.

        Args:
            log_callback: Called with a state dict whenever the SST state is logged.
            threshold_updater: Decides when and where the loss threshold moves. Without one,
                sampling weights are never updated.
            **kwargs: Fields of :class:`SstStateCallback.Config`.
        """
        self.config = self.Config(**kwargs)
        self.threshold_updater = threshold_updater
        self.log_callback = log_callback

        # Per-example loss buffer, allocated lazily once the sampler size is known; -1 marks unseen.
        self.ex_losses: torch.Tensor | None = None

        # Per-step batches of sampled ids. They are merged lazily (see `sampled_ex_ids`) rather than
        # with `torch.cat` on every step, which re-copied the whole history each time (O(steps^2)).
        self._sampled_ex_id_chunks: list[torch.Tensor] = []
        self._sampled_ds_id_chunks: list[torch.Tensor] = []

        self.min_ex_loss = float("inf")
        self.max_ex_loss = -float("inf")
        self.sampler = None
        self.sst_started = False
        self.train_max_steps = -1
        self.train_loss = None
        self.sst_start_global_step = -1  # When SST started

        if self.config.state_checkpoint_file is not None:
            self.load_state(self.config.state_checkpoint_file)

    @staticmethod
    def _merge_id_chunks(chunks: list[torch.Tensor]) -> torch.Tensor:
        """Collapse ``chunks`` in place into one tensor and return it (an empty long tensor if none)."""
        if not chunks:
            return torch.tensor([], dtype=torch.long, device="cpu")
        if len(chunks) > 1:
            chunks[:] = [torch.cat(chunks)]
        return chunks[0]

    @property
    def sampled_ex_ids(self) -> torch.Tensor:
        """Example ids reported so far by `on_compute_train_loss_end`, in report order (CPU, long)."""
        return self._merge_id_chunks(self._sampled_ex_id_chunks)

    @sampled_ex_ids.setter
    def sampled_ex_ids(self, value: torch.Tensor) -> None:
        self._sampled_ex_id_chunks = [value]

    @property
    def sampled_ds_ids(self) -> torch.Tensor:
        """Dataset ids reported so far by `on_compute_train_loss_end`, in report order (CPU, long)."""
        return self._merge_id_chunks(self._sampled_ds_id_chunks)

    @sampled_ds_ids.setter
    def sampled_ds_ids(self, value: torch.Tensor) -> None:
        self._sampled_ds_id_chunks = [value]

    def set_sampler(self, sampler):
        self.sampler = sampler

    def on_compute_train_loss_end(self, *, idx, loss, ids, ds_ids):
        """Record the per-example losses and the ids of the current batch.

        Args:
            idx: Indices into the ``(sampler.num_samples,)`` loss buffer, aligned with ``loss``.
            loss: Per-example losses of the batch.
            ids: Example ids of the batch.
            ds_ids: Dataset ids of the batch.
        """
        assert self.sampler is not None, "Expected sampler to be set"
        if self.ex_losses is None:
            # On CPU to avoid GPU OOM
            # Dtype to match the one of the sampling weights
            self.ex_losses = torch.full((self.sampler.num_samples,), -1.0, dtype=torch.double, device="cpu")

        # Indexing the CPU buffer requires CPU index tensors.
        if isinstance(idx, torch.Tensor):
            idx = idx.to(device="cpu")
        # Detach + move to CPU: the buffer must never join the autograd graph nor mix devices.
        self.ex_losses[idx] = loss.detach().to(device="cpu", dtype=self.ex_losses.dtype)

        self._sampled_ds_id_chunks.append(ds_ids.detach().to(device="cpu", dtype=torch.long).reshape(-1))
        self._sampled_ex_id_chunks.append(ids.detach().to(device="cpu", dtype=torch.long).reshape(-1))

        assert not self.ex_losses.requires_grad

    def set_current_train_loss(self, loss):
        self.train_loss = loss

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Start SST when due, update the sampling weights on the update cadence, then save/log periodically."""
        # Check if we should start SST
        if not self.sst_started and self.should_start_sst(state):
            self.start_sst(state)

        # Check if we should update sample weights
        sst_update = False
        if self.sst_started and self.should_update_weights(state):
            if sst_update := self.update_sampling_weights(trainer_state=state):
                self.log_callback(self._get_state(trainer_state=state))

        self._maybe_log_save_state(trainer_state=state, sst_update=sst_update)

    def _maybe_log_save_state(self, trainer_state: TrainerState, sst_update: bool):
        """Save and/or log the SST state when the current step hits the configured intervals."""
        assert isinstance(self.config.sst_state_save_steps, int)
        if trainer_state.global_step == 0:
            return

        if (
            not sst_update  # When sst update happens the state is already saved
            and self.config.sst_state_save_steps > 0
            and trainer_state.global_step % self.config.sst_state_save_steps == 0
        ):
            self.save_state(trainer_state=trainer_state)

        if self.config.sst_state_log_steps > 0 and trainer_state.global_step % self.config.sst_state_log_steps == 0:
            self.log_callback(self._get_state(trainer_state=trainer_state))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save the SST state as ``sst_state.pth`` inside the checkpoint folder of the current step."""
        # A checkpoint is being saved, save in a subdirectory
        # NOTE(review): assumes root_output_dir is where the trainer writes checkpoints (args.output_dir).
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_dir = self.config.root_output_dir / checkpoint_folder
        assert checkpoint_dir.exists(), f"Checkpoint directory does not exist: {checkpoint_dir}"

        self.save_state(trainer_state=state, file_name="sst_state", output_dir=checkpoint_dir)

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save the final SST state as ``sst_state.pth`` directly under root_output_dir."""
        self.save_state(trainer_state=state, file_name="sst_state", output_dir=self.config.root_output_dir)

    def on_train_begin(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, train_dataloader, **kwargs
    ):
        """Resolve fractional step settings against ``state.max_steps`` and save the example ids.

        A float value in [0, 1] (other than -1) is read as a fraction of the total training steps.

        - ``sst_start_steps`` is not converted for the EPOCH strategy, where it is a number of epochs.
        - Periodic intervals resolved from a positive fraction are clamped to at least one step. An
          interval of 0 would otherwise silently disable saving/logging, or divide by zero in
          ``should_update_weights``.
        """

        def _is_train_steps_ratio(value):
            if value != -1 and isinstance(value, float):
                assert 0 <= value <= 1
                return True
            return False

        def _ratio_to_steps(ratio: float, *, min_steps: int = 0) -> int:
            return max(min_steps, int(ratio * self.train_max_steps))

        self.train_max_steps = state.max_steps
        cfg = self.config

        if cfg.sst_start_strategy != IntervalStrategy.EPOCH and _is_train_steps_ratio(cfg.sst_start_steps):
            cfg.sst_start_steps = _ratio_to_steps(cfg.sst_start_steps)

        if _is_train_steps_ratio(cfg.sst_state_save_steps):
            cfg.sst_state_save_steps = _ratio_to_steps(
                cfg.sst_state_save_steps, min_steps=1 if cfg.sst_state_save_steps > 0 else 0
            )

        if _is_train_steps_ratio(cfg.sst_state_log_steps):
            cfg.sst_state_log_steps = _ratio_to_steps(
                cfg.sst_state_log_steps, min_steps=1 if cfg.sst_state_log_steps > 0 else 0
            )

        if _is_train_steps_ratio(cfg.sst_update_steps):
            # Used as a modulus in `should_update_weights`, so it must never become 0.
            cfg.sst_update_steps = _ratio_to_steps(cfg.sst_update_steps, min_steps=1)

        self._save_ex_ids(train_dataloader)

    def _save_ex_ids(self, train_dataloader):
        """Save the dataset's example ids and dataset ids to ``root_output_dir/ex_ids.pth`` (if it has ids)."""
        if CustomColName.ID.value not in train_dataloader.dataset.column_names:
            return
        data = {
            CustomColName.ID.value: torch.tensor(train_dataloader.dataset[CustomColName.ID.value]),
            CustomColName.DS_ID.value: torch.tensor(train_dataloader.dataset[CustomColName.DS_ID.value]),
        }
        file_path = self.config.root_output_dir / "ex_ids.pth"
        torch.save(data, file_path)

        logger.info(f"Saved example ids to: {file_path}")

    def update_sampling_weights(self, *, trainer_state: TrainerState) -> bool:
        """Recompute and install the sampler weights if the threshold updater requests it.

        Returns:
            True if new weights were set on the sampler (the state is then saved), False otherwise.
        """
        assert self.sst_started
        assert self.sampler is not None
        assert self.ex_losses is not None

        ex_losses = self.ex_losses
        seen_mask = ex_losses >= 0  # Unseen examples hold -1
        seen_losses = ex_losses[seen_mask]

        if self.threshold_updater is None or not self.threshold_updater.step(
            curr_train_loss=self.train_loss,
            ex_losses=seen_losses,
        ):
            # Either the threshold is not set or no update is needed
            return False

        threshold = self.threshold_updater.threshold
        assert threshold is not None

        # Refresh the loss range from the current losses. The cached values are otherwise only
        # refreshed when the state is logged/saved, so they may be stale or still +/-inf (-> NaN weights).
        if seen_losses.numel() > 0:
            self.min_ex_loss = seen_losses.min().item()
            self.max_ex_loss = seen_losses.max().item()
        min_loss, max_loss = self.min_ex_loss, self.max_ex_loss

        sampling_weights = torch.zeros_like(ex_losses, dtype=torch.double)

        # Increase the weights for losses below the threshold: 0 at min_loss, 1 at the threshold
        left_mask = seen_mask & (ex_losses <= threshold)
        left_span = threshold - min_loss
        if left_span > 0:
            sampling_weights[left_mask] = (ex_losses[left_mask] - min_loss) / left_span
        else:
            # threshold <= min_loss: every loss in left_mask equals the threshold, i.e. the peak weight
            sampling_weights[left_mask] = 1.0

        # Decrease the weights for losses above the threshold: 1 at the threshold, 0 at max_loss.
        # A non-empty right_mask implies max_loss > threshold, so the denominator is positive.
        right_mask = seen_mask & (ex_losses > threshold)
        if right_mask.any():
            sampling_weights[right_mask] = (max_loss - ex_losses[right_mask]) / (max_loss - threshold)

        # Handle unseen examples: set the weights to 1.0 to ensure they are sampled.
        # Unseen examples are only expected when SST starts before a full epoch has been seen.
        unseen_mask = ex_losses < 0
        if unseen_mask.any():
            assert not (
                self.config.sst_start_strategy == IntervalStrategy.EPOCH and self.config.sst_start_steps >= 1
            ), "We should have seen all examples"
            sampling_weights[unseen_mask] = 1.0

        # Normalize the weights
        sampling_weights_sum = sampling_weights.sum()
        assert sampling_weights_sum > 0, "Expected sampling_weights_sum to be > 0"
        sampling_weights /= sampling_weights_sum

        assert torch.isclose(sampling_weights.sum(), torch.tensor(1.0, dtype=torch.double))

        self.sampler.set_weights(sampling_weights)

        self.save_state(trainer_state=trainer_state)

        return True

    def should_start_sst(self, trainer_state: TrainerState) -> bool:
        """Whether training has progressed past the configured SST start (steps or epochs)."""
        if self.config.sst_start_strategy == IntervalStrategy.NO:
            return False

        assert (
            self.config.sst_start_steps >= 0
        ), f"Expected sst_start_steps to be >= 0. Got: {self.config.sst_start_steps}"

        if self.config.sst_start_strategy == IntervalStrategy.STEPS:
            return trainer_state.global_step >= self.config.sst_start_steps
        elif self.config.sst_start_strategy == IntervalStrategy.EPOCH:
            assert trainer_state.epoch is not None, "Expected epoch to be set"
            return trainer_state.epoch >= self.config.sst_start_steps
        else:
            raise ValueError(f"Invalid start strategy: {self.config.sst_start_strategy}")

    def start_sst(self, trainer_state: TrainerState):
        """Mark SST as started and remember the global step it started at."""
        # TODO: log to wandb
        assert not self.sst_started, "Expected SST to not be started"
        assert self.sampler is not None, "Expected weighted sampler to be set"

        self.sst_started = True
        self.sst_start_global_step = trainer_state.global_step
        logger.info(f"Started SST at global step: {self.sst_start_global_step}")

    def should_update_weights(self, trainer_state: TrainerState) -> bool:
        """Whether the current step is on the ``sst_update_steps`` cadence counted from the SST start.

        With ``sst_update_steps == -1`` this is true on every step (any int % -1 == 0).
        """
        assert self.sst_started, "SST should be started before updating weights"
        return (trainer_state.global_step - self.sst_start_global_step) % self.config.sst_update_steps == 0

    def _compute_percentiles(self, *, ex_losses):
        """Return ``{"ex_loss_P<q>": value}`` for each quantile in ``_LOSS_PERCENTILES`` (NaN if empty)."""
        if ex_losses.numel() == 0:
            # torch.quantile rejects empty input
            return {f"ex_loss_P{q}": float("nan") for q in self._LOSS_PERCENTILES}
        quantiles = torch.tensor(self._LOSS_PERCENTILES, device=ex_losses.device, dtype=ex_losses.dtype)
        results = torch.quantile(ex_losses, quantiles)
        return {f"ex_loss_P{q}": v.item() for q, v in zip(self._LOSS_PERCENTILES, results)}

    def _seen_ex_losses(self) -> torch.Tensor:
        """Return the recorded losses of seen examples (empty if nothing has been recorded yet)."""
        if self.ex_losses is None:
            return torch.empty(0, dtype=torch.double)
        return self.ex_losses[self.ex_losses >= 0]

    def _get_state(self, trainer_state: TrainerState, for_save: bool = False):
        """Build a snapshot of the SST state for logging or, with ``for_save``, for checkpointing.

        Loss statistics cover seen examples only and are NaN while none has been seen.
        Side effect: ``min_ex_loss`` / ``max_ex_loss`` are refreshed when losses are available.
        """
        # TODO: Move this later to be efficient
        seen_losses = self._seen_ex_losses()
        has_seen = seen_losses.numel() > 0
        if has_seen:
            self.min_ex_loss = seen_losses.min().item()
            self.max_ex_loss = seen_losses.max().item()
        nan = float("nan")

        state = {
            "sst_started": 1 if self.sst_started else 0,
            "sst_start_global_step": self.sst_start_global_step,
            "min_ex_loss": self.min_ex_loss if has_seen else nan,
            "max_ex_loss": self.max_ex_loss if has_seen else nan,
            "global_step": trainer_state.global_step,
            "epoch": trainer_state.epoch,
        }
        if self.config.stats_compute_percentiles:
            state.update(self._compute_percentiles(ex_losses=seen_losses))

        if self.config.stats_compute_mean:
            state["mean_ex_loss"] = seen_losses.mean().item() if has_seen else nan

        if for_save:
            # Save the SST state; samplers without a `weights` attribute store None
            state.update(
                {
                    "ex_losses": self.ex_losses,
                    "sampling_weights": getattr(self.sampler, "weights", None),
                    "sampled_ex_ids": self.sampled_ex_ids,
                    "sampled_ds_ids": self.sampled_ds_ids,
                }
            )
        if self.threshold_updater is not None:
            state.update(self.threshold_updater._get_state())

        return state

    def save_state(
        self,
        *,
        trainer_state: TrainerState,
        file_name: str | None = None,
        output_dir: Path | None = None,
    ):
        """Write the full SST state to ``<output_dir>/<file_name>.pth``.

        Args:
            trainer_state: Current trainer state (provides the step/epoch and the default file name).
            file_name: File stem; defaults to ``global_step__<step>``.
            output_dir: Target directory; defaults to ``config.output_dir``.
        """
        state = self._get_state(trainer_state=trainer_state, for_save=True)
        file_name = (file_name or f"global_step__{trainer_state.global_step}") + ".pth"
        if output_dir is None:
            output_dir = self.config.output_dir
        state_file = output_dir / file_name
        torch.save(state, state_file)

        logger.info(f"Saved SST state to: {state_file}")

    def load_state(self, resume_file: Path):
        """Restore the per-example losses, and their seen min/max, from a file written by `save_state`."""
        # map_location keeps the buffer on CPU, as allocated in `on_compute_train_loss_end`.
        state = torch.load(resume_file, map_location="cpu")
        self.ex_losses = state["ex_losses"]
        seen_losses = self._seen_ex_losses()
        if seen_losses.numel() > 0:
            # Unseen slots hold -1, so the range must be taken over seen examples only.
            self.min_ex_loss = seen_losses.min().item()
            self.max_ex_loss = seen_losses.max().item()
        logger.info(f"Loaded SST state from: {resume_file}")
