import json
import openai
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level handles for the language model and its tokenizer. Both stay None
# until main() assigns them (via a `global` declaration).
model = None
tokenizer = None

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments selecting one slice of the dataset.

    Several copies of the script can run in parallel, each on its own
    ``[--start, --end]`` row range (both ends inclusive).

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, so existing ``parse_args()`` callers are unaffected.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir`` and ``csv_path``.

    Exits via ``parser.error`` (status 2) when ``--start`` is negative or
    ``--end`` is smaller than ``--start``. A negative start would otherwise be
    interpreted by list slicing as an offset from the END of the dataset, and
    an inverted range would silently process nothing.
    """
    parser = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )
    parser.add_argument("--start", type=int, default=0, help="Start index (inclusive) of the slice.")
    parser.add_argument("--end", type=int, default=None, help="End index (inclusive) of the slice.")
    parser.add_argument("--output_dir", type=str, default="topic_results", help="Directory to write JSON results.")
    parser.add_argument("--csv_path", type=str, required=True, help="Path to the input CSV with columns video_id,story")
    args = parser.parse_args(argv)

    if args.start < 0:
        parser.error("--start must be >= 0")
    if args.end is not None and args.end < args.start:
        parser.error("--end must be >= --start")
    return args

# ---------------------------------------------------------------------------
# Zero-shot classification system prompt
# ---------------------------------------------------------------------------
# System message for every request: describes the inputs (topic vocabulary and
# ad story) and restricts the answer to a single topic key with no extra text.
SYSTEM_PROMPT = " ".join(
    [
        "You are an expert content analyst.",
        "You will be given a dictionary called topic_vocab, which lists marketing topics and their definitions.",
        "You will also be given the STORY text of a video advertisement.",
        "Your task is to choose the SINGLE most relevant topic key from topic_vocab that best describes the advertisement.",
        "Output ONLY the topic key, nothing else.",
    ]
)

# Allowed topic labels mapped to short human-readable definitions. The model is
# asked (see SYSTEM_PROMPT) to answer with exactly one of these keys.
TOPIC_VOCAB = {
    "restaurant": "Restaurants, cafe, fast food",
    "chocolate": "Chocolate, cookies, candy, ice cream",
    "chips": "Chips, snacks, nuts, fruit, gum, cereal, yogurt, soups",
    "seasoning": "Seasoning, condiments, ketchup",
    "petfood": "Petfood",
    "alcohol": "Alcohol",
    "coffee": "Coffee, tea",
    "soda": "Soda, juice, milk, energy drinks, water",
    "cars": "Cars, automobile sales, parts, insurance, repair, gas, motor oil",
    "electronics": "Electronics (computers, laptops, tablets, cellphones, TVs)",
    "phone_tv_internet_providers": "Phone, TV and internet service providers",
    "financial": "Financial services (banks, credit cards, investment firms)",
    "education": "Education (universities, colleges, kindergarten, online degrees)",
    "security": "Security and safety services (anti-theft, safety courses)",
    "software": "Software (internet radio, streaming, job search website, grammar correction, travel planning)",
    "other_service": "Other services (dating, tax, legal, loan, religious, printing, catering)",
    "beauty": "Beauty products and cosmetics (deodorants, toothpaste, makeup, hair products, laser hair removal)",
    "healthcare": "Healthcare and medications (hospitals, health insurance, allergy, cold remedy, home tests, vitamins)",
    "clothing": "Clothing and accessories (jeans, shoes, eye glasses, handbags, watches, jewelry)",
    "baby": "Baby products (food, sippy cups, diapers)",
    "game": "Games and toys (including video and mobile games)",
    "cleaning": "Cleaning products (detergents, fabric softeners, soap, tissues, paper towels)",
    "home_improvement": "Home improvements and repairs (furniture, decoration, lawn care, plumbing)",
    "home_appliance": "Home appliances (coffee makers, dishwashers, cookware, vacuum cleaners, heaters, music players)",
    "travel": "Vacation and travel (airlines, cruises, theme parks, hotels, travel agents)",
    "media": "Media and arts (TV shows, movies, musicals, books, audio books)",
    "sports": "Sports equipment and activities",
    "shopping": "Shopping (department stores, drug stores, groceries)",
    "gambling": "Gambling (lotteries, casinos)",
    "environment": "Environment, nature, pollution, wildlife",
    "animal_right": "Animal rights, animal abuse",
    "human_right": "Human rights",
    "safety": "Safety, safe driving, fire safety",
    "smoking_alcohol_abuse": "Smoking, alcohol abuse",
    "domestic_violence": "Domestic violence",
    "self_esteem": "Self esteem, bullying, cyber bullying",
    "political": "Political candidates",
    "charities": "Charities",
}

# Prompt fragment sent with every story. It is derived from TOPIC_VOCAB so the
# prompt and the key list cannot drift apart. No definition contains a quote
# character, so the dict repr renders every string in single quotes, i.e.
# "Topic_vocab : {'restaurant': 'Restaurants, cafe, fast food', ...}".
topics = f"Topic_vocab : {TOPIC_VOCAB}"

def _clean_cell(value: object) -> str:
    """Return a CSV cell as a whitespace-normalized string; missing cells become ''.

    pandas represents empty CSV cells as NaN, which ``str()`` would turn into
    the literal text "nan". That text must not leak into prompts or YouTube URLs.
    """
    if value is None:
        return ""
    if not isinstance(value, str) and pd.isna(value):
        return ""
    # str.split() with no argument splits on every whitespace run (including
    # \n and \f), so re-joining with single spaces already removes them.
    return " ".join(str(value).split())


def _extract_topic(resp_text: str) -> str:
    """Reduce a raw model reply to one lowercase topic key ('' if there is none).

    Defensive against reasoning output. Text up to a closing ``</think>`` is
    dropped. A reply that opened ``<think>`` but never closed it (cut off by the
    token budget) yields ''. An empty reply also yields '' instead of raising
    IndexError on ``split()[0]``.
    """
    if "</think>" in resp_text:
        resp_text = resp_text.split("</think>", 1)[1]
    elif "<think>" in resp_text:
        return ""
    tokens = resp_text.strip().lower().split()
    if not tokens:
        return ""
    # Strip quotes/punctuation the model may wrap around the key.
    return tokens[0].strip("'\". ,")


def _write_json_atomic(data, path: str) -> None:
    """Serialize *data* as indented JSON to *path* via a temp file + os.replace.

    The results file is rewritten after every record. Writing it in place
    would leave a truncated, unparseable file if the job is killed mid-write,
    losing every result of the slice. os.replace is atomic on POSIX.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, path)


