import copy
import os
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from itertools import product
from pathlib import Path
from typing import Iterator

import psutil
import yaml
from argparse_dataclass import ArgumentParser

from wmcal.experiments import ExperimentConfig, run_experiment
from wmcal.utils import get_logger


@dataclass
class Args:
    """Command-line options accepted by the sweep runner."""

    # Required: location of the YAML file describing the sweep.
    cfg: str = field(metadata={"help": "Path to YAML sweep config file"})
    # Optional flag, off by default.
    redo: bool = field(
        metadata={"help": "Rerun experiments even if already completed"},
        default=False,
    )
    # Optional flag, off by default.
    debug: bool = field(metadata={"help": "Enable debug logging"}, default=False)


@dataclass
class SweepConfig:
    """Configuration for a parameter sweep."""

    base_config_path: str
    n_seeds: int
    workers: int | float = 1
    sweep_params: dict = field(default_factory=dict)
    override_config: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "SweepConfig":
        return cls(
            base_config_path=config_dict["base_config_path"],
            n_seeds=config_dict["n_seeds"],
            workers=config_dict.get("workers", 1),
            sweep_params=config_dict.get("sweep_params", {}),
            override_config=config_dict.get("override_config", {}),
        )


def load_base_config(path: str) -> dict:
    """Load the base experiment config from a YAML file.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        ValueError: If the file does not contain a top-level mapping (an
            empty file parses to ``None``).
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Base config {path} must contain a YAML mapping, got {type(config)}"
        )
    return config


def _apply_override_recursive(config: dict, override_key: str, override_value) -> None:
    """Override matching keys in a config dict, in place.

    Two modes, selected by whether ``override_key`` contains a dot:

    * Dotted path (``"a.b.c"``): walk the parent keys from the root and set
      the final key on the dict found there. If any parent key is missing,
      or the walk reaches a value that is not a dict, nothing is changed.
    * Plain key (``"c"``): replace the value of every existing key with that
      name, at any nesting depth reachable through dict values. A matched
      key is replaced and not descended into. Dicts inside lists are not
      visited.

    Args:
        config: The config dict to modify in-place.
        override_key: The key to override (may be nested path with dots).
        override_value: The value to set.
    """
    if "." in override_key:
        *parent_keys, leaf_key = override_key.split(".")
        node = config
        for key in parent_keys:
            # A scalar/list on the way down means the path does not exist;
            # treat it the same as a missing key instead of raising TypeError.
            if not isinstance(node, dict) or key not in node:
                return
            node = node[key]
        if isinstance(node, dict):
            node[leaf_key] = override_value
        return

    # Iterate over a snapshot because values are reassigned during the loop.
    for key, value in list(config.items()):
        if key == override_key:
            config[key] = override_value
        elif isinstance(value, dict):
            _apply_override_recursive(value, override_key, override_value)


def is_experiment_completed(exp_id: str) -> bool:
    """Report whether the completion marker for an experiment exists.

    Args:
        exp_id: Experiment identifier; used as the directory name under
            ``.logs`` (relative to the current working directory).

    Returns:
        True if the path ``.logs/<exp_id>/done`` exists, False otherwise.
    """
    return Path(".logs", exp_id, "done").exists()


def get_experiment_configs(sweep_config: SweepConfig) -> Iterator[ExperimentConfig]:
    """Generate all experiment configs for a sweep.

    The base config is loaded from ``sweep_config.base_config_path``. Both
    ``sweep_params`` and ``override_config`` are flattened into
    ``(dotted_path, value_list)`` pairs, and the Cartesian product of all
    value lists is applied to a deep copy of the base config:

    * ``sweep_params`` entries set one exact dotted path. The path must
      exist in the base config and every value must be an instance of the
      type of the base value at that path.
    * ``override_config`` entries are applied with
      ``_apply_override_recursive`` and are not validated against the base
      config.

    Each resulting config is emitted once per seed; seeds are
    ``42, 43, ..., 42 + n_seeds - 1``.

    This is a generator, so the base config is loaded and validation errors
    are raised when iteration starts, not when the function is called.

    Args:
        sweep_config: The sweep configuration.

    Yields:
        ExperimentConfig for each parameter combination and seed.

    Raises:
        ValueError: If a leaf in ``sweep_params``/``override_config`` is not
            a list, if a ``sweep_params`` path is not present in the base
            config, or if a swept value has the wrong type.
    """
    base_config = load_base_config(sweep_config.base_config_path)

    def flatten_sweep_params(d: dict, prefix: str = "") -> list[tuple[str, list]]:
        """Flatten a nested dict into (dotted_path, value_list) pairs."""
        result = []
        for k, v in d.items():
            path = f"{prefix}.{k}" if prefix else k
            if isinstance(v, dict):
                result.extend(flatten_sweep_params(v, path))
            elif isinstance(v, list):
                result.append((path, v))
            else:
                raise ValueError(f"Sweep param at {path} must be a list, got {type(v)}")
        return result

    sweep_items = flatten_sweep_params(sweep_config.sweep_params)
    override_items = flatten_sweep_params(sweep_config.override_config)

    # Validate sweep_params against the base config.
    for path, value_list in sweep_items:
        node = base_config
        for key in path.split("."):
            # A path that walks through a scalar/list is "not found" too;
            # without the isinstance check this surfaced as a TypeError.
            if not isinstance(node, dict) or key not in node:
                raise ValueError(f"Path {path} not found in base config")
            node = node[key]
        base_type = type(node)
        for val in value_list:
            if not isinstance(val, base_type):
                raise ValueError(
                    f"Sweep param at {path} has {type(val)}, expected {base_type}"
                )

    # One value list per swept path: sweep_params first, then override_config.
    n_sweep = len(sweep_items)
    sweep_lists = [value_list for _, value_list in sweep_items + override_items]

    for combo in product(*sweep_lists):
        config = copy.deepcopy(base_config)

        # Apply sweep_params (exact paths, already validated above).
        for (path, _), val in zip(sweep_items, combo[:n_sweep]):
            *parent_keys, leaf_key = path.split(".")
            node = config
            for key in parent_keys:
                node = node[key]
            node[leaf_key] = val

        # Apply override_config (recursive key match or dotted path).
        for (override_key, _), override_val in zip(override_items, combo[n_sweep:]):
            _apply_override_recursive(config, override_key, override_val)

        # Emit one config per seed, starting at 42.
        for seed in range(42, 42 + sweep_config.n_seeds):
            config_copy = copy.deepcopy(config)
            config_copy["seed"] = seed
            # Add metric_configs to match old experiment hashes
            config_copy.setdefault("metric_configs", [{"type": "mvs"}])
            yield ExperimentConfig.from_dict(config_copy)


def get_experiment_ids(sweep_config: SweepConfig) -> list[str]:
    """Collect the ``id`` of every experiment config the sweep generates.

    Args:
        sweep_config: The sweep configuration.

    Returns:
        Experiment IDs, in the order ``get_experiment_configs`` yields them.
    """
    exp_ids = []
    for exp in get_experiment_configs(sweep_config):
        exp_ids.append(exp.id)
    return exp_ids


@dataclass
class RunResult:
    """Result of a single experiment run."""

    exp_id: str
    status: str  # "completed", "skipped", "failed"
    duration_s: float = 0.0
    error: str | None = None


def get_machine_stats() -> dict:
    """Get current machine resource stats.

    Returns:
        Dict with the keys ``cpu_count`` (from ``os.cpu_count()``, may be
        None), ``cpu_percent`` (sampled over a 0.1 s interval),
        ``ram_total_gb`` and ``ram_available_gb`` (bytes divided by 1024**3).
    """
    bytes_per_gb = 1024**3
    cpu_percent = psutil.cpu_percent(interval=0.1)
    # Take a single memory snapshot so total and available are consistent
    # with each other (and the query is not repeated).
    memory = psutil.virtual_memory()
    return {
        "cpu_count": os.cpu_count(),
        "cpu_percent": cpu_percent,
        "ram_total_gb": memory.total / bytes_per_gb,
        "ram_available_gb": memory.available / bytes_per_gb,
    }


def run_single_experiment(
    exp_config: ExperimentConfig, redo: bool = False, debug: bool = False
) -> RunResult:
    """Execute one experiment and report the outcome as a RunResult.

    Exceptions raised by the run are not propagated; they are converted to
    a "failed" result carrying ``str(exception)``.

    Args:
        exp_config: The experiment configuration.
        redo: If True, rerun even if already completed.
        debug: If True, enable debug logging.

    Returns:
        RunResult whose status is "skipped" (already done and ``redo`` is
        False), "completed", or "failed" (with an error message).
    """
    exp_id = exp_config.id
    started_at = time.perf_counter()

    already_done = (not redo) and is_experiment_completed(exp_id)
    if already_done:
        return RunResult(exp_id=exp_id, status="skipped")

    try:
        ok = run_experiment(exp_config.to_dict(), redo=redo, debug=debug)
    except Exception as exc:
        return RunResult(
            exp_id=exp_id,
            status="failed",
            duration_s=time.perf_counter() - started_at,
            error=str(exc),
        )

    elapsed = time.perf_counter() - started_at
    if not ok:
        return RunResult(
            exp_id=exp_id,
            status="failed",
            duration_s=elapsed,
            error="run returned False",
        )
    return RunResult(exp_id=exp_id, status="completed", duration_s=elapsed)


def run_sweep(
    sweep_config: SweepConfig,
    dry_run: bool = False,
    redo: bool = False,
    debug: bool = False,
) -> None:
    """Run all experiments in a sweep, one after another in this process.

    Experiments whose completion marker already exists are skipped unless
    ``redo`` is set. A failing experiment (exception or a falsy return from
    ``run_experiment``) is logged and counted, and the sweep continues with
    the next one.

    Args:
        sweep_config: The sweep configuration.
        dry_run: If True, only log what would be run without executing.
        redo: If True, rerun experiments even if already completed.
        debug: If True, enable debug logging.
    """
    logger = get_logger(__name__)

    experiments = list(get_experiment_configs(sweep_config))
    total = len(experiments)

    if dry_run:
        logger.info("DRY RUN MODE - Not actually running experiments")
        for i, exp_config in enumerate(experiments, 1):
            logger.info(
                f"[{i}/{total}] DRY RUN: seed={exp_config.seed} [id={exp_config.id[:8]}...]"
            )
        logger.info(f"Summary: {total} experiments (dry run)")
        return

    logger.info(f"Running {total} experiments...")

    completed = 0
    skipped = 0
    failed = 0

    for i, exp_config in enumerate(experiments, 1):
        exp_id = exp_config.id

        if not redo and is_experiment_completed(exp_id):
            logger.info(
                f"[{i}/{total}] SKIP: seed={exp_config.seed} (already done) [id={exp_id[:8]}...]"
            )
            skipped += 1
            continue

        logger.info(f"[{i}/{total}] RUN: seed={exp_config.seed} [id={exp_id[:8]}...]")

        try:
            success = run_experiment(exp_config.to_dict(), redo=redo, debug=debug)
        except Exception as e:
            # Identify the failing experiment so the log line is actionable.
            logger.error(
                f"[{i}/{total}] FAILED: seed={exp_config.seed} [id={exp_id[:8]}...] {e}"
            )
            failed += 1
            continue

        if success:
            completed += 1
        else:
            # A falsy return is a failure too; log it instead of only
            # counting it silently.
            logger.error(
                f"[{i}/{total}] FAILED: seed={exp_config.seed} [id={exp_id[:8]}...] "
                "run returned False"
            )
            failed += 1

    logger.info(f"Summary: {completed} completed, {skipped} skipped, {failed} failed")


def run_sweep_parallel(
    sweep_config: SweepConfig,
    max_workers: int | None = None,
    dry_run: bool = False,
    redo: bool = False,
    debug: bool = False,
) -> None:
    """Run every experiment of a sweep through a process pool.

    All experiments are submitted up front to a ``ProcessPoolExecutor``;
    each worker runs ``run_single_experiment``. Results are logged as they
    finish, followed by a summary with counts, wall time, the sum of the
    per-experiment durations, and their ratio.

    Args:
        sweep_config: The sweep configuration.
        max_workers: Maximum number of parallel workers. When None,
            ``os.cpu_count()`` is used (4 if that is unavailable).
        dry_run: If True, only log what would be run without executing.
        redo: If True, rerun experiments even if already completed.
        debug: If True, enable debug logging.
    """
    log = get_logger(__name__)

    # Report the machine we are running on.
    machine = get_machine_stats()
    log.info(
        f"Machine: {machine['cpu_count']} cores, "
        f"{machine['ram_total_gb']:.1f}GB RAM ({machine['ram_available_gb']:.1f}GB free)"
    )

    exp_list = list(get_experiment_configs(sweep_config))
    total = len(exp_list)

    if max_workers is None:
        max_workers = os.cpu_count() or 4

    if dry_run:
        log.info(
            f"DRY RUN MODE - Would run {total} experiments with {max_workers} workers"
        )
        for i, exp_config in enumerate(exp_list, 1):
            log.info(
                f"[{i}/{total}] DRY RUN: seed={exp_config.seed} [id={exp_config.id[:8]}...]"
            )
        return

    log.info(f"Running {total} experiments with {max_workers} workers (FULL SEND)")

    started_at = time.perf_counter()
    tally = {"completed": 0, "skipped": 0, "failed": 0}
    summed_durations = 0.0

    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        # Remember which config each future belongs to, for error reporting.
        pending = {}
        for exp in exp_list:
            pending[pool.submit(run_single_experiment, exp, redo, debug)] = exp

        for finished in as_completed(pending):
            source_cfg = pending[finished]
            position = tally["completed"] + tally["skipped"] + tally["failed"] + 1
            tag = f"[{position}/{total}]"
            try:
                outcome = finished.result()
                summed_durations += outcome.duration_s
                if outcome.status == "completed":
                    tally["completed"] += 1
                    log.info(
                        f"{tag} DONE in {outcome.duration_s:.1f}s "
                        f"[id={outcome.exp_id[:8]}...]"
                    )
                elif outcome.status == "skipped":
                    tally["skipped"] += 1
                    log.info(f"{tag} SKIP [id={outcome.exp_id[:8]}...]")
                else:
                    tally["failed"] += 1
                    log.error(
                        f"{tag} FAIL [id={outcome.exp_id[:8]}...] {outcome.error}"
                    )
            except Exception as exc:
                # The future itself failed (the worker did not return a result).
                tally["failed"] += 1
                log.error(f"{tag} FAIL [id={source_cfg.id[:8]}...] {exc}")

    wall_seconds = time.perf_counter() - started_at
    if wall_seconds > 0:
        speedup = summed_durations / wall_seconds
    else:
        speedup = 0

    banner = "=" * 60
    log.info(banner)
    log.info("SWEEP COMPLETE")
    log.info(
        f"  Results: {tally['completed']} completed, "
        f"{tally['skipped']} skipped, {tally['failed']} failed"
    )
    log.info(f"  Wall time: {wall_seconds:.1f}s")
    log.info(f"  Total CPU time: {summed_durations:.1f}s")
    log.info(f"  Parallelism speedup: {speedup:.1f}x")
    log.info(banner)


def main():
    """Command-line entry point: load a sweep config and run the sweep.

    Environment variables:
        DRY_RUN: ``"true"`` (case-insensitive) logs the experiments without
            running them.
        WORKERS: Overrides ``workers`` from the sweep config. A value
            containing ``"."`` is parsed as a float, anything else as an int.
            An unparsable value is ignored with a warning.

    Worker semantics:
        * int ``1`` runs the sweep serially in this process.
        * any other int is the number of worker processes and must be
          between 1 and the CPU count.
        * a float in ``(0, 1]`` is a fraction of the CPU count (at least one
          worker is always used).

    Raises:
        ValueError: If the sweep config is not a mapping or the workers
            setting is invalid.
    """
    dry_run = os.getenv("DRY_RUN", "false").lower() == "true"

    # Use the module-level parser when the script guard created one; build
    # one here otherwise so main() also works after a plain import.
    arg_parser = globals().get("parser")
    if arg_parser is None:
        arg_parser = ArgumentParser(Args)
    args = arg_parser.parse_args()

    with open(args.cfg, "r", encoding="utf-8") as f:
        sweep_config_dict = yaml.safe_load(f)
    if not isinstance(sweep_config_dict, dict):
        raise ValueError(
            f"Sweep config {args.cfg} must contain a YAML mapping, "
            f"got {type(sweep_config_dict)}"
        )
    sweep_config = SweepConfig.from_dict(sweep_config_dict)

    # Allow overriding workers via env var
    workers_env = os.getenv("WORKERS")
    if workers_env:
        try:
            if "." in workers_env:
                sweep_config.workers = float(workers_env)
            else:
                sweep_config.workers = int(workers_env)
        except ValueError:
            # Keep the configured value, but tell the user the override
            # was not applied instead of dropping it silently.
            get_logger(__name__).warning(
                f"Ignoring invalid WORKERS value {workers_env!r}; "
                f"using workers={sweep_config.workers}"
            )

    cpu_count = os.cpu_count() or 4
    workers = sweep_config.workers

    # Only the *integer* 1 means "serial". The float 1.0 compares equal to 1
    # but means "100% of the cores", so it must take the parallel path.
    if isinstance(workers, int) and workers == 1:
        run_sweep(sweep_config, dry_run=dry_run, redo=args.redo, debug=args.debug)
        return

    # Calculate max_workers
    if isinstance(workers, float):
        if not (0 < workers <= 1):
            raise ValueError(f"Float workers must be between 0 and 1, got {workers}")
        # Small fractions on small machines truncate to 0, which
        # ProcessPoolExecutor rejects; always keep at least one worker.
        max_workers = max(1, int(workers * cpu_count))
    elif isinstance(workers, int):
        if workers < 1:
            raise ValueError(f"Workers must be >= 1, got {workers}")
        if workers > cpu_count:
            raise ValueError(f"Workers {workers} > cpu_count {cpu_count}")
        max_workers = workers
    else:
        raise ValueError(f"Invalid workers type: {type(workers)}")

    run_sweep_parallel(sweep_config, max_workers, dry_run, args.redo, args.debug)


# Script entry point. `parser` is assigned at module level here, and main()
# reads it as a global to parse the command line.
if __name__ == "__main__":
    parser = ArgumentParser(Args)
    main()
