"""
Universal trainer CLI for contextual bandit experiments.

This script provides a flexible command-line interface for running bandit
experiments on different datasets with configurable parameters.
"""

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict

from universal_bandit_optimizer import NormConstraintType
from universal_data_loader import (
    create_toolret_data_loader,
    create_ultratool_data_loader,
)
from universal_experiment_manager import (
    UniversalBanditConfig,
    run_universal_bandit_experiment,
)


def create_config_from_args(args: argparse.Namespace) -> UniversalBanditConfig:
    """Build a ``UniversalBanditConfig`` populated from parsed CLI arguments.

    Each listed option is read from ``args`` and stored on a freshly
    constructed config under the same attribute name. The only value that is
    transformed on the way is ``norm_constraint``, which is wrapped in
    ``NormConstraintType``.

    Args:
        args: Parsed command-line namespace; must expose every attribute
            named in the field list below.

    Returns:
        The populated configuration object.
    """
    # Fields are copied in this exact order, grouped by purpose.
    field_names = (
        # Data
        "embedding_model",
        "true_embedding_model",
        "max_queries",
        "add_noise",
        "noise_std",
        # Training
        "n_epochs",
        "batch_size",
        "learning_rate",
        "temperature",
        "epsilon",
        "lambda_reg",
        "clip_value",
        # Optimizer
        "norm_constraint",
        "beta1",
        "beta2",
        "optimizer_epsilon",
        # Evaluation
        "eval_interval",
        "recall_k",
        # Parameter tracking
        "track_interval",
        # Reproducibility
        "seed",
    )
    # Options whose raw CLI value must be converted before being stored.
    converters = {"norm_constraint": NormConstraintType}

    config = UniversalBanditConfig()
    for field_name in field_names:
        raw_value = getattr(args, field_name)
        convert = converters.get(field_name)
        setattr(
            config,
            field_name,
            raw_value if convert is None else convert(raw_value),
        )
    return config


def run_toolret_experiment(args: argparse.Namespace) -> Dict[str, Any]:
    """Execute a bandit experiment against the ToolRet dataset.

    Builds the ToolRet data loader from the path, subset and text-field
    options in ``args``, derives the experiment configuration from the same
    namespace, runs the experiment and packages its outputs.

    Args:
        args: Parsed command-line namespace.

    Returns:
        A dict with the dataset name (``"toolret"``), the chosen subset and
        embedding model, the three history sequences returned by the
        experiment runner, and its final metrics.
    """
    print(
        f"Running ToolRet experiment: {args.embedding_model} embeddings, {args.subset} subset"
    )

    # Loader first, then config, then the run itself.
    loader_options = {
        "tools_dataset_path": args.tools_dataset_path,
        "queries_dataset_path": args.queries_dataset_path,
        "max_queries": args.max_queries,
        "subset": args.subset,
        "tool_text_field": args.tool_text_field,
        "query_text_field": args.query_text_field,
    }
    loader = create_toolret_data_loader(**loader_options)
    experiment_config = create_config_from_args(args)

    rewards, regrets, recalls, metrics = run_universal_bandit_experiment(
        experiment_config, loader
    )

    summary: Dict[str, Any] = {
        "dataset": "toolret",
        "subset": args.subset,
        "embedding_model": args.embedding_model,
    }
    summary["rewards_history"] = rewards
    summary["regrets_history"] = regrets
    summary["recall_history"] = recalls
    summary["final_metrics"] = metrics
    return summary


def run_ultratool_experiment(args: argparse.Namespace) -> Dict[str, Any]:
    """Execute a bandit experiment against the UltraTool dataset.

    Builds the UltraTool data loader from the dataset paths and query limit
    in ``args``, derives the experiment configuration from the same
    namespace, runs the experiment and packages its outputs.

    Args:
        args: Parsed command-line namespace.

    Returns:
        A dict with the dataset name (``"ultratool"``), the embedding model,
        the three history sequences returned by the experiment runner, and
        its final metrics.
    """
    print(f"Running UltraTool experiment: {args.embedding_model} embeddings")

    # Loader first, then config, then the run itself.
    loader_options = {
        "tools_dataset_path": args.tools_dataset_path,
        "queries_dataset_path": args.queries_dataset_path,
        "max_queries": args.max_queries,
    }
    loader = create_ultratool_data_loader(**loader_options)
    experiment_config = create_config_from_args(args)

    rewards, regrets, recalls, metrics = run_universal_bandit_experiment(
        experiment_config, loader
    )

    summary: Dict[str, Any] = {
        "dataset": "ultratool",
        "embedding_model": args.embedding_model,
    }
    summary["rewards_history"] = rewards
    summary["regrets_history"] = regrets
    summary["recall_history"] = recalls
    summary["final_metrics"] = metrics
    return summary


