import argparse
import os

from collections import defaultdict
from itertools import product
from pathlib import Path
from typing import Any, Dict, List, Optional

from .bo_replication import run_one_bo_replication
from .regression_replication import run_one_regression_replication


def run_regression_benchmark(
    results_dir: str,
    n_trials: int,
    method_names: List[str],
    function_name: str,
    outlier_fraction: float,
    outlier_generator_name: str,
    n_train: int,
    n_test: Optional[int] = None,
    noise_std: Optional[float] = None,
    outlier_generator_kwargs: Optional[Dict[str, Any]] = None,
    base_seed: int = 12346,
    cuda: bool = True,
) -> Dict[str, List[str]]:
    """Run a regression benchmark over every (method, trial) pair.

    Thin wrapper around ``_run_benchmark`` with ``kind="regression"``: one
    replication is run per method per trial, each with its own result file
    under ``results_dir``.

    Args:
        results_dir: Output directory (created if it does not exist).
        n_trials: Number of trials per method; trial ``t`` uses seed
            ``base_seed + t``.
        method_names: Methods to evaluate, in order.
        function_name, outlier_fraction, outlier_generator_name,
        outlier_generator_kwargs, n_train, n_test, noise_std: Forwarded
            unchanged to the regression replication function, which defines
            their meaning.
        base_seed: Seed used for trial 0.
        cuda: Passed through to ``_run_benchmark``, which currently does not
            use or forward it.

    Returns:
        Mapping from each method name to its result file names, in trial order.
    """
    # Options that only the regression replication understands.
    regression_options = dict(n_train=n_train, n_test=n_test, noise_std=noise_std)
    return _run_benchmark(
        results_dir=results_dir,
        n_trials=n_trials,
        kind="regression",
        method_names=method_names,
        base_seed=base_seed,
        cuda=cuda,
        function_name=function_name,
        outlier_fraction=outlier_fraction,
        outlier_generator_name=outlier_generator_name,
        outlier_generator_kwargs=outlier_generator_kwargs,
        **regression_options,
    )

def run_bo_benchmark(
    results_dir: str,
    n_trials: int,
    method_names: List[str],
    function_name: str,
    outlier_fraction: float,
    outlier_generator_name: str,
    n_evals: int,
    batch_size: int,
    n_init: Optional[int] = None,
    outlier_generator_kwargs: Optional[Dict[str, Any]] = None,
    base_seed: int = 12346,
    cuda: bool = True,
) -> Dict[str, List[str]]:
    """Run a Bayesian-optimization benchmark over every (method, trial) pair.

    Thin wrapper around ``_run_benchmark`` with ``kind="BO"``: one replication
    is run per method per trial, each with its own result file under
    ``results_dir``.

    Args:
        results_dir: Output directory (created if it does not exist).
        n_trials: Number of trials per method; trial ``t`` uses seed
            ``base_seed + t``.
        method_names: Methods to evaluate, in order.
        function_name, outlier_fraction, outlier_generator_name,
        outlier_generator_kwargs, n_evals, n_init, batch_size: Forwarded
            unchanged to the BO replication function, which defines their
            meaning.
        base_seed: Seed used for trial 0.
        cuda: Passed through to ``_run_benchmark``, which currently does not
            use or forward it.

    Returns:
        Mapping from each method name to its result file names, in trial order.
    """
    # Options that only the BO replication understands.
    bo_options = dict(n_evals=n_evals, n_init=n_init, batch_size=batch_size)
    return _run_benchmark(
        results_dir=results_dir,
        n_trials=n_trials,
        kind="BO",
        method_names=method_names,
        base_seed=base_seed,
        cuda=cuda,
        function_name=function_name,
        outlier_fraction=outlier_fraction,
        outlier_generator_name=outlier_generator_name,
        outlier_generator_kwargs=outlier_generator_kwargs,
        **bo_options,
    )


def _run_benchmark(
    results_dir: str,
    n_trials: int,
    kind: str,  # either "regression" or "BO"
    method_names: List[str],
    base_seed: int = 12346,
    cuda: bool = True,
    **kwargs,  # BO/regression specific kwargs
) -> Dict[str, List[str]]:
    """Run every (method, trial) replication of one benchmark.

    For each method in ``method_names`` and each trial index ``t`` in
    ``range(n_trials)``, calls the replication function selected by ``kind``
    with ``seed = base_seed + t`` and ``results_fpath`` set to
    ``<results_dir>/<method_name>_<t>.pt``. All methods share the same seed
    for a given trial index.

    Args:
        results_dir: Directory for per-replication result files; created
            (including parents) if missing.
        n_trials: Number of trials per method.
        kind: ``"regression"`` or ``"BO"``; selects the replication function.
        method_names: Methods to run, in order (all trials of a method run
            before the next method starts).
        base_seed: Seed of trial 0; trial ``t`` uses ``base_seed + t``.
        cuda: Currently unused (see the NOTE in the body).
        **kwargs: Forwarded unchanged to the replication function.

    Returns:
        Mapping (a ``defaultdict(list)``) from method name to the result file
        names, relative to ``results_dir``, in trial order.

    Raises:
        ValueError: If ``kind`` is neither ``"regression"`` nor ``"BO"``.
    """
    if kind not in ("regression", "BO"):
        raise ValueError(f"kind must be either 'regression' or 'BO', got {kind}")
    if kind == "regression":
        run_one_replication = run_one_regression_replication
    else:
        run_one_replication = run_one_bo_replication
    # NOTE(review): `cuda` is accepted here (and by both public wrappers) but is
    # neither used nor forwarded to the replication function, so cuda=False has
    # no effect. Confirm whether run_one_*_replication accepts a `cuda` argument
    # and forward it if so.

    Path(results_dir).mkdir(parents=True, exist_ok=True)

    keys: Dict[str, List[str]] = defaultdict(list)
    for method_name, trial_number in product(method_names, range(n_trials)):
        seed = base_seed + trial_number
        result_key = f"{method_name}_{trial_number}.pt"
        # Each replication's return value is intentionally discarded. Earlier
        # code kept every one in a list that was never read, which kept them
        # all in memory for the whole run. Only the file names are returned.
        run_one_replication(
            results_fpath=os.path.join(results_dir, result_key),
            seed=seed,
            method_name=method_name,
            **kwargs,
        )
        keys[method_name].append(result_key)
    return keys
