"""
GEPA baseline for rule arena (US tax) domain - Prompt-to-improve version

This version uses the original calculation stored in the dataset as context,
then prompts the model to improve it.

Example usage:

    python -m baselines.gepa.gepa_rule_arena_improve \
        --minibatch-size 3 \
        --max-calls 10240 \
        --train-size 50 \
        --val-size 25 \
        --num-threads 20 \
        --log-dir baselines/gepa/logs/rule_arena_improve \
        --num-runs 5 \
        --output-dir baselines/gepa/results \
        --wandb-name rule_arena_improve_gepa_paper \
        --temperature 1.0
"""

import dspy
from dspy.evaluate import Evaluate
from dspy import GEPA
import pandas as pd
import random
import os
import statistics
import json
import argparse
from pathlib import Path

from advisor_models.rule_arena.config import (
    build_prompt,
    compute_score,
)
from utils.eval_utils import compute_multi_run_statistics

# Seed Python's global RNG; the random.shuffle / random.sample calls in this
# file draw from it.
random.seed(42)
# Default LM configured at import time. main() builds a new LM using the
# --temperature argument and configures DSPy again with it.
llm = dspy.LM("openai/gpt-4.1-mini", cache=False, temperature=1.0)
dspy.settings.configure(lm=llm)


def _rows_to_examples(df):
    """Convert every row of a rule arena DataFrame into a ``dspy.Example``.

    Each row must provide ``info_dict`` and ``reward_spec["ground_truth"]``;
    ``original_response`` falls back to ``""`` when the row has no such entry.
    """
    examples = []
    for _, row in df.iterrows():
        info_dict = row["info_dict"]
        ground_truth = row["reward_spec"]["ground_truth"]
        original_response = row.get("original_response", "")

        # Build the problem prompt
        problem = build_prompt(info_dict)

        examples.append(
            dspy.Example(
                prompt=problem,
                ground_truth=ground_truth,
                original_response=original_response,
            ).with_inputs("prompt", "ground_truth", "original_response")
        )
    return examples


def load_rule_arena_data():
    """Load rule arena data from parquet files.

    Returns:
        A ``(trainset, valset)`` tuple of lists of ``dspy.Example`` objects,
        built from the train and validation parquet files respectively.

    Raises:
        FileNotFoundError: If either parquet file does not exist.
    """
    train_path = "data/rule_arena/train_gpt-4.1-mini_0.parquet"
    val_path = "data/rule_arena/validation_gpt-4.1-mini_0.parquet"

    if not os.path.exists(train_path) or not os.path.exists(val_path):
        raise FileNotFoundError(
            f"Rule arena data files not found at {train_path} or {val_path}"
        )

    train_df = pd.read_parquet(train_path)
    val_df = pd.read_parquet(val_path)

    # Convert to DSPy format; both splits share the same row layout, so a
    # single helper handles the conversion.
    trainset = _rows_to_examples(train_df)
    valset = _rows_to_examples(val_df)

    return trainset, valset


# DSPy signature for the single "improve" step: two input fields (the problem
# prompt and the original response) and one output field.
# NOTE(review): DSPy presumably uses the class docstring and the field `desc`
# strings below as prompt text (main() prints `pred.signature.instructions`
# after optimization), so treat them as runtime strings rather than ordinary
# documentation -- confirm before rewording them.
class ImprovedTaxCalculation(dspy.Signature):
    """Improve the initial tax calculation."""

    # Problem statement built by build_prompt() in load_rule_arena_data().
    prompt = dspy.InputField(desc="The original tax calculation problem")
    # Response read from the "original_response" entry of each data row.
    original_response = dspy.InputField(desc="The original calculation from the data")
    # Read back by compute_score_metric() as `pred.improved_calculation`.
    improved_calculation = dspy.OutputField(
        desc="The improved tax calculation with reasoning"
    )


class TaxCalculatorImproveModule(dspy.Module):
    """Tax calculator module with improvement step using original response from data."""

    def __init__(self):
        super().__init__()
        # Single chain-of-thought predictor over the improvement signature.
        self.improve = dspy.ChainOfThought(ImprovedTaxCalculation)

    def forward(self, prompt, ground_truth=None, original_response=None):
        """Run the improvement predictor on one example.

        Args:
            prompt: The problem text passed to the predictor.
            ground_truth: Accepted but not used by this method.
            original_response: The original response from the data file;
                must be truthy.

        Returns:
            Whatever ``self.improve`` returns for the given inputs.

        Raises:
            ValueError: If ``original_response`` is missing or empty.
        """
        if original_response:
            # Improve the calculation based on the original response
            return self.improve(prompt=prompt, original_response=original_response)

        raise ValueError("original_response is required but not provided")


def compute_score_metric(example, pred, trace=None):
    """Score one prediction against its example's ground truth.

    Passes ``pred.improved_calculation`` and ``example.ground_truth`` to
    ``compute_score`` and returns the first element of the resulting pair.
    ``trace`` is accepted but not used.
    """
    # compute_score yields a two-element result; only the reward is kept.
    reward, _ = compute_score(pred.improved_calculation, example.ground_truth)
    return reward


