import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional
from collections import Counter
import re  # Added for parsing persona responses

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options that select one slice of the dataset.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so a bare ``parse_args()`` call
            behaves as a normal CLI entry point. Passing a list allows the
            parser to be driven programmatically.

    Returns:
        Namespace with ``start`` (int, default 0), ``end`` (int or ``None``),
        ``output_dir`` (str, default ``"topic_results"``) and ``csv_path`` (str).

    Raises:
        SystemExit: Raised by argparse when ``--csv_path`` is missing or an
            option value cannot be converted.
    """
    parser = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )
    parser.add_argument("--start", type=int, default=0, help="Start index (inclusive) of the slice.")
    parser.add_argument("--end", type=int, default=None, help="End index (inclusive) of the slice.")
    parser.add_argument("--output_dir", type=str, default="topic_results", help="Directory to write JSON results.")
    parser.add_argument("--csv_path", type=str, required=True, help="Path to the input CSV with columns video_id,story")
    return parser.parse_args(argv)

# ---------------------------------------------------------------------------
# Persona system prompts, keyed by "<age-range>_<gender>".
#
# Each prompt describes one demographic persona, tells the model it will be
# given a topic_vocab dictionary and the story of a video advertisement, and
# asks for exactly two output lines:
#     Reason: <one-sentence justification>
#     Answer: <topic_key>
# The wording differs slightly between personas; edit with care, because the
# text is sent to the model verbatim.
# ---------------------------------------------------------------------------
persona_prompts = {
    "18-24_female": """You are a woman aged 18–24 who intuitively understands what resonates with your generation—bold aesthetics, authenticity, humor, pop culture references, and individuality.

You will be given a *topic_vocab* dictionary that lists marketing topics (e.g. restaurants, electronics, charities) with short definitions. You are then shown the **story** of a video advertisement.

Your task is to choose the SINGLE most relevant topic key from *topic_vocab* that best describes the advertisement.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "18-24_male": """You are a man aged 18–24 who knows what grabs young men's attention—humor, edge, cultural references, and visual flair.

You will be given a *topic_vocab* dictionary that lists marketing topics (e.g. restaurants, electronics, charities) with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that best fits the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "25-34_female": """You are a woman aged 25–34 who connects with content that is visually refined, emotionally resonant, and aligned with lifestyle interests—career, wellness, and relationships.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to choose the SINGLE most relevant topic key from *topic_vocab* that best represents the ad.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "25-34_male": """You are a man aged 25–34 who appreciates content that shows ambition, clarity, innovation, fitness, and smart humor.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to choose the SINGLE most relevant topic key from *topic_vocab* that best matches the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "35-44_female": """You are a woman aged 35–44 who is drawn to emotionally intelligent storytelling, depth, and purpose.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that best reflects the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "35-44_male": """You are a man aged 35–44 who connects with grounded, aspirational content about family, success, and purpose.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that best fits the narrative.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "45-54_female": """You are a woman aged 45–54 who appreciates visuals and stories that carry meaning, clarity, and purpose.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to choose the SINGLE most relevant topic key from *topic_vocab* that best matches the ad.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "45-54_male": """You are a man aged 45–54 who values storytelling that emphasizes responsibility, growth, trust, and wisdom.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that aligns best with the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "55+_female": """You are a woman aged 55 or older who resonates with content that conveys warmth, legacy, and deep emotional meaning.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that best describes the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",

    "55+_male": """You are a man aged 55 or older who prefers storytelling with sincerity, meaning, and timeless values.

You will be given a *topic_vocab* dictionary that lists marketing topics with short definitions. You are then shown the **story** of a video advertisement.

Your task is to select the SINGLE most relevant topic key from *topic_vocab* that best fits the story.

Return exactly two lines:
Reason: <brief justification in one sentence>
Answer: <topic_key>
Only output the topic key after "Answer:".""",
}


