import json
import random
from typing import List

import numpy as np
import torch
from einops import einsum
from torch.nn.utils.rnn import pad_sequence

import wandb
from evaluators import qwen3_grader, qwen25_grader

# ==================== Evaluation Criteria ====================
# Pairwise judge prompt templates, keyed by criterion name.
# Each template is filled via str.format(question1=..., story1=..., story2=...);
# only "prompt_alignment" actually uses {question1} (str.format ignores unused
# keyword arguments). Every template asks the grader for JSON containing
# 'reasoning' and 'final_answer' (1 or 2, naming the better story).
EVALUATION_CRITERIA = {
    "prompt_alignment": """Evaluate the following stories for alignment with the given prompt.

Prompt: {question1}

Story 1: {story1}

Story 2: {story2}

Does the story fully address the requirements of the prompt? Are all key elements, details, constraints, or events implied or explicitly stated in the prompt incorporated into the story in a meaningful way? Does the story avoid ignoring or contradicting any part of the prompt?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story is better aligned with the prompt""",
    "plot_coherence_and_progression": """Evaluate the following stories for plot coherence and progression.

Story 1: {story1}

Story 2: {story2}

Does the story present a clear and logically consistent plot? Do events follow naturally from one another with a sense of cause and effect? Are transitions between scenes smooth, and does the narrative maintain momentum without unnecessary repetition or confusion?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better plot coherence and progression""",
    "character_depth_and_motivation": """Evaluate the following stories for character depth and motivation.

Story 1: {story1}

Story 2: {story2}

Do the characters have believable goals, emotions, and internal conflicts? Are their actions motivated and consistent with their personalities? Does the story explore their development or transformation rather than keeping them static or purely functional?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better character depth and motivation""",
    "narrative_tension_and_stakes": """Evaluate the following stories for narrative tension and stakes.

Story 1: {story1}

Story 2: {story2}

Does the story sustain a sense of tension, uncertainty, or conflict that keeps readers engaged? Are there meaningful stakes—emotional, physical, or moral—that make the outcome matter? Does the tension rise and evolve rather than remaining flat?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better narrative tension and stakes""",
    "show_dont_tell": """Evaluate the following stories for "show, don't tell" technique.

Story 1: {story1}

Story 2: {story2}

Does the story rely on vivid imagery, sensory detail, and action to communicate emotion or meaning? Does it avoid excessive exposition and direct statements of feelings or morals? Are readers able to infer the deeper message through what is shown rather than told?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story better utilizes the "show, don't tell" technique""",
    "emotional_resonance": """Evaluate the following stories for emotional resonance.

Story 1: {story1}

Story 2: {story2}

Does the story evoke a genuine emotional response? Are the emotions integrated naturally into the narrative rather than being forced or sentimental? Does the reader feel connected to the protagonist's experiences or internal journey?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better emotional resonance""",
    "thematic_subtlety": """Evaluate the following stories for thematic subtlety.

Story 1: {story1}

Story 2: {story2}

Does the story communicate its themes organically rather than through explicit explanation? Are ideas and morals implied through events, imagery, or character growth? Does it avoid oversimplified messages or moralizing conclusions?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better thematic subtlety""",
    "pacing_and_structure": """Evaluate the following stories for pacing and structure.

Story 1: {story1}

Story 2: {story2}

Does the story maintain good rhythm and flow across its beginning, middle, and end? Are slower sections purposeful and balanced with moments of tension or revelation? Does the structure build toward a satisfying emotional or narrative peak?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better pacing and structure""",
    "atmosphere_and_sensory_detail": """Evaluate the following stories for atmosphere and sensory detail.

Story 1: {story1}

Story 2: {story2}

Does the story effectively create mood and setting through sensory language? Are sights, sounds, textures, or other details used to evoke tone and emotion? Does the environment reflect or enhance the psychological state of the characters?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better atmosphere and sensory detail""",
    "dialogue_and_relationships": """Evaluate the following stories for dialogue and relationships.

Story 1: {story1}

Story 2: {story2}

Is the dialogue natural and revealing of character or tension? Does it serve to advance the plot or deepen relationships between characters? Are interactions meaningful, showing emotional or ideological contrasts rather than serving only exposition?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better dialogue and relationships""",
    "opening_and_ending_effectiveness": """Evaluate the following stories for opening and ending effectiveness.

Story 1: {story1}

Story 2: {story2}

Does the story begin with an engaging or distinctive hook that draws the reader in? Does the ending feel earned, emotionally satisfying, or thought-provoking? Does it avoid abrupt closure or moral summarization while providing resolution or meaningful ambiguity?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better opening and ending effectiveness""",
    "narrative_voice_and_style": """Evaluate the following stories for narrative voice and style.

Story 1: {story1}

Story 2: {story2}

Is the narrative voice consistent and well-suited to the story's tone and genre? Does the prose demonstrate rhythm, clarity, and control without unnecessary ornamentation? Are point of view and stylistic choices coherent throughout?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better narrative voice and style""",
    "originality_and_unresolved_mystery": """Evaluate the following stories for originality and unresolved mystery.

Story 1: {story1}

Story 2: {story2}

Does the story present fresh ideas, unique imagery, or unexpected developments? Does it avoid clichés or predictable arcs? Does it leave readers with lingering curiosity, reflection, or an element of mystery that deepens the impact of the ending?

Respond in JSON with your concise reasoning under 'reasoning' and your final answer under 'final_answer' being either 1 or 2 depending on which story has better originality and unresolved mystery""",
}