def print_results(results: Dict[str, Any]):
    """Print a human-readable summary of an experiment's results.

    Args:
        results: Result dict as produced by the ``run_*_experiment``
            functions. Must contain ``"dataset"``, ``"embedding_model"`` and
            ``"final_metrics"``; ``"subset"`` is printed only when present.

    Metrics missing from ``final_metrics`` are shown as ``N/A`` instead of
    raising, so a partial metrics dict still produces a full report.
    """

    def format_metric(value: Any) -> str:
        """Render ``value`` with 4 decimals, or verbatim if not numeric."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            # e.g. the 'N/A' placeholder or None: the float format spec
            # does not apply, so show the value as-is.
            return str(value)

    print("\n=== Experiment Results ===")
    print(f"Dataset: {results['dataset']}")
    if "subset" in results:
        print(f"Subset: {results['subset']}")
    print(f"Embedding Model: {results['embedding_model']}")

    final_metrics = results["final_metrics"]
    recall_at_1 = final_metrics.get("final_recall_at_1", "N/A")
    recall_at_k = final_metrics.get("final_recall_at_k", "N/A")
    avg_reward = final_metrics.get("avg_reward", "N/A")

    print(f"Final Recall@1: {format_metric(recall_at_1)}")
    print(f"Final Recall@k: {format_metric(recall_at_k)}")
    print(f"Average Reward: {format_metric(avg_reward)}")
    print(
        f"Total Queries Processed: {final_metrics.get('total_queries_processed', 'N/A')}"
    )

    if "parameter_regrets" in final_metrics:
        param_regrets = final_metrics["parameter_regrets"]
        if param_regrets:
            # Entries are pairs; only the second element (the regret) is
            # averaged, the first is ignored.
            avg_param_regret = sum(r for _, r in param_regrets) / len(
                param_regrets
            )
            print(f"Average Parameter Regret: {avg_param_regret:.4f}")


def save_results(results: Dict[str, Any], output_path: str):
    """Save experiment results to a JSON file.

    Args:
        results: Result dict to serialize. Nested lists, tuples and dicts
            are walked recursively; any object exposing ``tolist()`` is
            replaced by the value that method returns.
        output_path: Destination file path. Missing parent directories are
            created so that a finished experiment is not lost to a
            ``FileNotFoundError`` at save time.

    Raises:
        TypeError: If a value remains that ``json`` cannot serialize.
        OSError: If the destination cannot be created or written.
    """

    def convert_for_json(obj):
        """Recursively replace non-JSON containers with plain lists/dicts."""
        # Checked first so array-like objects (presumably torch tensors or
        # numpy arrays coming from the experiment) become plain lists.
        if hasattr(obj, "tolist"):
            return obj.tolist()
        elif isinstance(obj, (list, tuple)):
            return [convert_for_json(item) for item in obj]
        elif isinstance(obj, dict):
            return {key: convert_for_json(value) for key, value in obj.items()}
        else:
            return obj

    serializable_results = convert_for_json(results)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform locale.
    with destination.open("w", encoding="utf-8") as f:
        json.dump(serializable_results, f, indent=2)

    print(f"Results saved to {output_path}")


def main(argv=None):
    """Parse command-line options, run the selected experiment and report.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behavior.

    The process exits with status 1 (after printing the error to stderr) if
    the experiment, result printing or result saving raises an exception.
    """
    parser = argparse.ArgumentParser(
        description="Universal trainer for contextual bandit experiments"
    )

    # Dataset selection
    parser.add_argument(
        "dataset",
        choices=["toolret", "ultratool"],
        help="Dataset to use for the experiment",
    )

    # Data parameters
    parser.add_argument(
        "--tools_dataset_path",
        type=str,
        help="Path to tools dataset (auto-detected if not provided)",
    )
    parser.add_argument(
        "--queries_dataset_path",
        type=str,
        help="Path to queries dataset (auto-detected if not provided)",
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="large",
        choices=["ada", "small", "large"],
        help="Embedding model to use (default: large)",
    )
    parser.add_argument(
        "--true_embedding_model",
        type=str,
        default="large",
        choices=["ada", "small", "large"],
        help="True embedding model for regret computation (default: large)",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default="code",
        choices=["code", "web", "customized"],
        help="Dataset subset to use (ToolRet only, default: code)",
    )
    parser.add_argument(
        "--max_queries",
        type=int,
        default=None,
        help="Maximum number of queries to use (default: all)",
    )
    parser.add_argument(
        "--add_noise",
        action="store_true",
        help="Add noise to initial embeddings",
    )
    # NOTE(review): with the default std of 0.0, --add_noise alone
    # presumably has no effect; confirm against the experiment manager.
    parser.add_argument(
        "--noise_std",
        type=float,
        default=0.0,
        help="Standard deviation of noise to add (default: 0.0)",
    )

    # Training parameters
    parser.add_argument(
        "--n_epochs",
        type=int,
        default=3,
        help="Number of training epochs (default: 3)",
    )
    parser.add_argument(
        "--batch_size", type=int, default=20, help="Batch size (default: 20)"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-7,
        help="Learning rate (default: 1e-7)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature for softmax policy (default: 1.0)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.0,
        help="Epsilon for epsilon-greedy exploration (default: 0.0)",
    )
    parser.add_argument(
        "--lambda_reg",
        type=float,
        default=0.0,
        help="L2 regularization coefficient (default: 0.0)",
    )
    parser.add_argument(
        "--clip_value",
        type=float,
        default=10.0,
        help="IPS weight clipping value for stability (default: 10.0)",
    )

    # Optimizer parameters
    parser.add_argument(
        "--norm_constraint",
        type=str,
        default="unit_sphere",
        choices=["none", "unit_sphere", "unit_vector"],
        help="Norm constraint type (default: unit_sphere)",
    )
    parser.add_argument(
        "--beta1",
        type=float,
        default=0.9,
        help="Adam beta1 parameter (default: 0.9)",
    )
    parser.add_argument(
        "--beta2",
        type=float,
        default=0.999,
        help="Adam beta2 parameter (default: 0.999)",
    )
    parser.add_argument(
        "--optimizer_epsilon",
        type=float,
        default=1e-8,
        help="Adam epsilon parameter (default: 1e-8)",
    )

    # Evaluation parameters
    parser.add_argument(
        "--eval_interval",
        type=int,
        default=40,
        help="Evaluation interval in batches (default: 40)",
    )
    parser.add_argument(
        "--recall_k",
        type=int,
        default=10,
        help="K for recall@k evaluation (default: 10)",
    )

    # Parameter tracking
    parser.add_argument(
        "--track_interval",
        type=int,
        default=20,
        help="Parameter tracking interval in steps (default: 20)",
    )

    # Reproducibility
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed (default: 42)"
    )

    # Output
    parser.add_argument(
        "--output", type=str, help="Output file path for results (JSON format)"
    )

    args = parser.parse_args(argv)

    # Fill in per-dataset default paths. ``dataset`` is restricted by
    # ``choices`` above, so these resolve to the toolret/ultratool locations.
    if args.tools_dataset_path is None:
        args.tools_dataset_path = f"embeddings/{args.dataset}_tools_embedded"

    if args.queries_dataset_path is None:
        args.queries_dataset_path = (
            f"embeddings/{args.dataset}_queries_embedded"
        )

    print("Starting Universal Contextual Bandit Experiment")
    print("=" * 50)

    try:
        # Run experiment based on dataset
        if args.dataset == "toolret":
            results = run_toolret_experiment(args)
        elif args.dataset == "ultratool":
            results = run_ultratool_experiment(args)
        else:
            # Defensive: argparse ``choices`` should make this unreachable.
            raise ValueError(f"Unknown dataset: {args.dataset}")

        # Print results
        print_results(results)

        # Save results if output path provided
        if args.output:
            save_results(results, args.output)

        print("\n=== Experiment Complete ===")

    except Exception as e:
        # Top-level boundary: report on stderr so the failure is not mixed
        # into regular output, and signal it through the exit status.
        print(f"Error running experiment: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
