"""
Config-file-based universal trainer CLI for contextual bandit experiments.

This script runs experiments by loading parameters from a JSON config file.
Any config parameter that has a matching command-line flag can be overridden on the command line.

Usage:
  # Run an experiment defined in a config file
  python -m your_module.run_experiment --config configs/experiment_config.json

  # Override specific parameters for a quick test
  python -m your_module.run_experiment --config configs/experiment_config.json --learning_rate 1e-6 --n_epochs 1
"""

import argparse
import dataclasses
import json
import sys
from pathlib import Path
from typing import Any, Dict

# Import from the same package using relative imports
from .universal_bandit_optimizer import NormConstraintType
from .universal_data_loader import (
    create_arguana_data_loader,
    create_fiqa_data_loader,
    create_multihop_data_loader,
    create_nfcorpus_data_loader,
    create_toolret_data_loader,
    create_ultratool_data_loader,
)
from .universal_experiment_manager import (
    UniversalBanditConfig,
    run_universal_bandit_experiment,
)


def setup_arg_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the CLI.

    ``--config`` is the only required argument. Every other argument is an
    optional override whose default is ``None``, so that a later merge step
    can tell "not given on the command line" apart from an explicit value.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Config-based universal trainer for contextual bandit experiments."
    )

    # --- Primary Argument: Config File ---
    parser.add_argument(
        "--config",
        type=Path,
        required=True,
        help="Path to the JSON experiment configuration file.",
    )

    # --- Override Arguments ---
    # All overrides default to None so unset flags never clobber the config.

    # Data parameters
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Mirrors every dataset that run_experiment knows how to load.
        choices=[
            "toolret",
            "ultratool",
            "nfcorpus",
            "arguana",
            "fiqa",
            "multihop",
        ],
    )
    parser.add_argument("--tools_dataset_path", type=str, default=None)
    parser.add_argument("--queries_dataset_path", type=str, default=None)
    parser.add_argument(
        "--embedding_model",
        type=str,
        default=None,
        choices=["ada", "small", "large"],
    )
    parser.add_argument(
        "--true_embedding_model",
        type=str,
        default=None,
        choices=["ada", "small", "large"],
    )
    parser.add_argument(
        "--subset",
        type=str,
        default=None,
        choices=["code", "web", "customized"],
    )
    parser.add_argument("--max_queries", type=int, default=None)
    # store_true with default=None yields None (unset) or True (flag given).
    parser.add_argument("--add_noise", action="store_true", default=None)
    # Companion flag writing False to the same destination, so a config file
    # that enables noise can be overridden from the command line.
    parser.add_argument(
        "--no_add_noise",
        dest="add_noise",
        action="store_false",
        default=None,
        help="Force add_noise to False, overriding the config file.",
    )
    parser.add_argument("--noise_std", type=float, default=None)

    # Training parameters
    parser.add_argument("--n_epochs", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--temperature", type=float, default=None)
    parser.add_argument("--epsilon", type=float, default=None)
    parser.add_argument("--lambda_reg", type=float, default=None)
    parser.add_argument("--clip_value", type=float, default=None)

    # Optimizer parameters
    parser.add_argument(
        "--norm_constraint",
        type=str,
        default=None,
        choices=["none", "unit_sphere", "unit_vector"],
    )
    parser.add_argument("--beta1", type=float, default=None)
    parser.add_argument("--beta2", type=float, default=None)
    parser.add_argument("--optimizer_epsilon", type=float, default=None)

    # Evaluation and tracking
    parser.add_argument("--eval_interval", type=int, default=None)
    parser.add_argument("--recall_k", type=int, default=None)
    parser.add_argument("--track_interval", type=int, default=None)

    # Reproducibility and Output
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file path for results (JSON format).",
    )

    return parser


