"""
Difficulty Metrics for GDO-DPO

Implements semantic complexity (Csem) and preference uncertainty (Upref)
as described in the paper.
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer
from tqdm import tqdm


class DifficultyMetrics:
    """
    Computes semantic complexity and preference uncertainty for preference pairs.

    Following Definition 3.1 and 3.2 from the paper.

    Attributes:
        reference_model: Reference policy π_ref; switched to eval mode on construction.
        tokenizer: Tokenizer used for every prompt / response encoding.
        device: Device the tokenized inputs are moved to before each model call.
        num_samples: Number of Monte Carlo samples per prompt (K in Equation 2).
    """

    def __init__(
        self,
        reference_model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: str = "cuda",
        num_samples: int = 8,  # K in Equation 2
    ):
        """
        Args:
            reference_model: Reference policy π_ref for complexity estimation
            tokenizer: Tokenizer for the model
            device: Device to run computations on
            num_samples: Number of MC samples for Csem estimation (K in paper)

        Raises:
            ValueError: If ``num_samples`` is smaller than 1 (the Csem estimate
                averages over ``num_samples`` draws and would divide by zero).
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        self.reference_model = reference_model
        self.tokenizer = tokenizer
        self.device = device
        self.num_samples = num_samples
        self.reference_model.eval()

    @staticmethod
    def _generated_log_prob(token_ids, step_scores) -> float:
        """Sum the log-probabilities of sampled tokens using per-step generation scores (batch row 0)."""
        # zip() stops at the shorter of the two, so a token without a matching
        # score entry is simply not counted.
        token_log_probs = [
            # Cast to float32 so half-precision models do not lose accuracy in
            # the softmax normalisation.
            F.log_softmax(scores[0].float(), dim=-1)[token_id]
            for token_id, scores in zip(token_ids, step_scores)
        ]
        if not token_log_probs:
            return 0.0
        # One host/device sync per sample instead of one per token; the sum is
        # accumulated in float64.
        return torch.stack(token_log_probs).double().sum().item()

    def _response_reward(self, model, text: str, prompt_len: int, max_length: int) -> float:
        """Score ``text`` (prompt + response) with ``model`` and return a scalar reward."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(self.device)
        logits = model(**inputs).logits

        if logits.dim() == 2:
            # Sequence-level head (e.g. a sequence-classification reward model):
            # the single logit is the reward itself.
            if logits.shape[-1] != 1:
                raise ValueError(
                    "reward_model must produce a single scalar per sequence, "
                    f"got logits of shape {tuple(logits.shape)}"
                )
            return logits[0, 0].item()

        # Token-level logits: the reward is the summed log-probability of the
        # response tokens. Logits at position i predict token i + 1, so the
        # first response token (index prompt_len) is predicted at prompt_len - 1.
        # Clamp at 0: a negative start would index from the end of the sequence.
        # NOTE(review): prompt_len comes from tokenizing the prompt on its own;
        # this assumes `prompt + response` tokenizes to the same prefix and that
        # the tokenizer appends no trailing special tokens - confirm for the
        # tokenizer in use.
        start = max(prompt_len - 1, 0)
        response_labels = inputs.input_ids[0, 1:][start:]
        if response_labels.numel() == 0:
            # Nothing left to score (e.g. the prompt alone fills max_length).
            return 0.0

        # Only normalise the positions that are actually summed.
        response_logits = logits[0, :-1, :][start:]
        log_probs = F.log_softmax(response_logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, response_labels.unsqueeze(-1)).squeeze(-1)
        return token_log_probs.double().sum().item()

    @torch.no_grad()
    def compute_semantic_complexity(
        self,
        prompts: List[str],
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        batch_size: int = 4,
        max_prompt_length: int = 512,
    ) -> np.ndarray:
        """
        Compute semantic complexity Csem(x) via predictive entropy.

        Following Definition 3.1 and Equation 2:
        Csem(x) ≈ -1/K Σ log π_ref(y_k|x)

        For every prompt, ``self.num_samples`` responses are sampled from the
        reference model and the negative log-probabilities of the sampled
        sequences are averaged.

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens to generate for sampling
            temperature: Sampling temperature
            batch_size: Currently unused (samples are drawn one at a time);
                kept for interface compatibility.
            max_prompt_length: Prompts are truncated to this many tokens
                before generation.

        Returns:
            Array of complexity scores, one per prompt
        """
        # Avoid handing `None` to generate() when the tokenizer has no pad token.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        complexities = []

        for prompt in tqdm(prompts, desc="Computing Csem"):
            # The encoded prompt is identical for all K samples: build it once.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_prompt_length,
            ).to(self.device)
            prompt_len = inputs.input_ids.shape[1]

            total_log_prob = 0.0

            # Sample K responses from reference model
            for _ in range(self.num_samples):
                outputs = self.reference_model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    pad_token_id=pad_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

                # Keep only the newly generated tokens.
                # NOTE(review): assumes `sequences` starts with the prompt
                # tokens (decoder-only model) - confirm for the model in use.
                generated = outputs.sequences[0, prompt_len:]

                # NOTE(review): `outputs.scores` are the scores after the
                # generation-time processors (temperature, and any top-k /
                # top-p set in the generation config), so this is the
                # log-probability under the *sampling* distribution. It equals
                # log π_ref only for temperature=1 without truncation -
                # confirm this matches the estimator in the paper.
                total_log_prob += self._generated_log_prob(generated, outputs.scores)

            # Average negative log probability (entropy estimate)
            complexities.append(-total_log_prob / self.num_samples)

        return np.array(complexities)

    @torch.no_grad()
    def compute_preference_uncertainty(
        self,
        prompts: List[str],
        chosen_responses: List[str],
        rejected_responses: List[str],
        reward_model: PreTrainedModel = None,
        max_length: int = 1024,
    ) -> np.ndarray:
        """
        Compute preference uncertainty Upref via reward margin.

        Following Definition 3.2 and Equation 3:
        Upref(y^w, y^l|x) = exp(-|r*(y^w|x) - r*(y^l|x)|)

        Args:
            prompts: List of prompts
            chosen_responses: List of preferred responses
            rejected_responses: List of rejected responses
            reward_model: Optional reward model. If None, the reference model
                is used. A model returning per-token logits is scored by the
                summed log-probability of the response tokens; a model
                returning one logit per sequence is used as the reward directly.
            max_length: ``prompt + response`` is truncated to this many tokens.

        Returns:
            Array of uncertainty scores, one per preference pair

        Raises:
            ValueError: If the three input lists differ in length, or if the
                model returns more than one logit per sequence.
        """
        if not (len(prompts) == len(chosen_responses) == len(rejected_responses)):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                "prompts, chosen_responses and rejected_responses must have the "
                f"same length, got {len(prompts)}, {len(chosen_responses)} and "
                f"{len(rejected_responses)}"
            )

        model_to_use = reward_model if reward_model is not None else self.reference_model
        uncertainties = []

        for prompt, chosen, rejected in tqdm(
            zip(prompts, chosen_responses, rejected_responses),
            desc="Computing Upref",
            total=len(prompts),
        ):
            # Number of prompt tokens, used to skip the prompt when scoring.
            prompt_len = len(self.tokenizer(prompt, return_tensors="pt").input_ids[0])

            chosen_reward = self._response_reward(
                model_to_use, prompt + chosen, prompt_len, max_length
            )
            rejected_reward = self._response_reward(
                model_to_use, prompt + rejected, prompt_len, max_length
            )

            # Compute uncertainty as exp(-|margin|)
            margin = abs(chosen_reward - rejected_reward)
            uncertainties.append(np.exp(-margin))

        return np.array(uncertainties)

    def compute_rank_normalized_scores(
        self,
        scores: np.ndarray
    ) -> np.ndarray:
        """
        Convert absolute scores to rank-normalized [0, 1] scores.

        This ensures robustness to outliers and consistent threshold interpretation.
        The smallest score maps to 0 and the largest to 1. Equal scores receive
        distinct ranks, ordered by their position in the input.

        Args:
            scores: Array of difficulty scores

        Returns:
            Rank-normalized scores in [0, 1]. An empty input gives an empty
            array; a single score maps to 0.0.
        """
        scores = np.asarray(scores)
        n = len(scores)

        # argsort of argsort yields each element's rank; the stable sort makes
        # the ordering of tied scores deterministic.
        ranks = np.argsort(np.argsort(scores, kind="stable"), kind="stable")

        if n <= 1:
            # With fewer than two scores there is no rank spread to normalise
            # by (n - 1 would be 0 or negative).
            return np.zeros(ranks.shape, dtype=float)

        return ranks / (n - 1)

    def precompute_dataset_difficulties(
        self,
        dataset: List[Dict],
        save_path: str = None,
    ) -> Dict[str, np.ndarray]:
        """
        Precompute difficulty metrics for entire dataset.

        Runs the semantic-complexity and preference-uncertainty passes with
        their default settings, rank-normalizes both, and optionally stores
        the four arrays with ``np.savez``.

        Args:
            dataset: List of preference pairs with keys:
                     'prompt', 'chosen', 'rejected'
            save_path: Optional path to save computed metrics. ``np.savez``
                appends ``.npz`` when the name does not already end with it.

        Returns:
            Dictionary with keys 'Csem', 'Upref', 'Rsem', 'Runc'
        """
        prompts = [sample['prompt'] for sample in dataset]
        chosen = [sample['chosen'] for sample in dataset]
        rejected = [sample['rejected'] for sample in dataset]

        print("Computing semantic complexity...")
        csem = self.compute_semantic_complexity(prompts)

        print("Computing preference uncertainty...")
        upref = self.compute_preference_uncertainty(prompts, chosen, rejected)

        # Rank normalize
        results = {
            'Csem': csem,
            'Upref': upref,
            'Rsem': self.compute_rank_normalized_scores(csem),
            'Runc': self.compute_rank_normalized_scores(upref),
        }

        if save_path:
            np.savez(save_path, **results)
            print(f"Saved difficulty metrics to {save_path}")

        return results
