import os
import json
import sys
import re
import random
import numpy as np
from openai import AzureOpenAI
from tqdm import tqdm
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--count_personas', action='store_true', help='Output the number of personas')
parser.add_argument('--start', type=int, default=0, help='Start index (0-based) for dataset slicing')
parser.add_argument('--end', type=int, default=None, help='End index (0-based, inclusive) for dataset slicing')
parser.add_argument('--output_dir', type=str, default='results', help='Directory to save per-job JSON outputs')
# Repetition counts: one results file is written per value in this list.
parser.add_argument('--runs_list', type=str, default='20,40,60,80,100',
                    help='Comma-separated list indicating how many times to repeat the prediction for each datapoint (e.g., "20,40,60")')
# GMO evaluation is the only evaluation mode this script still implements.
parser.add_argument('--gmo', action='store_true',
                    help='Run GMO CTR/CPA evaluation mode (ads); the legacy WebAES mode has been removed')
parser.add_argument('--dataset_dir', type=str, default='/path/to/ctr_dataset/',
                    help='Directory containing *.jsonl files for GMO evaluation')
# Sample-size control and prior-run subset reuse
parser.add_argument('--limit', type=int, default=None, help='Total number of samples to evaluate across all dataset files (evenly sampled)')
parser.add_argument('--use_np_subset', action='store_true', help='Use the same datapoints as the prior GPT-CTR-NP run20 merged results')
# Seed for deterministic sampling across personas and run counts
parser.add_argument('--seed', type=int, default=42, help='Random seed to ensure consistent sampling across runs')
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Deterministic seeding: seed both the stdlib and numpy RNGs from --seed so
# record sampling is reproducible.
# ---------------------------------------------------------------------------
random.seed(args.seed)
np.random.seed(args.seed)


def _parse_runs_list(spec: str) -> list[int]:
    """Parse a comma-separated string of repetition counts.

    Empty tokens are ignored and non-positive counts are dropped; a token that
    is not an integer raises ValueError (from ``int``).
    """
    counts = []
    for token in spec.split(','):
        if not token.strip():
            continue
        value = int(token)
        if value > 0:
            counts.append(value)
    return counts


runs_list = _parse_runs_list(args.runs_list)
if not runs_list:
    raise ValueError("--runs_list must contain at least one positive integer")
