import re
import random
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
from PIL import Image
from qwen_vl_utils import process_vision_info
from scipy.stats import spearmanr


def extract_task_completion_percentages(text: str, maxlen: int) -> List[float]:
    """Collect every ``<number>%`` token in *text* as a float.

    Args:
        text: Free-form model output to scan.
        maxlen: When positive, keep only the trailing ``maxlen`` values;
            zero or a negative value keeps every match.

    Returns:
        The parsed percentages in order of appearance, truncated to the
        last ``maxlen`` entries when ``maxlen`` is positive.
    """
    percents = list(map(float, re.findall(r"(\d+(?:\.\d+)?)%", text)))
    overflow = len(percents) - maxlen
    if maxlen <= 0 or overflow <= 0:
        return percents
    # Model outputs may echo earlier numbers; the final answers come last.
    return percents[overflow:]


def extract_frame_progress_percentages(text: str, expected_k: int) -> Optional[List[int]]:
    """Read per-frame relative ranks from a model response.

    Matching lines look like ``"Frame i: ... Relative Rank: r"`` (case
    insensitive). A line is kept only when both ``i`` and ``r`` lie in
    ``1..expected_k``; if a frame index appears more than once, the later
    line wins.

    Args:
        text: Raw model output.
        expected_k: Number of ranked frames K.

    Returns:
        ``[rank_of_frame_1, ..., rank_of_frame_K]`` when exactly the frames
        1..K were ranked, otherwise ``None``.
    """
    line_re = re.compile(
        r"^\s*Frame\s+(\d+):\s*.*?Relative\s+Rank:\s*(\d+)\s*$",
        re.IGNORECASE | re.MULTILINE,
    )
    valid = range(1, expected_k + 1)

    rank_by_frame: Dict[int, int] = {}
    for m in line_re.finditer(text):
        # Both groups are \d+, so int() always succeeds here.
        frame_no, claimed = int(m.group(1)), int(m.group(2))
        if frame_no in valid and claimed in valid:
            rank_by_frame[frame_no] = claimed

    # Keys are restricted to 1..K, so a size of K means every frame is present.
    if len(rank_by_frame) != expected_k:
        return None
    return [rank_by_frame[i] for i in valid]


def shuffle_except_first(lst: List[Any], rng: Optional[random.Random] = None) -> Tuple[List[Any], List[int]]:
    """Keep ``lst[0]`` in place and randomly permute the remaining items.

    Args:
        lst: Items to permute; the input is not modified.
        rng: Source of randomness. When omitted, the module-level
            ``random`` functions are used; pass a seeded ``random.Random``
            for reproducible results.

    Returns:
        ``(shuffled, order)`` where ``shuffled[0] is lst[0]`` and
        ``shuffled[j + 1] is lst[order[j]]``. ``order`` is empty when
        ``lst`` has fewer than two items.
    """
    if len(lst) < 2:
        return lst[:], []

    source = rng if rng else random
    order = list(range(1, len(lst)))
    source.shuffle(order)
    return [lst[0], *(lst[i] for i in order)], order


