import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional
from collections import Counter
import re  # regex parsing
import random  # For sampling distractor reasons

def parse_args() -> argparse.Namespace:
    """Build the command-line interface for one worker and parse ``sys.argv``.

    Each worker processes the inclusive row range ``[--start, --end]`` of the
    input CSV, so several workers can run in parallel over disjoint slices.
    ``--csv_path`` is the only required option.
    """
    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--start", dict(type=int, default=0, help="Start index (inclusive) of the slice.")),
        ("--end", dict(type=int, default=None, help="End index (inclusive) of the slice.")),
        ("--output_dir", dict(type=str, default="topic_results", help="Directory to write JSON results.")),
        ("--csv_path", dict(type=str, required=True, help="Path to the input CSV with columns video_id,story")),
        ("--annotation_path", dict(type=str, default="reaction_annotation.json", help="Path to reaction_annotation.json for sampling distractor reasons")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

# ---------------------------------------------------------------------------
# Persona system prompts; each asks the model for a single "Answer: <reason>" line
# ---------------------------------------------------------------------------
# System prompts keyed by "<age band>_<gender>". main() sends each one as the
# system message of a separate model call on the same story + reason list, then
# majority-votes over the per-persona answers.
persona_prompts = {
    # Every prompt asks for a single "Answer: <reason>" line; only 18-24_female
    # spells out that answering with the list number is also acceptable
    # (main() maps a bare list number back to the reason text).
    "18-24_female": """You are a woman aged 18–24 who intuitively understands what resonates with your generation—bold aesthetics, authenticity, humor, pop culture references, and individuality.

You will be shown (1) the STORY of a video advertisement and (2) a LIST OF REASONS that state why someone might respond positively to the ad.

Your job: Choose the SINGLE best reason from the list that fits the story.

Return EXACTLY one line:
Answer: <reason>
The line must contain only the word *Answer:* followed by the chosen reason (verbatim or its list number). Do not output anything else.""",

    "18-24_male": """You are a man aged 18–24 who knows what grabs young men's attention—humor, edge, cultural references, and visual flair.

You will be shown the STORY of a video advertisement and a LIST OF REASONS about why someone might respond to it.

Pick the single best reason.

Return exactly one line:
Answer: <reason>""",

    "25-34_female": """You are a woman aged 25–34 who connects with content that is visually refined, emotionally resonant, and aligned with lifestyle interests—career, wellness, and relationships.

Given the STORY and a LIST OF REASONS, select the single best reason matching the story.

Return exactly:
Answer: <reason>""",

    "25-34_male": """You are a man aged 25–34 who appreciates content that shows ambition, clarity, innovation, fitness, and smart humor.

Pick ONE reason from the provided list that best explains why the ad is persuasive.

Return exactly:
Answer: <reason>""",

    "35-44_female": """You are a woman aged 35–44 who is drawn to emotionally intelligent storytelling, depth, and purpose.

Choose the single best reason from the list.

Return:
Answer: <reason>""",

    "35-44_male": """You are a man aged 35–44 who connects with grounded, aspirational content about family, success, and purpose.

Pick the best reason.

Return:
Answer: <reason>""",

    "45-54_female": """You are a woman aged 45–54 who appreciates visuals and stories that carry meaning, clarity, and purpose.

Select one best reason from the list.

Return:
Answer: <reason>""",

    "45-54_male": """You are a man aged 45–54 who values storytelling that emphasizes responsibility, growth, trust, and wisdom.

Choose the best reason.

Return:
Answer: <reason>""",

    "55+_female": """You are a woman aged 55 or older who resonates with content that conveys warmth, legacy, and deep emotional meaning.

Pick the single best reason.

Return:
Answer: <reason>""",

    "55+_male": """You are a man aged 55 or older who prefers storytelling with sincerity, meaning, and timeless values.

Select one best reason.

Return:
Answer: <reason>""",
}


# The CSV 'reasons' column (JSON array or ';'-separated string) is optional: it is
# only consulted for video_ids that have no reasons in the annotation JSON.

# Upper bound on distractor reasons mixed into each candidate list.
_NUM_DISTRACTORS = 25

# "Answer: ..." line, tolerating markdown decoration such as "**Answer:** 3".
_ANSWER_RE = re.compile(
    r"^[ \t*_#>]*answer[ \t*_]*:[\s*_]*(.+?)[ \t*_]*$",
    re.IGNORECASE | re.MULTILINE,
)
# Optional "Reason: ..." justification line (same tolerance as _ANSWER_RE).
_JUSTIFICATION_RE = re.compile(
    r"^[ \t*_#>]*reason[ \t*_]*:[\s*_]*(.+?)[ \t*_]*$",
    re.IGNORECASE | re.MULTILINE,
)
# A list number ("3", "3.", "3)") optionally followed by whitespace and text
# ("3. reason text"). A separator is required before any text so that reasons
# which merely start with a number ("30% off ...") are not read as indices.
_INDEX_RE = re.compile(r"(\d+)(?:[.):](?:\s+(.*))?)?", re.DOTALL)


def _is_missing(value) -> bool:
    """Return True for ``None`` or a float NaN (pandas' value for an empty CSV cell)."""
    # NaN is the only float value that compares unequal to itself.
    return value is None or (isinstance(value, float) and value != value)


def _coerce_reason_list(value) -> List:
    """Normalise a reasons value into a list of non-empty items.

    ``None``/NaN -> ``[]``; a single string -> ``[string]`` (``[]`` if empty);
    a list/tuple -> its truthy items; anything else (e.g. a bare number) -> ``[]``.
    """
    if _is_missing(value):
        return []
    if isinstance(value, str):
        return [value] if value else []
    if isinstance(value, (list, tuple)):
        return [item for item in value if item]
    return []


def _parse_csv_reasons(raw) -> List:
    """Parse a CSV ``reasons`` cell: a JSON array, else a ';'-separated string."""
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except ValueError:
            # Not valid JSON (includes the empty string): split on ';'.
            parsed = [part.strip() for part in raw.split(';')]
        return _coerce_reason_list(parsed)
    return _coerce_reason_list(raw)


def _sample_distractors(pool: List, correct: List, k: int = _NUM_DISTRACTORS, rng=random) -> List:
    """Draw up to *k* distinct-position items from *pool* that are not in *correct*.

    *rng* is anything with a ``sample`` method (the ``random`` module by
    default); pass a seeded ``random.Random`` for reproducible draws.
    """
    eligible = [reason for reason in pool if reason not in correct]
    return rng.sample(eligible, min(k, len(eligible)))


def _extract_answer(raw_resp: str) -> str:
    """Return the text after the first ``Answer:`` line, or the whole stripped response."""
    match = _ANSWER_RE.search(raw_resp)
    return match.group(1).strip() if match else raw_resp.strip()


def _resolve_choice(chosen: str, candidates: List) -> str:
    """Map the model's answer onto a candidate reason when that can be done reliably.

    Tried in order: a case-insensitive exact match of the whole answer; for a
    ``"N. text"`` answer, an exact match of ``text``; then the 1-based list
    number ``N`` if it is in range. Otherwise the stripped answer is returned.
    """
    by_text = {str(candidate).casefold(): candidate for candidate in candidates}
    text = chosen.strip()

    exact = by_text.get(text.casefold())
    if exact is not None:
        return exact

    index_match = _INDEX_RE.fullmatch(text)
    if index_match:
        tail = (index_match.group(2) or "").strip()
        if tail:
            tail_hit = by_text.get(tail.casefold())
            if tail_hit is not None:
                return tail_hit
        position = int(index_match.group(1))
        if 1 <= position <= len(candidates):
            return candidates[position - 1]

    return text


def _predict_for_persona(model, tokenizer, messages: List[Dict[str, str]], candidate_reasons: List) -> Dict[str, str]:
    """Run one sampled chat completion and parse the chosen reason.

    Returns a dict with ``reason`` (the parsed answer, mapped onto the candidate
    text where possible), ``explanation`` (text of an optional ``Reason:``
    line, else ``""``) and ``raw`` (the decoded completion).
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,  # Qwen3 chat-template flag: no <think> block
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # A single unpadded sequence: attend to every prompt token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=300,
            temperature=0.85,
            do_sample=True,
            min_p=0.1,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    raw_resp = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

    justification_match = _JUSTIFICATION_RE.search(raw_resp)
    return {
        'reason': _resolve_choice(_extract_answer(raw_resp), candidate_reasons),
        'explanation': justification_match.group(1).strip() if justification_match else "",
        'raw': raw_resp,
    }


def _majority_vote(persona_predictions: Dict[str, Dict[str, str]]) -> str:
    """Return the most common non-error ``reason`` across personas.

    Ties go to the reason seen first (``Counter.most_common`` keeps insertion
    order for equal counts). Returns ``"error_no_predictions"`` for an empty
    mapping and ``"error_no_valid_predictions"`` when every entry is an error.
    """
    if not persona_predictions:
        return "error_no_predictions"
    valid = [pred['reason'] for pred in persona_predictions.values() if pred['reason'] != "error"]
    if not valid:
        return "error_no_valid_predictions"
    return Counter(valid).most_common(1)[0][0]


def _write_json_atomic(path: str, data) -> None:
    """Write *data* as indented JSON to *path* via a temp file and ``os.replace``.

    The rename is atomic on POSIX, so an interrupted run leaves the previous
    checkpoint intact instead of a truncated JSON file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, path)


