#!/usr/bin/env python3
# --------------------------------------------------------------------------------------
# Imports & Config
# --------------------------------------------------------------------------------------
import argparse
import csv
import json
import os
import random
import re
import sys
import time
from typing import Dict, List, Optional

import numpy as np
from openai import AzureOpenAI
from datasets import load_dataset
from tqdm import tqdm

# --------------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------------
# Shared Azure OpenAI client used by verbalize(). It is None at import time and
# is created by main() before any chat completion is requested.
client: Optional[AzureOpenAI] = None

# --------------------------------------------------------------------------------------
# Embedding model configuration
# --------------------------------------------------------------------------------------
# Model name passed to embeddings.create() by _get_embeddings().
EMBED_MODEL = "text-embedding-3-small"

# File name for global diagnostics (intended to be created inside output_dir).
# NOTE(review): not referenced in the visible part of this file - confirm it is still needed.
EMBED_DIAG_FILENAME = "embedding_diagnostics.txt"

# Azure OpenAI configuration (copied from gpt_static_thisruns.py).
# "api_key" here is only a placeholder fallback: every client construction in
# this file reads the OPENAI_API_KEY environment variable first. Do not commit
# a real key into this dict.
api_version = "2024-02-15-preview"
config_dict: Dict[str, str] = {
    "api_key": "YOUR_OPENAI_API_KEY",
    "api_version": api_version,
    "azure_endpoint": "https://your-azure-openai-endpoint/",
}

def _get_embeddings(
    texts: List[str],
    batch_size: int = 96,
    diag_path: Optional[str] = None,
    max_retries: int = 5,
) -> List[Optional[np.ndarray]]:
    """Compute embeddings for *texts* using Azure OpenAI with robust retry logic.

    For each *batch_size* slice of *texts* we attempt up to *max_retries* times;
    on a 429 (rate-limit) or transient network error we back-off exponentially.

    Unlike the previous implementation we *do not* abort the entire run when a
    single batch fails.  Instead, we insert ``None`` placeholders for every text
    in the failed batch so downstream code can still leverage the successful
    embeddings it has and fall back to random sampling only for the missing
    ones.
    """

    # Pre-allocate output so ordering is preserved even with failures.
    results: List[Optional[np.ndarray]] = [None] * len(texts)

    # Build Azure client once (outside the loop for efficiency)
    try:
        client = AzureOpenAI(
            api_key=os.getenv("OPENAI_API_KEY", config_dict["api_key"]),
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    except Exception as e:
        if diag_path:
            with open(diag_path, "a") as f:
                f.write(f"[CLIENT-INIT-ERROR] Failed to create AzureOpenAI client: {e}\n")
        return results  # all None

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]

        # Retry loop for this chunk only
        attempt = 0
        while attempt <= max_retries:
            try:
                resp = client.embeddings.create(model=EMBED_MODEL, input=chunk)
                resp.data.sort(key=lambda x: x.index)  # preserve original order

                for i, d in enumerate(resp.data):
                    results[start + i] = np.array(d.embedding, dtype=np.float32)

                break  # success, move to next batch

            except Exception as e:
                attempt += 1
                # Parse retry-after seconds if available (Azure puts it in the
                # error message sometimes) – default to 5 * attempt seconds.
                wait_secs = 5 * attempt
                if "retry" in str(e).lower():
                    # crude extraction of the first integer in the message
                    import re as _re

                    m = _re.search(r"retry after (\\d+)", str(e).lower())
                    if m:
                        wait_secs = int(m.group(1))

                if diag_path:
                    with open(diag_path, "a") as f:
                        f.write(
                            f"[EMBEDDING-ERROR] Batch {start}-{start+len(chunk)-1} attempt {attempt}/{max_retries}: {e}. Waiting {wait_secs}s.\n"
                        )

                if attempt > max_retries:
                    # Give up on this batch – leave None placeholders.
                    break

                time.sleep(wait_secs)

        # End retry loop

        # Respect base rate-limit between *successful* calls only.
        if attempt == 0 or results[start] is not None:
            time.sleep(1)  # more conservative than 0.1s previously

    return results


def _cosine(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))

