import json
import openai
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional
from collections import Counter
from openai import AzureOpenAI
import re  # regex parsing
import random  # For sampling distractor reasons

def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for parallel execution.

    Each invocation handles the inclusive index range ``[--start, --end]`` of
    the input CSV, so several workers can be given disjoint ranges.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``csv_path`` and
        ``annotation_path``. ``--csv_path`` is required; argparse reports an
        error and exits when it is absent.
    """
    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )
    # (flag, keyword options) in the order they should appear in --help.
    option_table = [
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "topic_results", "help": "Directory to write JSON results."}),
        ("--csv_path", {"type": str, "required": True, "help": "Path to the input CSV with columns video_id,story"}),
        ("--annotation_path", {"type": str, "default": "action_annotation.json", "help": "Path to action_annotation.json for sampling distractor actions"}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# ---------------------------------------------------------------------------
# Updated persona prompts in Action/Answer format
# ---------------------------------------------------------------------------
# System prompts keyed by persona name ("<age band>_<gender>"). main() sends
# each value as the system message, paired with the same user message (story
# plus numbered action list), and majority-votes the personas' answers.
# Every prompt asks for exactly two lines, "Answer: <action>" and
# "Reason: <brief justification>"; main() extracts them with case-insensitive
# regexes anchored on the "Answer:" / "Reason:" line prefixes, so keep that
# output format if a prompt is edited.
# NOTE: everything inside the triple-quoted strings is sent verbatim to the model.
persona_prompts = {
    "18-24_female": """You are a woman aged 18–24 who intuitively understands what resonates with your generation—bold aesthetics, authenticity, humor, pop culture references, and individuality.

You will be shown (1) the STORY of a video advertisement and (2) a LIST OF ACTIONS that a viewer might take after watching it.

Choose the SINGLE best action.

Return EXACTLY two lines:
Answer: <action>
Reason: <brief justification>""",

    "18-24_male": """You are a man aged 18–24 who knows what grabs young men's attention—humor, edge, cultural references, and visual flair.

You will be shown the STORY of a video advertisement and a LIST OF ACTIONS that viewers might consider.

Pick the single best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "25-34_female": """You are a woman aged 25–34 who connects with content that is visually refined, emotionally resonant, and aligned with lifestyle interests—career, wellness, and relationships.

Given the STORY and a LIST OF ACTIONS, select the single best action that fits the story.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "25-34_male": """You are a man aged 25–34 who appreciates content that shows ambition, clarity, innovation, fitness, and smart humor.

Pick ONE action from the provided list that best reflects what a viewer should do.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "35-44_female": """You are a woman aged 35–44 who is drawn to emotionally intelligent storytelling, depth, and purpose.

Choose the single best action from the list.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "35-44_male": """You are a man aged 35–44 who connects with grounded, aspirational content about family, success, and purpose.

Pick the best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "45-54_female": """You are a woman aged 45–54 who appreciates visuals and stories that carry meaning, clarity, and purpose.

Select one best action from the list.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "45-54_male": """You are a man aged 45–54 who values storytelling that emphasizes responsibility, growth, trust, and wisdom.

Choose the best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "55+_female": """You are a woman aged 55 or older who resonates with content that conveys warmth, legacy, and deep emotional meaning.

Pick the single best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "55+_male": """You are a man aged 55 or older who prefers storytelling with sincerity, meaning, and timeless values.

Select one best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",
}


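# Input CSV: needs 'video_id' and 'story' columns. An optional 'reasons' column
# (JSON array or ';'-separated string) is used only for videos that have no
# entry in the annotation JSON.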

# Number of distractor reasons mixed in with each video's correct reasons.
_NUM_DISTRACTORS = 25

# "Answer: ..." / "Reason: ..." lines anywhere in the model reply (case-insensitive).
_ANSWER_RE = re.compile(r"^answer:\s*(.+)$", re.IGNORECASE | re.MULTILINE)
_REASON_RE = re.compile(r"^reason:\s*(.+)$", re.IGNORECASE | re.MULTILINE)
# An answer that starts with a 1-based list index: "3", "3.", "3) Share the video".
_INDEXED_ANSWER_RE = re.compile(r"^(\d+)\s*[.):]?\s*(.*)$", re.DOTALL)


def _is_missing(value: object) -> bool:
    """Return True for None or a float NaN (pandas' value for an empty CSV cell)."""
    return value is None or (isinstance(value, float) and value != value)


def _normalize_reason_list(value: object) -> List[str]:
    """Coerce an already-decoded reasons value into a list of non-blank strings.

    A single string becomes a one-element list. Anything that is not a string,
    list or tuple (None, NaN, numbers, dicts) yields an empty list.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    return [r for r in value if isinstance(r, str) and r.strip()]


def _parse_csv_reasons(raw: object) -> List[str]:
    """Decode a CSV 'reasons' cell: a JSON array/string, or a ';'-separated string."""
    if not isinstance(raw, str):
        return _normalize_reason_list(raw)
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        decoded = None
    if isinstance(decoded, (str, list)):
        return _normalize_reason_list(decoded)
    # Not JSON, or JSON of another type (e.g. "42"): treat the text as ';'-separated.
    return [r.strip() for r in raw.split(";") if r.strip()]


def _lookup_correct_reasons(video_id: str, rec: Dict[str, object], annotation_data: Dict[str, object]) -> List[str]:
    """Return the ground-truth reasons for one video.

    The annotation JSON takes priority. The CSV 'reasons' column is consulted
    only when the annotation has no usable entry for ``video_id``.
    """
    reasons = _normalize_reason_list(annotation_data.get(video_id))
    if not reasons:
        reasons = _parse_csv_reasons(rec.get('reasons', ''))
    return reasons


def _build_candidates(correct: List[str], pool: List[str], num_distractors: int = _NUM_DISTRACTORS) -> List[str]:
    """Return the correct reasons plus up to ``num_distractors`` random distractors, shuffled.

    Distractors are drawn without replacement from the de-duplicated ``pool``,
    minus every correct reason, so no candidate text appears twice.
    """
    correct_unique = list(dict.fromkeys(correct))
    excluded = set(correct_unique)
    distractor_pool = [r for r in dict.fromkeys(pool) if r not in excluded]
    distractors = random.sample(distractor_pool, min(num_distractors, len(distractor_pool)))
    candidates = correct_unique + distractors
    random.shuffle(candidates)
    return candidates


def _resolve_choice(answer: str, candidates: List[str]) -> str:
    """Map a model answer onto the canonical candidate text when possible.

    The following are tried in order:
    - a case-insensitive exact match against the candidates;
    - a bare 1-based index ("3", "3.");
    - an index followed by a candidate's text ("3. Share the video").

    Any other answer is returned stripped but otherwise unchanged, so it is
    still recorded and voted on.
    """
    text = answer.strip()
    by_key = {c.strip().casefold(): c for c in candidates}
    if text.casefold() in by_key:
        return by_key[text.casefold()]
    match = _INDEXED_ANSWER_RE.match(text)
    if match:
        index = int(match.group(1))
        rest = match.group(2).strip()
        if not rest and 1 <= index <= len(candidates):
            return candidates[index - 1]
        if rest and rest.casefold() in by_key:
            return by_key[rest.casefold()]
    return text


def _ask_persona(client, sys_prompt: str, user_content: str, candidates: List[str]) -> Dict[str, str]:
    """Query the model once for one persona and parse its two-line reply.

    Returns a dict with 'reason' (the chosen action, mapped onto a candidate
    when possible), 'explanation' (the Reason line, or "") and 'raw' (the
    full reply). API errors propagate to the caller.
    """
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_content},
        ],
        max_tokens=300,  # allow space for Reason + Answer
        temperature=0.85,
        n=1,
    )
    # The SDK may return None content (e.g. a filtered reply); treat it as empty.
    raw_resp = (response.choices[0].message.content or "").strip()
    ans_match = _ANSWER_RE.search(raw_resp)
    # Without an "Answer:" line, fall back to the whole reply.
    chosen = ans_match.group(1).strip() if ans_match else raw_resp
    reason_match = _REASON_RE.search(raw_resp)
    return {
        'reason': _resolve_choice(chosen, candidates),
        'explanation': reason_match.group(1).strip() if reason_match else "",
        'raw': raw_resp,
    }


def _majority_vote(persona_predictions: Dict[str, object]) -> str:
    """Return the most common chosen reason across personas.

    Failed personas (stored as the string "error") are ignored. Ties go to the
    reason first encountered, following Counter.most_common ordering.
    """
    if not persona_predictions:
        return "error_no_predictions"
    valid_preds = [p['reason'] for p in persona_predictions.values() if isinstance(p, dict)]
    if not valid_preds:
        return "error_no_valid_predictions"
    return Counter(valid_preds).most_common(1)[0][0]


def _save_results(results: List[Dict[str, object]], output_path: str) -> None:
    """Write ``results`` as JSON via a temp file and os.replace.

    If the process is interrupted mid-write, the previous complete file
    stays in place.
    """
    tmp_path = output_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, output_path)


def main():
    """Classify every record in the requested CSV slice with each persona.

    For each row, the function:
    - looks up the row's ground-truth reasons;
    - mixes them with random distractor reasons drawn from the annotation file;
    - asks every persona in ``persona_prompts`` to pick one;
    - majority-votes the picks.

    Results are re-saved after every row to
    ``<output_dir>/topic_results_<start>_<end>.json``. Exits with status 1 if
    the annotation JSON or the CSV cannot be read.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Setup Azure OpenAI client (credentials and endpoint come from the environment).
    client = AzureOpenAI(
        api_key=os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
        api_version="2024-02-15-preview",
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
    )

    # -------------------------------------------------------------------
    # Load full reaction annotation to draw distractor reasons
    # -------------------------------------------------------------------
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)
    if not isinstance(annotation_data, dict):
        print(f"Error reading annotation JSON {args.annotation_path}: expected an object mapping video_id to reasons")
        sys.exit(1)

    # Flatten all reasons into a single pool we can sample distractors from.
    all_reasons_pool = [
        reason
        for reasons in annotation_data.values()
        for reason in _normalize_reason_list(reasons)
    ]

    # Load CSV data
    try:
        df = pd.read_csv(args.csv_path)
    except (OSError, ValueError) as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine the inclusive slice for this run.
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Reset per record so the error message below never shows a stale id.
        video_id = ""
        try:
            raw_id = rec.get('video_id')
            video_id = "" if _is_missing(raw_id) else str(raw_id).strip()
            story = rec.get('story')
            # split()/join collapses every whitespace run (incl. \n, \f) to one space.
            cleaned_text = "" if _is_missing(story) else ' '.join(str(story).split())

            correct_reasons = _lookup_correct_reasons(video_id, rec, annotation_data)
            if not correct_reasons:
                print(f"No reasons found for id {video_id}; skipping")
                continue

            candidate_reasons = _build_candidates(correct_reasons, all_reasons_pool)

            # Identical user message for every persona; only the system prompt varies.
            numbered_actions = "\n".join(f"{i}. {r}" for i, r in enumerate(candidate_reasons, start=1))
            user_content = (
                f"Story:\n{cleaned_text}\n\nList of actions:\n"
                + numbered_actions
                + "\n\nReturn exactly two lines:\nAnswer: <action>\nReason: <brief justification>"
            )

            persona_predictions = {}
            for persona_name, sys_prompt in persona_prompts.items():
                try:
                    persona_predictions[persona_name] = _ask_persona(
                        client, sys_prompt, user_content, candidate_reasons
                    )
                except Exception as e:
                    print(f"Error during OpenAI call for key {video_id}, persona {persona_name}: {e}")
                    # Kept as a bare string for output compatibility; _majority_vote skips it.
                    persona_predictions[persona_name] = "error"

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'persona_predictions': persona_predictions,
                'final_reason': _majority_vote(persona_predictions),
                'candidate_reasons': candidate_reasons,
                'correct_reasons': correct_reasons,
            })

            # Incremental save so partial progress survives a crash.
            _save_results(results, output_path)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")

    # Always write the file, even if every record was skipped, so the message is true.
    _save_results(results, output_path)
    print(f"Finished processing. Results saved to {output_path}")

if __name__ == "__main__":
    main()