def main():
    """Run persona-based reason selection over one slice of the input CSV.

    For every record in ``[--start, --end]`` (inclusive) this:
      1. looks up the ground-truth reasons (annotation JSON first, then the
         CSV ``reasons`` column), skipping the record if there are none;
      2. mixes them with up to ``_NUM_DISTRACTORS`` distractors sampled from
         the annotation pool and shuffles the list;
      3. asks the model, once per entry in ``persona_prompts``, to pick a reason;
      4. majority-votes over the personas' picks.
    Results are checkpointed to ``<output_dir>/topic_results_<start>_<end>.json``
    after every record and once more at the end.

    Exits with status 1 if the annotation JSON or CSV cannot be read, or if
    ``--start`` is negative.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # -------------------------------------------------------------------
    # Read the cheap inputs before loading the model so bad paths fail fast.
    # -------------------------------------------------------------------
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)
    if not isinstance(annotation_data, dict):
        print(f"Error reading annotation JSON {args.annotation_path}: expected an object mapping video_id to reasons")
        sys.exit(1)

    # Flatten every annotated reason into one pool; dedupe (order-preserving)
    # so the same text cannot be drawn twice as a distractor.
    all_reasons_pool = list(dict.fromkeys(
        reason
        for reasons in annotation_data.values()
        for reason in _coerce_reason_list(reasons)
    ))

    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (both bounds inclusive).
    if args.start < 0:
        print(f"--start must be non-negative, got {args.start}")
        sys.exit(1)
    start_idx = args.start
    last_idx = len(all_records) - 1
    end_idx = last_idx if args.end is None else min(args.end, last_idx)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    # --------------------------------------------------------------
    # Load the Qwen chat model once. Kept as module globals, as before.
    # --------------------------------------------------------------
    global model, tokenizer
    from transformers import BitsAndBytesConfig

    # Qwen3 dense checkpoints are 0.6B/1.7B/4B/8B/14B/32B; there is no 7B.
    model_name = "Qwen/Qwen3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        # 4-bit bitsandbytes quantization to keep memory low; remove to load unquantized.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Resolved before the try so the error message below never refers to
        # a stale id from the previous record.
        raw_id = rec.get('video_id', '')
        video_id = "" if _is_missing(raw_id) else str(raw_id).strip()
        try:
            # Ground truth: the annotation file wins; the CSV column is the fallback.
            correct_reasons = _coerce_reason_list(annotation_data.get(video_id))
            if not correct_reasons:
                correct_reasons = _parse_csv_reasons(rec.get('reasons', ''))
            if not correct_reasons:
                print(f"No reasons found for id {video_id}; skipping")
                continue

            # Correct reasons + random distractors, shuffled so position gives nothing away.
            candidate_reasons = correct_reasons + _sample_distractors(all_reasons_pool, correct_reasons)
            random.shuffle(candidate_reasons)

            # str.split() with no argument splits on all whitespace (including
            # newlines and form feeds), so this collapses the story onto one line.
            story_text = rec.get('story', '')
            cleaned_text = "" if _is_missing(story_text) else ' '.join(str(story_text).split())

            # The user message is identical for every persona: build it once.
            numbered_reasons = "\n".join(f"{i}. {r}" for i, r in enumerate(candidate_reasons, start=1))
            user_content = (
                f"Story:\n{cleaned_text}\n\nList of reasons:\n{numbered_reasons}"
                "\n\nReturn exactly one line:\nAnswer: <reason>"
            )

            persona_predictions = {}
            for persona_name, sys_prompt in persona_prompts.items():
                messages = [
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": user_content},
                ]
                try:
                    persona_predictions[persona_name] = _predict_for_persona(
                        model, tokenizer, messages, candidate_reasons
                    )
                except Exception as e:
                    print(f"Error during model inference for key {video_id}, persona {persona_name}: {e}")
                    # Same shape as a successful prediction so the vote (and
                    # downstream readers) can treat every entry uniformly.
                    persona_predictions[persona_name] = {
                        'reason': "error",
                        'explanation': "",
                        'raw': "",
                        'error': str(e),
                    }

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'persona_predictions': persona_predictions,
                'final_reason': _majority_vote(persona_predictions),
                'candidate_reasons': candidate_reasons,
                'correct_reasons': correct_reasons,
            })

            # Incremental checkpoint.
            _write_json_atomic(output_path, results)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")

    # Final write also guarantees the file exists when no record succeeded.
    _write_json_atomic(output_path, results)
    print(f"Finished processing. Results saved to {output_path}")

if __name__ == "__main__":
    main()