def scalar_feedback_metric(example, pred, trace=None, *args, **kwargs):
    """Metric handed to GEPA: the same reward as ``compute_score_metric``.

    Any extra positional or keyword arguments are accepted and ignored.
    """
    return compute_score_metric(example, pred, trace)


def evaluate_model(model, dataset, model_name, num_threads=72, max_samples=100):
    """Evaluate a model on (a random subset of) the given dataset.

    Args:
        model: The module to evaluate; passed to the ``Evaluate`` instance.
        dataset: Sequence of examples to draw the evaluation set from.
        model_name: Label printed in the progress header.
        num_threads: Thread count forwarded to ``Evaluate``.
        max_samples: Upper bound on the number of examples evaluated. The
            default of 100 keeps the historical behavior; pass ``None`` to
            evaluate every example in ``dataset`` without sub-sampling.

    Returns:
        The list of per-example scores (third element of each entry in the
        evaluation result).
    """
    print(f"\n=== Evaluating {model_name} ===")

    if max_samples is None:
        # No cap requested: evaluate the whole dataset in its given order.
        eval_dataset = list(dataset)
    else:
        # Evaluate on subset for faster testing. random.sample is called even
        # when the dataset is smaller than the cap, as before, so the global
        # RNG is consumed identically.
        eval_dataset = random.sample(dataset, min(max_samples, len(dataset)))

    evaluator = Evaluate(
        devset=eval_dataset,
        metric=compute_score_metric,
        num_threads=num_threads,
        display_progress=True,
    )

    eval_result = evaluator(model)
    score = eval_result.score
    results = [entry[2] for entry in eval_result.results]

    # Standard error of the mean; stdev needs at least two data points.
    reward_se = (
        statistics.stdev(results) / (len(results) ** 0.5) if len(results) > 1 else 0
    )

    print(f"Average accuracy: {score:.4f}±{reward_se:.4f}")

    return results


def run_multi_evaluation(model, dataset, num_runs, num_threads=72):
    """Evaluate ``model`` ``num_runs`` times and collect each run's scores.

    Returns a list with one entry per run, each entry being whatever
    ``evaluate_model`` returned for that run. A non-positive ``num_runs``
    yields an empty list.
    """
    per_run_scores = []

    # Runs are numbered from 1 for display purposes.
    for run_number in range(1, num_runs + 1):
        print(f"\n=== Run {run_number}/{num_runs} ===")
        per_run_scores.append(
            evaluate_model(model, dataset, f"Run {run_number}", num_threads)
        )

    return per_run_scores