# Persona system prompts, keyed by "<age band>_<gender>". Each prompt asks the
# model to reply with a one-sentence "Reason:" line followed by an "Answer:"
# score on a 0-100 scale. run_gmo_evaluation queries every persona here for
# every record and parses that "Answer:" value.
persona_prompts = {
    "18-24_female": """You are a digital ad analyst who is also a woman in the 18-24 age group. You're deeply familiar with what resonates with your generation—emotional authenticity, aesthetic quality, bold individuality, and social relevance. You instinctively recognize what appeals to younger women: body positivity, mental health awareness, empowerment, humor, and cultural fluency (like meme literacy or TikTok trends).

You've been shown up to five similar ads with their CTR performance (from 0 to 100), and now you must predict how well a new ad will perform with women your age. Use your personal insight, generational awareness, and pattern recognition to evaluate it.

Return:
Reason: [In one sentence, why you think this ad will perform that way]
Answer: [0–100]""",

    "18-24_male": """You are a digital ad analyst who is also a man in the 18-24 age group. You understand the mindset of young men—seeking confidence, entertainment, edge, and relevance. You're fluent in gaming, influencer culture, memes, and the kind of humor or visual punch that lands with this group.

Given several example ads with performance scores, and a new ad to assess, you must judge how likely it is to capture attention and convert for men your age.

Return:
Reason: [In one sentence, why this ad will or won't work for your group]
Answer: [0–100]""",

    "25-34_female": """You are a digital ad analyst who is also a woman in the 25-34 age group. You understand the balance this age group seeks—between career ambitions, lifestyle goals, personal growth, and relationships. Ads that show aesthetic clarity, empowerment, self-care, and intelligent value propositions resonate well.

You're given several similar ads and their CTR percentiles, followed by a new one to evaluate. Use your cultural fluency and marketing insight to estimate performance among women in your age group.

Return:
Reason: [In one sentence, your age group's likely response and why]
Answer: [0–100]""",

    "25-34_male": """You are a digital ad analyst and a man aged 25–34. You know what appeals to this demographic—ads that are ambitious, tech-savvy, direct, a little edgy, and tied to lifestyle aspiration (fitness, career growth, travel, or finance). You are skeptical of fluff, and you appreciate clarity, wit, and efficiency.

Given past ads with known CTR and a new one, predict its success based on how well it aligns with male values in this age range.

Return:
Reason: [In one sentence, why you predict this level of engagement]
Answer: [0–100]""",

    "35-44_female": """You're a digital ad analyst and a woman aged 35–44. You understand the balance your generation strikes—juggling responsibilities, making informed decisions, but also seeking meaningful and joyful moments. Emotional intelligence, warmth, family, health, and practical luxury matter here.

Given example ads with performance data, evaluate a new ad's resonance with your peers.

Return:
Reason: [In one sentence, how well the ad speaks to this age group's values]
Answer: [0–100]""",

    "35-44_male": """You're a digital ad analyst and a man aged 35–44. You know your generation values trust, utility, and smart messaging. You're discerning, experienced, and don't fall for fluff. You appreciate ads that are well-crafted, respect your intelligence, and deliver real value—especially around career, family, and financial growth.

Given ad examples and a new target ad, predict how well it will land with men in your cohort.

Return:
Reason: [In one sentence, why this ad will or won't resonate]
Answer: [0–100]""",

    "45-54_female": """You're a digital ad analyst and a woman aged 45–54. You represent a generation that values trust, authenticity, and depth. You're less interested in trendiness and more in whether an ad respects your intelligence and lived experience. You care about wellness, family, quality, and emotional clarity.

Reviewing example ads and a new one, estimate how it will perform among women like you.

Return:
Reason: [In one sentence, your interpretation of how this ad fits generational taste]
Answer: [0–100]""",

    "45-54_male": """You're a digital ad analyst and a man aged 45–54. You've seen advertising evolve and know when something is thoughtful versus superficial. You appreciate ads that offer clarity, real-world value, and a touch of inspiration. Trust and practical appeal matter most.

Given ad performance examples and a new one, analyze how it will fare with men your age.

Return:
Reason: [In one sentence, your logic for the predicted score]
Answer: [0–100]""",

    "55+_female": """You are a seasoned digital ad analyst and a woman over 55. You don't just evaluate ads—you see them through decades of shifting media, values, and cultural narratives. You know your peers prefer messaging that's clear, emotionally resonant, warm, and respectful. Themes like wellness, family, community, and security matter deeply.

You're shown similar ad examples with CTRs, followed by a new one. Predict its performance based on emotional tone, clarity, and meaningfulness to older women.

Return:
Reason: [In one sentence, why your age group will or won't respond to this ad]
Answer: [0–100]""",

    "55+_male": """You are a digital ad analyst and a man aged 55 or older. You've experienced enough advertising to know when something is valuable versus manipulative. You favor sincerity, logic, and emotional grounding. You value health, legacy, family, security, and clarity.

Given example ads and a new one, analyze how it will perform for men your age.

Return:
Reason: [In one sentence, your thoughtful rationale]
Answer: [0–100]"""
}
# Resolve --output_dir to an absolute path once and store it back on `args`,
# so every later reader of args.output_dir sees the same absolute value.
output_dir = os.path.abspath(args.output_dir)
args.output_dir = output_dir
os.makedirs(args.output_dir, exist_ok=True)

# Merged run-20 results from the earlier gpt_ctr_np experiment; read by
# run_gmo_evaluation when --use_np_subset is given.
NP_SUBSET_PATH = "/path/to/final_gmo_results_runs20_merged.json"

def _load_np_subset_prompts(path: str) -> set[str]:
    """Return a set of prompt strings present in the prior NP results file."""
    if not os.path.isfile(path):
        print(f"[ERROR] NP subset file not found: {path}", file=sys.stderr)
        sys.exit(1)
    try:
        with open(path, 'r', encoding='utf-8') as f_in:
            data = json.load(f_in)
    except Exception as e:
        print(f"[ERROR] Failed to load NP subset file {path}: {e}", file=sys.stderr)
        sys.exit(1)
    prompts = {rec.get('prompt', '') for rec in data if isinstance(rec, dict)}
    return prompts

# ---------------------------------------------------------------------------
# NOTE: Legacy WebAES (website likeability) code and prompts have been removed.
# This script now focuses solely on GMO CPA-percentile evaluation.
# ---------------------------------------------------------------------------