def _classify_story(cleaned_text: str) -> str:
    """Ask the loaded model for the topic key of one already-cleaned story.

    Uses the module-level ``model``/``tokenizer`` assigned in ``main``.
    Exceptions from tokenization or generation propagate to the caller.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"{topics}\n\nStory: {cleaned_text}"},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        # Qwen3 chat templates otherwise start the reply with a <think>...</think>
        # reasoning block. With max_new_tokens=50 that block would eat the whole
        # budget and the parsed "topic" would be "<think>". Templates that do not
        # use this variable simply ignore it.
        enable_thinking=False,
        # return_dict=True also returns the attention mask for generate().
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            temperature=0.7,
            do_sample=True,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    resp_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _extract_topic(resp_text)


def main():
    """Classify every story in the selected CSV slice and write results to JSON.

    Reads the ``video_id`` and ``story`` columns from ``--csv_path`` and asks
    the model for one topic key per story. After every record it saves a list
    of ``{video_id, url, story, predicted_topic}`` dicts to
    ``<output_dir>/topic_results_<start>_<end>.json``.

    Exits with status 1 if the CSV cannot be read. If inference fails for a
    record, that record is labelled ``error_model``. Any other per-record error
    is printed to stderr and processing continues with the next record.
    """
    global model, tokenizer

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Read the CSV first so a bad path fails before the slow model load.
    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}", file=sys.stderr)
        sys.exit(1)

    # Qwen3 dense checkpoints are 0.6B/1.7B/4B/8B/14B/32B. "Qwen/Qwen3-7B" is
    # not a published repo, so from_pretrained would fail. 8B is the nearest size.
    model_name = "Qwen/Qwen3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

    all_records = df.to_dict(orient="records")

    # Slice bounds are inclusive on both ends. --end is clamped to the last row.
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        video_id = ""  # bound up front so the error handler below can always print it
        try:
            video_id = _clean_cell(rec.get("video_id"))
            cleaned_text = _clean_cell(rec.get("story"))

            try:
                pred_topic = _classify_story(cleaned_text)
            except Exception as e:
                print(f"Error during Qwen inference for key {video_id}: {e}", file=sys.stderr)
                pred_topic = "error_model"

            results.append(
                {
                    "video_id": video_id,
                    "url": f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                    "story": cleaned_text,
                    "predicted_topic": pred_topic,
                }
            )
            # Incremental save: a crash or kill loses at most the current record.
            _write_json_atomic(results, output_path)
        except Exception as e:
            print(f"Error processing key {video_id}: {e}", file=sys.stderr)

    # Final write also covers an empty slice, where the loop never saved anything.
    _write_json_atomic(results, output_path)
    print(f"Finished processing. Results saved to {output_path}")

# Script entry point: run the classification pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()




