import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional

import numpy as np
from datasets import load_dataset
from openai import AzureOpenAI
from tqdm import tqdm
import openai

# Azure OpenAI connection settings shared by the embedding and chat helpers.
# The key and endpoint are read from the environment (the same variable names
# the `openai` SDK uses for Azure) so real credentials never have to be written
# into this file; the placeholder strings are kept as fallbacks so behaviour is
# unchanged when the variables are unset.
api_version = "2024-02-15-preview"
config_dict: Dict[str, str] = {
    "api_key": os.environ.get("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    "api_version": api_version,
    "azure_endpoint": os.environ.get(
        "AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"
    ),
}

# Persona system prompts keyed by "<age-band>_<gender>". Each one asks the model
# to rate how memorable a described video ad is on a 0–10 scale, given scored
# example ads.
persona_prompts = {
    "18-24_female": (
        "You are a woman aged 18–24. You're part of a generation raised on visual platforms like "
        "Instagram, TikTok, and YouTube Shorts. You instinctively know what kinds of video ads grab "
        "attention and stick—bold aesthetics, emotional authenticity, humor, pop culture references, "
        "and strong individuality.\n\nYou are given 5 example video ads and their memorability scores "
        "(on a 0–10 scale). You're shown a sixth ad described in text. Your task is to judge how "
        "**memorable** this video ad is likely to be for women your age. Consider whether it would "
        "stand out in a fast-scroll feed, be shared, or come to mind later.\n\nReturn:\nReason: "
        "[Explain what aspects of the ad make it memorable or forgettable]\nAnswer: [0–10] ← You must "
        "include this score."
    ),
    "18-24_male": (
        "You are a man aged 18–24. You're part of a generation immersed in fast digital content—Twitch, "
        "TikTok, YouTube, and memes. You naturally notice what makes video ads stick—humor, shock value,"
        " edgy style, cultural references, or visual flair.\n\nYou are given 5 example video ads and "
        "their memorability scores (on a 0–10 scale). You're shown a sixth ad in text form. Your task "
        "is to judge how **memorable** this video ad is likely to be for men your age. Would it stand "
        "out or be scrolled past?\n\nReturn:\nReason: [Your honest view on what makes it stand out or "
        "not]\nAnswer: [0–10] ← You must include this score."
    ),
    "25-34_female": (
        "You are a woman aged 25–34. You value storytelling in video ads that feels authentic, visually "
        "appealing, and aligned with your interests—career, wellness, relationships, or self-expression. "
        "You appreciate ads with both style and substance.\n\nYou are given 5 video ad examples and their "
        "memorability scores (on a 0–10 scale). You're shown a sixth ad described in text. Your task is "
        "to judge how **memorable** this ad is likely to be for women your age. Would it emotionally "
        "connect, inspire, or linger in your mind?\n\nReturn:\nReason: [Explain what makes it likely—"
        "or unlikely—to be remembered]\nAnswer: [0–10] ← You must include this score."
    ),
    "25-34_male": (
        "You are a man aged 25–34. You notice video ads that speak to ambition, tech, fitness, or personal "
        "growth—especially when presented with clarity, energy, and style.\n\nYou are given 5 video ads "
        "and their memorability scores (on a 0–10 scale). You're shown a sixth ad in text. Judge how "
        "**memorable** it is likely to be for men your age. Would it catch your attention and stay with "
        "you?\n\nReturn:\nReason: [Why you think the ad has staying power—or not]\nAnswer: [0–10] ← You "
        "must include this score."
    ),
}

# A strict two-line output contract appended to every persona prompt; it keeps
# the reply easy to parse for a single numeric score.
_RESPONSE_FORMAT_SUFFIX = (
    "\n\nIMPORTANT: Provide your response in exactly two lines:\n"
    "Reason: <brief justification>\n"
    "Answer: <numeric score 0-10>\n"
    "Include only one number (0-10) after 'Answer:' and no other numbers in the response."
)
persona_prompts = {
    persona: base_prompt + _RESPONSE_FORMAT_SUFFIX
    for persona, base_prompt in persona_prompts.items()
}

# --------------------------------------------------------------------------------------
# Embedding model configuration
# --------------------------------------------------------------------------------------
# Model name passed as ``model=`` to ``client.embeddings.create`` (on Azure this
# is presumably the deployment name — confirm against the Azure resource).
EMBED_MODEL: str = "text-embedding-3-small"

# Bare file name (not a full path) for embedding error diagnostics; main() joins
# it with the output directory.
EMBED_DIAG_FILENAME: str = "embedding_diagnostics.txt"


def _get_embeddings(
    texts: List[str],
    batch_size: int = 96,
    diag_path: Optional[str] = None,
    max_retries: int = 5,
) -> List[Optional[np.ndarray]]:
    """Compute embeddings for *texts* using Azure OpenAI with robust retry logic.

    For each *batch_size* slice of *texts* we attempt up to *max_retries* times;
    on a 429 (rate-limit) or transient network error we back-off exponentially.

    Unlike the previous implementation we *do not* abort the entire run when a
    single batch fails.  Instead, we insert ``None`` placeholders for every text
    in the failed batch so downstream code can still leverage the successful
    embeddings it has and fall back to random sampling only for the missing
    ones.
    """

    # Pre-allocate output so ordering is preserved even with failures.
    results: List[Optional[np.ndarray]] = [None] * len(texts)

    # Build Azure client once (outside the loop for efficiency)
    try:
        client = AzureOpenAI(
            api_key=config_dict["api_key"],
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    except Exception as e:
        if diag_path:
            with open(diag_path, "a") as f:
                f.write(f"[CLIENT-INIT-ERROR] Failed to create AzureOpenAI client: {e}\n")
        return results  # all None

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]

        # Retry loop for this chunk only
        attempt = 0
        while attempt <= max_retries:
            try:
                resp = client.embeddings.create(model=EMBED_MODEL, input=chunk)
                resp.data.sort(key=lambda x: x.index)  # preserve original order

                for i, d in enumerate(resp.data):
                    results[start + i] = np.array(d.embedding, dtype=np.float32)

                break  # success, move to next batch

            except Exception as e:
                attempt += 1
                # Parse retry-after seconds if available (Azure puts it in the
                # error message sometimes) – default to 5 * attempt seconds.
                wait_secs = 5 * attempt
                if "retry" in str(e).lower():
                    # crude extraction of the first integer in the message
                    import re as _re

                    m = _re.search(r"retry after (\d+)", str(e).lower())
                    if m:
                        wait_secs = int(m.group(1))

                if diag_path:
                    with open(diag_path, "a") as f:
                        f.write(
                            f"[EMBEDDING-ERROR] Batch {start}-{start+len(chunk)-1} attempt {attempt}/{max_retries}: {e}. Waiting {wait_secs}s.\n"
                        )

                if attempt > max_retries:
                    # Give up on this batch – leave None placeholders.
                    break

                time.sleep(wait_secs)

        # End retry loop

        # Respect base rate-limit between *successful* calls only.
        if attempt == 0 or results[start] is not None:
            time.sleep(0.25)  # more conservative than 0.1s previously

    return results


def _cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

# Lazily-created AzureOpenAI client reused by every verbalize() call, instead of
# building (and never closing) a new client per request.
_chat_client = None


def _get_chat_client() -> "AzureOpenAI":
    """Return the shared chat client, creating it from ``config_dict`` on first use."""
    global _chat_client
    if _chat_client is None:
        _chat_client = AzureOpenAI(
            api_key=config_dict["api_key"],
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    return _chat_client


def verbalize(user_prompt: str, sys_prompt: str) -> str:
    """Send one system + user exchange to the ``gpt-4o`` chat model and return the reply.

    Args:
        user_prompt: Content of the user message.
        sys_prompt: Content of the system message.

    Returns:
        The first choice's message text, stripped. Returns ``""`` when the
        response has no choices or the message content is ``None`` (the SDK
        types ``message.content`` as optional, e.g. for filtered output), so
        callers see "no score" instead of a crash.

    Exceptions raised by the OpenAI client propagate to the caller.
    """
    resp = _get_chat_client().chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.85,
        max_tokens=350,
    )
    if not resp.choices:
        return ""
    content = resp.choices[0].message.content
    return (content or "").strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Parse the command line for one evaluation worker.

    Options: ``--start`` / ``--end`` (inclusive slice bounds; ``--end`` defaults
    to ``None``), ``--output_dir`` (default ``static_folder``), the
    ``--full_eval`` flag, and ``--similarity_json`` (path to a pre-computed
    neighbour map).
    """
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "static_folder", "help": "Directory to write JSON results."}),
        ("--full_eval", {"action": "store_true", "help": "(Optional) run full evaluation over the entire set; not used by the parallel runner."}),
        ("--similarity_json", {"type": str, "default": None, "help": "Path to JSON produced by similarityengine.py containing pre-computed 5-nearest neighbours per video."}),
    )
    cli = argparse.ArgumentParser(
        description="Run static-persona evaluation over a slice of the LAMBDA stories_test dataset."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic (chunk mode)
# --------------------------------------------------------------------------------------

# Score following the last "Answer:" label; tolerates markdown bold / brackets,
# e.g. "**Answer:** 7" or "Answer: [7]".
_ANSWER_RE = re.compile(r"\banswer[\s*:\[]*(\d+(?:\.\d+)?)", re.IGNORECASE)
# Any standalone integer or decimal number (fallback when no label is found).
_NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")


def _extract_score(resp_text: str) -> Optional[float]:
    """Parse a memorability score from a model reply, clamped to [0, 10].

    Prefers the number right after the last ``Answer`` label, so replies such as
    ``Answer: 7/10`` or ``Answer: 7 (0-10)`` give 7 rather than 10. Falls back to
    the last number anywhere in the text. Returns ``None`` when no number exists.
    """
    labelled = _ANSWER_RE.findall(resp_text)
    if labelled:
        raw = labelled[-1]
    else:
        numbers = _NUMBER_RE.findall(resp_text)
        if not numbers:
            return None
        raw = numbers[-1]
    return min(10.0, max(0.0, float(raw)))


def _select_by_similarity(
    idx_global: int,
    all_records: List[Dict],
    all_embeddings: List[Optional[np.ndarray]],
    log_file,
    threshold: float = 0.5,
) -> List[int]:
    """Choose up to five example indices for record *idx_global* using embeddings.

    Policy (the target itself is never a candidate):
      * target has no embedding -> random sample of up to 5 other records;
      * fewer than 5 other records have embeddings -> use all of them;
      * >= 5 candidates with cosine >= *threshold* -> top 5 of those;
      * 1-4 candidates above *threshold* -> exactly those;
      * none above *threshold* -> top 5 by cosine.
    Every similarity score and the final choice are written to *log_file*.
    """
    other_ids = [j for j in range(len(all_records)) if j != idx_global]
    target_emb = all_embeddings[idx_global]

    if target_emb is None:
        sample_ids = random.sample(other_ids, k=min(5, len(other_ids)))
        log_file.write("Target embedding is None, fallback to random sampling:\n")
        for j in sample_ids:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return sample_ids

    # Only compare against records that actually have an embedding.
    valid_pairs = []
    for j in other_ids:
        if all_embeddings[j] is not None:
            sim = _cosine(target_emb, all_embeddings[j])
            valid_pairs.append((sim, j))
            log_file.write(f"Similarity with video_id {all_records[j]['video_id']}: {sim}\n")

    if len(valid_pairs) < 5:
        # Too few embedded candidates to rank: use every one available.
        sample_ids = [j for _, j in valid_pairs]
        log_file.write("Fallback to using all valid indices (less than 5):\n")
        for j in sample_ids:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return sample_ids

    above_thresh = [(sim, j) for sim, j in valid_pairs if sim >= threshold]
    if len(above_thresh) >= 5:
        above_thresh.sort(reverse=True)
        chosen = above_thresh[:5]
        log_file.write("Selected top 5 above threshold:\n")
    elif above_thresh:
        chosen = above_thresh
        log_file.write("Selected all above threshold (less than 5):\n")
    else:
        valid_pairs.sort(reverse=True)
        chosen = valid_pairs[:5]
        log_file.write("Fallback to top 5 most similar:\n")
    for sim, j in chosen:
        log_file.write(f"  video_id {all_records[j]['video_id']}: {sim}\n")
    return [j for _, j in chosen]


def main() -> None:
    """Evaluate every persona prompt on one slice of the LAMBDA stories set.

    Steps:
      1. Parse CLI args; optionally load a pre-computed neighbour map.
      2. Join the stories CSV with the LAMBDA test split on ``video_id``.
      3. Embed all stories (only when no neighbour map was loaded).
      4. For each record in ``[start, end]``: choose up to five in-context
         examples, query every persona, parse the scores and average them.
      5. Write results to ``static_results_{start}_{end}.json`` (checkpointed
         every five records) and neighbour choices to
         ``similarity_scores_{start}_{end}.txt`` inside ``--output_dir``.
    """
    args = parse_args()

    # If the user explicitly asks for full evaluation, fallback to the old behaviour.
    if args.full_eval:
        print("--full_eval requested; please revert to the previous version of the script.")
        sys.exit(1)

    start_idx: int = max(0, args.start)
    output_dir: str = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Optional: pre-computed similarity map (video_id -> list of neighbour dicts).
    sim_map: Optional[Dict[str, List[Dict]]] = None
    if args.similarity_json:
        if not os.path.isfile(args.similarity_json):
            print(f"[WARNING] --similarity_json provided but file not found: {args.similarity_json}")
        else:
            with open(args.similarity_json, "r", encoding="utf-8") as _f:
                sim_map = json.load(_f)
            print(f"Loaded pre-computed similarity map from {args.similarity_json} (entries: {len(sim_map)})")

    # -----------------------------------------------------------------------------
    # Load datasets (paths are placeholders to be set for the environment)
    # -----------------------------------------------------------------------------
    cache_root = "/path/to/hf_cache"
    lambda_dataset = load_dataset(
        "behavior-in-the-wild/LAMBDA",
        split="test",
        cache_dir=cache_root,
    )

    STORIES_CSV_PATH = "/path/to/stories_test.csv"
    story_rows: List[Dict] = []
    with open(STORIES_CSV_PATH, newline="", encoding="utf-8") as csv_f:
        for row in csv.DictReader(csv_f):
            try:
                row["video_id"] = int(row["video_id"])
            except (KeyError, TypeError, ValueError):
                continue  # skip malformed rows (missing / empty / non-integer id)
            story_rows.append(row)

    # video_id -> record from the HF dataset, for quick lookup.
    dataset_map: Dict[int, Dict] = {int(r["video_id"]): r for r in lambda_dataset}

    # Working records; global_idx is the record's position in all_records and
    # indexes all_embeddings.
    all_records: List[Dict] = []
    for row in story_rows:
        rec = dataset_map.get(row["video_id"])
        if rec is None:
            continue  # require recall score
        all_records.append({
            "video_id": row["video_id"],
            "story": row["story"],
            "recall_score": rec["recall_score"],
            "global_idx": len(all_records),
        })

    # video_id -> index in all_records, for resolving sim_map neighbours.
    vid_to_idx_map: Dict[int, int] = {rec["video_id"]: i for i, rec in enumerate(all_records)}

    # -----------------------------------------------------------------------------
    # Embeddings (only needed when there is no pre-computed neighbour map)
    # -----------------------------------------------------------------------------
    diag_path = os.path.join(output_dir, EMBED_DIAG_FILENAME)
    if sim_map is None:
        _story_texts = [rec["story"] for rec in all_records]
        try:
            all_embeddings = _get_embeddings(_story_texts, diag_path=diag_path)
        except Exception as e:
            print(f"Embedding generation failed; falling back to random sampling: {e}")
            with open(diag_path, "a") as f:
                f.write(f"[GENERAL-ERROR] {e}\n")
            all_embeddings = [None] * len(all_records)
    else:
        all_embeddings = [None] * len(all_records)

    # -----------------------------------------------------------------------------
    # Slice records for this worker
    # -----------------------------------------------------------------------------
    end_idx = args.end if args.end is not None else len(all_records) - 1
    end_idx = min(end_idx, len(all_records) - 1)
    slice_records = all_records[start_idx:end_idx + 1]
    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    similarity_log_path = os.path.join(output_dir, f"similarity_scores_{start_idx}_{end_idx}.txt")
    out_path = os.path.join(output_dir, f"static_results_{start_idx}_{end_idx}.json")

    results: List[Dict] = []

    def _save_results() -> None:
        # Overwrite the slice's results file with everything gathered so far.
        with open(out_path, "w") as f:
            json.dump(results, f, indent=2)

    # -----------------------------------------------------------------------------
    # Main loop
    # -----------------------------------------------------------------------------
    for record in tqdm(slice_records, desc="Static-persona eval"):
        target_text = record["story"]
        idx_global = record["global_idx"]

        # Resolve pre-computed neighbours. JSON ids may be strings or numbers, so
        # coerce to int before lookup; malformed entries are skipped.
        precomputed_ids: List[int] = []
        if sim_map:
            for nb in sim_map.get(str(record["video_id"]), []):
                try:
                    nb_idx = vid_to_idx_map.get(int(nb["video_id"]))
                except (KeyError, TypeError, ValueError):
                    nb_idx = None
                # Never use the target as its own example: that would leak its
                # ground-truth score (the embedding path excludes it as well).
                if nb_idx is not None and nb_idx != idx_global:
                    precomputed_ids.append(nb_idx)
            precomputed_ids = precomputed_ids[:5]

        with open(similarity_log_path, "a") as log_file:
            log_file.write(f"Processing video_id: {record['video_id']}\n")
            if precomputed_ids:
                # Use the pre-computed neighbours exclusively.
                sample_ids = precomputed_ids
                log_file.write("Using pre-computed neighbour IDs:\n")
                for j in sample_ids:
                    log_file.write(f"  video_id {all_records[j]['video_id']}\n")
            else:
                sample_ids = _select_by_similarity(idx_global, all_records, all_embeddings, log_file)
            log_file.write("\n")

        example_blocks = []
        for sid in sample_ids:
            ex = all_records[sid]
            score = round(float(ex["recall_score"]) * 10, 2)
            example_blocks.append(f"{ex['story']}\nScore: {score}")
        examples_text = "\n---\n".join(example_blocks)

        # The user prompt is identical for every persona; build it once.
        user_prompt = (
            "Below are five example video ads and their memorability scores. "
            "After these, you'll see a new ad. Give its long-term memorability "
            "score (0–10).\n\n" + examples_text + "\n---\n" + target_text + "\n\n" +
            "Question: What is the long-term memorability score of this video (0-10)?"
        )

        persona_predictions: Dict[str, Dict] = {}
        for persona_name, sys_prompt in persona_prompts.items():
            resp_text = verbalize(user_prompt, sys_prompt)
            persona_predictions[persona_name] = {
                "prediction": _extract_score(resp_text),
                "response": resp_text,
                "prompt": user_prompt,
            }

        # Mean prediction across personas, ignoring unparseable replies.
        valid_preds = [v["prediction"] for v in persona_predictions.values() if v["prediction"] is not None]
        mean_pred = sum(valid_preds) / len(valid_preds) if valid_preds else None

        results.append({
            "video_id": int(record["video_id"]),
            "ground_truth": float(record["recall_score"]) * 10,
            "predictions": persona_predictions,
            "mean_prediction": mean_pred,
        })

        # Checkpoint every 5 records to avoid data loss on pre-emption.
        if len(results) % 5 == 0:
            _save_results()

    # Final save
    _save_results()


if __name__ == "__main__":
    # Run the slice evaluation only when executed as a script, not on import.
    main()


            

            

        
        
        

        
