import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
from openai import AzureOpenAI

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Placeholders populated by main() (declared ``global`` there) with the Qwen
# tokenizer and causal LM once they are loaded.
model, tokenizer = None, None

# System prompt for the single no-persona inference. It asks for a 0-10
# memorability score in a two-line "Reason:" / "Answer:" format; main()
# extracts the predicted score from the model's text response.
baseline_system_prompt = (
    "You are an expert evaluator of video advertisement memorability. "
    "Given structured details about a video advertisement, your job is to "
    "predict its long-term memorability score for a general audience on a "
    "0–10 scale (where 0 means instantly forgotten and 10 means "
    "unforgettable even weeks later). You can provide precise scores including decimal values (e.g., 6.7).\n\n"
    "Think about narrative strength, emotional resonance, uniqueness, brand "
    "fit, pacing, and any visual or auditory hooks described.\n\n"
    "Return your answer in this exact two-line format:\n"
    "Reason: <brief justification>\n"
    "Answer: <numeric score 0-10>\n"
    "Include only one number (0-10) after 'Answer:' and no other numbers in the response."
)

# --------------------------------------------------------------------------------------
# Embedding model configuration
# --------------------------------------------------------------------------------------
# Azure OpenAI embedding model (deployment) name passed to client.embeddings.create.
EMBED_MODEL = "text-embedding-3-small"

# File name for embedding diagnostics; main() joins it onto --output_dir.
EMBED_DIAG_FILENAME = "embedding_diagnostics.txt"

# Azure OpenAI configuration (copied from gpt_static_thisruns.py).
# NOTE(review): api_key and azure_endpoint look like placeholders. _get_embeddings
# prefers the OPENAI_API_KEY environment variable over the api_key value here,
# but the endpoint must be edited in place before embeddings can be fetched.
api_version = "2024-02-15-preview"
config_dict: Dict[str, str] = {
    "api_key": "YOUR_OPENAI_API_KEY",
    "api_version": api_version,
    "azure_endpoint": "https://your-azure-openai-endpoint/",
}

