# Reference
# - https://github.com/huggingface/trl/blob/d633c4337f75bf359f27e680b6d9bd7343e7bacb/trl/trainer/reward_trainer.py
# - https://github.com/huggingface/trl/blob/d633c4337f75bf359f27e680b6d9bd7343e7bacb/trl/trainer/reward_config.py
# - https://github.com/tjoo512/belief-matching-framework

import inspect
import os
import warnings
from collections import defaultdict
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Optional, Union

import pandas as pd
import torch
import torch.nn as nn
from accelerate import PartialState
from accelerate.utils import gather_object
from datasets import Dataset
from transformers import (
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from trl.data_utils import maybe_apply_chat_template
from trl.trainer.utils import (
    RewardDataCollatorWithPadding,
    compute_accuracy,
    decode_and_strip_padding,
    disable_dropout_in_model,
    generate_model_card,
    get_comet_experiment_url,
    log_table_to_comet_experiment,
    print_rich_table,
)
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb


def _tokenize(batch: dict[str, list[Any]], tokenizer: PreTrainedTokenizerBase) -> dict[str, list[Any]]:
    """Tokenize the `chosen` and `rejected` columns of a batch of preference pairs.

    Args:
        batch: Batched examples providing the `"chosen"`, `"rejected"` and `"icl_counts"` entries.
        tokenizer: Callable applied to each text column (truncation on, no padding, attention mask returned).

    Returns:
        A dict holding the input ids and attention masks of both completions, plus the `"icl_counts"`
        entry of `batch` forwarded unchanged.
    """

    def _encode(column: str):
        # NOTE(review): `max_length` is looked up in the data batch itself, so it resolves to `None`
        # unless the batch carries a "max_length" entry — confirm this is the intended source.
        return tokenizer(
            batch[column],
            truncation=True,
            max_length=batch.get("max_length", None),
            padding=False,
            return_attention_mask=True,
        )

    encodings = (("chosen", _encode("chosen")), ("rejected", _encode("rejected")))

    tokenized = {}
    for suffix, encoding in encodings:
        tokenized[f"input_ids_{suffix}"] = encoding["input_ids"]
        tokenized[f"attention_mask_{suffix}"] = encoding["attention_mask"]
    tokenized["icl_counts"] = batch["icl_counts"]
    return tokenized



@dataclass
class ICRMConfig(TrainingArguments):
    r"""
    Configuration class for the [`ICRMTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`int` or `None`, *optional*, defaults to `32768`):
            Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
            limit. This argument is required if you want to use the default data collator.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        center_rewards_coefficient (`float`, *optional*, defaults to `None`):
            Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
            https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
            the dataset is pretokenized.
        kl_penalty_lambda (`float`, *optional*, defaults to `0.01`):
            The weight for the KL penalty (`lambda` in Equation (9)). Must be non-negative.
        beta_prior_param (`float`, *optional*, defaults to `1.0`):
            The parameter for the Beta prior, used for both `alpha_0` and `beta_0` in Equation (9). Must be
            strictly positive.

    Raises:
        ValueError: If `kl_penalty_lambda` is negative or `beta_prior_param` is not strictly positive.
    """

    max_length: Optional[int] = field(
        default=32768,
        metadata={
            "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that "
            "exceed the limit. This argument is required if you want to use the default data collator."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model and reference model."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    center_rewards_coefficient: Optional[float] = field(
        default=None,
        metadata={
            "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by "
            "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
        },
    )
    remove_unused_columns: bool = field(
        default=False,
        metadata={
            "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
            "if the dataset is pretokenized."
        },
    )

    # ICRM-specific hyperparameters
    kl_penalty_lambda: float = field(
        default=0.01,
        metadata={
            "help": "Weight for KL penalty in ICRM loss."
        }
    )
    beta_prior_param: float = field(
        default=1.0,
        metadata={
            "help": "The parameter for Beta prior in ICRM loss (both alpha_0 and beta_0)."
        }
    )

    def __post_init__(self):
        # Run the parent's post-init first so its own processing is unchanged.
        super().__post_init__()

        # A negative weight would reward divergence from the prior instead of penalising it.
        if self.kl_penalty_lambda < 0:
            raise ValueError(f"`kl_penalty_lambda` must be non-negative, got {self.kl_penalty_lambda}.")
        # Beta distribution parameters must be > 0; otherwise `lgamma` of the prior is not finite.
        if self.beta_prior_param <= 0:
            raise ValueError(f"`beta_prior_param` must be strictly positive, got {self.beta_prior_param}.")

class SafeDigamma(torch.autograd.Function):
    """Digamma whose backward pass evaluates the trigamma factor on a lower-clamped input.

    The forward value is plain `torch.digamma(x)`; only the derivative is computed on `max(x, eps)`.
    """

    @staticmethod
    def forward(ctx, x, eps):
        """Return `digamma(x)` and stash `x` and `eps` for the backward pass."""
        ctx.save_for_backward(x)
        ctx.eps = eps
        return torch.digamma(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale `grad_output` by trigamma of the clamped input; `eps` receives no gradient."""
        saved_x = ctx.saved_tensors[0]
        trigamma = torch.polygamma(1, torch.clamp(saved_x, min=ctx.eps))
        return grad_output * trigamma, None

class ICRMLoss(nn.Module):
    """Per-example loss for a Beta posterior `Beta(alpha_q, beta_q)` against a fixed Beta prior.

    For each element the loss is `-E_q[log p] + lambda_m * KL(q || prior)`, where
    `E_q[log p] = digamma(alpha_q) - digamma(alpha_q + beta_q)` and the KL term is the closed-form
    divergence between two Beta distributions. The result is averaged over all elements.

    Args:
        prior_alpha: First parameter of the Beta prior.
        prior_beta: Second parameter of the Beta prior.
        eps: Lower bound applied to `alpha_q` and `beta_q` before any special function is evaluated.
        tau_max: Upper bound applied to `alpha_q` and `beta_q`.
    """

    def __init__(self, prior_alpha: float, prior_beta: float, eps: float = 1e-6, tau_max: float = 1e6):
        super().__init__()
        self.prior_alpha, self.prior_beta = float(prior_alpha), float(prior_beta)
        self.eps, self.tau_max = float(eps), float(tau_max)

    def forward(self, alpha_q: torch.Tensor, beta_q: torch.Tensor, lambda_m: torch.Tensor) -> torch.Tensor:
        """Return the mean of `recon + lambda_m * kl` over the inputs.

        Args:
            alpha_q: First parameter of the posterior, clamped to `[eps, tau_max]`.
            beta_q: Second parameter of the posterior, clamped to `[eps, tau_max]`.
            lambda_m: Weight of the KL term; must broadcast against `alpha_q`.
        """
        # Keep both parameters in the range where digamma / lgamma are finite.
        a = torch.clamp(alpha_q, min=self.eps, max=self.tau_max)
        b = torch.clamp(beta_q, min=self.eps, max=self.tau_max)
        total = a + b

        digamma_total = torch.digamma(total)
        expected_log_p = torch.digamma(a) - digamma_total
        expected_log_one_minus_p = torch.digamma(b) - digamma_total

        # Prior parameters as tensors matching the posterior's shape, dtype and device.
        prior_a = torch.full_like(a, self.prior_alpha)
        prior_b = torch.full_like(b, self.prior_beta)

        # Negative log Beta functions (log normalisers) of the posterior and the prior.
        log_norm_q = torch.lgamma(total) - torch.lgamma(a) - torch.lgamma(b)
        log_norm_p = torch.lgamma(prior_a + prior_b) - torch.lgamma(prior_a) - torch.lgamma(prior_b)

        kl = (log_norm_q - log_norm_p) + (a - prior_a) * expected_log_p + (b - prior_b) * expected_log_one_minus_p
        recon = -expected_log_p

        return (recon + lambda_m * kl).mean()
        
class ICRMTrainer(Trainer):
    """Trainer for a pairwise preference model that predicts a Beta posterior over the preference.

    The model is called once on the chosen and once on the rejected sequence. Column 0 of its logits is
    used as the utility and column 1 as the evidence logit (see `compute_loss`); these are turned into
    the parameters of a Beta distribution and scored with `ICRMLoss`.

    Args:
        model: Model to train. Wrapped with PEFT when `peft_config` is given and it is not already a `PeftModel`.
        args: `ICRMConfig` holding the training arguments and ICRM hyperparameters.
        data_collator: Collator to use; when `None`, `ICRMDataCollatorWithPadding` is built from
            `processing_class`.
        train_dataset: Training dataset. Tokenized here unless it already has an `input_ids_chosen` column.
        eval_dataset: Evaluation dataset, tokenized together with the training dataset.
        processing_class: Tokenizer / processor used for tokenization, padding and decoding.
        model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics: Forwarded to
            `Trainer`; `compute_metrics` defaults to `compute_accuracy`.
        peft_config: Optional PEFT configuration.

    Raises:
        ValueError: If `peft_config` is passed without PEFT installed, or if the default collator is
            requested without a `processing_class`.
    """

    _tag_names = ["trl", "icrm-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[ICRMConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    # Older peft releases do not accept `gradient_checkpointing_kwargs`.
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
                            UserWarning,
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        if args.disable_dropout:
            disable_dropout_in_model(model)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
                )

            data_collator = ICRMDataCollatorWithPadding(processing_class)

            if args.remove_unused_columns:
                # The default collator needs the paired columns, so column pruning must be off.
                try:
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your ICRMConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False

        model.warnings_issued["estimate_tokens"] = True

        if "input_ids_chosen" not in train_dataset.column_names:
            # Read the limit from `args` here rather than inside the default-collator branch, so that
            # passing a custom `data_collator` with an un-tokenized dataset no longer hits a NameError.
            max_length = args.max_length
            fn_kwargs = {"tokenizer": processing_class}

            def _tokenize_and_filter(dataset):
                """Tokenize `dataset` and drop pairs where either side exceeds `max_length`."""
                dataset = dataset.map(
                    _tokenize,
                    batched=True,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                )
                # `max_length` is Optional in the config; without a limit there is nothing to filter.
                if max_length is None:
                    return dataset
                return dataset.filter(
                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
                    num_proc=args.dataset_num_proc,
                )

            with PartialState().main_process_first():
                train_dataset = _tokenize_and_filter(train_dataset)
                if eval_dataset is not None:
                    eval_dataset = _tokenize_and_filter(eval_dataset)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        # The same config value is used for both parameters of the Beta prior.
        self.icrm_loss_fn = ICRMLoss(
            prior_alpha=args.beta_prior_param,
            prior_beta=args.beta_prior_param,
        )
        # Per-step statistics collected in `compute_loss` and averaged in `log`.
        self.logging_accumulator = defaultdict(list)

    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        """Add the averaged accumulated statistics to `logs`, reset the accumulator, then log as usual.

        Each non-empty accumulator entry is written to `logs` under `train/<key>` as the mean of the
        values collected since the previous call.
        """
        for key, values in self.logging_accumulator.items():
            if values:
                logs[f"train/{key}"] = sum(values) / len(values)

        # Start a fresh window for the next logging call.
        self.logging_accumulator.clear()

        super().log(logs, *args, **kwargs)

    def compute_loss(
        self,
        model,
        inputs: dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
        return_outputs=False,
    ):
        """Compute the ICRM loss for a batch of (chosen, rejected) pairs.

        Args:
            model: Model called on the chosen and rejected inputs; its output must expose `"logits"`.
            inputs: Batch with `input_ids_*` / `attention_mask_*` for both sides and `icl_counts`.
            num_items_in_batch: Accepted for `Trainer` compatibility; not used here.
            return_outputs: When true, also return a dict with `mu`, `tau` and both utilities.

        Returns:
            The scalar loss, or `(loss, outputs)` when `return_outputs` is true.
        """
        out_w = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )
        out_l = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )

        logits_w = out_w["logits"]
        logits_l = out_l["logits"]

        # utility and evidence logits per sequence
        u_w = logits_w[:, 0]
        u_l = logits_l[:, 0]
        s_w = logits_w[:, 1]
        s_l = logits_l[:, 1]

        # Preference mean and concentration factor
        # We set clamping for `tau` to 1e3 just in case, but it doesn't reach that high in practice.
        mu = torch.sigmoid(u_w - u_l)
        tau = (torch.nn.functional.softplus(s_w) + torch.nn.functional.softplus(s_l) + 1.0).clamp(max=1e3)

        # Reparameterization for the Beta posterior Beta(alpha_q, beta_q)
        alpha_q = mu * tau
        beta_q = (1.0 - mu) * tau

        # λ(M) schedule: the KL weight shrinks with the (at least 1) in-context count.
        icl_counts = inputs["icl_counts"].to(self.accelerator.device).clamp_min(1.0)
        eps = 1e-8
        lambda_m = self.args.kl_penalty_lambda / (icl_counts + eps)

        loss = self.icrm_loss_fn(alpha_q, beta_q, lambda_m)

        # Gather each tensor across processes once and reuse it for every statistic.
        mu_all = self.accelerator.gather(mu).detach()
        tau_all = self.accelerator.gather(tau).detach()
        u_w_all = self.accelerator.gather(u_w).detach()
        u_l_all = self.accelerator.gather(u_l).detach()
        lambda_all = self.accelerator.gather(lambda_m).detach()

        self.logging_accumulator["mu_mean"].append(mu_all.mean().item())
        self.logging_accumulator["tau_mean"].append(tau_all.mean().item())
        self.logging_accumulator["mu_std"].append(mu_all.std().item())
        self.logging_accumulator["tau_std"].append(tau_all.std().item())
        self.logging_accumulator["rewards_chosen_mean"].append(u_w_all.mean().item())
        self.logging_accumulator["rewards_rejected_mean"].append(u_l_all.mean().item())
        self.logging_accumulator["lambda_m_mean"].append(lambda_all.mean().item())

        if return_outputs:
            return loss, {
                "mu": mu,
                "tau": tau,
                "rewards_chosen": u_w,
                "rewards_rejected": u_l,
            }
        return loss

    def prediction_step(
        self,
        model,
        inputs: dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[list[str]] = None,
    ):
        """Run one evaluation step.

        Args:
            model: Model to evaluate.
            inputs: Collated batch; prepared with `_prepare_inputs` before use.
            prediction_loss_only: When true, return only the loss.
            ignore_keys: Accepted for `Trainer` compatibility; the outputs used here are fixed, so it
                has no effect.

        Returns:
            `(loss, probs, labels)` where `probs[:, 0]` is the predicted preference mean `mu`,
            `probs[:, 1]` is `1 - mu`, and `labels` is all zeros (index 0 is the chosen side).
            `probs` and `labels` are `None` when `prediction_loss_only` is true.
        """
        inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            loss, out = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss.detach(), None, None)

        p_chosen = out["mu"]
        probs = torch.stack([p_chosen, 1.0 - p_chosen], dim=1)
        labels = torch.zeros(probs.size(0), device=probs.device)

        return loss.detach(), probs.detach(), labels

    def evaluate(self, *args, **kwargs):
        """Print a few evaluation samples, then run the regular `Trainer.evaluate`.

        The optional `num_print_samples` keyword (default 4) is consumed here and not forwarded.
        """
        num_print_samples = kwargs.pop("num_print_samples", 4)
        self.visualize_samples(num_print_samples)
        return super().evaluate(*args, **kwargs)

    def visualize_samples(self, num_print_samples: int):
        """Show decoded chosen/rejected texts next to the predicted probabilities.

        Args:
            num_print_samples: Number of samples to print. A negative value iterates over the whole
                evaluation dataloader.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        for inputs in eval_dataloader:
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
            table["chosen_text"].extend(gather_object(chosen_text))
            table["rejected_text"].extend(gather_object(rejected_text))
            table["logits"].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
            )
            # Stop as soon as enough rows were collected.
            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
                break
        df = pd.DataFrame(table)
        if self.accelerator.process_index == 0:
            print_rich_table(df[:num_print_samples])
            if "wandb" in self.args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})

            if "comet_ml" in self.args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=df,
                )

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """Write a model card to `<output_dir>/README.md` (main process only).

        Args:
            model_name: Name of the model.
            dataset_name: Name of the training dataset.
            tags: Extra tag or list of tags for the card. The caller's list is not modified.
        """
        if not self.is_world_process_zero():
            return

        # A local directory path is not a meaningful base-model identifier.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        if not tags:
            tags = []
        elif isinstance(tags, str):
            tags = [tags]
        else:
            # Copy so that appending below does not mutate the list passed in by the caller.
            tags = list(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="Reward",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))


