import argparse
import json
import os
import re
import sys
from typing import Dict, List

from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Module-level handles for the loaded model and tokenizer. Both start as None,
# are assigned inside main(), and are read by run_gmo_evaluation().
model, tokenizer = None, None

# --------------------------------------------------------------------------------------
# Tweet generation prompt
# --------------------------------------------------------------------------------------
# Instruction text placed at the top of every generation request. It asks the
# model for exactly four tweets returned as a bare JSON list of strings, which
# is the shape the caller expects when it parses the reply.
TWEET_GENERATION_PROMPT = """You are a master social-media strategist. You will receive an ad input and several example tweets.

Your objectives:
1.  Synthesize the key insights from the examples.
2.  Generate FOUR distinct, high-quality tweets that cater to the ad input. Each tweet should explore a slightly different angle or tone.
3.  Keep each tweet ≤ 280 characters.
4.  Ensure the tweets are professional, engaging, and use hashtags effectively.

Return ONLY a JSON-formatted list of 4 strings, where each string is a generated tweet.
Example format:
[
    "This is the first tweet, leveraging a direct and bold tone. #Innovation",
    "Here's a second option, focusing more on the emotional and community aspect. #Together",
    "A third tweet, using a question to drive engagement. What do you think? #Future",
    "The fourth and final tweet, more professional and corporate in style. #Official"
]

Do not include any other text, analysis, or commentary in your response."""


# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(full_prompt: str, model, tokenizer, args) -> str:
    """Generate one sampled completion for ``full_prompt`` and return its text.

    The prompt is sent as a single user turn through the tokenizer's chat
    template with thinking disabled, then sampled with fixed decoding settings.

    Args:
        full_prompt: Complete prompt text for the user turn.
        model: Causal LM exposing ``.device`` and ``.generate(...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        args: Parsed CLI namespace. Currently unused; kept so existing callers
            do not need to change.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    messages = [
        # Keep the "/no_think" switch on its own line; gluing it directly onto
        # the prompt fuses it with the prompt's first word ("/no_thinkYou ...").
        {"role": "user", "content": "/no_think\n" + full_prompt},
    ]

    # return_dict=True yields both input_ids and attention_mask, so generate()
    # does not have to guess the mask.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,
    )
    inputs = inputs.to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=1200,
            temperature=0.85,
            use_cache=True,
            do_sample=True,
            min_p=0.1,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for the tweet-generation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing a list
            makes the parser usable from tests and other scripts.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``gpu_id``,
        ``dataset_paths``, ``max_examples`` and ``tweet_eval``.
    """
    parser = argparse.ArgumentParser(
        description="Run GMO evaluation over a slice of a campaign dataset.",
    )
    parser.add_argument("--start", type=int, default=0, help="Start index (inclusive) of the slice.")
    parser.add_argument("--end", type=int, default=None, help="End index (inclusive) of the slice.")
    parser.add_argument("--output_dir", type=str, default="gmo_results", help="Directory to write JSON results.")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use for inference.")
    parser.add_argument("--dataset_paths", type=str, required=True, help="Comma-separated list of *.jsonl datasets to evaluate.")
    parser.add_argument("--max_examples", type=int, default=None, help="(Optional) truncate dataset to this many examples – useful for quick smoke tests.")
    # --tweet_eval is kept for compatibility with the run.sh script, but could be removed
    # if the runner is also updated to remove it.
    parser.add_argument("--tweet_eval", action="store_true", help="Legacy flag to trigger this evaluation path.")
    return parser.parse_args(argv)

# --------------------------------------------------------------------------------------
# Main evaluation logic
# --------------------------------------------------------------------------------------

