import argparse
import json
import os
import random  # For sampling distractor reasons
import re
import sys
from typing import Dict, List, Optional

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``csv_path`` and
        ``annotation_path``. Only ``--csv_path`` is mandatory.
    """
    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )

    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "topic_results", "help": "Directory to write JSON results."}),
        ("--csv_path", {"type": str, "required": True, "help": "Path to the input CSV with columns video_id,story"}),
        (
            "--annotation_path",
            {
                "type": str,
                "default": "reaction_annotation.json",
                "help": "Path to reaction_annotation.json for sampling distractor reasons",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

# ---------------------------------------------------------------------------
# Zero-shot system prompt for best-reason selection
# ---------------------------------------------------------------------------
# NOTE: main() builds each candidate list from the video's correct reasons plus
# up to 25 sampled distractors, so the list length varies and the prompt must
# not promise a fixed count (it previously said "five").
SYSTEM_PROMPT = (
    "You will be given the STORY of a video advertisement and a numbered list "
    "of candidate reasons a viewer might give for taking the recommended action. "
    "Choose EXACTLY ONE reason that best explains why a viewer should take the action, and output that reason verbatim. "
    "Do not output any additional text—just the chosen reason."
)

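# Expected CSV columns: video_id, story, and optionally reasons (a JSON list or
# a ';'-separated string). The reasons column is only used as a fallback when
# the annotation file has no reasons for the video.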

# Hugging Face repo id of the chat model. Qwen3 dense checkpoints are published
# as 0.6B/1.7B/4B/8B/14B/32B; "Qwen/Qwen3-7B" is not among them, so the 8B
# checkpoint is used.
_MODEL_NAME = "Qwen/Qwen3-8B"

# Upper bound on the number of distractor reasons mixed into each candidate list.
_NUM_DISTRACTORS = 25

# Prediction stored for a record whose inference raised an exception.
_INFERENCE_ERROR = "error_api"

# "Answer: <text>" on any line of the reply (case-insensitive).
_ANSWER_LINE_RE = re.compile(r"(?i)^answer:\s*(.+)$", re.MULTILINE)
# A reply that is only a list index, optionally followed by "." or ")".
_BARE_INDEX_RE = re.compile(r"([0-9]+)[.)]?")
# A reply that repeats the list numbering in front of the reason text.
_NUMBERED_REASON_RE = re.compile(r"([0-9]+)[.)]\s+(.+)", re.DOTALL)


def _as_reason_list(value) -> List[str]:
    """Return *value* as a list of non-empty reasons (``[]`` for anything unusable)."""
    if isinstance(value, str):
        return [value] if value else []
    if isinstance(value, (list, tuple)):
        return [reason for reason in value if reason]
    # NaN cells, None, numbers, dicts, ... carry no usable reasons.
    return []


def _parse_csv_reasons(raw) -> List[str]:
    """Parse the CSV ``reasons`` cell: a JSON list/string, else a ';'-separated string."""
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except ValueError:
            parsed = None
        if isinstance(parsed, (str, list)):
            raw = parsed
        else:
            # Not JSON (or a JSON scalar): treat the cell as ';'-separated text.
            raw = [part.strip() for part in raw.split(';')]
    return _as_reason_list(raw)


def _build_candidates(correct_reasons, reasons_pool, num_distractors: int = _NUM_DISTRACTORS) -> list:
    """Return the correct reasons plus up to *num_distractors* sampled distractors, shuffled.

    Reasons are assumed to be hashable (strings) so membership can use a set.
    """
    excluded = set(correct_reasons)
    distractor_pool = [reason for reason in reasons_pool if reason not in excluded]
    sample_size = min(num_distractors, len(distractor_pool))
    candidates = list(correct_reasons) + random.sample(distractor_pool, sample_size)
    random.shuffle(candidates)
    return candidates


def _clean_story(story_text) -> str:
    """Collapse every whitespace run (newlines and form feeds included) to one space."""
    return ' '.join(str(story_text).split())


def _build_user_content(cleaned_text: str, candidate_reasons) -> str:
    """Format the user message: the story followed by the 1-based numbered candidates."""
    reasons_block = "\n".join(
        f"{number}. {reason}" for number, reason in enumerate(candidate_reasons, start=1)
    )
    return (
        f"Story:\n{cleaned_text}\n\nList of reasons:\n{reasons_block}\n\n"
        "Return exactly one line:\nAnswer: <reason>"
    )


def _extract_choice(raw_resp: str, candidate_reasons):
    """Turn the model reply into the chosen reason.

    Takes the text after ``Answer:`` when present, otherwise the whole reply.
    A bare in-range list index ("3", "3.", "3)") is mapped to that candidate,
    and "3. <reason>" is mapped to candidate 3 when the text matches it.
    Anything else is returned stripped and otherwise unchanged.
    """
    text = raw_resp.strip()
    answer = _ANSWER_LINE_RE.search(text)
    chosen = answer.group(1).strip() if answer else text

    bare = _BARE_INDEX_RE.fullmatch(chosen)
    if bare:
        index = int(bare.group(1))
        if 1 <= index <= len(candidate_reasons):
            return candidate_reasons[index - 1]
        return chosen

    numbered = _NUMBERED_REASON_RE.fullmatch(chosen)
    if numbered:
        index = int(numbered.group(1))
        if 1 <= index <= len(candidate_reasons):
            candidate = candidate_reasons[index - 1]
            # Only drop the numbering when the text really is that candidate.
            if numbered.group(2).strip() == str(candidate).strip():
                return candidate
    return chosen


def _generate_response(chat_model, chat_tokenizer, messages, max_new_tokens: int = 120) -> str:
    """Greedy-decode a reply to *messages* and return only the newly generated text."""
    # return_dict=True also yields the attention mask, which generate() expects.
    inputs = chat_tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
    ).to(chat_model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # do_sample=False selects greedy decoding; temperature is unused in
        # that mode, so it is not passed.
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    return chat_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _save_results(results, output_path: str) -> None:
    """Write *results* as JSON via a temp file so an interrupted write cannot truncate the output."""
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, output_path)


def main():
    """Run best-reason selection over one slice of the CSV and save the predictions.

    Reads the annotation JSON (video id -> list of reasons) and the CSV, loads
    the Qwen chat model, and for every record in the inclusive slice
    ``[--start, --end]`` asks the model to pick one reason from a shuffled list
    of the video's correct reasons plus sampled distractors. Results are
    written to ``<output_dir>/topic_results_<start>_<end>.json`` after every
    record and once more at the end.

    Exits with status 1 when the annotation JSON or the CSV cannot be read.
    Records without any correct reason, or whose processing raises, are
    skipped with a message; an inference failure stores ``"error_api"`` as the
    prediction.
    """
    # Kept module-global as in the original script.
    global model, tokenizer

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Read both inputs before loading the model so a bad path fails fast.
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except Exception as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)
    if not isinstance(annotation_data, dict):
        print(
            f"Error reading annotation JSON {args.annotation_path}: "
            "expected an object mapping video ids to lists of reasons"
        )
        sys.exit(1)

    # Flatten all reasons into one de-duplicated pool (first-seen order) so a
    # candidate list never shows the same distractor twice.
    all_reasons_pool = list(dict.fromkeys(
        reason
        for reasons in annotation_data.values()
        for reason in _as_reason_list(reasons)
    ))

    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (--end is inclusive and clamped to the data).
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    # Load the chat model once for the whole run.
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_NAME,
        torch_dtype="auto",
        device_map="auto",
        load_in_4bit=True,
    )

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Bound before the try so the error message below can always use it.
        video_id = str(rec.get('video_id', '')).strip()
        try:
            # Correct reasons: annotation file first, CSV column as fallback.
            correct_reasons = _as_reason_list(annotation_data.get(video_id))
            if not correct_reasons:
                correct_reasons = _parse_csv_reasons(rec.get('reasons', ''))

            if not correct_reasons:
                print(f"No reasons for id {video_id}; skipping")
                continue

            candidate_reasons = _build_candidates(correct_reasons, all_reasons_pool)
            cleaned_text = _clean_story(rec.get('story', ''))

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _build_user_content(cleaned_text, candidate_reasons)},
            ]

            # Best effort: one failed generation must not abort the whole slice.
            try:
                raw_resp = _generate_response(model, tokenizer, messages)
                chosen_reason = _extract_choice(raw_resp, candidate_reasons)
            except Exception as e:
                print(f"Error during Qwen inference for id {video_id}: {e}")
                chosen_reason = _INFERENCE_ERROR

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'predicted_reason': chosen_reason,
                'candidate_reasons': candidate_reasons,
                'correct_reasons': correct_reasons,
            })

            # Incremental save so a crash loses at most the current record.
            _save_results(results, output_path)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")
            continue

    # Final save: guarantees the file exists (even for an empty slice) before
    # reporting success.
    _save_results(results, output_path)
    print(f"Finished processing. Results saved to {output_path}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




