from __future__ import annotations

import argparse
import importlib
import importlib.util
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Iterable, List, Sequence, Optional

from .environment import ExperimentEnvironment


# Zero-argument callable that produces the bandit object; ExperimentBuilder.build()
# invokes it once per call.
BanditFactory = Callable[[], Any]
# Zero-argument callable for an algorithm. This module never calls these itself;
# it only forwards them to the environment (presumably one fresh algorithm
# instance per call -- confirm against the environment implementation).
AlgorithmFactory = Callable[[], Any]


@dataclass
class ExperimentBuilder:
    """Declarative recipe for constructing an experiment environment.

    Attributes:
        bandit_factory: Called with no arguments by :meth:`build` to create
            the bandit handed to the environment.
        algorithm_factories: Copied into a list and passed to the environment
            as ``algorithm_factories``.
        algorithm_names: Passed to the environment as ``algorithm_names``,
            but only when non-empty.
        environment_kwargs: Extra keyword arguments for ``environment_cls``.
            Keys given here take precedence over the values derived from the
            fields above.
        environment_cls: Callable invoked with the assembled keyword
            arguments to create the environment.
        n_mc: Returned by :meth:`build` alongside the environment.
        n_jobs: Returned by :meth:`build` alongside the environment.
        output: Free-form settings; not read anywhere in this class.
    """

    bandit_factory: BanditFactory
    algorithm_factories: Sequence[AlgorithmFactory] = field(default_factory=list)
    algorithm_names: Sequence[str] = field(default_factory=list)
    environment_kwargs: dict[str, Any] = field(default_factory=dict)
    environment_cls: Callable[..., Any] = ExperimentEnvironment
    n_mc: int = 1
    n_jobs: int = -1
    output: dict[str, Any] = field(default_factory=dict)

    def build(self) -> tuple[ExperimentEnvironment, int, int]:
        """Instantiate the environment and return it with the run settings.

        Returns:
            A ``(environment, n_mc, n_jobs)`` tuple.
        """
        # The bandit is always created, even if ``environment_kwargs`` ends
        # up overriding it.
        bandit = self.bandit_factory()
        kwargs = dict(self.environment_kwargs)

        derived: dict[str, Any] = {
            "bandit": bandit,
            "algorithm_factories": list(self.algorithm_factories),
        }
        if self.algorithm_names:
            derived["algorithm_names"] = list(self.algorithm_names)

        # Explicit environment_kwargs win; derived values only fill the gaps.
        for key, value in derived.items():
            if key not in kwargs:
                kwargs[key] = value

        return self.environment_cls(**kwargs), self.n_mc, self.n_jobs


def run(builder: ExperimentBuilder) -> dict[str, Any]:
    """Build the environment from *builder*, run it, and label the results.

    The environment's ``run_experiment(n_mc=..., n_jobs=...)`` must return
    exactly five values; they are returned in a dict under the keys
    ``avg_regret``, ``std_regret``, ``timings``, ``detections`` and
    ``detection_delays`` (in that order).
    """
    env, mc_runs, workers = builder.build()
    outcome = env.run_experiment(n_mc=mc_runs, n_jobs=workers)
    # Unpack explicitly so a result of the wrong length fails loudly here.
    mean_regret, regret_std, timings, detections, delays = outcome
    return dict(
        avg_regret=mean_regret,
        std_regret=regret_std,
        timings=timings,
        detections=detections,
        detection_delays=delays,
    )


def _import_module_from_path(path: Path):
    """Execute the Python file at *path* and return the resulting module.

    The module is named after the file's stem. It is not inserted into
    ``sys.modules`` by this function.

    Raises:
        ImportError: If no import spec or loader can be derived for *path*.
    """
    module_spec = importlib.util.spec_from_file_location(path.stem, path)
    loader = module_spec.loader if module_spec is not None else None
    if loader is None:
        raise ImportError(f"Cannot import module from {path}")
    loaded = importlib.util.module_from_spec(module_spec)
    loader.exec_module(loaded)
    return loaded


def load_builder(source: str) -> ExperimentBuilder:
    """Load the ``CONFIG`` experiment builder from a file path or module name.

    Args:
        source: Either a path to a Python file or an importable module name.
            If *source* names an existing regular file it is executed from
            that path; otherwise it is imported with
            ``importlib.import_module``.

    Returns:
        The module-level ``CONFIG`` object of the loaded module.

    Raises:
        ImportError: If the file or module cannot be imported.
        ValueError: If the module has no ``CONFIG`` attribute or it is not an
            ``ExperimentBuilder`` instance.
    """
    path = Path(source)
    # Use is_file() rather than exists(): a directory in the working directory
    # that shares its name with the requested module (e.g. a package folder)
    # must be imported by name, not handed to the file loader, which cannot
    # load a directory and would raise ImportError.
    if path.is_file():
        module = _import_module_from_path(path.resolve())
    else:
        module = importlib.import_module(source)
    builder = getattr(module, "CONFIG", None)
    if not isinstance(builder, ExperimentBuilder):
        raise ValueError(f"{source} does not define CONFIG=ExperimentBuilder")
    return builder


def main(argv: Optional[Iterable[str]] = None) -> dict[str, Any]:
    """Command-line entry point: load a config, run it, print final regrets.

    Args:
        argv: Command-line arguments to parse. When ``None``, argparse falls
            back to its default (``sys.argv[1:]``).

    Returns:
        The summary dict produced by :func:`run`.
    """
    cli = argparse.ArgumentParser(description="Run DAL experiments from config modules.")
    cli.add_argument(
        "--config",
        required=True,
        help="Python module or file path exporting CONFIG=ExperimentBuilder.",
    )
    # Materialise the iterable so argparse receives a concrete list.
    raw_args = None if argv is None else list(argv)
    options = cli.parse_args(raw_args)

    summary = run(load_builder(options.config))

    # One line per entry in avg_regret (keys are presumably algorithm names),
    # reporting the last element of its regret sequence.
    for label, curve in summary["avg_regret"].items():
        last = float(curve[-1])
        print(f"{label}: final regret = {last:.3f}")
    return summary


# Script entry point; the summary returned by main() is discarded here.
# NOTE: this module uses a relative import (``from .environment import ...``),
# so it has to be executed as part of its package (``python -m <package>.<module>``)
# rather than as a bare file.
if __name__ == "__main__":
    main()
