import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
# Module-level handles for the loaded LLM; main() assigns them via `global`.
# They stay None until main() has loaded the model and tokenizer.
model = None
tokenizer = None
# Shared response-format instructions appended to every persona prompt so the
# model's reply can be parsed as a "Reason:" line followed by an "Answer:" line.
_RESPONSE_FORMAT_SUFFIX = (
    "\n\nIMPORTANT: Provide your response in exactly two lines:\n"
    "Reason: <brief justification>\n"
    "Answer: <numeric score 0-10>\n"
    "Include only one number (0-10) after 'Answer:' and no other numbers in the response."
)

# Persona system prompts keyed by "<age band>_<gender>", each ending with the
# strengthened response-format requirements above.
persona_prompts = {
    persona_key: persona_text + _RESPONSE_FORMAT_SUFFIX
    for persona_key, persona_text in (
        (
            "18-24_female",
            "You are a woman aged 18–24. You're part of a generation raised on visual platforms like "
            "Instagram, TikTok, and YouTube Shorts. You instinctively know what kinds of video ads grab "
            "attention and stick—bold aesthetics, emotional authenticity, humor, pop culture references, "
            "and strong individuality.\n\nYou are given 5 example video ads and their memorability scores "
            "(on a 0–10 scale). You're shown a sixth ad described in text. Your task is to judge how "
            "**memorable** this video ad is likely to be for women your age. Consider whether it would "
            "stand out in a fast-scroll feed, be shared, or come to mind later.\n\nReturn:\nReason: "
            "[Explain what aspects of the ad make it memorable or forgettable]\nAnswer: [0–10] ← You must "
            "include this score.",
        ),
        (
            "18-24_male",
            "You are a man aged 18–24. You're part of a generation immersed in fast digital content—Twitch, "
            "TikTok, YouTube, and memes. You naturally notice what makes video ads stick—humor, shock value,"
            " edgy style, cultural references, or visual flair.\n\nYou are given 5 example video ads and "
            "their memorability scores (on a 0–10 scale). You're shown a sixth ad in text form. Your task "
            "is to judge how **memorable** this video ad is likely to be for men your age. Would it stand "
            "out or be scrolled past?\n\nReturn:\nReason: [Your honest view on what makes it stand out or "
            "not]\nAnswer: [0–10] ← You must include this score.",
        ),
        (
            "25-34_female",
            "You are a woman aged 25–34. You value storytelling in video ads that feels authentic, visually "
            "appealing, and aligned with your interests—career, wellness, relationships, or self-expression. "
            "You appreciate ads with both style and substance.\n\nYou are given 5 video ad examples and their "
            "memorability scores (on a 0–10 scale). You're shown a sixth ad described in text. Your task is "
            "to judge how **memorable** this ad is likely to be for women your age. Would it emotionally "
            "connect, inspire, or linger in your mind?\n\nReturn:\nReason: [Explain what makes it likely—"
            "or unlikely—to be remembered]\nAnswer: [0–10] ← You must include this score.",
        ),
        (
            "25-34_male",
            "You are a man aged 25–34. You notice video ads that speak to ambition, tech, fitness, or personal "
            "growth—especially when presented with clarity, energy, and style.\n\nYou are given 5 video ads "
            "and their memorability scores (on a 0–10 scale). You're shown a sixth ad in text. Judge how "
            "**memorable** it is likely to be for men your age. Would it catch your attention and stay with "
            "you?\n\nReturn:\nReason: [Why you think the ad has staying power—or not]\nAnswer: [0–10] ← You "
            "must include this score.",
        ),
    )
}
# Embedding model configuration
EMBED_MODEL = "text-embedding-3-small"
# File name for embedding-related diagnostics (written inside the output directory).
EMBED_DIAG_FILENAME = "embedding_diagnostics.txt"
# Azure OpenAI Configuration
# The endpoint and key can be supplied through the AZURE_OPENAI_ENDPOINT and
# AZURE_OPENAI_API_KEY environment variables so real credentials never have to be
# edited into (and committed with) this file. The literals below are placeholders
# used only when those variables are unset.
api_version = "2024-02-15-preview"
config_dict: Dict[str, str] = {
    "api_key": os.getenv("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    "api_version": api_version,
    "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
}
def _get_embeddings(
    texts: List[str], batch_size: int = 96, diag_path: Optional[str] = None, max_retries: int = 5
) -> List[Optional[np.ndarray]]:
    """Compute embeddings via Azure OpenAI with robust retry logic."""
    # Pre-allocate output so ordering is preserved even with failures.
    results: List[Optional[np.ndarray]] = [None] * len(texts)

    # Build Azure client once (outside the loop for efficiency)
    try:
        client = AzureOpenAI(
            api_key=os.getenv("OPENAI_API_KEY", config_dict["api_key"]),
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    except Exception as e:
        if diag_path:
            with open(diag_path, "a") as f:
                f.write(f"[CLIENT-INIT-ERROR] Failed to create AzureOpenAI client: {e}\n")
        return results  # all None

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]

        # Retry loop for this chunk only
        attempt = 0
        while attempt <= max_retries:
            try:
                resp = client.embeddings.create(model=EMBED_MODEL, input=chunk)
                resp.data.sort(key=lambda x: x.index)  # preserve original order

                for i, d in enumerate(resp.data):
                    results[start + i] = np.array(d.embedding, dtype=np.float32)

                break  # success, move to next batch

            except Exception as e:
                attempt += 1
                # Parse retry-after seconds if available (Azure puts it in the
                # error message sometimes) – default to 5 * attempt seconds.
                wait_secs = 5 * attempt
                if "retry" in str(e).lower():
                    # crude extraction of the first integer in the message
                    import re as _re

                    m = _re.search(r"retry after (\\d+)", str(e).lower())
                    if m:
                        wait_secs = int(m.group(1))

                if diag_path:
                    with open(diag_path, "a") as f:
                        f.write(
                            f"[EMBEDDING-ERROR] Batch {start}-{start+len(chunk)-1} attempt {attempt}/{max_retries}: {e}. Waiting {wait_secs}s.\n"
                        )

                if attempt > max_retries:
                    # Give up on this batch – leave None placeholders.
                    break

                time.sleep(wait_secs)

        # End retry loop

        # Respect base rate-limit between *successful* calls only.
        if attempt == 0 or results[start] is not None:
            time.sleep(1)  # more conservative than 0.1s previously

    return results


def _cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))