# Topic vocabulary shown to the model: the text "Topic_vocab : {...}" mapping
# each topic key to a short definition. It is interpolated verbatim into the
# user message, so it is kept as ONE single-line string. It is assembled from
# adjacent literals (one entry per source line) because a plain quoted string
# cannot span physical lines -- a line break inside the literal is a
# SyntaxError.
topics = (
    "Topic_vocab : {"
    "'restaurant': 'Restaurants, cafe, fast food', "
    "'chocolate': 'Chocolate, cookies, candy, ice cream', "
    "'chips': 'Chips, snacks, nuts, fruit, gum, cereal, yogurt, soups', "
    "'seasoning': 'Seasoning, condiments, ketchup', "
    "'petfood': 'Petfood', "
    "'alcohol': 'Alcohol', "
    "'coffee': 'Coffee, tea', "
    "'soda': 'Soda, juice, milk, energy drinks, water', "
    "'cars': 'Cars, automobile sales, parts, insurance, repair, gas, motor oil', "
    "'electronics': 'Electronics (computers, laptops, tablets, cellphones, TVs)', "
    "'phone_tv_internet_providers': 'Phone, TV and internet service providers', "
    "'financial': 'Financial services (banks, credit cards, investment firms)', "
    "'education': 'Education (universities, colleges, kindergarten, online degrees)', "
    "'security': 'Security and safety services (anti-theft, safety courses)', "
    "'software': 'Software (internet radio, streaming, job search website, grammar correction, travel planning)', "
    "'other_service': 'Other services (dating, tax, legal, loan, religious, printing, catering)', "
    "'beauty': 'Beauty products and cosmetics (deodorants, toothpaste, makeup, hair products, laser hair removal)', "
    "'healthcare': 'Healthcare and medications (hospitals, health insurance, allergy, cold remedy, home tests, vitamins)', "
    "'clothing': 'Clothing and accessories (jeans, shoes, eye glasses, handbags, watches, jewelry)', "
    "'baby': 'Baby products (food, sippy cups, diapers)', "
    "'game': 'Games and toys (including video and mobile games)', "
    "'cleaning': 'Cleaning products (detergents, fabric softeners, soap, tissues, paper towels)', "
    "'home_improvement': 'Home improvements and repairs (furniture, decoration, lawn care, plumbing)', "
    "'home_appliance': 'Home appliances (coffee makers, dishwashers, cookware, vacuum cleaners, heaters, music players)', "
    "'travel': 'Vacation and travel (airlines, cruises, theme parks, hotels, travel agents)', "
    "'media': 'Media and arts (TV shows, movies, musicals, books, audio books)', "
    "'sports': 'Sports equipment and activities', "
    "'shopping': 'Shopping (department stores, drug stores, groceries)', "
    "'gambling': 'Gambling (lotteries, casinos)', "
    "'environment': 'Environment, nature, pollution, wildlife', "
    "'animal_right': 'Animal rights, animal abuse', "
    "'human_right': 'Human rights', "
    "'safety': 'Safety, safe driving, fire safety', "
    "'smoking_alcohol_abuse': 'Smoking, alcohol abuse', "
    "'domestic_violence': 'Domestic violence', "
    "'self_esteem': 'Self esteem, bullying, cyber bullying', "
    "'political': 'Political candidates', "
    "'charities': 'Charities'}"
)

# Patterns for the two-line "Reason: ... / Answer: ..." reply format, compiled
# once here instead of being rebuilt for every persona call.
_REASON_RE = re.compile(r"(?i)^reason:\s*(.+)$", re.MULTILINE)
_ANSWER_RE = re.compile(r"(?i)^answer:\s*([^\s\.,;\n]+)", re.MULTILINE)
# Characters trimmed from both ends of an extracted topic key.
_TOPIC_TRIM_CHARS = "'\". ,"
# Marker stored in place of a prediction dict when a persona call fails.
_PERSONA_ERROR = "error"


def _parse_persona_response(raw_resp: str) -> Dict[str, str]:
    """Extract the reason and topic key from one raw model reply.

    The topic is taken from the first ``Answer:`` line; if there is none, the
    last whitespace-separated word of the reply is used instead.

    Raises:
        ValueError: If the reply has no ``Answer:`` line and no words at all.
    """
    reason_match = _REASON_RE.search(raw_resp)
    reason_text = reason_match.group(1).strip() if reason_match else ""

    answer_match = _ANSWER_RE.search(raw_resp)
    if answer_match:
        candidate = answer_match.group(1)
    else:
        words = raw_resp.split()
        if not words:
            raise ValueError("empty model response")
        candidate = words[-1]

    return {
        'topic': candidate.strip().lower().strip(_TOPIC_TRIM_CHARS),
        'reason': reason_text,
        'raw_response': raw_resp,
    }