def get_gvl_content_v2(
    frames: List[Any],
    success_frames: Optional[List[Any]] = None,
    allow_ties: bool = False,
) -> List[Dict[str, Any]]:
    """Build the v2 (relative-rank) chat prompt for one gameplay segment.

    The single user message is laid out as:
      1. the task instruction;
      2. optional success reference images (never ranked);
      3. the anchor image ``frames[0]`` with a "do not rank" note;
      4. the hard rules and the required answer format;
      5. ``frames[1:]`` labelled ``Frame 1`` .. ``Frame K``.

    Args:
        frames: Anchor followed by the K frames to rank, in display order.
        success_frames: Optional reference images. Items that are
            2-tuples ``(image, value)`` contribute only their image.
        allow_ties: When True, the prompt permits repeated ranks.

    Returns:
        ``[{"role": "user", "content": items}]``, where each item is either
        ``{"type": "text", "text": ...}`` or ``{"type": "image", "image": ...}``.
    """
    num_ranked = len(frames) - 1

    def text_item(s: str) -> Dict[str, Any]:
        return {"type": "text", "text": s}

    def image_item(im: Any) -> Dict[str, Any]:
        return {"type": "image", "image": im}

    preamble = "You are an expert game analyst. "
    if success_frames:
        preamble += "You will first see reference frames that DEFINE SUCCESS. Then "
    preamble += (
        "You will see frames from ONE gameplay segment. "
        "The FIRST image is the earliest frame (anchor) — do NOT rank it. "
        "The anchor provides the starting-state context (baseline) for this segment. "
        f"Next, you will see K={num_ranked} frames in RANDOM order from the SAME segment.\n\n"
        "TASK: For these K frames, assign RELATIVE RANKS 1..K by CLOSENESS TO SUCCESS "
        "(1 = least close to success, K = most close to success). "
    )
    if allow_ties:
        preamble += "Ranks may repeat; ties are allowed.\n\n"
        uniqueness_rule = f"• Ranks in 1..{num_ranked} may be duplicated (ties are allowed)."
    else:
        preamble += "Use each rank exactly once (no duplicates, no ties).\n\n"
        uniqueness_rule = f"• Use each rank in 1..{num_ranked} EXACTLY ONCE (no duplicates, no ties)."

    items: List[Dict[str, Any]] = [text_item(preamble)]

    if success_frames:
        items.append(text_item("SUCCESS references (do NOT rank):"))
        for ref in success_frames:
            # (image, value) pairs contribute only their image.
            if isinstance(ref, tuple) and len(ref) == 2:
                ref = ref[0]
            items.append(image_item(ref))
        items.append(text_item("\n\n"))

    items += [
        text_item("Anchor (earliest frame — baseline for context; implicit rank = 0, do NOT rank):"),
        image_item(frames[0]),
        text_item("\nDo not assign any rank to this frame.\n\n"),
    ]

    rules = "\n".join(
        [
            "• Use the anchor only as context/baseline; rank only frames 1..K.",
            uniqueness_rule,
            "• The description must not be empty.",
            "• Keep your answer strictly to the required format, one line per frame.",
        ]
    )
    items.append(
        text_item(
            "Now, output the frame description and RELATIVE RANK for the following frames, "
            "which are presented in RANDOM order within this segment.\n"
            "HARD RULES:\n"
            + rules
            + "\n\nUse EXACTLY this format for EACH frame:\n"
            "```\nFrame {i}: Frame Description: {}, Relative Rank: {r}\n```\n\n"
        )
    )

    for position, frame in enumerate(frames[1:], start=1):
        items.extend((text_item(f"Frame {position}:"), image_item(frame), text_item("\n")))

    return [{"role": "user", "content": items}]