@dataclass
class ICRMDataCollatorWithPadding:
    r"""
    Collator for ICRM preference pairs: pads the chosen and the rejected sequences separately, each to the
    longest sequence of its side in the batch, and carries the `icl_counts` (and optional `margin`) along.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer whose `pad` method is used to pad the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`, `optional`, defaults to `"pt"`):
            The tensor type to use.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate tokenized preference pairs into one padded batch.

        Raises:
            ValueError: If a feature lacks any of the four tokenized columns.
        """
        required_keys = (
            "input_ids_chosen",
            "input_ids_rejected",
            "attention_mask_chosen",
            "attention_mask_rejected",
        )
        # Whether margins are collected is decided from the first feature only.
        has_margin = "margin" in features[0]

        chosen_rows, rejected_rows, counts, margins = [], [], [], []
        for feature in features:
            if not all(key in feature for key in required_keys):
                raise ValueError(
                    "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`"
                )

            chosen_rows.append(
                {"input_ids": feature["input_ids_chosen"], "attention_mask": feature["attention_mask_chosen"]}
            )
            rejected_rows.append(
                {"input_ids": feature["input_ids_rejected"], "attention_mask": feature["attention_mask_rejected"]}
            )
            counts.append(feature["icl_counts"])
            if has_margin:
                margins.append(feature["margin"])

        def _pad(rows):
            # Both sides are padded with the same settings, but independently of each other.
            return self.tokenizer.pad(
                rows,
                padding=self.padding,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )

        padded_chosen = _pad(chosen_rows)
        padded_rejected = _pad(rejected_rows)

        batch = {
            "input_ids_chosen": padded_chosen["input_ids"],
            "attention_mask_chosen": padded_chosen["attention_mask"],
            "input_ids_rejected": padded_rejected["input_ids"],
            "attention_mask_rejected": padded_rejected["attention_mask"],
            "icl_counts": torch.tensor(counts, dtype=torch.float),
            "return_loss": True,
        }
        if has_margin:
            batch["margin"] = torch.tensor(margins, dtype=torch.float)
        return batch