def load_and_merge_config(args: argparse.Namespace) -> Dict[str, Any]:
    """Load a JSON config file and overlay the CLI overrides on top of it.

    Args:
        args: Parsed CLI namespace. ``args.config`` is the path to the JSON
            file; every other attribute whose value is not ``None`` is
            treated as a user-provided override.

    Returns:
        The configuration dictionary from the file, updated in place with
        the CLI overrides.

    Raises:
        ValueError: If the JSON document's top level is not an object.
        SystemExit: With code 1 if the config file does not exist.
    """
    config_path = args.config
    if not config_path.is_file():
        print(f"❌ Error: Config file not found at {config_path}")
        sys.exit(1)

    # Read as UTF-8 explicitly rather than relying on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Keys are merged and looked up by name, so the root must be an object.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a JSON object at the "
            f"top level, got {type(config).__name__}."
        )
    print(f"✅ Loaded base configuration from {config_path}")

    # An override is any argument other than 'config' that is not None.
    # Explicit False/0 values are kept, since they are deliberate user input.
    overrides = {
        key: value
        for key, value in vars(args).items()
        if key != "config" and value is not None
    }
    config.update(overrides)

    if overrides:
        print(
            f"🔧 Overriding config with CLI args: {', '.join(overrides)}"
        )

    return config


def run_experiment(config: Dict[str, Any]) -> None:
    """Set up and run a single bandit experiment from a config dictionary.

    Args:
        config: Merged experiment configuration. It must contain a non-empty
            ``"dataset"`` entry. ``tools_dataset_path``,
            ``queries_dataset_path`` and ``max_queries`` are forwarded to the
            dataset's loader factory (plus ``subset`` for ``toolret``, which
            falls back to ``"code"`` when the key is absent). Keys that match
            a ``UniversalBanditConfig`` field are passed to its constructor;
            all other keys are ignored for that purpose.

    Raises:
        ValueError: If ``"dataset"`` is missing/empty or names an
            unsupported dataset.
    """
    dataset_name = config.get("dataset")
    if not dataset_name:
        raise ValueError(
            "Config must specify a 'dataset' (one of: 'toolret', 'ultratool', "
            "'nfcorpus', 'arguana', 'fiqa', 'multihop')."
        )

    print("\n" + "=" * 50)
    print(f"🚀 Starting Experiment: {dataset_name.upper()}")
    print("=" * 50)

    # 1. Create Data Loader based on config
    # Dispatch table: dataset name -> loader factory.
    loader_factories = {
        "toolret": create_toolret_data_loader,
        "ultratool": create_ultratool_data_loader,
        "nfcorpus": create_nfcorpus_data_loader,
        "arguana": create_arguana_data_loader,
        "fiqa": create_fiqa_data_loader,
        "multihop": create_multihop_data_loader,
    }
    create_loader = loader_factories.get(dataset_name)
    if create_loader is None:
        raise ValueError(f"Unknown dataset in config: {dataset_name}")

    # Arguments shared by every loader factory.
    loader_kwargs = {
        "tools_dataset_path": config.get("tools_dataset_path"),
        "queries_dataset_path": config.get("queries_dataset_path"),
        "max_queries": config.get("max_queries"),
    }
    if dataset_name == "toolret":
        # Only the toolret loader takes a subset.
        loader_kwargs["subset"] = config.get("subset", "code")
    data_loader = create_loader(**loader_kwargs)

    # 2. Create UniversalBanditConfig from the dictionary
    # Drop keys the dataclass does not declare (e.g. data-loader options).
    config_fields = {f.name for f in dataclasses.fields(UniversalBanditConfig)}
    dataclass_kwargs = {k: v for k, v in config.items() if k in config_fields}

    # The config stores the constraint as a plain string; convert it to the
    # enum member with that value.
    if "norm_constraint" in dataclass_kwargs:
        dataclass_kwargs["norm_constraint"] = NormConstraintType(
            dataclass_kwargs["norm_constraint"]
        )

    bandit_config = UniversalBanditConfig(**dataclass_kwargs)

    # 3. Run the experiment
    # The runner returns a 4-tuple; none of it is used here, and the
    # final-metrics summary printing is currently disabled.
    _, _, _, _final_metrics = run_universal_bandit_experiment(
        bandit_config, data_loader
    )

    print("\n🎉 Experiment Complete!")

def main():
    """CLI entry point: parse arguments, merge the config, run the experiment.

    Any ``Exception`` raised while loading the config or running the
    experiment is reported, its traceback is printed, and the process exits
    with status 1.
    """
    cli_args = setup_arg_parser().parse_args()

    try:
        run_experiment(load_and_merge_config(cli_args))
    except Exception as exc:
        print(f"\n❌ An error occurred during the experiment: {exc}")
        # Show the full traceback as well, to make failures debuggable.
        import traceback

        traceback.print_exc()
        sys.exit(1)


# Run the CLI only when this file is executed as the entry point, not when
# it is imported by another module.
if __name__ == "__main__":
    main()
