# Copyright (c) Alibaba, Inc. and its affiliates.
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from peft import PeftModel
from transformers import PreTrainedModel
from trl import KTOTrainer as HFKTOTrainer

from swift.utils import get_logger
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

# Module-level logger obtained from swift's logging helper.
logger = get_logger()

# Remove the ``__init__`` defined on trl's KTOTrainer class itself, so that
# attribute lookup for ``__init__`` on that class falls through to whatever it
# inherits. KTOTrainer.__init__ below assigns several trainer attributes
# itself before delegating to ``super().__init__``.
# NOTE: this mutates the imported class object, so it affects every user of
# ``trl.KTOTrainer`` in the same process, not just this module.
del HFKTOTrainer.__init__


class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):
    """KTO trainer built from the swift mixins and trl's ``KTOTrainer``.

    ``__init__`` copies the KTO-specific settings from ``args`` onto the
    instance and then delegates to the next ``__init__`` in the MRO.
    ``forward`` wraps the inherited ``forward`` with a temporary forward
    pre-hook that feeds the model directly from the prefixed entries of
    ``batch``.
    """

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        *_args,
        **kwargs
    ):
        """Set the KTO-specific attributes and delegate to ``super().__init__``.

        Args:
            model: Policy model; forwarded unchanged to ``super().__init__``.
            ref_model: Reference model; forwarded unchanged to
                ``super().__init__``.
            *_args: Extra positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged. Must contain
                ``"args"``, an object exposing ``desirable_weight``,
                ``undesirable_weight`` and ``precompute_ref_log_probs``
                (and optionally ``loss_type``).

        Raises:
            KeyError: If ``"args"`` is not passed as a keyword argument.
        """
        args = kwargs["args"]
        # Forced on for every run; note that this mutates the caller's object.
        args.disable_dropout = True
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.is_peft_model = isinstance(model, PeftModel)
        # Fall back to the plain KTO loss when ``args`` has no ``loss_type``.
        self.loss_type = getattr(args, "loss_type", "kto")

        self.ref_adapter_name = None
        # Not all losses require a KL calculation
        self.calculate_KL = self.loss_type not in ["apo_zero_unpaired"]
        super().__init__(model, ref_model, *_args, **kwargs)

    def forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[
        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
    ]:
        """Run the inherited ``forward`` with model inputs taken from ``batch``.

        While the inherited ``forward`` runs, a forward pre-hook on ``model``
        discards the arguments of every call to ``model`` and substitutes the
        entries of ``batch`` whose key starts with ``"KL_completion_"`` (KL
        pass) or ``"completion_"`` (completion pass), with that prefix
        stripped from the key. The hook is removed again even if the
        inherited ``forward`` raises.

        Args:
            model: Module to run; the hook is registered on this object.
            batch: Collated batch holding the prefixed model inputs.

        Returns:
            Whatever the inherited ``forward`` returns.
        """
        # The inherited forward is expected to call the model for the KL pass
        # first, and only when ``self.calculate_KL`` is set (see trl's
        # KTOTrainer.forward). When the KL pass is disabled the first -- and
        # only -- call is the completion pass, so it must not be handed the
        # ``KL_completion_*`` inputs.
        expect_kl_call = self.calculate_KL

        def _add_data_hook(module, args, kwargs):
            nonlocal expect_kl_call
            prefix = "KL_completion_" if expect_kl_call else "completion_"
            if self.calculate_KL:
                # Calls alternate KL -> completion when the KL pass is enabled.
                expect_kl_call = not expect_kl_call
            model_inputs = {
                key[len(prefix) :]: value
                for key, value in batch.items()
                if key.startswith(prefix)
            }
            # Drop the caller's positional arguments; everything is passed
            # by keyword from ``batch``.
            return (), model_inputs

        @contextmanager
        def _patch_model_call():
            # ``prepend=True`` makes this hook fire before any other forward
            # pre-hook already registered on ``model``.
            handle = model.register_forward_pre_hook(
                _add_data_hook, with_kwargs=True, prepend=True
            )
            try:
                yield
            finally:
                handle.remove()

        with _patch_model_call():
            return super().forward(model, batch)