def save_optimized_prompt(model, output_dir, domain_name):
    """Persist the optimized model under ``output_dir``.

    Creates ``output_dir`` if needed, then calls ``model.save`` with the path
    ``<output_dir>/<domain_name>_optimized_model.json`` (as a string).
    """
    os.makedirs(output_dir, exist_ok=True)

    destination = Path(output_dir).joinpath(f"{domain_name}_optimized_model.json")
    model.save(str(destination))
    print(f"Saved optimized model to {destination}")


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            behavior; passing a list allows programmatic use and testing.

    Returns:
        The ``argparse.Namespace`` holding all parsed options.
    """
    parser = argparse.ArgumentParser(
        description="GEPA baseline for rule arena domain (prompt-to-improve version)"
    )

    # --- GEPA optimization budget and data sizes ---
    parser.add_argument(
        "--minibatch-size",
        type=int,
        default=3,
        help="Reflection minibatch size (default: 3)",
    )

    parser.add_argument(
        "--max-calls",
        type=int,
        default=64000,
        help="Maximum number of metric calls (default: 64000)",
    )

    parser.add_argument(
        "--train-size",
        type=int,
        default=100,
        help="Number of training examples to use (default: 100)",
    )

    parser.add_argument(
        "--val-size",
        type=int,
        default=50,
        help="Number of validation examples to use (default: 50)",
    )

    parser.add_argument(
        "--eval-size",
        type=int,
        default=100,
        help="Number of evaluation examples to use (default: 100)",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=72,
        help="Number of threads for parallel execution (default: 72)",
    )

    # --- Logging, output locations, and final evaluation ---
    parser.add_argument(
        "--log-dir",
        type=str,
        default="gepa_logs_rule_arena_improve",
        help="Directory for GEPA logs",
    )
    parser.add_argument(
        "--num-runs",
        type=int,
        default=5,
        help="Number of evaluation runs for final evaluation",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="baselines/gepa/results",
        help="Directory to save results and optimized prompts",
    )

    # --- Weights & Biases ---
    parser.add_argument(
        "--wandb-name",
        type=str,
        default="rule_arena_improve_gepa",
        help="W&B run name (default: rule_arena_improve_gepa)",
    )

    parser.add_argument("--no-wandb", action="store_true", help="Disable W&B logging")

    # --- LLM sampling ---
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="LLM temperature (default: 1.0)"
    )

    return parser.parse_args(argv)


def main():
    """Run the full GEPA baseline: optimize, save, evaluate, and report.

    Steps, in order:
      1. Parse CLI arguments and print the configuration.
      2. Configure DSPy with an LM using the requested temperature.
      3. Load the data, shuffle the training examples, and carve the GEPA
         train/validation subsets out of the training file; the validation
         file is sampled for the final evaluation only.
      4. Run GEPA, print and save the optimized predictor instructions.
      5. Evaluate the optimized model ``--num-runs`` times, print summary
         statistics, and write them to a JSON file in ``--output-dir``.
    """
    args = parse_args()

    print("Configuration:")
    print(f"  Minibatch size: {args.minibatch_size}")
    print(f"  Max metric calls: {args.max_calls}")
    print(f"  Train size: {args.train_size}")
    print(f"  Val size: {args.val_size}")
    print(f"  Eval size: {args.eval_size}")
    print(f"  Num threads: {args.num_threads}")
    print(f"  Temperature: {args.temperature}")
    print(f"  Log directory: {args.log_dir}")
    print(f"  W&B enabled: {not args.no_wandb}")
    if not args.no_wandb:
        print(f"  W&B run name: {args.wandb_name}")
    print()

    # Configure LLM
    llm = dspy.LM("openai/gpt-4.1-mini", cache=False, temperature=args.temperature)
    dspy.settings.configure(lm=llm)

    # Load data
    print("Loading rule arena data...")
    trainset, valset = load_rule_arena_data()
    print(
        f"Loaded {len(trainset)} training examples, {len(valset)} validation examples"
    )

    # Initialize model
    model = TaxCalculatorImproveModule()
    print("Running GEPA optimization...")

    # Prepare datasets. Both GEPA subsets come from the shuffled training
    # file; the order of the shuffle/sample calls is kept as-is because they
    # share the global RNG.
    random.shuffle(trainset)
    train_subset = trainset[: args.train_size]
    val_subset = trainset[args.train_size : args.train_size + args.val_size]
    if len(train_subset) < args.train_size or len(val_subset) < args.val_size:
        # Slicing never fails, so a too-small training file would otherwise
        # shrink (or empty) these subsets without any indication.
        print(
            f"Warning: training file has only {len(trainset)} examples; "
            f"using {len(train_subset)} for training and {len(val_subset)} "
            f"for GEPA validation (requested {args.train_size} and "
            f"{args.val_size})."
        )
    eval_subset = random.sample(valset, min(args.eval_size, len(valset)))

    # Configure GEPA
    gepa_kwargs = {
        "metric": scalar_feedback_metric,
        "max_metric_calls": args.max_calls,
        "num_threads": args.num_threads,
        "track_stats": True,
        "reflection_minibatch_size": args.minibatch_size,
        "reflection_lm": dspy.LM(
            model="openai/gpt-4.1-mini", temperature=args.temperature, max_tokens=4000
        ),
        "log_dir": args.log_dir,
    }

    # Add W&B configuration if enabled
    if not args.no_wandb:
        gepa_kwargs.update(
            {
                "use_wandb": True,
                "wandb_init_kwargs": {
                    "entity": "bare-sky",
                    "project": "advisor-models-baselines",
                    "name": args.wandb_name,
                },
                "wandb_api_key": os.getenv("WANDB_API_KEY"),
            }
        )

    gepa = GEPA(**gepa_kwargs)

    optimized_model = gepa.compile(model, trainset=train_subset, valset=val_subset)

    print("Optimized prompt:")
    for name, pred in optimized_model.named_predictors():
        print("================================")
        print(f"Predictor: {name}")
        print("================================")
        print("Prompt:")
        print(pred.signature.instructions)
        print("*********************************")

    # Save optimized prompts
    save_optimized_prompt(optimized_model, args.output_dir, "rule_arena_improve")

    # Run multi-run evaluation
    print(f"\nRunning {args.num_runs} evaluation runs...")
    all_run_scores = run_multi_evaluation(
        optimized_model, eval_subset, args.num_runs, args.num_threads
    )

    # Compute and report statistics
    stats = compute_multi_run_statistics(all_run_scores)
    print("\n=== Final Evaluation Statistics ===")
    print(f"Mean: {stats['mean']:.4f}")
    print(f"SEM: {stats['sem']:.4f}")
    print(
        f"95% Bootstrap CI: [{stats['bootstrap_ci_lower']:.4f}, {stats['bootstrap_ci_upper']:.4f}]"
    )
    print(f"Number of runs: {args.num_runs}")

    # evaluate_model() may score fewer examples than eval_subset holds (it
    # sub-samples the dataset it is given), so report the number of scores
    # actually produced per run rather than the size of eval_subset.
    num_samples = len(all_run_scores[0]) if all_run_scores else len(eval_subset)

    # Save results
    os.makedirs(args.output_dir, exist_ok=True)
    results_file = (
        Path(args.output_dir) / f"rule_arena_improve_gepa_{args.num_runs}runs.json"
    )
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(
            {
                "domain": "rule_arena_improve",
                "num_runs": args.num_runs,
                "num_samples": num_samples,
                "statistics": stats,
            },
            f,
            indent=2,
        )
    print(f"Saved results to {results_file}")


if __name__ == "__main__":
    main()