def extract_prompts_and_stories(prompts, completions, model_name):
    """Return ``(questions, stories)`` for a batch of rollouts.

    Questions come from ``extract_prompts``. For Qwen3 models the completions
    are already plain strings and are returned unchanged; otherwise each
    completion is a list of chat messages and the first message's content is
    taken as the story.
    """
    questions = extract_prompts(prompts, model_name)
    if "Qwen3" not in model_name:
        completions = [messages[0]["content"] for messages in completions]
    return questions, completions


def extract_prompts(prompts, model_name):
    """Recover the user question from every prompt in a batch.

    For Qwen3 models each prompt is a chat-template string; the text between
    the first ``<|im_start|>user\\n`` marker and the following ``<|im_end|>``
    is returned, stripped. When either marker is missing a warning is printed
    and the whole prompt string is used instead.

    For other models each prompt is a list of chat messages and the content of
    the last message is returned.
    """
    if "Qwen3" not in model_name:
        return [messages[-1]["content"] for messages in prompts]

    open_tag = "<|im_start|>user\n"
    close_tag = "<|im_end|>"
    questions = []
    for text in prompts:
        _, found_open, after_open = text.partition(open_tag)
        body, found_close, _ = after_open.partition(close_tag)
        if found_open and found_close:
            questions.append(body.strip())
            continue
        # Fall back to the raw prompt when the template cannot be parsed.
        print("Warning: Could not parse question from prompt, using fallback")
        questions.append(text)
    return questions


def extract_validity(response_text):
    """Parse a grader's JSON verdict and return the 0-based index of the winner.

    The response must be a JSON object whose ``final_answer`` is 1 or 2
    (int or numeric string). A surrounding Markdown code fence
    (```` ```json ... ``` ````) is tolerated and stripped before parsing.

    Returns:
        0 if story 1 won, 1 if story 2 won.

    Raises:
        json.JSONDecodeError (a ValueError): the text is not valid JSON.
        KeyError / TypeError: the JSON has no usable ``final_answer``.
        ValueError: ``final_answer`` is not 1 or 2. Without this check an
            answer of 0 or 3 would yield an index that callers misuse to
            address the pair.
    """
    text = response_text
    if isinstance(text, str):
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (e.g. "```json") and the closing fence.
            text = text.partition("\n")[2].rsplit("```", 1)[0]

    data = json.loads(text)
    final_answer = data["final_answer"]
    winner = int(final_answer) - 1
    if winner not in (0, 1):
        raise ValueError(f"final_answer must be 1 or 2, got {final_answer!r}")
    return winner


