from typing import Callable

from src.data.utils import CustomColName
from src.trainer.chat_sft_trainer import ChatSFTTrainer
from src.trainer.sst.sst_state_callback import SstStateCallback
from src.utils.logging_utils import get_logger

# Module-level logger obtained from the project's logging utilities.
logger = get_logger()


class SstTrainer(ChatSFTTrainer):
    """``ChatSFTTrainer`` with optional SST hooks.

    Two optional extension points are added on top of the parent trainer:

    * ``sampler_factory``: replaces the default training sampler (e.g. a
      ``WeightedRandomSampler``).
    * ``sst_state_callback_factory``: builds an ``SstStateCallback``. It is
      registered as a trainer callback and receives the training sampler, the
      per-step training losses (with sample metadata when available) and the
      accumulated train loss.

    When neither factory is given, the trainer behaves like ``ChatSFTTrainer``.
    """

    class Config(ChatSFTTrainer.Config):
        pass

    def __init__(
        self,
        *,
        sampler_factory: Callable | None = None,
        sst_state_callback_factory: Callable | None = None,
        **kwargs,
    ):
        """Create the trainer.

        Args:
            sampler_factory: Called as ``sampler_factory(dataset=...)`` to build
                the training sampler. When ``None``, the parent's sampler is used.
            sst_state_callback_factory: Called as
                ``sst_state_callback_factory(log_callback=self.log)``. The
                returned callback is appended to the trainer callbacks and
                exposed as ``self.sst_state``. When ``None``, no SST state is
                tracked.
            **kwargs: Forwarded to ``ChatSFTTrainer.__init__`` (with
                ``callbacks`` extended as described above) and then used to
                build ``self.config``.
        """
        # Copy so the caller's callbacks list is never mutated by the append below.
        callbacks = list(kwargs.pop("callbacks", None) or [])

        self.sst_state: SstStateCallback | None = None
        if sst_state_callback_factory is not None:
            self.sst_state = sst_state_callback_factory(log_callback=self.log)
            callbacks.append(self.sst_state)
            logger.info(f"Added SST State callback: {self.sst_state.__class__.__name__}")

        # Set before the parent constructor runs so that ``_get_train_sampler``
        # works even if it is triggered during parent initialisation.
        self.train_sampler = None
        self.sampler_factory = sampler_factory

        super().__init__(callbacks=callbacks, **kwargs)

        self.config = self.Config(**kwargs)

    def _get_train_sampler(self, train_dataset=None):
        """Build the training sampler and register it with the SST state callback.

        Args:
            train_dataset: Dataset to sample from. Some ``transformers`` versions
                pass it explicitly, others call this hook with no argument. When
                omitted, ``self.train_dataset`` is used.

        Returns:
            The sampler produced by ``sampler_factory`` or, when no factory is
            set, whatever the parent implementation returns. It is also stored
            in ``self.train_sampler``.
        """
        if self.sampler_factory is None:
            # Only forward the dataset when one was given, so parents whose
            # ``_get_train_sampler`` takes no argument keep working.
            if train_dataset is None:
                self.train_sampler = super()._get_train_sampler()
            else:
                self.train_sampler = super()._get_train_sampler(train_dataset)
        else:
            dataset = self.train_dataset if train_dataset is None else train_dataset
            # Custom sampler (e.g. WeightedRandomSampler)
            self.train_sampler = self.sampler_factory(dataset=dataset)

        if self.sst_state is not None:
            self.sst_state.set_sampler(self.train_sampler)

        return self.train_sampler

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss with the parent trainer and report it to the SST state.

        While ``model.training`` is true and an SST state callback exists, the
        detached loss is flattened with ``view(-1)``, moved to the CPU and passed
        to ``self.sst_state.on_compute_train_loss_end``. The batch's sample
        indices, ids and dataset ids are passed alongside it when the batch
        contains the IDX column.

        Returns:
            Exactly what the parent ``compute_loss`` returns: the loss, or
            ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        ret_value = super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )
        if self.sst_state is not None and model.training:
            loss = ret_value[0] if return_outputs else ret_value
            _loss = loss.view(-1).detach().cpu()
            # Per-sample metadata is forwarded only when the IDX column is present.
            # The ID and DS_ID columns are then expected to be present as well.
            if CustomColName.IDX.value in inputs:
                _idx = inputs[CustomColName.IDX.value].detach().cpu()
                _ids = inputs[CustomColName.ID.value].detach().cpu()
                _ds_ids = inputs[CustomColName.DS_ID.value].detach().cpu()
            else:
                _idx = _ids = _ds_ids = None
            self.sst_state.on_compute_train_loss_end(
                idx=_idx,
                loss=_loss,
                ids=_ids,
                ds_ids=_ds_ids,
            )
        return ret_value

    def _maybe_log_save_evaluate(
        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, *args, **kwargs
    ):
        """Pass the accumulated train loss to the SST state, then defer to the parent.

        This override lets the SST state observe the train loss without changing
        the trainer's training loop, because this hook is called after every
        step. Extra positional/keyword arguments (added by some ``transformers``
        versions) are passed through unchanged.
        """
        # Guard: ``sst_state`` is None when no SST callback factory was configured.
        if self.sst_state is not None:
            # ``.cpu()`` returns the same tensor when it is already on the CPU,
            # and the parent implementation may reset ``tr_loss`` in place after
            # logging. Clone so the recorded value is not overwritten.
            self.sst_state.set_current_train_loss(tr_loss.detach().cpu().clone())
        return super()._maybe_log_save_evaluate(
            tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, *args, **kwargs
        )
