import torch
from abc import ABC, abstractmethod
from trl import GRPOTrainer

from grpo_regularization.featextract import compute_feature_space_distance_v1
from grpo_regularization.kl_divergence import kl_divergence_loss


class RegGRPOTrainer(GRPOTrainer, ABC):
    """
    GRPOTrainer + an auxiliary regularizer computed between:
      - the trainable policy (model)
      - a frozen reference model (frozen_model)
    """

    def __init__(
        self,
        *args,
        frozen_model=None,
        beta: float = 10.0,
        num_intermediate_layers: int = 5,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # If you already pass ref_model to GRPOTrainer, you can default to that.
        if frozen_model is None:
            frozen_model = getattr(self, "ref_model", None)

        assert frozen_model is not None, "Provide frozen_model=... (or ref_model in GRPOTrainer)."
        self.frozen_model = frozen_model
        self.frozen_model.eval()
        for p in self.frozen_model.parameters():
            p.requires_grad_(False)

        self.beta = beta
        self.num_intermediate_layers = num_intermediate_layers
        self._accum_reg_losses = []
        self._accum_base_losses = []

    @abstractmethod
    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask, outputs=None):
        pass

    def _get_full_sequence_inputs(self, inputs):
        """
        GRPO batches vary across versions/collators.

        Common cases:
          1) inputs already has full 'input_ids'/'attention_mask'
          2) inputs has prompt+completion split fields; we concat them
        """
        if "input_ids" in inputs:
            input_ids = inputs["input_ids"]
            attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids))
            return input_ids, attention_mask

        # Some GRPO pipelines keep prompt/completion separate.
        prompt_ids = inputs.get("prompt_input_ids", None)
        completion_ids = inputs.get("completion_input_ids", None)
        if prompt_ids is not None and completion_ids is not None:
            input_ids = torch.cat([prompt_ids, completion_ids], dim=1)

            prompt_mask = inputs.get("prompt_attention_mask", torch.ones_like(prompt_ids))
            completion_mask = inputs.get("completion_attention_mask", torch.ones_like(completion_ids))
            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
            return input_ids, attention_mask

        raise KeyError(
            "Couldn't find inputs for distance regularizer. Expected either "
            "('input_ids','attention_mask') or ('prompt_input_ids','completion_input_ids', ...)."
        )

    def log(self, logs, start_time=None):
        if self._accum_reg_losses:
            logs["reg_loss"] = sum(self._accum_reg_losses) / len(self._accum_reg_losses)
            self._accum_reg_losses.clear()
        if self._accum_base_losses:
            logs["base_loss"] = sum(self._accum_base_losses) / len(self._accum_base_losses)
            self._accum_base_losses.clear()
        return super().log(logs, start_time=start_time)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # GRPO base loss (policy-gradient style objective)
        base_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)

        input_ids, attention_mask = self._get_full_sequence_inputs(inputs)

        reg_loss = self.compute_feature_space_distance(
            model_trained=model,
            model_frozen=self.frozen_model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            outputs=outputs,  # allow reuse of trained forward if available
        )

        total_loss = base_loss + (self.beta * reg_loss)

        self._accum_base_losses.append(base_loss.detach().float().cpu().item())
        self._accum_reg_losses.append(reg_loss.detach().float().cpu().item())

        return (total_loss, outputs) if return_outputs else total_loss


class LDIFSTrainer(RegGRPOTrainer):
    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask, outputs=None):
        return compute_feature_space_distance_v1(
            model_trained=model_trained,
            model_frozen=model_frozen,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_intermediate_layers=self.num_intermediate_layers,
        )


class KLTrainer(RegGRPOTrainer):
    """
    Note: GRPOTrainer often already supports a KL-to-reference term internally
    (depending on TRL version/args). Use this only if you specifically want *your*
    KL implementation on logits.
    """
    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask, outputs=None):
        # Reuse trained logits if the base GRPO forward returned them
        if outputs is not None and hasattr(outputs, "logits") and outputs.logits is not None:
            logits_trained = outputs.logits
        else:
            logits_trained = model_trained(input_ids=input_ids, attention_mask=attention_mask).logits

        with torch.no_grad():
            logits_frozen = model_frozen(input_ids=input_ids, attention_mask=attention_mask).logits

        return kl_divergence_loss(
            logits_trained,
            logits_frozen,
            attention_mask,
            reduction="batchmean",
            temperature=1.0,
        )