def pick_qwen_grader(prompts: List[str], model_name: str):
    """Grade ``prompts`` with the grader matching the policy's model family.

    Qwen3 policies use ``qwen3_grader``; every other model name uses
    ``qwen25_grader``. Returns whatever the chosen grader returns.
    """
    grader = qwen3_grader if "Qwen3" in model_name else qwen25_grader
    return grader(prompts)


def evaluate_with_criterion_parallel(
    questions: List[str],
    stories: List[str],
    active_criteria: List[str],
    model_name: str,
) -> List[float]:
    """Score stories with pairwise LLM-judge comparisons over several criteria.

    Every pair of stories answering the same question is judged once per
    criterion. The order of each pair inside the prompt is randomized to
    mitigate position bias. All prompts are sent to the grader as one batch,
    and responses that fail to parse are re-sent, up to ``max_retries`` rounds
    in total. Comparisons still unparsed after that default to the story shown
    first.

    A win adds 1.0 to the winner's score. A loss subtracts a criterion-
    dependent penalty that falls linearly from 0.5 for the first entry of
    ``active_criteria`` to 0.0 for the last.

    Args:
        questions: one question per story. Stories are compared only with
            stories that share the exact same question text.
        stories: story texts, aligned with ``questions``.
        active_criteria: keys into ``EVALUATION_CRITERIA``, most important first.
        model_name: used to pick the grader.

    Returns:
        One reward per story: ``max(0, score)`` divided by the number of
        comparisons that story took part in. Stories without any comparison
        (no sibling, or no active criteria) get 0.0.
    """
    max_retries = 3
    n = len(stories)

    # compute penalties for losses in pair comparisons based on criterion index
    num_criteria = len(active_criteria)
    max_penalty = 0.5  # Losing first criterion costs half a win
    scale = max_penalty / (num_criteria - 1) if num_criteria > 1 else 0.0
    loss_penalties = {
        c: (num_criteria - i - 1) * scale for i, c in enumerate(active_criteria)
    }
    print(f"Loss penalties: {loss_penalties}")

    # How many stories answer each question (used for the per-story denominator).
    question_counts = {}
    for question in questions:
        question_counts[question] = question_counts.get(question, 0) + 1

    # Prepare all prompts
    prompts, pair_comparisons, criteria_list = [], [], []

    for criterion in active_criteria:
        template = EVALUATION_CRITERIA[criterion]

        for i in range(n):
            for j in range(i):
                # Comparing stories written for different prompts is meaningless
                # (the alignment template only shows story1's question).
                if questions[i] != questions[j]:
                    continue
                # Randomize which story appears first to mitigate position bias
                if random.random() < 0.5:
                    prompt = template.format(
                        question1=questions[i], story1=stories[i], story2=stories[j]
                    )
                    pair_comparisons.append((i, j))
                else:
                    prompt = template.format(
                        question1=questions[j], story1=stories[j], story2=stories[i]
                    )
                    pair_comparisons.append((j, i))
                prompts.append(prompt)
                criteria_list.append(criterion)

    scores = [0.0] * n

    for _ in range(max_retries):
        if not prompts:
            break

        responses = list(pick_qwen_grader(prompts, model_name))
        # A short response list would make zip() silently drop comparisons;
        # treat missing responses as parse failures so they are retried.
        responses += [None] * (len(prompts) - len(responses))

        retry_prompts, retry_pairs, retry_criteria = [], [], []
        for response, prompt, pair, criterion in zip(
            responses, prompts, pair_comparisons, criteria_list
        ):
            try:
                winner = extract_validity(response)
                if winner not in (0, 1):
                    raise ValueError(f"unexpected winner index {winner!r}")
            except Exception:
                print(f"Warning: {criterion} response not well-formatted")
                retry_prompts.append(prompt)
                retry_pairs.append(pair)
                retry_criteria.append(criterion)
                continue

            # Scores are only touched once the verdict is fully validated, so a
            # bad response can never be partially counted and then retried.
            winner_idx, loser_idx = pair[winner], pair[1 - winner]
            scores[winner_idx] += 1.0
            scores[loser_idx] -= loss_penalties[criterion]
            print(f"  {criterion}: {winner_idx} beats {loser_idx}")

        prompts, pair_comparisons, criteria_list = (
            retry_prompts,
            retry_pairs,
            retry_criteria,
        )

    # Handle remaining failures
    for pair, criterion in zip(pair_comparisons, criteria_list):
        print(
            f"  {criterion}: Defaulting to Story {pair[0]} in pair {pair} after {max_retries} retries"
        )
        scores[pair[0]] += 1.0
        scores[pair[1]] -= loss_penalties[criterion]

    rewards = []
    for i in range(n):
        num_comparisons = (question_counts[questions[i]] - 1) * num_criteria
        rewards.append(
            max(0.0, scores[i]) / num_comparisons if num_comparisons > 0 else 0.0
        )
    return rewards