def main() -> None:
    """Parse CLI options, load the model and tokenizer once, and run the evaluation.

    The loaded objects are stored in the module-level ``model`` and
    ``tokenizer`` globals, which ``run_gmo_evaluation`` reads.
    """
    global model, tokenizer
    # Local import: only this function needs it, and `transformers` is already
    # a dependency of this file.
    from transformers import BitsAndBytesConfig

    args = parse_args()

    # -----------------------------------------------------------------------------
    # Load the model once at startup
    # -----------------------------------------------------------------------------
    model_name = "Qwen/Qwen3-32B"  # change to meta-llama/Llama-3.3-70B-Instruct for LLaMA
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # 4-bit loading is requested through an explicit quantization config;
    # passing `load_in_4bit=True` directly to from_pretrained() is deprecated.
    # NOTE(review): args.gpu_id is not consulted here; device placement is left
    # to device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    # The script performs a single task, so call it directly.
    run_gmo_evaluation(args)


def _read_jsonl_records(path, max_examples=None):
    """Read JSON-object lines from ``path``; return ``(records, skipped_count)``."""
    records = []
    skipped = 0
    with open(path, "r", encoding="utf-8") as f_in:
        for line_idx, line in enumerate(f_in):
            # A falsy max_examples (None or 0) means "no limit".
            if max_examples and line_idx >= max_examples:
                break
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Later code calls .get() on every record, so only objects qualify.
            if not isinstance(parsed, dict):
                skipped += 1
                continue
            records.append(parsed)
    return records, skipped


def _build_few_shot_context(similar_examples, limit=5):
    """Format up to ``limit`` similar examples into the few-shot section of the prompt."""
    lines = []
    for ex in similar_examples[:limit]:
        tweet = ex.get("response", "").replace("<hyperlink>", "").strip()
        lines.append(f"Input: {ex.get('prompt', '')}\nTweet: {tweet}")
    return (
        "You are given up to 5 similar examples. Study them carefully.\n\n" + "\n\n".join(lines)
    )


def _parse_generated_tweets(raw, expected=4):
    """Parse the model reply into ``expected`` tweets, or blank placeholders on failure."""
    cleaned = re.sub(r"```json\n?|\n?```", "", raw)
    candidates = [cleaned]
    # Fallback: if the model wrapped the list in prose, try the outermost [...] span.
    lo, hi = cleaned.find("["), cleaned.rfind("]")
    if 0 <= lo < hi and cleaned[lo:hi + 1] != cleaned:
        candidates.append(cleaned[lo:hi + 1])

    parsed = None
    parsed_any = False
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        parsed_any = True
        if isinstance(parsed, list) and len(parsed) == expected:
            return parsed

    if parsed_any:
        print(f"[WARNING] Expected {expected} tweets but got: {parsed}")
    else:
        print(f"[ERROR] Failed to parse JSON: {raw}")
    return ["" for _ in range(expected)]