# Helper functions

def verbalize(user_prompt: str, sys_prompt: str, model, tokenizer, args) -> str:
    """Generate one persona-conditioned completion with a local chat model.

    The persona text ``sys_prompt`` is sent as the system message. It is also
    repeated in the user message after a fixed "stay in persona" instruction and
    ahead of ``user_prompt``. Decoding is sampled (temperature 0.85, min_p 0.1,
    at most 1200 new tokens), so repeated calls can return different text.

    Args:
        user_prompt: Task text (few-shot examples plus the target ad).
        sys_prompt: Persona description.
        model: Causal LM exposing ``generate`` and ``device`` (e.g. the
            ``AutoModelForCausalLM`` loaded in ``main``).
        tokenizer: Matching tokenizer providing ``apply_chat_template`` and ``decode``.
        args: Parsed CLI namespace; currently unused, kept for call compatibility.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of
        surrounding whitespace.
    """
    persona_instruction = (
        "Strictly answer ONLY from the perspective of the given persona. Do NOT give generic or "
        "general answers. Your response must be specific to the persona described above."
    )
    # "/no_think" is the soft switch Qwen3-style templates use to disable reasoning;
    # presumably inert for Llama — kept so the model can be swapped. The parts are
    # joined with blank lines so adjacent pieces do not fuse into run-on text.
    user_content = "\n\n".join(["/no_think", persona_instruction, sys_prompt, user_prompt])
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_content},
    ]
    # enable_thinking is passed through to the chat template as an extra variable;
    # presumably templates that do not reference it simply ignore it.
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,
    ).to(model.device)  # Ensure input_ids are on the same device as the model

    # Explicit pad id avoids the per-call "Setting pad_token_id" warning for
    # tokenizers (like Llama's) that define no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # A single unpadded sequence: every position is real input.
            attention_mask=torch.ones_like(input_ids),
            pad_token_id=pad_token_id,
            max_new_tokens=1200,
            temperature=0.85,
            use_cache=True,
            do_sample=True,
            min_p=0.1,
        )

    # Keep only the newly generated tokens (everything after the prompt).
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# CLI