# --------------------------------------------------------------------------------------
# Helper functions
# --------------------------------------------------------------------------------------

def verbalize(user_prompt: str) -> str:
    """Send *user_prompt* as a single user message to GPT-4o and return the reply.

    Uses the module-level ``client``, which must have been created beforehand
    (``main()`` does this).

    Args:
        user_prompt: Full prompt text; no system message is added.

    Returns:
        The first choice's message text with surrounding whitespace removed.
        An empty string is returned when the completion carries no text
        (``message.content`` is ``None``), so callers always receive a ``str``.
    """
    messages = [{"role": "user", "content": user_prompt}]

    # Generate response
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=messages,
        max_tokens=1200,
        temperature=0.85,
    )

    # ``content`` is optional in the API response; guard before calling .strip().
    content = response.choices[0].message.content
    return content.strip() if content else ""

# --------------------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Parse the process command line for the tweet evaluation script.

    Only ``--dataset_paths`` is required; every other option has a default.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "static_folder", "help": "Directory to write JSON results."}),
        ("--dataset_paths", {"type": str, "required": True, "help": "Comma-separated list of *.jsonl datasets to evaluate."}),
        ("--max_examples", {"type": int, "default": None, "help": "(Optional) truncate dataset to this many examples – useful for quick smoke tests."}),
        ("--similarity_json", {"type": str, "default": None, "help": "Path to JSON with pre-computed nearest neighbours."}),
    )

    cli = argparse.ArgumentParser(description="Run static evaluation for tweet engagement.")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# --------------------------------------------------------------------------------------
# Main evaluation logic (chunk mode)
# --------------------------------------------------------------------------------------

def main() -> None:
    """Script entry point: parse the CLI, create the shared client, run the evaluation."""
    global client

    cli_args = parse_args()

    # Create the shared chat client once; the environment variable wins over
    # the placeholder key stored in config_dict.
    if client is None:
        api_key = os.getenv("OPENAI_API_KEY", config_dict["api_key"])
        client = AzureOpenAI(
            api_key=api_key,
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )

    # Tweet evaluation is the only task this script performs.
    neighbours = _get_similarity_map(cli_args)
    run_tweet_evaluation(cli_args, sim_map=neighbours)

def _get_similarity_map(args):
    """Load precomputed similarity map if provided, else return None."""
    sim_map = None
    if args.similarity_json:
        if not os.path.isfile(args.similarity_json):
            print(f"[WARNING] --similarity_json provided but file not found: {args.similarity_json}")
        else:
            with open(args.similarity_json, "r", encoding="utf-8") as _f:
                sim_map = json.load(_f)
            print(f"Loaded pre-computed similarity map from {args.similarity_json} (entries: {len(sim_map)})")
    return sim_map

def _tweet_key(rec, idx):
    """Return a unique key for a tweet record for similarity lookup. By default, use the index as string."""
    return str(idx)

def _extract_brand_and_date(text: str):
    """Very lightweight extraction of brand name and year from *text*.

    This utility tries to identify:
    1. A brand name specified as `Brand: <XYZ>` (case-insensitive, stops at whitespace).
    2. A four-digit year between 1900–2099.

    If either element is not found we fall back to the string "unknown" so that
    downstream accuracy statistics buckets remain well-formed.
    """
    # Brand
    brand_match = re.search(r"brand\s*:\s*([A-Za-z0-9_\-]+)", text, flags=re.IGNORECASE)
    brand = brand_match.group(1).lower() if brand_match else "unknown"

    # Year
    year_match = re.search(r"\b(19|20)\d{2}\b", text)
    year = year_match.group(0) if year_match else "unknown"

    return brand, year

# Patterns and few-shot settings shared by the evaluation helpers below.
_HIGH_LIKES_RE = re.compile(r"high likes", flags=re.IGNORECASE)
_ANSWER_RE = re.compile(r"Answer:\s*(high|low)", flags=re.IGNORECASE)
_NUM_FEWSHOT = 5
_MAX_MIX_ATTEMPTS = 10