def _atomic_json_dump(payload, path):
    """Write ``payload`` as JSON to a temp file, then swap it into ``path`` with os.replace."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_tmp:
        json.dump(payload, f_tmp, indent=2)
    os.replace(tmp_path, path)


def run_gmo_evaluation(args):
    """
    End-to-end evaluation on GMO datasets containing {"prompt":..., "response":...} per line.

    For every dataset in ``args.dataset_paths`` this reads the records, applies
    the optional ``--start``/``--end`` slice, builds a few-shot prompt from the
    pre-computed similarity data, asks the model for four tweets, and writes
    the accumulated results to
    ``<output_dir>/gmo_results_<dataset file name>_<start>_<end>.json``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset_paths``, ``output_dir``,
            ``max_examples`` and (if present) ``start`` / ``end``; the
            namespace is also forwarded to ``verbalize``.

    Raises:
        SystemExit: if ``dataset_paths`` contains no usable path.
        OSError: if a dataset cannot be read or the final save fails.
    """
    dset_paths = [p.strip() for p in (args.dataset_paths or "").split(",") if p.strip()]
    if not dset_paths:
        # This path should not be taken if run via run.sh
        print("[ERROR] --dataset_paths is a required argument.", file=sys.stderr)
        sys.exit(1)

    overall_out_dir = args.output_dir
    os.makedirs(overall_out_dir, exist_ok=True)

    script_dir = os.path.dirname(__file__) or "."
    sim_paths = {
        "comp": os.path.join(script_dir, "similarity_test_tweet_comp_con.jsonl"),
        "ran": os.path.join(script_dir, "similarity_test_tweet_ran_con.jsonl"),
    }
    sim_data_cache = {}

    def load_sim(path):
        # NOTE(review): the files are named *.jsonl but are parsed as one JSON
        # document mapping record id -> list of examples; confirm the format.
        if path not in sim_data_cache:
            print(f"[INFO] Loading similarity data from: {path}")
            try:
                with open(path, "r", encoding="utf-8") as sim_file:
                    loaded = json.load(sim_file)
            except (OSError, json.JSONDecodeError) as e:
                print(f"[ERROR] Failed to load similarity file {path}: {e}", file=sys.stderr)
                loaded = {}  # Fall back to "no examples" rather than crashing
            if not isinstance(loaded, dict):
                print(
                    f"[ERROR] Similarity file {path} does not contain a JSON object; ignoring it.",
                    file=sys.stderr,
                )
                loaded = {}
            sim_data_cache[path] = loaded
        return sim_data_cache[path]

    # Brittle positional mapping: the first dataset path uses the 'comp'
    # similarity file, the second uses 'ran'; anything else defaults to 'comp'.
    dset_to_sim_key = dict(zip(dset_paths, ("comp", "ran")))

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        # Similarity data is loaded on demand and cached, so a file that no
        # dataset uses is never opened.
        sim_key = dset_to_sim_key.get(dpath, "comp")
        similarity_data = load_sim(sim_paths[sim_key])
        print(f"[INFO] Using '{sim_key}' similarity data for this dataset.")

        records, skipped = _read_jsonl_records(dpath, args.max_examples)
        if skipped:
            # Skipped lines shift the positional index used as the fallback
            # record id further down, so make them visible.
            print(
                f"[WARNING] Skipped {skipped} malformed line(s) in {dataset_name}",
                file=sys.stderr,
            )

        # --- Apply slicing if --start/--end are provided ---
        last_index = len(records) - 1
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = last_index if end_arg is None else min(end_arg, last_index)
        # Clamp so a negative --end gives an empty slice instead of a
        # wrap-around negative-index slice.
        slice_end = max(slice_end, -1)
        if slice_start > 0 or slice_end < last_index:
            records = records[slice_start:slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        slice_suffix = f"_{slice_start}_{slice_end}"
        # Dataset name and slice bounds in the filename keep parallel runs from clashing.
        out_path = os.path.join(overall_out_dir, f"gmo_results_{dataset_name}{slice_suffix}.json")

        all_results = []

        for idx, rec in enumerate(tqdm(records, desc=dataset_name)):
            ad_prompt = rec.get("prompt", "")
            gt_resp_tweet = rec.get("response", "")

            # Look up pre-computed similar examples by record id, falling back
            # to the record's position in the (unsliced) list.
            record_id_str = str(rec.get("id", idx + slice_start))
            similar_examples = similarity_data.get(record_id_str, [])
            few_shot_context = _build_few_shot_context(similar_examples)

            # ---- Generate 4 tweets ----
            generation_prompt = (
                f"{TWEET_GENERATION_PROMPT}\n\n"
                f"{few_shot_context}\n\n"
                f"Ad Input: {ad_prompt}\n"
                f"Generate your list of 4 tweets now:"
            )
            generated_tweets_str = verbalize(generation_prompt, model, tokenizer, args)
            generated_tweets = _parse_generated_tweets(generated_tweets_str)

            # ---- Save results for this record ----
            all_results.append({
                "original_prompt": ad_prompt,
                "ground_truth_tweet": gt_resp_tweet,
                "generated_tweets": generated_tweets,
            })

            # Best-effort incremental save after every example; a failure here
            # must not abort the run.
            try:
                _atomic_json_dump(all_results, out_path)
            except (OSError, TypeError, ValueError) as _e:
                print(f"[WARNING] Incremental save failed: {_e}")

        # Final save per dataset (also creates the file when the slice is empty).
        _atomic_json_dump(all_results, out_path)

        print(f"[INFO] Completed processing {dataset_name} slice {slice_start}-{slice_end}. Results saved to {out_path}")

    print("\n[INFO] GMO ad evaluation complete.")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()