def parse_args() -> argparse.Namespace:
    """Parse the command line of one evaluation worker from ``sys.argv``.

    Options are registered from a table in a fixed order, which is also the
    order they appear in ``--help``.
    """
    cli = argparse.ArgumentParser(
        description="Run static-persona evaluation over a slice of the LAMBDA stories_test dataset.",
    )
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "static_folder", "help": "Directory to write JSON results."}),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID to use for inference."}),
        (
            "--full_eval",
            {
                "action": "store_true",
                "help": "(Optional) run full evaluation over the entire set; not used by the parallel runner.",
            },
        ),
        (
            "--similarity_json",
            {
                "type": str,
                "default": None,
                "help": "Path to JSON produced by similarityengine.py containing pre-computed 5-nearest neighbours per video.",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# Main evaluation logic (chunk mode)

# Minimum cosine similarity for an example to count as a close neighbour.
SIMILARITY_THRESHOLD = 0.75  # Adjust as needed

# Score after the last "Answer:" label; tolerates markdown ("**Answer:** 7") and
# trailing text ("7/10", which the plain last-number heuristic misread as 10).
_ANSWER_RE = re.compile(r"answer\W*?:\W*(\d+(?:\.\d+)?)", re.IGNORECASE)
# Any integer/decimal; fallback when the response has no "Answer:" label.
_NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")


def _parse_prediction(resp_text: str) -> Optional[float]:
    """Extract a score clamped to [0, 10] from a model response, or None if it has no number.

    The number after the last "Answer:" label wins. Otherwise the last number in
    the text is used.
    """
    candidates = _ANSWER_RE.findall(resp_text) or _NUMBER_RE.findall(resp_text)
    if not candidates:
        return None
    return max(0.0, min(10.0, float(candidates[-1])))


def _resolve_precomputed_ids(neighbours, vid_to_idx_map: Dict[int, int], self_idx: int) -> List[int]:
    """Map pre-computed neighbour entries to indices into all_records (at most 5).

    The ``video_id`` values are coerced to int, because JSON producers may store
    them as strings. Unknown or malformed entries, duplicates, and the target
    itself (which would leak its own score) are skipped.
    """
    ids: List[int] = []
    for nb in neighbours:
        try:
            idx = vid_to_idx_map.get(int(nb["video_id"]))
        except (KeyError, TypeError, ValueError):
            continue
        if idx is None or idx == self_idx or idx in ids:
            continue
        ids.append(idx)
        if len(ids) == 5:
            break
    return ids


def _select_similar_examples(idx_global: int, all_records: List[Dict], all_embeddings, log_file) -> List[int]:
    """Choose few-shot example indices for ``all_records[idx_global]``, logging to ``log_file``.

    If the target has no embedding, up to 5 random other records are drawn.
    Records with embeddings are then ranked by cosine similarity:
    - If at least 5 score >= SIMILARITY_THRESHOLD, the top 5 of those are used.
    - If only some do, all of those are used.
    - If none do, the 5 most similar are used.
    - If fewer than 5 other records have embeddings, all of them are used.
    """
    other_ids = [j for j in range(len(all_records)) if j != idx_global]
    target_emb = all_embeddings[idx_global]

    if target_emb is None:
        # Target embedding is None, fall back to random sampling
        sample_ids = random.sample(other_ids, k=min(5, len(other_ids)))
        log_file.write("Target embedding is None, fallback to random sampling:\n")
        for j in sample_ids:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return sample_ids

    # Only compute similarities for non-None embeddings
    valid_pairs = []
    for j in other_ids:
        if all_embeddings[j] is not None:
            sim = _cosine(target_emb, all_embeddings[j])
            valid_pairs.append((sim, j))
            log_file.write(f"Similarity with video_id {all_records[j]['video_id']}: {sim}\n")

    if len(valid_pairs) < 5:
        # Too few embedded candidates to rank meaningfully: use all of them.
        valid_indices = [j for _, j in valid_pairs]
        log_file.write("Fallback to using all valid indices (less than 5):\n")
        for j in valid_indices:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return valid_indices

    above_thresh = [(sim, j) for sim, j in valid_pairs if sim >= SIMILARITY_THRESHOLD]
    if len(above_thresh) >= 5:
        above_thresh.sort(reverse=True)
        chosen = above_thresh[:5]
        log_file.write("Selected top 5 above threshold:\n")
    elif above_thresh:
        chosen = above_thresh
        log_file.write("Selected all above threshold (less than 5):\n")
    else:
        valid_pairs.sort(reverse=True)
        chosen = valid_pairs[:5]
        log_file.write("Fallback to top 5 most similar:\n")
    for sim, j in chosen:
        log_file.write(f"  video_id {all_records[j]['video_id']}: {sim}\n")
    return [j for _, j in chosen]


def _write_results(out_path: str, results: List[Dict]) -> None:
    """Write ``results`` as indented JSON, replacing ``out_path`` atomically.

    The data is written to a temporary file first. If the job is pre-empted
    mid-write, the previous checkpoint survives instead of being truncated.
    """
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, out_path)


def main() -> None:
    """Evaluate one slice of the stories test set with every static persona.

    Steps:
    1. Parse the CLI and load the Llama model and tokenizer.
    2. Join ``stories_test.csv`` with the LAMBDA test split on ``video_id``.
    3. Pick few-shot examples per target: pre-computed neighbours from
       ``--similarity_json`` when usable, otherwise embedding similarity,
       otherwise random.
    4. Query every persona.
    5. Write ``static_results_{start}_{end}.json`` (rewritten after every
       record) and a similarity log into ``--output_dir``.
    """
    global model, tokenizer
    args = parse_args()

    # If the user explicitly asks for full evaluation, fallback to the old behaviour.
    if args.full_eval:
        print("--full_eval requested; please revert to the previous version of the script.")
        sys.exit(1)

    start_idx: int = max(0, args.start)
    output_dir: str = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------------
    # Load Model
    # -----------------------------------------------------------------------------
    # NOTE(review): --gpu_id is parsed but never used. device_map="auto" spreads the
    # model over all visible GPUs, so device selection presumably happens via
    # CUDA_VISIBLE_DEVICES in the launcher — confirm.
    model_name = "meta-llama/Llama-3.3-70B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )

    # -----------------------------------------------------------------------------
    # Load datasets
    # -----------------------------------------------------------------------------
    cache_root = "hf_cache"
    lambda_dataset = load_dataset(
        "behavior-in-the-wild/LAMBDA",
        split="test",
        cache_dir=cache_root,
    )

    # Load stories CSV; rows whose video_id is missing or non-numeric are skipped.
    STORIES_CSV_PATH = "stories_test.csv"
    story_rows: List[Dict] = []
    with open(STORIES_CSV_PATH, newline="", encoding="utf-8") as csv_f:
        for row in csv.DictReader(csv_f):
            try:
                row["video_id"] = int(row["video_id"])
            except (KeyError, TypeError, ValueError):
                continue  # skip malformed rows
            story_rows.append(row)

    # Map video_id -> record from HF dataset for quick lookup
    dataset_map: Dict[int, Dict] = {int(r["video_id"]): r for r in lambda_dataset}

    # Keep only stories that have a recall score. global_idx is the record's
    # position in all_records and therefore also its index into all_embeddings.
    all_records: List[Dict] = []
    for row in story_rows:
        vid = row["video_id"]
        if vid not in dataset_map:
            continue  # require recall score
        all_records.append({
            "video_id": vid,
            "story": row["story"],
            "recall_score": dataset_map[vid]["recall_score"],
            "global_idx": len(all_records),
        })

    # Map video_id -> index in all_records (used to resolve pre-computed neighbours)
    vid_to_idx_map: Dict[int, int] = {rec["video_id"]: i for i, rec in enumerate(all_records)}

    # -----------------------------------------------------------------------------
    # Optionally load a pre-computed similarity map if provided
    # -----------------------------------------------------------------------------
    sim_map: Optional[Dict[str, List[Dict]]] = None
    if args.similarity_json:
        if not os.path.isfile(args.similarity_json):
            print(f"[WARNING] --similarity_json provided but file not found: {args.similarity_json}")
        else:
            with open(args.similarity_json, "r", encoding="utf-8") as _f:
                sim_map = json.load(_f)
            print(f"Loaded pre-computed similarity map from {args.similarity_json} (entries: {len(sim_map)})")

    # -----------------------------------------------------------------------------
    # Compute embeddings for every story (once) to enable similarity search
    # -----------------------------------------------------------------------------
    # Diagnostics file to capture embedding-related errors
    diag_path = os.path.join(output_dir, EMBED_DIAG_FILENAME)

    if sim_map is None:
        _story_texts = [rec_["story"] for rec_ in all_records]
        try:
            all_embeddings = _get_embeddings(_story_texts, diag_path=diag_path)
        except Exception as e:
            print(f"Embedding generation failed; falling back to random sampling: {e}")
            with open(diag_path, "a") as f:
                f.write(f"[GENERAL-ERROR] {e}\n")
            all_embeddings = [None] * len(all_records)
    else:
        # We won't need embeddings because we'll use the pre-computed neighbours.
        all_embeddings = [None] * len(all_records)

    # -----------------------------------------------------------------------------
    # Slice records for this worker (both ends inclusive)
    # -----------------------------------------------------------------------------
    end_idx = args.end if args.end is not None else len(all_records) - 1
    end_idx = min(end_idx, len(all_records) - 1)
    slice_records = all_records[start_idx:end_idx + 1]
    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    # -----------------------------------------------------------------------------
    # Main loop
    # -----------------------------------------------------------------------------
    similarity_log_path = os.path.join(output_dir, f"similarity_scores_{start_idx}_{end_idx}.txt")
    out_path = os.path.join(output_dir, f"static_results_{start_idx}_{end_idx}.json")
    results: List[Dict] = []

    for record in tqdm(slice_records, desc="Static-persona eval"):
        target_text = record["story"]
        idx_global = record["global_idx"]

        # Pre-computed neighbours take precedence, but only if at least one resolves;
        # an empty list must not replace the selection with zero examples.
        precomputed_ids: List[int] = []
        if sim_map and str(record["video_id"]) in sim_map:
            precomputed_ids = _resolve_precomputed_ids(
                sim_map[str(record["video_id"])], vid_to_idx_map, idx_global
            )

        with open(similarity_log_path, "a") as log_file:
            log_file.write(f"Processing video_id: {record['video_id']}\n")
            if precomputed_ids:
                sample_ids = precomputed_ids
                log_file.write(f"--> Overrode sample with pre-computed neighbours for video {record['video_id']}\n")
            else:
                sample_ids = _select_similar_examples(idx_global, all_records, all_embeddings, log_file)
                log_file.write("\n")
            # Always log the final sample IDs that will be used for few-shot examples
            log_file.write("Final sample set used:\n")
            for sid in sample_ids:
                log_file.write(f"  video_id {all_records[sid]['video_id']}\n")
            log_file.write("\n\n")

        example_blocks = []
        for sid in sample_ids:
            ex = all_records[sid]
            score = round(float(ex["recall_score"]) * 10, 2)
            example_blocks.append(f"{ex['story']}\nScore: {score}")
        examples_text = "\n---\n".join(example_blocks)

        # The task prompt is identical for every persona; only the system prompt varies.
        user_prompt = (
            "Below are five example video ads and their memorability scores. "
            "After these, you'll see a new ad. Estimate its long-term memorability for the persona on a "
            "0–10 scale (0 = instantly forgotten, 10 = unforgettable even weeks later). "
            "Anchor your judgment relative to the example scores and consider narrative strength, emotional resonance, uniqueness, brand fit, pacing, and any sensory hooks described. "
            "Decimals (e.g., 6.7) are allowed for precision.\n\n"
            "Examples:\n" + examples_text +
            "\n---\n" + target_text +
            "\n\nQuestion: What is the calibrated memorability score of this video (0–10)?"
        )

        persona_predictions: Dict[str, Dict] = {}
        for persona_name, sys_prompt in persona_prompts.items():
            resp_text = verbalize(user_prompt, sys_prompt, model, tokenizer, args)
            persona_predictions[persona_name] = {
                "prediction": _parse_prediction(resp_text),
                "response": resp_text,
                "prompt": user_prompt,
            }

        # Compute mean prediction across all personas (ignoring None)
        valid_preds = [v["prediction"] for v in persona_predictions.values() if v["prediction"] is not None]
        mean_pred = sum(valid_preds) / len(valid_preds) if valid_preds else None

        results.append({
            "video_id": int(record["video_id"]),
            "ground_truth": float(record["recall_score"]) * 10,
            "predictions": persona_predictions,
            "mean_prediction": mean_pred,
        })

        # Checkpoint after every record to avoid data loss on pre-emption.
        _write_results(out_path, results)

    # Final save (also produces an empty list for an empty slice)
    _write_results(out_path, results)


if __name__ == "__main__":
    # Run one evaluation worker when executed as a script (not when imported).
    main()


            

            

        
        
        

        
