import torch
from torch.nn import functional as F
from transformers import Trainer

from egu.utils.utils import target_logprob


def batch_mean_nll_from_logits(logits, labels):
    """Return the mean next-token negative log-likelihood of each sequence.

    The logits at position ``t`` are scored against the label at position
    ``t + 1`` (standard causal-LM shift). Positions whose shifted label is
    ``-100`` are excluded from both the sum and the token count.

    Args:
        logits: Tensor indexed as ``[B, T, V]``.
        labels: Tensor indexed as ``[B, T]``; ``-100`` marks ignored tokens.

    Returns:
        Tensor of shape ``[B]``. A row with no valid token yields ``0``
        because the token count is clamped to at least one.
    """
    # Drop the last prediction and the first label so they line up.
    shifted_logits = logits[:, :-1, :].contiguous()
    targets = labels[:, 1:].contiguous().long()

    # Token-level cross entropy on the flattened [B*(T-1), V] view, then
    # restored to [B, T-1]. Ignored targets contribute exactly zero.
    token_nll = F.cross_entropy(
        shifted_logits.view(-1, shifted_logits.size(-1)),
        targets.view(-1),
        ignore_index=-100,
        reduction="none",
    ).view_as(targets)

    keep = targets.ne(-100)
    # Clamp avoids a 0/0 for sequences that are fully masked.
    denom = keep.sum(dim=-1).clamp_min(1)
    return (token_nll * keep).sum(dim=-1) / denom


class NPOTrainer(Trainer):
    """Trainer that optimizes an NPO-style loss on the "forget" batch.

    For each example the loss is ``-(2 / beta) * logsigmoid(beta * r)`` where
    ``r = nll_policy - nll_ref`` is the difference of the per-example mean
    token NLLs, i.e. minus the (length-normalized) log-ratio between the
    policy and the reference model.

    Args:
        ref_model: Frozen reference model. It is put in eval mode and all of
            its parameters get ``requires_grad = False``.
        beta: Temperature of the loss; must be strictly positive.
        center_log_ratios: When true (default, the historical behavior) the
            detached batch mean of ``r`` is subtracted before the sigmoid.
            Note that this makes the loss depend only on each example's
            offset from the batch mean; pass ``False`` to use ``r`` as is.

    Raises:
        ValueError: If ``ref_model`` is missing or ``beta`` is not positive.
    """

    def __init__(
        self,
        *args,
        ref_model=None,
        beta=0.1,
        center_log_ratios=True,
        **kwargs,
    ):
        # Validate before the (potentially expensive) Trainer set-up, and with
        # real exceptions so the checks survive ``python -O``.
        if ref_model is None:
            raise ValueError("ref_model must be supplied")
        if not beta > 0:
            # beta scales the sigmoid argument and divides the loss.
            raise ValueError(f"beta must be > 0, got {beta!r}")

        super().__init__(*args, **kwargs)

        self.ref_model = ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad = False
        self.beta = beta
        self.center_log_ratios = center_log_ratios

    def _align_ref_model(self, device):
        """Move the reference model to *device* if its weights live elsewhere.

        Nothing in this class otherwise places ``ref_model`` on the training
        device, so a reference model left on CPU would fail on CUDA inputs.
        """
        # Models carrying an ``hf_device_map`` were placed by an explicit
        # device map; leave that placement untouched.
        if getattr(self.ref_model, "hf_device_map", None) is not None:
            return
        first_param = next(self.ref_model.parameters(), None)
        if first_param is not None and first_param.device != device:
            self.ref_model.to(device)

    def save_model(self, output_dir=None, _internal_call=True):
        """Save the model, preferring the model's own ``save_checkpoint``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            _internal_call: Forwarded to ``Trainer.save_model`` on the
                fallback path.
        """
        output_dir = output_dir or self.args.output_dir

        # NOTE(review): presumably targets a DeepSpeed engine (see the log
        # message) -- confirm ``self.model`` is the engine in that set-up.
        if hasattr(self.model, "save_checkpoint"):
            self.model.save_checkpoint(output_dir)
            print(f"[DeepSpeed] Saved checkpoint at {output_dir}")
        else:
            super().save_model(output_dir, _internal_call)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the NPO-style loss for one batch.

        Args:
            model: Policy model being trained.
            inputs: Mapping providing ``forget_input_ids``,
                ``forget_attention_mask`` and ``forget_labels``.
            return_outputs: When true, also return the policy forward outputs.
            **kwargs: Accepted for compatibility with ``Trainer``; unused.

        Returns:
            The scalar loss, or ``(loss, policy_outputs)`` when
            ``return_outputs`` is true.
        """
        f_ids = inputs["forget_input_ids"]
        f_mask = inputs["forget_attention_mask"]
        f_lbls = inputs["forget_labels"]

        # Policy forward. Labels are deliberately not passed to the model:
        # only the logits are used, so a model-side loss would be wasted work.
        out_cur = model(input_ids=f_ids, attention_mask=f_mask)
        nll_cur = batch_mean_nll_from_logits(out_cur.logits, f_lbls)  # [B]

        # Reference forward (no grad), on the same device as the batch.
        self._align_ref_model(f_ids.device)
        with torch.no_grad():
            out_ref = self.ref_model(input_ids=f_ids, attention_mask=f_mask)
            nll_ref = batch_mean_nll_from_logits(out_ref.logits, f_lbls)  # [B]

        neg_log_ratios = nll_cur - nll_ref  # equals -(log p_pi - log p_ref)

        # Statistics are taken BEFORE centering; after centering the mean is
        # zero by construction and would carry no information.
        raw = neg_log_ratios.detach()
        ratio_mean = raw.mean()
        # The default (unbiased) std of a single element is NaN.
        ratio_std = raw.std().item() if raw.numel() > 1 else 0.0

        if self.center_log_ratios:
            # Constant (detached) shift: gradients flow through each example's
            # own ratio only.
            neg_log_ratios = neg_log_ratios - ratio_mean

        beta = self.beta
        loss = -F.logsigmoid(beta * neg_log_ratios).mean() * (2.0 / beta)
        self.log(
            {
                "npo_neg_log_ratio_mean": ratio_mean.item(),
                "npo_neg_log_ratio_std": ratio_std,
            }
        )
        return (loss, out_cur) if return_outputs else loss

    # def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    #
    #     input_ids = inputs["forget_input_ids"]
    #     attention_mask = inputs["forget_attention_mask"]
    #     labels = inputs["forget_labels"]
    #
    #     # pi_theta
    #     # log_pi_theta  =
    #     # log_pi_ref_z =
    #     log_pi_theta = target_logprob(model, input_ids, attention_mask, labels)
    #
    #     with torch.no_grad():
    #
    #         log_ref_theta = target_logprob(
    #             self.ref_model, input_ids, attention_mask, labels
    #         )
    #
    #     print(log_pi_theta.device)  # cuda
    #     print(log_ref_theta.device)  # cpu
    #
    #     log_ratio = log_pi_theta - log_ref_theta
    #
    #     loss = -F.logsigmoid(log_ratio * -self.beta).mean() * 2 / self.beta
    #     self.log({"npo_log_ratio": log_ratio.mean().item()})
    #     return (loss, None) if return_outputs else loss
