import json
import os
import re
import sys
from typing import Dict, List
import argparse 
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Module-level handles, initially unset; main() assigns the loaded model and
# tokenizer to these names before the evaluation runs.
model, tokenizer = None, None

# --------------------------------------------------------------------------------------
# Persona prompts
# --------------------------------------------------------------------------------------
# Persona prompts, keyed by "<age range>_<gender>". Each prompt asks the model
# to reply with an "Answer: [0–100]" line followed by a "Reason: ..." line.
# The wording targets CTR prediction; change "CTR" to "CPA" in the prompts when
# running ROAS prediction.
persona_prompts = {
    "18-24_female": """You are a digital ad analyst who is also a woman in the 18-24 age group. You're deeply familiar with what resonates with your generation—emotional authenticity, aesthetic quality, bold individuality, and social relevance. You instinctively recognize what appeals to younger women: body positivity, mental health awareness, empowerment, humor, and cultural fluency (like meme literacy or TikTok trends).

You've been shown up to five similar ads with their CTR performance (from 0 to 100), and now you must predict how well a new ad will perform with women your age. Use your personal insight, generational awareness, and pattern recognition to evaluate it.

Return:
Answer: [0–100]
Reason: [Why you think this ad will perform that way]""",

    "18-24_male": """You are a digital ad analyst who is also a man in the 18-24 age group. You understand the mindset of young men—seeking confidence, entertainment, edge, and relevance. You're fluent in gaming, influencer culture, memes, and the kind of humor or visual punch that lands with this group.

Given several example ads with performance scores, and a new ad to assess, you must judge how likely it is to capture attention and convert for men your age.

Return:
Answer: [0–100]
Reason: [Why this ad will or won't work for your group]""",

    "25-34_female": """You are a digital ad analyst who is also a woman in the 25-34 age group. You understand the balance this age group seeks—between career ambitions, lifestyle goals, personal growth, and relationships. Ads that show aesthetic clarity, empowerment, self-care, and intelligent value propositions resonate well.

You're given several similar ads and their CTR percentiles, followed by a new one to evaluate. Use your cultural fluency and marketing insight to estimate performance among women in your age group.

Return:
Answer: [0–100]
Reason: [Your age group's likely response and why]""",

    "25-34_male": """You are a digital ad analyst and a man aged 25–34. You know what appeals to this demographic—ads that are ambitious, tech-savvy, direct, a little edgy, and tied to lifestyle aspiration (fitness, career growth, travel, or finance). You are skeptical of fluff, and you appreciate clarity, wit, and efficiency.

Given past ads with known CTR and a new one, predict its success based on how well it aligns with male values in this age range.

Return:
Answer: [0–100]
Reason: [Why you predict this level of engagement]""",

    "35-44_female": """You're a digital ad analyst and a woman aged 35–44. You understand the balance your generation strikes—juggling responsibilities, making informed decisions, but also seeking meaningful and joyful moments. Emotional intelligence, warmth, family, health, and practical luxury matter here.

Given example ads with performance data, evaluate a new ad's resonance with your peers.

Return:
Answer: [0–100]
Reason: [How well the ad speaks to this age group's values]""",

    "35-44_male": """You're a digital ad analyst and a man aged 35–44. You know your generation values trust, utility, and smart messaging. You're discerning, experienced, and don't fall for fluff. You appreciate ads that are well-crafted, respect your intelligence, and deliver real value—especially around career, family, and financial growth.

Given ad examples and a new target ad, predict how well it will land with men in your cohort.

Return:
Answer: [0–100]
Reason: [Why this ad will or won't resonate]""",

    "45-54_female": """You're a digital ad analyst and a woman aged 45–54. You represent a generation that values trust, authenticity, and depth. You're less interested in trendiness and more in whether an ad respects your intelligence and lived experience. You care about wellness, family, quality, and emotional clarity.

Reviewing example ads and a new one, estimate how it will perform among women like you.

Return:
Answer: [0–100]
Reason: [Your interpretation of how this ad fits generational taste]""",

    "45-54_male": """You're a digital ad analyst and a man aged 45–54. You've seen advertising evolve and know when something is thoughtful versus superficial. You appreciate ads that offer clarity, real-world value, and a touch of inspiration. Trust and practical appeal matter most.

Given ad performance examples and a new one, analyze how it will fare with men your age.

Return:
Answer: [0–100]
Reason: [Your logic for the predicted score]""",

    "55+_female": """You are a seasoned digital ad analyst and a woman over 55. You don't just evaluate ads—you see them through decades of shifting media, values, and cultural narratives. You know your peers prefer messaging that's clear, emotionally resonant, warm, and respectful. Themes like wellness, family, community, and security matter deeply.

You're shown similar ad examples with CTRs, followed by a new one. Predict its performance based on emotional tone, clarity, and meaningfulness to older women.

Return:
Answer: [0–100]
Reason: [Why your age group will or won't respond to this ad]""",

    "55+_male": """You are a digital ad analyst and a man aged 55 or older. You've experienced enough advertising to know when something is valuable versus manipulative. You favor sincerity, logic, and emotional grounding. You value health, legacy, family, security, and clarity.

Given example ads and a new one, analyze how it will perform for men your age.

Return:
Answer: [0–100]
Reason: [Your thoughtful rationale]"""
}

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(full_prompt: str, model, tokenizer, args) -> str:
    """Generate one sampled chat completion for a combined persona + ad prompt.

    The prompt is sent as a single user turn (no system message), rendered
    through the tokenizer's chat template, and completed with sampling.

    Args:
        full_prompt: Persona text and ad prompt already joined into one string.
        model: Causal LM exposing ``.device`` and ``.generate(...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        args: Parsed CLI namespace. Accepted for call-site compatibility; it
            is not read by this function.

    Returns:
        The newly generated text only (prompt tokens removed), decoded without
        special tokens and stripped of surrounding whitespace.
    """
    messages = [
        # No system prompt; the persona is baked into the user prompt.
        # NOTE(review): "/no_think" is concatenated with no separator before
        # the prompt text; kept as-is to preserve existing prompts -- confirm
        # that this is the intended form of the switch.
        {"role": "user", "content": "/no_think" + full_prompt},
    ]

    # Ask for a mapping so both the token ids and the attention mask are
    # available, rather than relying on the call returning a bare tensor.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,
    )
    # Inputs must live on the same device as the model.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            # Passed explicitly so generate() does not have to infer the mask.
            attention_mask=attention_mask,
            max_new_tokens=1200,
            temperature=0.85,
            use_cache=True,
            do_sample=True,
            min_p=0.1,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``gpu_id``,
        ``dataset_paths`` (required), ``max_examples`` and ``tweet_eval``.
    """
    cli = argparse.ArgumentParser(
        description="Run GMO evaluation over a slice of a campaign dataset.",
    )

    # (flag, add_argument keyword options), registered in display order.
    option_table = (
        ("--start", {"type": int, "default": 0,
                     "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None,
                   "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "gmo_results",
                          "help": "Directory to write JSON results."}),
        ("--gpu_id", {"type": int, "default": 0,
                      "help": "GPU ID to use for inference."}),
        ("--dataset_paths", {"type": str, "required": True,
                             "help": "Comma-separated list of *.jsonl datasets to evaluate."}),
        ("--max_examples", {"type": int, "default": None,
                            "help": "(Optional) truncate dataset to this many examples – useful for quick smoke tests."}),
        # --tweet_eval is retained only because the run.sh runner still passes
        # it; it can go once the runner stops doing so.
        ("--tweet_eval", {"action": "store_true",
                          "help": "Legacy flag to trigger this evaluation path."}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic
# --------------------------------------------------------------------------------------

def main() -> None:
    """Script entry point: parse the CLI, load the model once, run the evaluation.

    Assigns the loaded model and tokenizer to the module-level ``model`` and
    ``tokenizer`` names, then hands the parsed arguments to
    ``run_gmo_evaluation``.
    """
    global model, tokenizer
    # Function-scope import keeps this block self-contained; transformers is
    # already a dependency of this file.
    from transformers import BitsAndBytesConfig

    args = parse_args()

    # -----------------------------------------------------------------------------
    # Load Model once at the start
    # -----------------------------------------------------------------------------
    model_name = "Qwen/Qwen3-32B"  # change to meta-llama/Llama-3.3-70B-Instruct for LlaMA
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): device placement is left to device_map="auto"; the
    # --gpu_id argument is not read anywhere in this function.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        # 4-bit loading is requested through an explicit quantization config
        # instead of the deprecated bare load_in_4bit=True keyword.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    # The script performs a single task, so it is invoked directly.
    run_gmo_evaluation(args)


def _load_jsonl_records(path: str, max_lines=None) -> list:
    """Read a JSON Lines file, skipping blank lines and reporting malformed ones.

    Args:
        path: File to read (UTF-8).
        max_lines: Stop once this many physical lines have been read. A falsy
            value (``None`` or ``0``) means no limit.

    Returns:
        The parsed JSON values, in file order.
    """
    records = []
    malformed = 0
    with open(path, "r", encoding="utf-8") as f_in:
        for line_idx, line in enumerate(f_in):
            if max_lines and line_idx >= max_lines:
                break
            if not line.strip():
                continue  # blank lines carry no record
            try:
                records.append(json.loads(line))
            except ValueError:  # json.JSONDecodeError is a ValueError
                malformed += 1
    if malformed:
        print(f"[WARNING] Skipped {malformed} malformed line(s) in {path}")
    return records


def _atomic_json_dump(payload, path: str) -> None:
    """Write *payload* as JSON to *path* without ever leaving a partial file.

    The data goes to a sibling temporary file first and is then moved over the
    destination, so an interruption mid-write keeps the previous file intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    os.replace(tmp_path, path)