def _majority_topic(persona_predictions: Dict[str, object]) -> str:
    """Return the most frequent topic across the successful persona predictions.

    Entries that are not dicts (the ``"error"`` marker) or that carry an empty
    topic are ignored rather than indexed, so one failed persona no longer
    aborts the vote for the whole record.
    """
    if not persona_predictions:
        return "error_no_predictions"
    votes = [
        pred['topic']
        for pred in persona_predictions.values()
        if isinstance(pred, dict) and pred.get('topic') and pred['topic'] != _PERSONA_ERROR
    ]
    if not votes:
        return "error_no_valid_predictions"
    return Counter(votes).most_common(1)[0][0]


def _generate_reply(lm, tok, messages: List[Dict[str, str]]) -> str:
    """Run one sampled generation for ``messages`` and return only the new text."""
    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        # Forwarded to the chat template; for Qwen3 this is presumably the
        # switch that disables thinking output -- confirm against the model card.
        enable_thinking=False,
    ).to(lm.device)

    with torch.no_grad():
        outputs = lm.generate(
            input_ids=input_ids,
            max_new_tokens=300,
            temperature=0.85,
            do_sample=True,
            min_p=0.1,
        )

    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_len = input_ids.shape[-1]
    return tok.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _save_results(results: List[dict], output_path: str) -> None:
    """Write ``results`` as JSON, replacing the previous checkpoint atomically."""
    tmp_path = output_path + ".tmp"
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    # Swap in the finished file so an interrupted write cannot leave a
    # truncated checkpoint at output_path.
    os.replace(tmp_path, output_path)


def main():
    """Classify the topic of every story in the selected CSV slice.

    For each record, every persona in ``persona_prompts`` is asked for a topic;
    the per-persona answers and their majority vote are appended to a results
    list that is re-saved to
    ``<output_dir>/topic_results_<start>_<end>.json`` after each record.

    Per-persona and per-record failures are printed and processing continues;
    an unreadable CSV terminates the process with exit status 1.
    """
    global model, tokenizer

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Read the CSV before loading the model so a bad path fails immediately
    # instead of after the (expensive) model load.
    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run; --end is inclusive and clamped to the data.
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    model_name = "Qwen/Qwen3-32B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): newer transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True) over the bare
    # load_in_4bit flag -- confirm against the installed version.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", load_in_4bit=True)

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Resolved outside the try so the error message below always names the
        # current record, never the previous one.
        video_id = str(rec.get('video_id', '')).strip()
        try:
            # split()/join collapses every whitespace run (newlines and form
            # feeds included) to a single space.
            cleaned_text = ' '.join(str(rec.get('story', '')).split())

            persona_predictions = {}
            for persona_name, sys_prompt in persona_prompts.items():
                # NOTE(review): the user message repeats the system prompt and
                # joins "/no_think", the prompt and the vocabulary with no
                # separators. Left byte-identical so results stay comparable
                # with slices that were already processed.
                messages = [
                    {"role": "system", "content": sys_prompt},
                    {
                        "role": "user",
                        "content": "/no_think" + sys_prompt + f"{topics}\n\nStory: {cleaned_text}"
                    }
                ]

                try:
                    raw_resp = _generate_reply(model, tokenizer, messages)
                    persona_predictions[persona_name] = _parse_persona_response(raw_resp)
                except Exception as e:
                    print(f"Error during Qwen inference for key {video_id}, persona {persona_name}: {e}")
                    persona_predictions[persona_name] = _PERSONA_ERROR

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'persona_predictions': persona_predictions,
                'final_topic': _majority_topic(persona_predictions),
            })

            # Incremental save: checkpoint after every record.
            _save_results(results, output_path)

        except Exception as e:
            # Best effort: report the failing record and move on to the next.
            print(f"Error processing key {video_id}: {e}")

    print(f"Finished processing. Results saved to {output_path}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




