import sys
from pathlib import Path
_src = Path(__file__).resolve().parent.parent
if str(_src) not in sys.path:
    sys.path.insert(0, str(_src))
from load_dataset import get_overtonbench_data

import pandas as pd
import prompts, llm_api, helper_functions
import argparse
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import math
from tenacity import retry, stop_after_attempt, wait_exponential


# Short prompt ids accepted on the command line (--prompt / --prompts), mapped to
# the prompt-type names that are passed on to helper_functions.format_prompt.
PROMPT_TYPE_MAP = {
    "fr": "freeresponse",
    "demog": "demog",
    "demog+fr": "demog_freeresponse",
    "fs+fr": "freeresponse_fewshot",
    "fs": "fewshot",
}
# Reverse lookup (prompt-type name -> short id); the short id is used when
# naming result columns.
PROMPT_NAME_TO_INDEX = dict(zip(PROMPT_TYPE_MAP.values(), PROMPT_TYPE_MAP.keys()))


# Log messages with a timestamp
_last_log_time = 0  # module-level, used by log() for throttling
def log(message, force=False, verbose_interval=300):
    global _last_log_time
    now = time.time()

    if force or (now - _last_log_time >= verbose_interval):
        print(f"[{time.strftime('%H:%M:%S')}] {message}")
        _last_log_time = now


@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), reraise=True)
def _generate_rating(formatted_prompt, model_name, temperature):
    """Request one rating from *model_name*, retrying failed calls up to 3 times.

    The retry lives on this call (not only on ``evaluate_row``) because
    ``evaluate_row`` catches every exception raised here so it can continue with
    the remaining prompt types; that catch would otherwise hide API failures from
    tenacity. ``reraise=True`` surfaces the last underlying error to the caller.
    """
    return llm_api.generate(
        preformatted_prompt=formatted_prompt,
        sys_prompt=prompts.system_prompt,
        client=None,  # OpenAI client is created inside llm_api.generate() when needed
        model_name=model_name,
        max_retries=1,
        temperature=temperature,
    )


# Evaluate a single row for the given client and prompt types
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def evaluate_row(row, client_name, prompt_types, temperature=0):
    """Rate one dataset row with one client for each requested prompt type.

    Args:
        row: Mapping of the row's fields (needs ``.copy()`` and ``.get()``).
        client_name: ``'openai'`` selects gpt-5-mini; any other value selects
            gemini-2.5-pro.
        prompt_types: Prompt-type names (values of ``PROMPT_TYPE_MAP``).
        temperature: Sampling temperature for non-OpenAI models. gpt-5-mini is
            always called with temperature 1, as in the original hard-coded
            setting.

    Returns:
        Dict with ``"{client_name}_{prompt_id}_avg"`` -> rating and, when both
        the rating and ``row['representation_rating']`` are not None,
        ``"{client_name}_{prompt_id}_diff"`` -> rating - representation_rating.
        A prompt type whose rating request fails (after retries) is logged and
        omitted from the result.

    Errors from prompt formatting are not caught here; the decorator retries
    the whole call, and tenacity then raises its RetryError.
    """
    model_name = "gpt-5-mini" if client_name == 'openai' else "gemini-2.5-pro"
    # Previously the temperature argument was ignored (always 0 for Gemini).
    temp = 1 if model_name == "gpt-5-mini" else temperature

    results = {}
    for prompt_type in prompt_types:
        # Pass a copy so format_prompt cannot mutate the caller's row.
        effective_row = row.copy()
        formatted_prompt = helper_functions.format_prompt(prompt_type, effective_row)

        try:
            rating = _generate_rating(formatted_prompt, model_name, temp)

            prefix = f"{client_name}_{PROMPT_NAME_TO_INDEX[prompt_type]}"
            results[f"{prefix}_avg"] = rating
            if (actual := row.get('representation_rating')) is not None and rating is not None:
                results[f"{prefix}_diff"] = rating - actual
        except Exception as e:
            # Best-effort: skip this prompt type but keep evaluating the rest.
            log(f"ERROR: {client_name.upper()} | {prompt_type} | Exception: {str(e)[:100]}", force=True)

    return results

# Process a single row with all specified prompt types
def process_row(row, args, prompt_types):
    """Run every requested client/prompt combination for one dataset row.

    *row* may be a namedtuple (converted with ``_asdict``) or a dict (copied).
    Prompt types are processed in order. For each one, OpenAI is evaluated
    before Gemini, as selected by ``args.client`` ('openai', 'gemini' or
    'both'). Each result dict is merged into the working copy, and that
    working copy is also what later evaluations receive. Returns the
    working copy.
    """
    if hasattr(row, '_asdict'):
        merged = row._asdict()
    else:
        merged = row.copy()

    active_clients = [
        name for name in ('openai', 'gemini') if args.client in (name, 'both')
    ]
    for prompt_type in prompt_types:
        for client_name in active_clients:
            merged.update(evaluate_row(merged, client_name, [prompt_type]))
    return merged