def make_curriculum_reward_function(state, model_name):
    """Create the LLM-judge reward function driven by a curriculum ``state``.

    ``state`` must expose ``get_active_criteria()`` and a mutable integer
    ``current_step``. The returned function advances ``current_step`` by one
    per call and logs summary metrics to wandb when a run is active.
    """

    def curriculum_reward_function(prompts, completions, **kwargs) -> List[float]:
        """Grade each completion against its siblings using the active criteria."""
        criteria = state.get_active_criteria()
        print(
            f"\nStep {state.current_step}: Using {len(criteria)} criteria: {criteria}"
        )

        questions, stories = extract_prompts_and_stories(
            prompts, completions, model_name
        )

        print("QUESTION:")
        first_question = questions[0]
        print(first_question)
        print(f"\nEvaluating stories for question: {first_question[:100]}...")
        for idx, story in enumerate(stories):
            print(f"Story {idx}: {story[:200]}")

        rewards = evaluate_with_criterion_parallel(
            questions, stories, criteria, model_name
        )
        print(f"Curriculum rewards: {rewards}")

        # Advance the curriculum clock before logging (the logged step is post-increment).
        state.current_step += 1

        avg_reward = sum(rewards) / len(rewards)
        if wandb.run:
            metrics = {
                "step": state.current_step,
                "avg_reward": avg_reward,
                "num_criteria": len(criteria),
                "active_criteria": ", ".join(criteria),
            }
            wandb.log(metrics)

        print(f"\nAverage reward: {avg_reward:.3f}")
        return rewards

    return curriculum_reward_function


def extract_completion_representation(
    last_layer, prompt_lengths, full_ids, tokenizer, pad_id=None
):
    """Mean-pool hidden states over completion tokens and L2-normalize the result.

    Args:
        last_layer: hidden states of shape [num_rollouts, seq_len, hidden_dim],
            aligned position-by-position with ``full_ids``.
        prompt_lengths: number of leading prompt tokens in each row; positions
            before it are excluded from pooling.
        full_ids: token ids of shape [num_rollouts, seq_len] (prompt + completion
            + right padding).
        tokenizer: used only to look up the padding id when ``pad_id`` is None.
        pad_id: id treated as padding and excluded from pooling. Defaults to
            ``tokenizer.pad_token_id``, falling back to ``tokenizer.eos_token_id``
            when that is None. Previously a None pad id disabled pad masking
            entirely, even though callers pad with eos in that case.

    Returns:
        float32 tensor [num_rollouts, hidden_dim] of unit-norm vectors. A row with
        no completion tokens yields the zero vector.
    """
    if pad_id is None:
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

    device = last_layer.device
    positions = torch.arange(full_ids.shape[1], device=device).unsqueeze(0)  # [1, seq]
    p_lens = torch.as_tensor(prompt_lengths, device=device).unsqueeze(1)  # [B, 1]

    # Mask = (position >= prompt length) AND (token is not padding)
    completion_mask = positions >= p_lens
    if pad_id is not None:
        completion_mask &= full_ids != pad_id

    hidden = last_layer.to(torch.float32)
    weights = completion_mask.unsqueeze(-1).to(torch.float32)
    sum_hidden = (hidden * weights).sum(dim=1)  # [num_rollouts, hidden_dim]

    # Token counts per row; clamp avoids 0/0 for rows with no completion tokens.
    counts = completion_mask.sum(dim=1, keepdim=True).clamp(min=1).to(torch.float32)

    mean_pooled = sum_hidden / counts
    return torch.nn.functional.normalize(mean_pooled, p=2, dim=1)