def get_gvl_content(
    frames: List[Any],
    example_frames: List[Tuple[Any, float]],
    use_percentage: bool = False,
    is_shuffled_context: bool = True,
) -> List[Dict[str, Any]]:
    """Build the v1 (task-completion-percentage) chat prompt.

    Args:
        frames: ``frames[0]`` is shown as the initial game state; ``frames[1:]``
            are the frames the model must score, labelled ``Frame 1``..``Frame K``.
        example_frames: Optional in-context examples. Each item may be an
            ``(image, percent)`` pair (tuple or list) or a bare image; bare
            images carry no percentage.
        use_percentage: When True, each example whose percentage is known is
            followed by a sentence stating that percentage.
        is_shuffled_context: When True and examples are present, the prompt
            warns that the examples are in random order.

    Returns:
        ``[{"role": "user", "content": items}]`` where each item is a
        ``{"type": "text", ...}`` or ``{"type": "image", ...}`` dict.
    """
    content: List[Dict[str, Any]] = []

    instruction = (
        "\n"
        "        You are an expert game analyst tasked with predicting task completion percentages for frames from the Catrap game, \n"
        "        where the player character is solving a level involving enemies, puzzles, and obstacles.\n"
        "        The task completion percentages are between 0 and 100, where 100\n"
        "        corresponds to full task completion.\n"
        "    "
    )

    if example_frames:
        add = ", along with their corresponding task completion percentages" if use_percentage else ""
        instruction += f"We provide several example frames showing various stages of the gameplay{add}."
        if is_shuffled_context:
            # Leading space keeps this sentence separate from the previous one.
            instruction += (
                " Note that these frames are in random order, so please pay attention to the "
                "individual frames when reasoning about task completion percentage."
            )

    content.append({"type": "text", "text": instruction})

    for idx, example in enumerate(example_frames, 1):
        # Accept (image, percent) pairs as well as bare images: callers may pass
        # plain images when percentages are not shown.
        if isinstance(example, (tuple, list)) and len(example) == 2:
            img, percent = example
        else:
            img, percent = example, None
        content.append({"type": "text", "text": f"Example Frame {idx}:"})
        content.append({"type": "image", "image": img})
        if use_percentage and percent is not None:
            content.append({"type": "text", "text": f"In this frame, the task completion percentage is {percent}%."})

    content.append({"type": "text", "text": "Initial game state: "})
    content.append({"type": "image", "image": frames[0]})

    content.append(
        {
            "type": "text",
            "text": (
                "Now, for the level being solved in Catrap, output the task completion percentage for the following "
                "frames that are presented in random order. For each frame, format your response as follows:\n"
                "```\nFrame {i}: Frame Description: {}, Task Completion Percentage: {}%\n```"
            ),
        }
    )

    for idx, img in enumerate(frames[1:], 1):
        content.append({"type": "text", "text": f"Frame {idx}:"})
        content.append({"type": "image", "image": img})

    return [{"role": "user", "content": content}]


def get_video_frames(video_path: str, crop_box: Tuple[int, int, int, int] = (0, 0, 640, 330)) -> List[Image.Image]:
    """Decode every frame of a video into cropped RGB PIL images.

    Args:
        video_path: Source passed straight to ``cv2.VideoCapture``.
        crop_box: ``(left, upper, right, lower)`` box passed to ``Image.crop``.

    Returns:
        All decoded frames in order, each cropped to ``crop_box``.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Error: Cannot open video.")

        all_frames: List[Image.Image] = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            all_frames.append(Image.fromarray(frame_rgb).crop(crop_box))
        return all_frames
    finally:
        # Release the capture even if decoding, conversion or cropping raises.
        cap.release()


def _select_context_frames(
    context_path: Optional[str],
    context_len: int,
    context_offset: int,
    crop_box: Tuple[int, int, int, int],
    with_percentages: bool,
) -> Optional[List[Any]]:
    """Load up to ``context_len`` evenly spaced context frames, or None when disabled."""
    if context_len <= 0:
        return None
    if context_path is None:
        raise ValueError("context_path must be provided when context_len > 0")

    all_context_frames = get_video_frames(context_path, crop_box)[context_offset:]
    n_total = len(all_context_frames)
    if n_total == 0:
        return []

    if n_total <= context_len:
        indices = list(range(n_total))
    else:
        indices = np.linspace(0, n_total - 1, context_len).astype(int).tolist()
    selected = [all_context_frames[i] for i in indices]

    if not with_percentages:
        return selected

    # Label each frame with a 0..100 progress value from its absolute position
    # within the (offset) context video.
    denom = max(1, n_total - 1)
    return [(img, round(100.0 * idx / denom, 2)) for img, idx in zip(selected, indices)]


def _generate_response(model: Any, processor: Any, messages: List[Dict[str, Any]]) -> str:
    """Run greedy generation for one chat and return the decoded completion text."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=1536, do_sample=False)
    # Drop the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]


def _voc_from_values(
    unordered_values: Optional[List[float]],
    shuffled_indices: List[int],
    n_frames: int,
) -> float:
    """Spearman correlation between predicted and true frame order; -1.0 on parse failure.

    ``unordered_values[j]`` is the score of the j-th *shown* frame (slot 0 is
    the anchor sentinel) and ``shuffled_indices[j - 1]`` is that frame's true
    index in the original sequence.
    """
    if unordered_values is None or len(unordered_values) != n_frames:
        return -1.0
    # Shown positions sorted by predicted score, mapped back to true frame indices.
    values_order = np.argsort(unordered_values)
    order = np.array([0] + list(shuffled_indices))[values_order]
    true_order = np.arange(len(order))
    # Exclude the anchor slot so it cannot inflate the correlation.
    voc, _ = spearmanr(order[1:] - 1, true_order[1:] - 1)
    return float(voc)