def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--client', required=True, choices=['openai', 'gemini', 'both'])
    parser.add_argument('--n_rows', type=int, default=None)
    # Restrict prompt ids to known keys so a typo produces an argparse usage
    # error instead of an uncaught KeyError traceback.
    parser.add_argument('--prompt', type=str, choices=list(PROMPT_TYPE_MAP),
                        help='Prompt type to run (name). E.g., "fr" or "demog"')
    parser.add_argument('--prompts', nargs='+', type=str, choices=list(PROMPT_TYPE_MAP),
                        help='List of prompt types (by name), e.g., --prompts fr demog fs+fr')
    cpu_count = os.cpu_count() or 4
    parser.add_argument('--max_workers', type=int, default=min(cpu_count * 2, 8))
    parser.add_argument('--source', default=None,
                        help="Question source split when loading from HF: full (default), modelslant, or prism. Output filename gets _modelslant/_prism suffix when set.")
    parser.add_argument('--data', default=None,
                        help="Path to CSV to use instead of Hugging Face (overrides DATASET in .env if set). Same schema as OvertonBench.")
    return parser.parse_args()


def _selected_prompt_ids(args):
    """Return the short prompt ids requested on the CLI, or None meaning "all".

    ``--prompts`` takes precedence over ``--prompt``.
    """
    if args.prompts:
        return list(args.prompts)
    if args.prompt:
        return [args.prompt]
    return None


def _output_path(args, prompt_ids):
    """Build the CSV path for this run's predictions.

    The filename encodes the client, the row count and every prompt id actually
    run (from --prompt or --prompts), so runs over different prompt subsets do
    not overwrite each other. A non-default source split adds a final suffix.
    """
    num_rows = "all" if args.n_rows is None else args.n_rows
    filename = f"{args.client}_{num_rows}_rows"
    if prompt_ids:
        filename += "_" + "_".join(prompt_ids)
    source = (args.source or "full").strip().lower()
    if source in ("modelslant", "prism"):
        filename += f"_{source}"
    return f"outputs/predictions/{filename}.csv"


def _run_in_batches(sample, args, prompt_types):
    """Evaluate every row of *sample* on a thread pool, one batch at a time.

    Up to ``args.max_workers`` rows are submitted per batch, and each batch is
    fully drained before the next starts. There is a 2 s pause between batches
    as a simple rate limit. Rows whose processing raises are logged and left
    out. Returns the per-row result dicts in completion order.
    """
    results = []
    total = len(sample)
    batch_size = max(1, args.max_workers)
    total_batches = math.ceil(total / batch_size)

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        for batch_num in range(total_batches):
            batch = sample.iloc[batch_num * batch_size:(batch_num + 1) * batch_size]
            futures = {
                executor.submit(process_row, row.to_dict(), args, prompt_types): idx
                for idx, row in batch.iterrows()
            }

            for future in as_completed(futures):
                idx = futures[future]
                try:
                    results.append(future.result())
                except Exception as e:
                    log(f"ERROR in row {idx}: {str(e)[:200]}", force=True)
                else:
                    completed = len(results)
                    log(f"Progress: {completed}/{total} ({completed/total:.1%})")

            # Rate limiting between batches
            if batch_num < total_batches - 1:
                time.sleep(2)

    return results


# Main function to parse arguments and run evaluations
def main():
    """CLI entry point: load OvertonBench data, rate the rows, and save a CSV.

    Rows are rated with the selected client(s) and prompt types, and the
    predictions are written under ``outputs/predictions/``. The Gemini client
    is closed afterwards when it was in use, even if the run fails.
    """
    start_time = time.time()
    args = _parse_args()

    # Pass data options so main load and helper_functions (few-shot ratings/qa) use the same source
    helper_functions.set_data_options(path=args.data, source_split=args.source)

    prompt_ids = _selected_prompt_ids(args)
    if prompt_ids is None:
        prompt_types = list(PROMPT_NAME_TO_INDEX.keys())
    else:
        prompt_types = [PROMPT_TYPE_MAP[p] for p in prompt_ids]

    df = get_overtonbench_data(path=args.data, source_split=args.source)
    if args.n_rows is None:
        sample = df  # take the whole dataset
    else:
        sample = df.sample(min(args.n_rows, len(df)), random_state=42)

    try:
        all_results = _run_in_batches(sample, args, prompt_types)

        os.makedirs('outputs/predictions', exist_ok=True)
        output_path = _output_path(args, prompt_ids)

        log(f"Total runtime: {(time.time()-start_time)/60:.1f} minutes", force=True)
        pd.DataFrame(all_results).to_csv(output_path, index=False)
        log(f"Saved results to {output_path}", force=True)
    finally:
        if args.client in ('gemini', 'both'):
            llm_api.close_gemini_client()

# Entry point
if __name__ == "__main__":
    main()