def _load_jsonl_records(path: str, max_examples: Optional[int] = None) -> List[dict]:
    """Read JSON-object records from a ``.jsonl`` file, skipping unusable lines.

    Reading stops once *max_examples* lines have been seen (a falsy value
    means no limit).  Lines that are not valid JSON, or that decode to
    something other than an object, are skipped.
    """
    records: List[dict] = []
    with open(path, "r", encoding="utf-8") as f_in:
        for line_idx, line in enumerate(f_in):
            if max_examples and line_idx >= max_examples:
                break
            try:
                rec = json.loads(line)
            except ValueError:
                continue  # skip malformed
            if isinstance(rec, dict):
                records.append(rec)
    return records


def _pick_mixed_neighbors(idx: int, is_high: List[bool]) -> List[int]:
    """Choose up to five few-shot example indices (never *idx*) mixing both labels.

    Random draws are retried a few times until the sample holds at least one
    high and one low example.  If that never happens, one example of each
    available label is forced in and the remainder is filled at random.
    """
    pool = [i for i in range(len(is_high)) if i != idx]
    sample_size = min(_NUM_FEWSHOT, len(pool))

    for _ in range(_MAX_MIX_ATTEMPTS):
        neighbor_ids = random.sample(pool, k=sample_size)
        n_high = sum(1 for nid in neighbor_ids if is_high[nid])
        if 0 < n_high < len(neighbor_ids):
            return neighbor_ids  # at least one high and one low

    # Fallback: force at least one high and one low if possible.
    highs = [i for i in pool if is_high[i]]
    lows = [i for i in pool if not is_high[i]]
    neighbor_ids = []
    if highs:
        neighbor_ids.append(random.choice(highs))
    if lows:
        neighbor_ids.append(random.choice(lows))
    taken = set(neighbor_ids)
    rest = [i for i in pool if i not in taken]
    neighbor_ids += random.sample(rest, k=min(_NUM_FEWSHOT - len(neighbor_ids), len(rest)))
    return neighbor_ids