def random_projection(reps: torch.Tensor, proj_dim: int):
    """Project each row of ``reps`` onto ``proj_dim`` random unit-norm directions.

    A float32 Gaussian matrix of shape [reps.shape[1], proj_dim] is drawn from
    torch's global RNG on ``reps.device``. Each column is L2-normalized, and the
    result ``reps @ directions`` has shape [reps.shape[0], proj_dim]. A new
    matrix is drawn on every call.
    """
    in_dim = reps.shape[1]
    directions = torch.nn.functional.normalize(
        torch.randn(in_dim, proj_dim, device=reps.device, dtype=torch.float32),
        p=2,
        dim=0,
    )
    return torch.matmul(reps, directions)


def compute_repexp_bonuses(reps, lambda_reg=0.1) -> torch.Tensor:
    """Compute elliptical exploration bonuses hᵢᵀ Σ⁻¹ hᵢ for a batch of vectors.

    Σ = lambda_reg·I + Σₖ hₖ hₖᵀ is built from all rows of ``reps``, including
    each row's own outer product. Rows pointing in directions that are poorly
    covered by the batch receive larger bonuses.

    Args:
        reps: 2-D tensor [num_rollouts, dim], one representation per row.
        lambda_reg: ridge term added to the diagonal of Σ. For
            lambda_reg > 0, Σ is symmetric positive definite.

    Returns:
        1-D tensor [num_rollouts] of bonuses. Inputs that are not float32 or
        float64 (e.g. half precision) are upcast to float32 first, since the
        torch.linalg solvers work in single/double precision.
    """
    if reps.dtype not in (torch.float32, torch.float64):
        reps = reps.to(torch.float32)

    _, dim = reps.shape
    # Match dtype/device so the sum below does not silently promote or fail.
    identity = torch.eye(dim, device=reps.device, dtype=reps.dtype)
    covariance = lambda_reg * identity + reps.T @ reps  # Σ: [dim, dim]

    try:
        # Solving Σ X = Hᵀ is cheaper and more accurate than forming Σ⁻¹.
        solved = torch.linalg.solve(covariance, reps.T)  # [dim, num_rollouts]
    except RuntimeError:
        # torch.linalg.LinAlgError subclasses RuntimeError; fall back to the
        # pseudo-inverse when Σ is singular (e.g. lambda_reg == 0).
        solved = torch.linalg.pinv(covariance) @ reps.T

    # Row-wise dot product hᵢ · (Σ⁻¹ hᵢ)
    return (reps * solved.T).sum(dim=1)