def _get_embeddings(
    texts: List[str],
    batch_size: int = 96,
    diag_path: Optional[str] = None,
    max_retries: int = 5,
) -> List[Optional[np.ndarray]]:
    """Compute embeddings for *texts* using Azure OpenAI, retrying failed batches.

    *texts* are sent in slices of *batch_size*. Each slice is attempted up to
    ``max_retries + 1`` times. After a failed attempt the function sleeps for
    the delay suggested in the error message (``"retry after N"``) when one is
    present, and otherwise for ``5 * attempt`` seconds (linear back-off).

    A batch that still fails after every attempt does not abort the run. Its
    slots are left as ``None``, so callers can use the embeddings that did
    succeed and fall back to other strategies for the missing ones.

    Args:
        texts: Strings to embed. Output order matches input order.
        batch_size: Number of texts per API request.
        diag_path: Optional file to which client and batch errors are appended.
        max_retries: Extra attempts per batch after the first failure.

    Returns:
        A list the same length as *texts*. Each slot holds a ``float32`` vector,
        or ``None`` when its batch (or the client itself) failed.
    """
    # Pre-allocate output so ordering is preserved even with failures.
    results: List[Optional[np.ndarray]] = [None] * len(texts)
    if not texts:
        return results

    def _log(message: str) -> None:
        # Diagnostics are best-effort and only written when a path was given.
        if diag_path:
            with open(diag_path, "a") as f:
                f.write(message)

    # Build the Azure client once, outside the batch loop.
    try:
        client = AzureOpenAI(
            api_key=os.getenv("OPENAI_API_KEY", config_dict["api_key"]),
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    except Exception as e:
        _log(f"[CLIENT-INIT-ERROR] Failed to create AzureOpenAI client: {e}\n")
        return results  # all None

    total_attempts = max_retries + 1
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        batch_label = f"{start}-{start + len(chunk) - 1}"
        succeeded = False

        for attempt in range(1, total_attempts + 1):
            try:
                resp = client.embeddings.create(model=EMBED_MODEL, input=chunk)
                # Each item carries its position in the request. Sort on it so
                # the i-th vector lines up with chunk[i].
                for i, d in enumerate(sorted(resp.data, key=lambda x: x.index)):
                    results[start + i] = np.array(d.embedding, dtype=np.float32)
                succeeded = True
                break
            except Exception as e:
                if attempt == total_attempts:
                    # Give up on this batch and leave its None placeholders.
                    _log(
                        f"[EMBEDDING-ERROR] Batch {batch_label} attempt "
                        f"{attempt}/{total_attempts}: {e}. Giving up on this batch.\n"
                    )
                    break

                # Honour a server-suggested delay (Azure sometimes puts
                # "retry after N seconds" in the message); otherwise back off linearly.
                wait_secs = 5 * attempt
                match = re.search(r"retry after (\d+)", str(e).lower())
                if match:
                    wait_secs = int(match.group(1))

                _log(
                    f"[EMBEDDING-ERROR] Batch {batch_label} attempt "
                    f"{attempt}/{total_attempts}: {e}. Waiting {wait_secs}s.\n"
                )
                time.sleep(wait_secs)

        # Pace successful requests only; failures have already slept above.
        if succeeded:
            time.sleep(0.25)

    return results


def _cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Return the cosine similarity of vectors *u* and *v* as a Python float.

    ``1e-8`` is added to the product of the norms, so a zero vector gives
    ``0.0`` instead of a division by zero.
    """
    denominator = np.linalg.norm(u) * np.linalg.norm(v) + 1e-8
    return float(np.dot(u, v) / denominator)

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(
    user_prompt: str,
    sys_prompt: str,
    model,
    tokenizer,
    args,
    *,
    max_new_tokens: int = 1200,
    temperature: float = 0.85,
    min_p: float = 0.1,
) -> str:
    """Run one chat-completion turn on the loaded causal LM and return its reply.

    The conversation is one system message (*sys_prompt*) followed by one user
    message (*user_prompt*) prefixed with Qwen's ``/no_think`` soft switch.
    Thinking is also disabled through ``enable_thinking=False`` in the chat
    template. Decoding samples (``do_sample=True``), so outputs vary between calls.

    Args:
        user_prompt: User-turn text.
        sys_prompt: System-turn text.
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Matching tokenizer providing ``apply_chat_template``/``decode``.
        args: Parsed CLI namespace. Currently unused; kept for call compatibility.
        max_new_tokens: Generation length cap forwarded to ``generate``.
        temperature: Sampling temperature forwarded to ``generate``.
        min_p: Min-p sampling threshold forwarded to ``generate``.

    Returns:
        The newly generated text, decoded without special tokens and stripped.
    """
    messages = [
        {"role": "system", "content": sys_prompt},
        # Keep the soft switch separate from the prompt so it is not fused into
        # the prompt's first word (previously "/no_thinkBelow ...").
        {"role": "user", "content": "/no_think\n" + user_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,
    ).to(model.device)  # same device as the model
    # Single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            use_cache=True,
            do_sample=True,
            min_p=min_p,
        )

    # Decode only the tokens produced after the prompt.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Define the worker's command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``gpu_id``,
        ``full_eval`` and ``similarity_json`` attributes.
    """
    cli = argparse.ArgumentParser(
        description="Run static-persona evaluation over a slice of the LAMBDA stories_test dataset.",
    )

    # (flag, keyword options) pairs, registered in this order so --help is unchanged.
    option_specs = (
        ("--start", dict(type=int, default=0, help="Start index (inclusive) of the slice.")),
        ("--end", dict(type=int, default=None, help="End index (inclusive) of the slice.")),
        ("--output_dir", dict(type=str, default="static_folder", help="Directory to write JSON results.")),
        ("--gpu_id", dict(type=int, default=0, help="GPU ID to use for inference.")),
        ("--full_eval", dict(action="store_true", help="(Optional) run full evaluation over the entire set; not used by the parallel runner.")),
        ("--similarity_json", dict(type=str, default=None, help="Path to JSON produced by similarityengine.py containing pre-computed 5-nearest neighbours per video.")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic (chunk mode)
# --------------------------------------------------------------------------------------

# Score following the required "Answer:" marker; tolerates markdown emphasis
# such as "**Answer:** 6.7" or "**Answer**: 6.7".
_ANSWER_RE = re.compile(r"Answer\**\s*:\s*\**\s*(\d+(?:\.\d+)?)", re.IGNORECASE)
# Any unsigned integer or decimal; used only when no "Answer:" marker is found.
_NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")


def _parse_prediction(resp_text: str) -> Optional[float]:
    """Return the predicted score in *resp_text* clamped to [0, 10], or ``None``.

    Prefers the number after the last ``Answer:`` marker, which is the format
    the system prompt requests. This keeps a trailing scale such as the ``10``
    in ``"Answer: 6.7/10"`` from being taken as the score. Falls back to the
    last number anywhere in the text.
    """
    candidates = _ANSWER_RE.findall(resp_text) or _NUMBER_RE.findall(resp_text)
    if not candidates:
        return None
    return max(0.0, min(10.0, float(candidates[-1])))


def _load_records(stories_csv_path: str, cache_root: str) -> List[Dict]:
    """Join the stories CSV with recall scores from the LAMBDA test split.

    CSV rows whose ``video_id`` is missing or not an integer are skipped, as
    are rows with no matching LAMBDA record. Each returned record has
    ``video_id``, ``story``, ``recall_score`` and ``global_idx`` (its index
    in the returned list).
    """
    lambda_dataset = load_dataset(
        "behavior-in-the-wild/LAMBDA",
        split="test",
        cache_dir=cache_root,
    )
    # Map video_id -> record from HF dataset for quick lookup
    dataset_map: Dict[int, Dict] = {int(r["video_id"]): r for r in lambda_dataset}

    records: List[Dict] = []
    with open(stories_csv_path, newline="", encoding="utf-8") as csv_f:
        for row in csv.DictReader(csv_f):
            try:
                vid = int(row["video_id"])
            except (KeyError, TypeError, ValueError):
                continue  # skip malformed rows
            rec = dataset_map.get(vid)
            if rec is None:
                continue  # require recall score
            records.append({
                "video_id": vid,
                "story": row["story"],
                "recall_score": rec["recall_score"],
                "global_idx": len(records),
            })
    return records


def _precomputed_neighbours(
    video_id: int,
    sim_map: Optional[Dict[str, List[Dict]]],
    vid_to_idx_map: Dict[int, int],
    self_idx: int,
    k: int = 5,
) -> Optional[List[int]]:
    """Translate the pre-computed neighbours of *video_id* into record indices.

    Neighbour ``video_id`` values are coerced to ``int``, because JSON may hold
    them as strings. Unknown ids, duplicates and the target itself (which would
    leak its own label into the prompt) are dropped, and at most *k* are kept.

    Returns ``None`` when there is no map, no entry for *video_id*, or no
    usable neighbour. The caller then falls back to similarity/random selection
    instead of prompting with zero examples.
    """
    key = str(video_id)
    if not sim_map or key not in sim_map:
        return None

    ids: List[int] = []
    for nb in sim_map[key]:
        try:
            idx = vid_to_idx_map.get(int(nb["video_id"]))
        except (KeyError, TypeError, ValueError):
            continue
        if idx is None or idx == self_idx or idx in ids:
            continue
        ids.append(idx)
        if len(ids) == k:
            break
    return ids or None


def _select_similar_examples(
    idx_global: int,
    all_records: List[Dict],
    all_embeddings: List[Optional[np.ndarray]],
    log_file,
    threshold: float = 0.75,
    k: int = 5,
) -> List[int]:
    """Pick up to *k* few-shot example indices for record *idx_global*.

    Writes every computed similarity and the chosen examples to *log_file*,
    an open text file. Selection rules:

    * target embedding missing: random sample of all other records;
    * fewer than *k* other records with embeddings: all of those;
    * at least *k* records at or above *threshold*: the top *k* of them;
    * some, but fewer than *k*, above *threshold*: all of those;
    * none above *threshold*: the *k* most similar.

    Similarity-based picks are ordered most-similar first.
    """
    other_ids = [j for j in range(len(all_records)) if j != idx_global]
    target = all_embeddings[idx_global]

    if target is None:
        sample_ids = random.sample(other_ids, k=min(k, len(other_ids)))
        log_file.write("Target embedding is None, fallback to random sampling:\n")
        for j in sample_ids:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return sample_ids

    # Only compute similarities for non-None embeddings
    valid_pairs = []
    for j in other_ids:
        if all_embeddings[j] is not None:
            sim = _cosine(target, all_embeddings[j])
            valid_pairs.append((sim, j))
            log_file.write(f"Similarity with video_id {all_records[j]['video_id']}: {sim}\n")

    if len(valid_pairs) < k:
        sample_ids = [j for _, j in valid_pairs]
        log_file.write(f"Fallback to using all valid indices (less than {k}):\n")
        for j in sample_ids:
            log_file.write(f"  video_id {all_records[j]['video_id']}\n")
        return sample_ids

    ranked = sorted(valid_pairs, reverse=True)
    above_thresh = [pair for pair in ranked if pair[0] >= threshold]
    if len(above_thresh) >= k:
        chosen = above_thresh[:k]
        log_file.write(f"Selected top {k} above threshold:\n")
    elif above_thresh:
        chosen = above_thresh
        log_file.write(f"Selected all above threshold (less than {k}):\n")
    else:
        chosen = ranked[:k]
        log_file.write(f"Fallback to top {k} most similar:\n")
    for sim, j in chosen:
        log_file.write(f"  video_id {all_records[j]['video_id']}: {sim}\n")
    return [j for _, j in chosen]


def _dump_results(results: List[Dict], out_path: str) -> None:
    """Write *results* to *out_path* as indented JSON, replacing the file atomically.

    The data goes to a temporary file first and is then moved over *out_path*
    with ``os.replace``. A worker pre-empted mid-write therefore keeps its
    previous checkpoint instead of leaving a truncated JSON file.
    """
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, out_path)


def main() -> None:
    """Evaluate memorability predictions for one slice of the stories test set.

    Steps:

    1. Parse CLI arguments and load Qwen3-32B.
    2. Join ``stories_test.csv`` with the LAMBDA test split.
    3. Choose up to five few-shot examples per target: from ``--similarity_json``
       when it has usable neighbours, otherwise by embedding similarity,
       otherwise at random.
    4. Query the model once per target.
    5. Checkpoint all results to
       ``<output_dir>/static_results_<start>_<end>.json`` after every example.
    """
    global model, tokenizer
    args = parse_args()

    # If the user explicitly asks for full evaluation, fallback to the old behaviour.
    if args.full_eval:
        print("--full_eval requested; please revert to the previous version of the script.")
        sys.exit(1)

    start_idx: int = max(0, args.start)
    output_dir: str = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------------
    # Load Model
    # -----------------------------------------------------------------------------
    # NOTE(review): --gpu_id is parsed but never used. device_map="auto" spreads
    # the model over every visible GPU, so per-worker placement presumably relies
    # on the launcher setting CUDA_VISIBLE_DEVICES. Confirm.
    model_name = "Qwen/Qwen3-32B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )

    # -----------------------------------------------------------------------------
    # Load datasets
    # -----------------------------------------------------------------------------
    all_records = _load_records("stories_test.csv", cache_root="hf_cache")
    # Map video_id -> index in all_records (used to resolve pre-computed neighbours).
    vid_to_idx_map: Dict[int, int] = {rec["video_id"]: i for i, rec in enumerate(all_records)}

    # Optionally load a pre-computed similarity map if provided
    sim_map: Optional[Dict[str, List[Dict]]] = None
    if args.similarity_json:
        if not os.path.isfile(args.similarity_json):
            print(f"[WARNING] --similarity_json provided but file not found: {args.similarity_json}")
        else:
            with open(args.similarity_json, "r", encoding="utf-8") as _f:
                sim_map = json.load(_f)
            print(f"Loaded pre-computed similarity map from {args.similarity_json} (entries: {len(sim_map)})")

    # -----------------------------------------------------------------------------
    # Compute embeddings for every story (once) to enable similarity search
    # -----------------------------------------------------------------------------
    # Diagnostics file to capture embedding-related errors
    diag_path = os.path.join(output_dir, EMBED_DIAG_FILENAME)

    if sim_map is None:
        story_texts = [rec["story"] for rec in all_records]
        try:
            all_embeddings = _get_embeddings(story_texts, diag_path=diag_path)
        except Exception as e:
            print(f"Embedding generation failed; falling back to random sampling: {e}")
            with open(diag_path, "a") as f:
                f.write(f"[GENERAL-ERROR] {e}\n")
            all_embeddings = [None] * len(all_records)
    else:
        # Pre-computed neighbours replace embedding search. Records missing from
        # the map fall back to random sampling.
        all_embeddings = [None] * len(all_records)

    # -----------------------------------------------------------------------------
    # Slice records for this worker
    # -----------------------------------------------------------------------------
    end_idx = args.end if args.end is not None else len(all_records) - 1
    end_idx = min(end_idx, len(all_records) - 1)
    slice_records = all_records[start_idx:end_idx + 1]
    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    # -----------------------------------------------------------------------------
    # Main loop
    # -----------------------------------------------------------------------------
    SIMILARITY_THRESHOLD = 0.75  # Adjust as needed
    NUM_EXAMPLES = 5
    similarity_log_path = os.path.join(output_dir, f"similarity_scores_{start_idx}_{end_idx}.txt")
    out_path = os.path.join(output_dir, f"static_results_{start_idx}_{end_idx}.json")
    results: List[Dict] = []

    for record in tqdm(slice_records, desc="Static-persona eval"):
        target_text = record["story"]
        idx_global = record["global_idx"]

        precomputed_ids = _precomputed_neighbours(
            record["video_id"], sim_map, vid_to_idx_map, idx_global, k=NUM_EXAMPLES
        )

        # --------------------------------------------------------------
        # Select few-shot examples: pre-computed > similarity > random
        # --------------------------------------------------------------
        with open(similarity_log_path, "a") as log_file:
            log_file.write(f"Processing video_id: {record['video_id']}\n")
            if precomputed_ids is not None:
                sample_ids = precomputed_ids
                log_file.write(f"--> Using pre-computed neighbours for video {record['video_id']}\n")
            else:
                sample_ids = _select_similar_examples(
                    idx_global,
                    all_records,
                    all_embeddings,
                    log_file,
                    threshold=SIMILARITY_THRESHOLD,
                    k=NUM_EXAMPLES,
                )
            # Always log the final sample IDs that will be used for few-shot examples
            log_file.write("Final sample set used:\n")
            for sid in sample_ids:
                log_file.write(f"  video_id {all_records[sid]['video_id']}\n")
            log_file.write("\n\n")

        example_blocks = []
        for sid in sample_ids:
            ex = all_records[sid]
            score = round(float(ex["recall_score"]) * 10, 2)
            example_blocks.append(f"{ex['story']}\nScore: {score}")
        examples_text = "\n---\n".join(example_blocks)

        # -------------------------
        # Single no-persona inference
        # -------------------------
        user_prompt = (
            "Below are five example video ads and their memorability scores. "
            "After these, you'll see a new ad. Give its long-term memorability "
            "score (0–10).\n\n"
            + examples_text
            + "\n---\n"
            + target_text
            + "\n\nQuestion: What is the long-term memorability score of this video (0-10)?"
        )

        resp_text = verbalize(user_prompt, baseline_system_prompt, model, tokenizer, args)
        pred = _parse_prediction(resp_text)

        results.append({
            "video_id": int(record["video_id"]),
            "ground_truth": float(record["recall_score"]) * 10,
            "prediction": pred,
            "mean_prediction": pred,  # Preserve key expected by downstream metrics
            "prompt": user_prompt,
            "model_response": resp_text,
        })

        # Checkpoint after every example so pre-emption loses at most the one
        # currently in flight.
        _dump_results(results, out_path)

    # Final save (also creates the file when the slice is empty).
    _dump_results(results, out_path)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()


            

            

        
        
        

        