def voc_reward_fn(
    frames: List[Image.Image],
    model: Any,
    processor: Any,
    n_repeats: int = 1,
    context_len: int = -1,
    shuffle_context: bool = False,
    use_percentage: bool = False,
    context_path: Optional[str] = None,
    context_offset: int = 0,
    crop_box: Tuple[int, int, int, int] = (0, 0, 640, 330),
    prompt_version: str = "v2",
    rng_seed: Optional[int] = None,
) -> Tuple[float, np.ndarray]:
    """Compute VOC by querying the model with either v1 (percent) or v2 (rank) prompts.

    For each repeat, ``frames[1:]`` are shuffled while ``frames[0]`` stays as
    the anchor. The model is prompted, and its parsed predictions are turned
    into an ordering. The repeat's VOC is the Spearman correlation between
    that ordering and the true order; unparseable outputs score -1.
    ``spearmanr`` may yield NaN for degenerate inputs (e.g. constant ranks),
    which then propagates into the mean.

    Args:
        frames: Anchor frame followed by the frames to order.
        model: Generation model; must expose ``device`` and ``generate``.
        processor: Matching processor; must provide ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        n_repeats: Number of independently shuffled queries.
        context_len: Number of context frames to sample from ``context_path``;
            values <= 0 disable context.
        shuffle_context: Shuffle the context frames (in place) before every repeat.
        use_percentage: v1 only. Show the example percentages in the prompt.
        context_path: Video used for context frames; required when ``context_len > 0``.
        context_offset: Number of leading context-video frames to skip.
        crop_box: Crop applied to context-video frames.
        prompt_version: ``"v1"`` for percentages; anything else uses the v2 rank prompt.
        rng_seed: Seed for deterministic per-repeat shuffles; None uses the
            global ``random`` module.

    Returns:
        ``(mean VOC, per-repeat VOC array)``.

    Raises:
        ValueError: If ``context_len > 0`` and ``context_path`` is None.
    """
    is_v1 = prompt_version == "v1"
    # v1 always receives (image, percent) pairs; get_gvl_content only prints the
    # percentage when use_percentage is set.
    selected_context_frames = _select_context_frames(
        context_path, context_len, context_offset, crop_box, with_percentages=is_v1
    )
    n_ranked = len(frames) - 1
    all_vocs: List[float] = []

    for rep in range(n_repeats):
        # Per-repeat RNG, deterministic given the seed. Python 3.11+ rejects tuple
        # seeds, so a string seed is derived instead.
        rep_rng = random.Random(f"{rng_seed}:{rep}") if rng_seed is not None else None
        shuffled_frames, shuffled_indices = shuffle_except_first(frames, rng=rep_rng)

        if shuffle_context and selected_context_frames is not None:
            (rep_rng or random).shuffle(selected_context_frames)

        if is_v1:
            messages = get_gvl_content(shuffled_frames, selected_context_frames or [], use_percentage, shuffle_context)
        else:
            messages = get_gvl_content_v2(shuffled_frames, selected_context_frames, allow_ties=False)

        output_text = _generate_response(model, processor, messages)

        if is_v1:
            extracted = extract_task_completion_percentages(output_text, n_ranked)
        else:
            extracted = extract_frame_progress_percentages(output_text, n_ranked)

        # -1.0 sentinel sits below any valid score so the anchor always sorts first.
        unordered_values = None if extracted is None else [-1.0] + extracted
        print(unordered_values)

        all_vocs.append(_voc_from_values(unordered_values, shuffled_indices, len(frames)))

    all_vocs_arr = np.array(all_vocs, dtype=float)
    return float(np.mean(all_vocs_arr)), all_vocs_arr