def make_diversity_reward_function(model, tokenizer, model_name):
    """Build a reward function that favours completions with novel representations.

    The returned function re-encodes ``question + completion`` with ``model``
    and mean-pools the final hidden layer over completion tokens. It then
    randomly projects the pooled vectors to 256 dimensions and rewards each
    rollout with its elliptical (RepExp) bonus relative to the whole batch,
    min-max normalized to [0, 1].
    """

    def diversity_reward_function(prompts, completion_ids, **kwargs) -> List[float]:
        """Return one normalized diversity reward per completion in the batch.

        Args:
            prompts: batch prompts in the format ``extract_prompts`` expects for
                ``model_name``.
            completion_ids: per-completion sequences of generated token ids.
        """
        device = model.device
        # `pad_token_id or eos_token_id` would wrongly discard a pad id of 0.
        pad_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )

        full_seqs = []
        p_lengths = []

        prompt_texts = extract_prompts(prompts, model_name)

        for p_text, c_ids_raw in zip(prompt_texts, completion_ids):
            p_ids = tokenizer(
                p_text, return_tensors="pt", add_special_tokens=False
            ).input_ids[0]

            # Ids outside [0, vocab_size) are replaced by pad_id and thus
            # excluded from pooling. NOTE(review): HF `vocab_size` typically
            # excludes added special tokens, so those are dropped too — confirm
            # that is intended. dtype=long keeps torch.cat integer-typed even
            # for an empty completion (torch.tensor([]) would be float32).
            c_ids = torch.tensor(
                [t if 0 <= t < tokenizer.vocab_size else pad_id for t in c_ids_raw],
                dtype=torch.long,
            )

            full_seqs.append(torch.cat([p_ids, c_ids]))
            p_lengths.append(len(p_ids))

        full_ids = pad_sequence(full_seqs, batch_first=True, padding_value=pad_id).to(
            device
        )

        with torch.no_grad():
            # NOTE(review): no attention_mask is passed; with right padding this
            # only leaves real-token states untouched for a causal model — confirm.
            outputs = model(input_ids=full_ids, output_hidden_states=True)
            # [num_rollouts, max_prompt_len + max_completion_len, hidden_dim]
            last_layer = outputs.hidden_states[-1]
        del outputs  # release the hidden states of all other layers

        # Extract representation from completion tokens only
        representations = extract_completion_representation(
            last_layer, p_lengths, full_ids, tokenizer
        )
        projected_reps = random_projection(representations, 256)

        # Compute RepExp bonuses
        bonuses = compute_repexp_bonuses(projected_reps)
        rewards = normalize_rewards(bonuses.cpu().tolist())

        print("Diversity reward: ", rewards)
        return rewards

    return diversity_reward_function


def compute_perplexity_batched(
    model, tokenizer, prefixes: List[str], completions: List[str], chunk_size: int = 8
) -> List[float]:
    """Compute log-perplexity for multiple prefix-completion pairs in chunked forward passes.

    Each prefix and completion is tokenized separately (no special tokens),
    concatenated, right-padded, and run through ``model`` ``chunk_size`` rows
    at a time. The value for each pair is the mean cross-entropy of the
    completion tokens given everything before them, i.e. the natural log of
    the completion's perplexity.

    Padding is located by sequence length, not by token id. Completion tokens
    that happen to equal the pad id (e.g. when pad == eos) are therefore still
    scored and attended to.

    Args:
        model: causal LM returning an object with ``.logits`` [B, L, vocab];
            must expose ``.device``.
        tokenizer: callable returning an object with ``.input_ids``.
        prefixes / completions: aligned lists of texts.
        chunk_size: number of pairs per forward pass.

    Returns:
        One float per pair, in input order. A completion with no scorable
        tokens gets 0.0.
    """
    device = model.device
    # `or` would discard a legitimate pad id of 0.
    pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    if pad_id is None:
        # Padded positions are masked by length, so any valid id works here.
        pad_id = 0

    all_losses = []

    for chunk_start in range(0, len(prefixes), chunk_size):
        chunk_prefixes = prefixes[chunk_start : chunk_start + chunk_size]
        chunk_completions = completions[chunk_start : chunk_start + chunk_size]

        sequences, prefix_lengths = [], []
        for prefix, completion in zip(chunk_prefixes, chunk_completions):
            prefix_ids = list(tokenizer(prefix, add_special_tokens=False).input_ids)
            completion_ids = list(
                tokenizer(completion, add_special_tokens=False).input_ids
            )
            # dtype=long keeps empty sequences integer-typed.
            sequences.append(
                torch.tensor(prefix_ids + completion_ids, dtype=torch.long)
            )
            prefix_lengths.append(len(prefix_ids))

        full_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_id).to(
            device
        )
        positions = torch.arange(full_ids.size(1), device=device).unsqueeze(0)
        seq_lengths = torch.tensor(
            [len(s) for s in sequences], device=device
        ).unsqueeze(1)
        prefix_lens = torch.tensor(prefix_lengths, device=device).unsqueeze(1)

        is_real = positions < seq_lengths  # [B, L]; False on right padding
        attention_mask = is_real.long()
        # Only completion tokens are scored: mask prefix and padding positions.
        labels = full_ids.masked_fill(~is_real | (positions < prefix_lens), -100)

        with torch.no_grad():
            logits = model(full_ids, attention_mask=attention_mask).logits

        # Token t is predicted from the logits at position t-1.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]

        per_token_loss = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        ).view(shift_labels.shape)

        valid_mask = (shift_labels != -100).float()
        per_sample_loss = (per_token_loss * valid_mask).sum(dim=1) / valid_mask.sum(
            dim=1
        ).clamp(min=1)

        all_losses.extend(per_sample_loss.tolist())

        # Free memory
        del logits, shift_logits, per_token_loss

    return all_losses