# (Image handling utilities removed – not required for CPA evaluation)

# Azure OpenAI configuration. The API key and endpoint are taken from the
# AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT environment variables when set,
# so real credentials need not be committed to source. The original
# placeholders remain as fallbacks.
api_version = "2024-02-15-preview"
config_dict = {
    'api_key': os.environ.get("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    'api_version': api_version,
    'azure_endpoint': os.environ.get("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
}

# ----------------------------- Helper Functions -----------------------------

# Shared AzureOpenAI client, created lazily on first request by _get_azure_client.
_azure_client = None


def _get_azure_client() -> "AzureOpenAI":
    """Return the shared AzureOpenAI client, creating it from `config_dict` on first use.

    An evaluation issues (records x personas x n_runs) requests. Reusing one
    client avoids building a fresh client, and its underlying HTTP client, for
    every single request. Those per-request clients were also never closed.
    """
    global _azure_client
    if _azure_client is None:
        _azure_client = AzureOpenAI(
            api_key=config_dict['api_key'],
            api_version=config_dict['api_version'],
            azure_endpoint=config_dict['azure_endpoint'],
        )
    return _azure_client


def _verbalize_persona(prompt: str, persona_system_prompt: str) -> str:
    """Query the 'gpt-4o' deployment with `prompt` under the given persona system prompt.

    Returns the stripped text of the first choice. If the response carries no
    message content, returns "" rather than failing. The SDK types `content` as
    optional. Callers then treat "" like any other unparseable reply.
    """
    messages = [
        {"role": "system", "content": persona_system_prompt},
        {"role": "user", "content": prompt},
    ]
    resp = _get_azure_client().chat.completions.create(
        model='gpt-4o',
        messages=messages,
        max_tokens=350,
        temperature=0.85,
        n=1,
    )
    content = resp.choices[0].message.content
    return (content or "").strip()

def _sample_gmo_records(dataset_dir: str, total_limit: int | None):
    """Randomly sample records from each *.jsonl file in `dataset_dir`.

    If `total_limit` is provided, samples are taken as `total_limit // n_files` per file.
    Otherwise, all records from each file are returned.
    """
    file_paths = [os.path.join(dataset_dir, fp) for fp in os.listdir(dataset_dir) if fp.endswith('.jsonl')]
    if not file_paths:
        print(f"[ERROR] No .jsonl files found in {dataset_dir}", file=sys.stderr)
        sys.exit(1)

    random.shuffle(file_paths)  # shuffle to avoid ordering bias
    records = []
    per_file = None
    if total_limit is not None and total_limit > 0:
        per_file = max(1, total_limit // len(file_paths))

    for fp in file_paths:
        with open(fp, 'r', encoding='utf-8') as f_in:
            lines = f_in.readlines()
        if per_file is not None:
            chosen = random.sample(lines, min(per_file, len(lines)))
        else:
            chosen = lines
        for ln in chosen:
            try:
                rec = json.loads(ln)
                rec['_source_file'] = os.path.basename(fp)
                records.append(rec)
            except json.JSONDecodeError:
                continue  # skip malformed lines

    # If we overshot the limit due to rounding, trim back down
    if total_limit is not None and len(records) > total_limit:
        records = random.sample(records, total_limit)
    return records

# Matches the "Answer: <score>" line that every persona prompt asks for.
_ANSWER_RE = re.compile(r"(?i)answer[^0-9]{0,10}(\d{1,3}(?:\.\d+)?)")


def _parse_answer_score(resp_text: str) -> float | None:
    """Return the first "Answer" score in `resp_text`, clamped to [0, 100], or None if absent."""
    match = _ANSWER_RE.search(resp_text)
    if match is None:
        return None
    return max(0.0, min(100.0, float(match.group(1))))


def _write_json_atomic(path: str, payload) -> None:
    """Dump `payload` as indented JSON to `path` atomically.

    Writes to a sibling ``.tmp`` file and then ``os.replace``s it over `path`.
    An interruption mid-write therefore leaves the previous complete file in
    place instead of a truncated one.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    os.replace(tmp_path, path)


def run_gmo_evaluation(args):
    """Run the persona-based GMO evaluation once per repetition count in `runs_list`.

    For every sampled record and every persona in `persona_prompts`, the model
    is queried `n_runs` times. Per persona, the parsed scores, their mean and
    the raw responses are stored; per record, an overall mean across personas
    is added. Each `n_runs` value gets its own JSON file in `args.output_dir`.
    That file is rewritten after every record so partial progress survives a
    crash.

    Args:
        args: parsed CLI namespace; reads dataset_dir, limit, start, end,
            output_dir and use_np_subset.
    """

    # ---------------------------- Record Sampling ----------------------------
    records = _sample_gmo_records(args.dataset_dir, args.limit)

    # Apply optional slicing using --start and --end (0-based indices, end inclusive)
    slice_start = max(0, args.start)
    slice_end = args.end if args.end is not None else len(records) - 1
    slice_end = min(slice_end, len(records) - 1)
    records = records[slice_start : slice_end + 1]

    print(
        f"[INFO] Running GMO evaluation on {len(records)} sampled records (slice {slice_start}-{slice_end})."
    )

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Optionally restrict to NP subset prompts
    if args.use_np_subset:
        subset_prompts = _load_np_subset_prompts(NP_SUBSET_PATH)
        before_filter = len(records)
        records = [rec for rec in records if rec.get('prompt', '') in subset_prompts]
        print(f"[INFO] Using NP subset: {len(records)} of {before_filter} records retained.")

    # ----------------------- Iterate over repetition counts ------------------
    for n_runs in runs_list:
        print("\n" + "=" * 80)
        print(f"Running evaluation with {n_runs} repetitions per datapoint…")
        print("=" * 80)

        run_results = []

        # Construct output filename that includes run count so downstream merge works
        out_name = (
            f"gmo_persona_results_runs{n_runs}_samples{args.limit or 'all'}_{slice_start}_{slice_end}.json"
        )
        out_path = os.path.join(args.output_dir, out_name)

        for rec in tqdm(records, desc=f"GMO Samples x{n_runs}"):
            ad_prompt = rec.get("prompt", "")
            ground_truth = rec.get("response")

            persona_data = {}
            # Means of personas that produced at least one parseable score.
            persona_means: list[float] = []

            for persona_name, persona_sys_prompt in persona_prompts.items():
                responses: list[str] = [
                    _verbalize_persona(ad_prompt, persona_sys_prompt) for _ in range(n_runs)
                ]
                predictions: list[float] = [
                    score for score in map(_parse_answer_score, responses) if score is not None
                ]
                mean_score_persona = float(np.mean(predictions)) if predictions else None
                if mean_score_persona is not None:
                    persona_means.append(mean_score_persona)

                persona_data[persona_name] = {
                    "predictions": predictions,
                    "mean_prediction": mean_score_persona,
                    "responses": responses,
                }

            # Mean over personas that produced a score. Use None, not NaN, when none
            # did: json.dump would otherwise emit a bare NaN token (invalid JSON).
            overall_mean = float(np.mean(persona_means)) if persona_means else None

            run_results.append(
                {
                    "prompt": ad_prompt,
                    "ground_truth": ground_truth,
                    "personas": persona_data,
                    "overall_mean_prediction": overall_mean,
                    "source_file": rec.get("_source_file"),
                }
            )

            # Incremental write after each datapoint to protect against crashes
            try:
                _write_json_atomic(out_path, run_results)
            except (OSError, TypeError, ValueError) as e:
                print(f"[WARNING] Incremental save failed: {e}")

        # Final write for this n_runs (the only write when there are no records)
        try:
            _write_json_atomic(out_path, run_results)
        except (OSError, TypeError, ValueError) as e:
            print(f"[ERROR] Final save failed for {out_path}: {e}")

        print(f"[INFO] GMO evaluation with {n_runs} runs complete. Results saved to {out_path}")

# ---------------------------------------------------------------------------
# Entry point: optionally report the persona count, then run GMO mode (the
# only evaluation mode left) and exit. Nothing further below.
# ---------------------------------------------------------------------------
if args.count_personas:
    # Honour --count_personas, which is declared on the parser.
    print(len(persona_prompts))

if args.gmo:
    run_gmo_evaluation(args)
    sys.exit(0)

if not args.count_personas:
    # Without --gmo there is no evaluation to run; say so instead of exiting silently.
    print(
        "[WARNING] Nothing to do: pass --gmo to run the GMO evaluation "
        "(the legacy WebAES mode has been removed).",
        file=sys.stderr,
    )

# No additional code below this point.