def _dump_json_atomic(path: str, payload) -> None:
    """Write *payload* as JSON to *path* via a temporary file and a rename.

    The previous file stays intact until the new one is fully written, so an
    interrupted save cannot leave a truncated results file behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    os.replace(tmp_path, path)


def _tally(stats: Dict[str, Dict[str, int]], key: str, is_correct: bool) -> None:
    """Add one observation to the ``{"correct", "total"}`` bucket stored under *key*."""
    bucket = stats.setdefault(key, {"correct": 0, "total": 0})
    bucket["total"] += 1
    if is_correct:
        bucket["correct"] += 1


def _print_breakdown(title: str, stats: Dict[str, Dict[str, int]]) -> None:
    """Print per-key accuracy for *stats*, sorted by key, under the heading *title*."""
    print(title)
    for key, st in sorted(stats.items(), key=lambda kv: kv[0]):
        acc = st["correct"] / st["total"] if st["total"] else 0.0
        print(f"  {key}: {acc:.3f} ({st['correct']}/{st['total']})")


def run_tweet_evaluation(args, sim_map=None):
    """Run the few-shot high/low-likes evaluation over every dataset in ``args``.

    Each dataset is a ``.jsonl`` file with one ``{"prompt": ..., "response": ...}``
    object per line.  A record's ground truth is "high" when its response
    contains "high likes" (case-insensitive), otherwise "low".  For every
    record a prompt with up to five randomly chosen, label-mixed examples is
    sent to ``verbalize`` and the ``Answer: High/Low`` line of the reply is
    compared with the ground truth.

    Results are written to
    ``<output_dir>/tweet_results_<dataset>_<start>_<end>.json`` after every
    example and once more at the end; accuracy overall, per brand and per
    year is printed.

    Args:
        args: Parsed CLI namespace; reads ``dataset_paths``, ``output_dir``,
            ``max_examples`` and, when present, ``start`` / ``end``.
        sim_map: Optional pre-computed neighbour map.  It is accepted for
            interface compatibility but not consulted: neighbours are always
            sampled at random, so the evaluation runs whether or not a map is
            supplied.

    Raises:
        SystemExit: Always, with code 0, once all datasets are processed.
    """
    # Resolve dataset paths
    dset_paths = [p.strip() for p in args.dataset_paths.split(",") if p.strip()]

    overall_out_dir = args.output_dir or "tweet_static_results"
    os.makedirs(overall_out_dir, exist_ok=True)

    for dpath in dset_paths:
        dataset_name = os.path.basename(dpath)
        print(f"\n[INFO] Processing dataset: {dataset_name}")

        records = _load_jsonl_records(dpath, args.max_examples)

        # --- Apply slicing if --start/--end are provided (both inclusive) ---
        last_index = len(records) - 1
        start_arg = getattr(args, "start", None)
        end_arg = getattr(args, "end", None)
        slice_start = max(0, start_arg) if start_arg is not None else 0
        slice_end = min(end_arg, last_index) if end_arg is not None else last_index
        # Clamp so a negative --end yields an empty slice instead of being
        # interpreted as a Python negative index.
        slice_end = max(slice_end, -1)
        if slice_start > 0 or slice_end < last_index:
            records = records[slice_start : slice_end + 1]
            print(f"[INFO] Processing slice {slice_start}-{slice_end} (n={len(records)}) of {dataset_name}")
        else:
            print(f"[INFO] Processing full dataset {dataset_name} (n={len(records)})")

        out_path = os.path.join(
            overall_out_dir,
            f"tweet_results_{dataset_name}_{slice_start}_{slice_end}.json",
        )

        # Label every record once instead of re-running the regex per neighbour.
        is_high = [bool(_HIGH_LIKES_RE.search(rec.get("response", ""))) for rec in records]

        correct = 0
        brand_stats: Dict[str, Dict[str, int]] = {}
        time_stats: Dict[str, Dict[str, int]] = {}
        all_results = []

        for idx, rec in enumerate(tqdm(records, desc=dataset_name)):
            prompt_text = rec.get("prompt", "")
            gt_label = "high" if is_high[idx] else "low"

            # --- FEW-SHOT EXAMPLES: random, but with a mix of high/low ---
            neighbor_ids = _pick_mixed_neighbors(idx, is_high)
            log_msg = f"[RANDOM_MIXED] Used random mixed neighbors for idx {idx}: {neighbor_ids}"

            # Tweet tasks have no numeric score, so examples show only the prompt.
            example_blocks = [f"{records[nid].get('prompt', '')}" for nid in neighbor_ids]
            examples_text = "\n---\n".join(example_blocks)

            user_prompt = (
                "Below are five example tweets. After these, you'll see a new tweet. "
                "Predict whether it will receive high or low likes on Twitter. "
                "Return two lines exactly:\nReason: <brief>\nAnswer: [High / Low]\n"
                "Examples:\n" + examples_text +
                "\n---\n" + prompt_text
            )

            resp_text = verbalize(user_prompt)
            match = _ANSWER_RE.search(resp_text)
            # An unparseable reply gives None, which never equals the ground truth.
            pred_label = match.group(1).lower() if match else None

            # Accuracy bookkeeping
            is_correct = pred_label == gt_label
            if is_correct:
                correct += 1

            brand, year = _extract_brand_and_date(prompt_text)
            _tally(brand_stats, brand, is_correct)
            _tally(time_stats, year, is_correct)

            all_results.append({
                "prompt": prompt_text,
                "ground_truth": gt_label,
                "response": resp_text,
                "predicted_label": pred_label,
                "neighbor_ids": neighbor_ids,
                "neighbor_log": log_msg,
            })

            # Incremental save after every example to avoid data loss; a failed
            # save is reported but must not stop the run.
            try:
                _dump_json_atomic(out_path, all_results)
            except OSError as save_err:
                print(f"[WARNING] Incremental save failed: {save_err}")

        # Final save per-dataset results (also covers the empty-dataset case).
        _dump_json_atomic(out_path, all_results)

        # Report accuracies
        total = len(records)
        overall_acc = correct / total if total else 0.0
        print(f"Overall accuracy for {dataset_name}: {overall_acc:.3f} ({correct}/{total})")

        _print_breakdown("\nAccuracy by brand:", brand_stats)
        _print_breakdown("\nAccuracy by year:", time_stats)

    print("\n[INFO] Tweet evaluation complete. Exiting.")
    sys.exit(0)

if __name__ == "__main__":
    main()


            

            

        
        
        

        