def run_gmo_evaluation(args):
    """Score every record of each dataset with every persona and save the results.

    For each ``*.jsonl`` path in ``args.dataset_paths`` the records are loaded,
    optionally restricted to the inclusive ``[args.start, args.end]`` slice, and
    each record's ``prompt`` is combined with every persona prompt and sent to
    the model. The first number following the word "Answer" in each reply is
    taken as that persona's score (clamped to 0-100); the mean over the personas
    that produced a score is stored as ``avg_predicted_score``.

    Results are written to
    ``<output_dir>/gmo_results_<dataset file name>_<start>_<end>.json`` and are
    re-saved after every record so an interrupted run keeps its progress.

    Args:
        args: Namespace providing ``dataset_paths``, ``output_dir`` and
            ``max_examples``, plus optional ``start`` / ``end``.

    Raises:
        SystemExit: If no dataset path is supplied.
    """
    dset_paths = [p.strip() for p in (args.dataset_paths or "").split(",") if p.strip()]
    if not dset_paths:
        # This path should not be taken if run via run.sh
        print("[ERROR] --dataset_paths is a required argument.", file=sys.stderr)
        sys.exit(1)

    overall_out_dir = args.output_dir
    os.makedirs(overall_out_dir, exist_ok=True)

    # First number after the word "Answer"; tolerates a few markdown symbols
    # (**, *, _ ...) in between. Compiled once for all records and personas.
    answer_pattern = re.compile(r"(?i)answer[^0-9]{0,10}(\d{1,3}(?:\.\d+)?)")

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records = _load_jsonl_records(dpath, args.max_examples)

        # --- Apply slicing if --start/--end are provided ---
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        last_idx = len(records) - 1
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = min(end_arg, last_idx) if end_arg is not None else last_idx
        if slice_start > 0 or slice_end < last_idx:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        slice_suffix = f"_{slice_start}_{slice_end}"
        # Dataset name and slice bounds in the filename keep runs from clashing.
        out_path = os.path.join(overall_out_dir, f"gmo_results_{dataset_name}{slice_suffix}.json")

        all_results = []

        for rec in tqdm(records, desc=dataset_name):
            # The 'prompt' field from the dataset contains the ICL examples and query.
            ad_prompt = rec.get("prompt", "")
            gt_resp = rec.get("response", None)  # ground-truth CTR percentile if provided
            log_msg = "[INFO] Using combined persona and ad prompt."

            persona_outputs = {}
            score_list = []
            for persona_name, persona_text in persona_prompts.items():
                # Combine the persona and the ad-specific prompt into one.
                full_prompt = f"{persona_text}\n\n{ad_prompt}"

                resp_text = verbalize(full_prompt, model, tokenizer, args)

                num_match = answer_pattern.search(resp_text)
                score = float(num_match.group(1)) if num_match else None
                if score is not None:
                    score = max(0.0, min(100.0, score))
                    score_list.append(score)
                persona_outputs[persona_name] = {"response": resp_text, "score": score}

            # Personas whose reply had no parsable score are left out of the mean.
            avg_score = sum(score_list) / len(score_list) if score_list else None

            all_results.append({
                "prompt": ad_prompt,
                "ground_truth": gt_resp,
                "persona_predictions": persona_outputs,
                "avg_predicted_score": avg_score,
                "log": log_msg,
            })

            # Incremental save after every example to avoid data loss. It is
            # best-effort: a failed checkpoint only warns, and the atomic write
            # leaves the previous checkpoint readable.
            try:
                _atomic_json_dump(all_results, out_path)
            except (OSError, TypeError, ValueError) as save_err:
                print(f"[WARNING] Incremental save failed: {save_err}")

        # Final save per dataset; also covers empty slices and a failed last
        # checkpoint. Errors here are allowed to propagate.
        _atomic_json_dump(all_results, out_path)

        print(f"[INFO] Completed processing {dataset_name} slice {slice_start}-{slice_end}. Results saved to {out_path}")

    print("\n[INFO] GMO ad evaluation complete.")

# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()


            

            

        
        
        

        