def normalize_rewards(rewards: List[float]) -> List[float]:
    """Min-max scale rewards into the [0, 1] range.

    Returns an empty list for empty input (``min()`` would otherwise raise).
    When all rewards are within 1e-6 of each other, every entry becomes 0.5,
    so a flat batch gives a neutral signal instead of dividing by ~0.
    """
    if not rewards:
        return []
    min_v, max_v = min(rewards), max(rewards)
    spread = max_v - min_v
    if spread < 1e-6:
        return [0.5] * len(rewards)
    return [(r - min_v) / spread for r in rewards]


def make_perplexity_reward_function(model, tokenizer, model_name):
    """Build a reward function favouring completions the model finds surprising.

    Each story is scored by its log-perplexity given its extracted question
    (see ``compute_perplexity_batched``), then min-max normalized over the batch.
    """

    def perplexity_reward_function(prompts, completions, **kwargs) -> List[float]:
        """Reward based on log-perplexity (higher = more surprising/creative)."""
        questions, stories = extract_prompts_and_stories(
            prompts, completions, model_name
        )
        normalized = normalize_rewards(
            compute_perplexity_batched(model, tokenizer, questions, stories)
        )
        print("Perplexity reward:", normalized)
        return normalized

    return perplexity_reward_function


def make_in_context_usefulness_reward_function(model, tokenizer, model_name):
    """Build a reward function scoring how much each story helps predict its siblings.

    A story's raw reward is the total drop in its siblings' log-perplexity when
    the story is prepended as an in-context example, compared with the
    siblings' perplexity given only their question. Raw rewards are min-max
    normalized over the batch.
    """

    def in_context_usefulness_reward_function(
        prompts, completions, **kwargs
    ) -> List[float]:
        """Reward based on how much each rollout reduces perplexity of other rollouts.

        Siblings are the other rollouts whose extracted question text is
        identical. A batch containing several prompt groups is therefore handled
        per group, instead of pairing every rollout with the first prompt.
        """
        questions, stories = extract_prompts_and_stories(
            prompts, completions, model_name
        )
        if not stories:
            return []

        # Baseline log-perplexity of each story given only its own question
        # (chunked forward passes inside compute_perplexity_batched).
        baselines = compute_perplexity_batched(model, tokenizer, questions, stories)

        # Group rollout indices by question, preserving batch order.
        groups = {}
        for idx, question in enumerate(questions):
            groups.setdefault(question, []).append(idx)

        # One (example i -> target j) pair for every ordered sibling pair.
        prefixes, targets, pairs = [], [], []
        for question, members in groups.items():
            for i in members:
                for j in members:
                    if i != j:
                        prefixes.append(
                            question
                            + "\nHere is an example story for you: "
                            + stories[i]
                        )
                        targets.append(stories[j])
                        pairs.append((i, j))

        context_losses = compute_perplexity_batched(model, tokenizer, prefixes, targets)

        # Aggregate reductions: positive when story i made story j less surprising.
        rewards = [0.0] * len(stories)
        for (i, j), loss in zip(pairs, context_losses):
            rewards[i] += baselines[j] - loss

        rewards = normalize_rewards(rewards)
        print("In-context usefulness rewards:", rewards)
        return rewards

    return in_context_usefulness_reward_function
