import argparse
import json
import os
import random  # For sampling distractor reasons
import re
import sys
from typing import Dict, List, Optional

import openai
import pandas as pd
from openai import AzureOpenAI
from tqdm import tqdm

def _non_negative_int(value: str) -> int:
    """Argparse ``type=`` converter that accepts only integers >= 0."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for one (parallelisable) slice of the dataset.

    Returns:
        Namespace with ``start`` / ``end`` (inclusive row indices; ``end=None``
        means "through the last row"), ``output_dir``, ``csv_path`` and
        ``annotation_path``.

    Exits via ``argparse`` (status 2) on invalid input, including a negative
    index or ``--end`` smaller than ``--start``.
    """
    parser = argparse.ArgumentParser(
        description="Run zero-shot best-action selection over a slice of the dataset.",
    )
    parser.add_argument("--start", type=_non_negative_int, default=0, help="Start index (inclusive) of the slice.")
    parser.add_argument("--end", type=_non_negative_int, default=None, help="End index (inclusive) of the slice.")
    parser.add_argument("--output_dir", type=str, default="topic_results", help="Directory to write JSON results.")
    parser.add_argument(
        "--csv_path",
        type=str,
        required=True,
        help="Path to the input CSV with columns video_id,story (optional: reasons as fallback actions)",
    )
    parser.add_argument(
        "--annotation_path",
        type=str,
        default="action_annotation.json",
        help="Path to action_annotation.json for sampling distractor actions",
    )
    args = parser.parse_args()
    # An inverted range would silently yield an empty slice with a confusing
    # output file name; reject it up front instead.
    if args.end is not None and args.end < args.start:
        parser.error(f"--end ({args.end}) must be >= --start ({args.start})")
    return args

# ---------------------------------------------------------------------------
# Zero-shot system prompt for best-action selection.
# The model is asked for a fixed two-line reply ("Answer:" then "Reason:")
# so that the answer and its justification can be pulled out of the reply.
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = "".join(
    [
        "You will be given the STORY of a video advertisement and a numbered list ",
        "of candidate ACTIONS that a viewer might take after watching it. ",
        "Choose EXACTLY ONE action that best fits the story. ",
        "Return EXACTLY two lines:\n",
        "Answer: <action>\nReason: <brief justification>.",
    ]
)

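# CSV expected columns: video_id, story, and optionally reasons — a fallback list of
# correct actions (JSON list or ';'-separated), used only when the video_id has no
# actions in the annotation JSON.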

# Upper bound on distractor actions mixed into each candidate list.
_NUM_DISTRACTORS = 25
# Chat deployment / model name and API version sent to Azure OpenAI.
_MODEL_NAME = "gpt-4o"
_API_VERSION = "2024-02-15-preview"
# Room for both the "Answer:" and "Reason:" lines; a 50-token cap could cut
# off the Reason line when the chosen action text is long.
_MAX_RESPONSE_TOKENS = 150

# "Answer: ..." / "Reason: ..." lines, tolerating leading whitespace and
# markdown emphasis such as "**Answer:** ...".
_ANSWER_RE = re.compile(
    r"^[ \t*_]*answer[ \t*_]*:[\s*_]*(.+?)[ \t*_]*$", re.IGNORECASE | re.MULTILINE
)
_REASON_RE = re.compile(
    r"^[ \t*_]*reason[ \t*_]*:[\s*_]*(.+?)[ \t*_]*$", re.IGNORECASE | re.MULTILINE
)
# A numbered answer such as "7", "7." or "7) Visit the website".
_NUMBERED_RE = re.compile(r"^(\d+)(?:\s*[.):]\s*(.*))?$")


def _normalize_actions(raw) -> List[str]:
    """Coerce an annotation/CSV value into a list of non-blank action strings.

    A list/tuple keeps its non-blank string entries, a single non-blank string
    is wrapped in a list, and anything else (None, NaN, numbers, dicts) -> [].
    """
    if isinstance(raw, str):
        return [raw] if raw.strip() else []
    if isinstance(raw, (list, tuple)):
        return [a for a in raw if isinstance(a, str) and a.strip()]
    return []


def _parse_csv_actions(raw) -> List[str]:
    """Parse the legacy CSV ``reasons`` cell: a JSON list/string or ';'-separated text."""
    if not isinstance(raw, str):
        return _normalize_actions(raw)
    text = raw.strip()
    if not text:
        return []
    try:
        parsed = json.loads(text)
    except ValueError:
        parsed = None
    if isinstance(parsed, (list, str)):
        return _normalize_actions(parsed)
    # Not JSON (or JSON of another type, e.g. a bare number): split on ';'.
    return [part.strip() for part in text.split(";") if part.strip()]


def _canonical(text: str) -> str:
    """Normalise an action string for lenient comparison (case, spacing, quotes, trailing dots)."""
    return " ".join(text.casefold().split()).strip(" \"'`*.")


def _resolve_action(answer: str, candidate_actions: List[str]) -> str:
    """Map the model's answer back onto a candidate action when possible.

    Handles bare indices ("7"), numbered answers ("7. Visit the website") and
    textual answers differing only in case, spacing, quotes or a trailing
    period. Unrecognised answers are returned unchanged.
    """
    by_key: Dict[str, str] = {}
    for cand in candidate_actions:
        by_key.setdefault(_canonical(cand), cand)

    num_match = _NUMBERED_RE.match(answer)
    if num_match:
        # Prefer the text after the number when it names a candidate; fall
        # back to the 1-based index otherwise.
        text = num_match.group(2) or ""
        if text and _canonical(text) in by_key:
            return by_key[_canonical(text)]
        idx = int(num_match.group(1))
        if 1 <= idx <= len(candidate_actions):
            return candidate_actions[idx - 1]
    return by_key.get(_canonical(answer), answer)


def _parse_response(raw_resp: str, candidate_actions: List[str]) -> tuple:
    """Split a model reply into ``(chosen_action, justification)``.

    If no "Answer:" line is present, the whole reply is treated as the answer;
    a missing "Reason:" line yields an empty justification.
    """
    ans_match = _ANSWER_RE.search(raw_resp)
    chosen_action = ans_match.group(1).strip() if ans_match else raw_resp.strip()
    reason_match = _REASON_RE.search(raw_resp)
    justification = reason_match.group(1).strip() if reason_match else ""
    return _resolve_action(chosen_action, candidate_actions), justification


def _build_candidates(
    correct_actions: List[str],
    action_pool: List[str],
    num_distractors: int = _NUM_DISTRACTORS,
) -> List[str]:
    """Return the correct actions plus up to ``num_distractors`` sampled distractors, shuffled.

    Distractors are drawn without replacement from ``action_pool`` entries that
    are not among ``correct_actions``.
    """
    correct_set = set(correct_actions)
    distractor_pool = [a for a in action_pool if a not in correct_set]
    k = min(num_distractors, len(distractor_pool))
    candidates = list(correct_actions) + random.sample(distractor_pool, k)
    random.shuffle(candidates)
    return candidates


def _build_messages(cleaned_text: str, candidate_actions: List[str]) -> List[Dict[str, str]]:
    """Build the chat messages: the system prompt plus the story and numbered candidate list."""
    actions_block = "\n".join(f"{i}. {a}" for i, a in enumerate(candidate_actions, start=1))
    user_content = (
        f"Story:\n{cleaned_text}\n\nList of actions:\n{actions_block}\n\n"
        "Return exactly two lines:\nAnswer: <action>\nReason: <brief justification>"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]


def _make_client() -> "AzureOpenAI":
    """Create the Azure OpenAI client from environment variables.

    The key comes from ``OPENAI_API_KEY``. When that is unset or empty, ``None``
    is passed instead of ``""`` so the SDK can apply its own credential lookup
    and raise a configuration error if none is found — rather than sending
    every request with an empty key.
    """
    return AzureOpenAI(
        api_key=os.getenv("OPENAI_API_KEY") or None,
        api_version=_API_VERSION,
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
    )


def _load_annotations(path: str) -> Dict[str, object]:
    """Load the annotation JSON, expected to map video_id -> action(s).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if it is not valid UTF-8 JSON or its top level is not an object.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object at top level, got {type(data).__name__}")
    return data


def _query_model(client, messages: List[Dict[str, str]]) -> str:
    """Send one chat-completion request and return the stripped reply text.

    Raises:
        ValueError: if the reply has no text content.
    """
    response = client.chat.completions.create(
        model=_MODEL_NAME,
        messages=messages,
        max_tokens=_MAX_RESPONSE_TOKENS,
        temperature=0.0,
        n=1,
    )
    content = response.choices[0].message.content
    if not content or not content.strip():
        raise ValueError("empty response content")
    return content.strip()


def _save_results(results: list, output_path: str) -> None:
    """Write ``results`` as indented JSON, atomically replacing ``output_path``.

    Writing to a sibling temp file and then ``os.replace``-ing it means an
    interruption mid-write never leaves a truncated results file behind.
    """
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, output_path)


def main():
    """Run best-action selection over one slice of the CSV and save results as JSON.

    For each record in the inclusive slice ``[--start, --end]``, the correct
    actions come from the annotation JSON (falling back to the CSV ``reasons``
    column). They are mixed with up to ``_NUM_DISTRACTORS`` distractors drawn
    from all annotated actions, and the model is asked to pick one. Results go
    to ``<output_dir>/topic_results_<start>_<end>.json``. The file is rewritten
    after every record so partial progress survives interruption, and written
    once more at the end, so it exists even when every record was skipped.

    Exits with status 1 if the client cannot be configured or if the
    annotation JSON or the CSV cannot be read.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        client = _make_client()
    except openai.OpenAIError as e:
        print(f"Error configuring Azure OpenAI client: {e}")
        sys.exit(1)

    # -------------------------------------------------------------------
    # Load full action annotation to draw distractor actions
    # -------------------------------------------------------------------
    try:
        annotation_data = _load_annotations(args.annotation_path)
    except (OSError, ValueError) as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)

    # Deduplicated pool of every annotated action, in first-seen order.
    # Normalising first means a string-valued entry contributes one action
    # (not one per character). Deduplication ensures a candidate list never
    # shows the same distractor text twice.
    all_actions_pool = list(
        dict.fromkeys(
            action
            for acts in annotation_data.values()
            for action in _normalize_actions(acts)
        )
    )

    # Read every column as text and keep empty cells as "" rather than NaN.
    # This keeps IDs such as "0123" intact and lets missing stories/reasons
    # parse as empty instead of as the string "nan".
    try:
        df = pd.read_csv(args.csv_path, dtype=str, keep_default_na=False)
    except (OSError, ValueError) as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (both ends inclusive; end clamped to data).
    start_idx = args.start
    last_idx = len(all_records) - 1
    end_idx = last_idx if args.end is None else min(args.end, last_idx)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Bound before the try so the error message below always names the
        # current record, never a stale id from a previous iteration.
        video_id = str(rec.get('video_id', '')).strip()
        try:
            # Correct actions: annotation file first, legacy CSV column second.
            correct_actions = _normalize_actions(annotation_data.get(video_id))
            if not correct_actions:
                correct_actions = _parse_csv_actions(rec.get('reasons', ''))
            correct_actions = list(dict.fromkeys(correct_actions))

            if not correct_actions:
                print(f"No actions for id {video_id}; skipping")
                continue

            # All correct actions + up to _NUM_DISTRACTORS distractors, shuffled.
            candidate_actions = _build_candidates(correct_actions, all_actions_pool)

            # Collapse all whitespace runs (newlines, form feeds, ...) to single spaces.
            cleaned_text = ' '.join(str(rec.get('story', '')).split())
            messages = _build_messages(cleaned_text, candidate_actions)

            try:
                raw_resp = _query_model(client, messages)
                chosen_action, justification = _parse_response(raw_resp, candidate_actions)
            except Exception as e:  # API/network/empty-reply failure: record it and move on
                print(f"Error during OpenAI call for key {video_id}: {e}")
                chosen_action, justification = "error_api", ""

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'predicted_action': chosen_action,
                'explanation': justification,
                'candidate_actions': candidate_actions,
                'correct_actions': correct_actions,
            })

            # Incremental save so partial progress survives a crash.
            _save_results(results, output_path)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")

    # Final save guarantees the file exists even if no record produced a result.
    _save_results(results, output_path)
    print(f"Finished processing. Results saved to {output_path}")

# Run the slice evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()




