import argparse
import json
import time
import numpy as np
import torch
import random
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any

from generate_unified import UnifiedConfig, run_unified_pipeline
from run_sweep import build_feature_flags, cleanup_run_outputs

# Metric names accepted by the --metric command-line option.
VALID_METRICS = ["gcd", "netlsd_heat", "netlsd_wave", "portrait_div"]


@dataclass
class RunResult:
    """Outcome of one seeded trial: the diversity score plus phase timings."""

    # Seed used for the torch, random and numpy generators in this trial.
    seed: int
    # Name of the selection metric that was evaluated.
    metric: str
    # Diversity read from the selection result for `metric`
    # (recorded as 0.0 when the result exposes no diversity value).
    diversity: float
    # Elapsed seconds spent in the ensemble training call.
    training_time_sec: float
    # Elapsed seconds spent in the selection / post-processing call.
    postprocess_time_sec: float
    # Sum of the two timings above; work done between the two timed
    # phases (saving intermediates) is not included.
    total_time_sec: float


def load_wandb_config(config_path: str) -> dict:
    """Load a wandb config JSON file and flatten it to plain key/value pairs.

    Top-level entries stored as wrapper dicts containing a ``"value"`` key
    are replaced by the wrapped value; every other entry is copied through
    unchanged.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        A dict mapping each top-level key to its unwrapped value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; decode explicitly instead of relying
    # on the platform's locale-dependent default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        raw_config = json.load(f)

    return {
        key: val["value"] if isinstance(val, dict) and "value" in val else val
        for key, val in raw_config.items()
    }


def config_to_unified(config: dict, seed: int, output_dir: str, metric: str) -> UnifiedConfig:
    """Build the UnifiedConfig for one trial from a flat config dict.

    Missing settings fall back to the defaults listed below. The number of
    graphs per ensemble is derived as ``training_budget / num_ensembles``
    truncated to an int, and ``metric`` becomes the only selection metric.

    Args:
        config: Flat settings dict (setting name -> value).
        seed: Seed forwarded to the pipeline config.
        output_dir: Output directory forwarded to the pipeline config.
        metric: Selection metric name to evaluate.

    Returns:
        The populated UnifiedConfig.
    """
    # Settings forwarded to UnifiedConfig under the same keyword name.
    passthrough_defaults = {
        "n_vertices": 16,
        "num_ensembles": 20,
        "batch_size": 50,
        "projection_dim": 4,
        "hidden_dim": 256,
        "num_hidden": 6,
        "num_iterations": 5000,
        "postprocess_strategy": "iterative_survival",
        "n_top_matrices": 1000,
        "sampling_budget": 100000,
        "direction_seed": 42,
    }
    resolved = {
        name: config.get(name, fallback)
        for name, fallback in passthrough_defaults.items()
    }

    # The training budget is split evenly across ensembles (truncated).
    budget = config.get("training_budget", 10000)
    per_ensemble = int(budget / resolved["num_ensembles"])

    flags = build_feature_flags(config)

    return UnifiedConfig(
        **resolved,
        graphs_per_ensemble=per_ensemble,
        feature_flags=flags,
        selection_metrics=[metric],
        k_select=config.get("k_select", 100),
        selection_objective=config.get("selection_objective", "average"),
        output_dir=output_dir,
        seed=seed,
    )


def run_single_trial(config: dict, seed: int, run_idx: int, metric: str) -> RunResult:
    """Run one seeded trial (training, then selection) and time each phase.

    Seeds torch, ``random`` and numpy, trains the ensemble, saves the
    intermediate outputs, runs the selection step for ``metric`` and finally
    removes the trial's output directory.

    Args:
        config: Flat settings dict (see ``config_to_unified``).
        seed: Seed for all random number generators used by the trial.
        run_idx: Zero-based trial index; used in the output directory name
            and (one-based) in progress messages.
        metric: Selection metric whose diversity is reported.

    Returns:
        A RunResult with the diversity score and per-phase timings. The
        diversity is 0.0 (with a printed warning) when the selection result
        for ``metric`` exposes no diversity value.
    """
    output_dir = f"data/reproduce_run_{run_idx}_seed_{seed}"
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    unified_config = config_to_unified(config, seed, output_dir, metric)

    from generate_unified import (
        train_ensemble,
        save_intermediate,
        run_all_selections,
    )
    from distances_spectral import FeatureConfig as DistFeatureConfig

    direction_seed = config.get("direction_seed", 42)
    feature_config = DistFeatureConfig(unified_config.feature_flags, direction_seed=direction_seed)

    print(f"\n{'='*60}")
    print(f"RUN {run_idx + 1} (seed={seed}): TRAINING")
    print(f"{'='*60}")

    # perf_counter is monotonic, so measured durations cannot be skewed
    # (or go negative) when the system clock is adjusted mid-run.
    train_start = time.perf_counter()
    collected_graphs, collected_probs = train_ensemble(unified_config, feature_config)
    training_time = time.perf_counter() - train_start

    # Saving intermediates is deliberately outside both timed phases.
    _, all_probs = save_intermediate(collected_graphs, collected_probs, unified_config)

    print(f"\n{'='*60}")
    print(f"RUN {run_idx + 1} (seed={seed}): POSTPROCESSING ({metric})")
    print(f"{'='*60}")

    postprocess_start = time.perf_counter()
    results = run_all_selections(all_probs, unified_config)
    postprocess_time = time.perf_counter() - postprocess_start

    total_time = training_time + postprocess_time

    # The selection result may be an object with a `diversity` attribute or
    # a dict with a "diversity" key.
    result_obj = results.get(metric)
    if hasattr(result_obj, "diversity"):
        diversity = result_obj.diversity
    elif isinstance(result_obj, dict) and "diversity" in result_obj:
        diversity = result_obj["diversity"]
    else:
        # Keep the 0.0 fallback, but make it visible: a silent zero would
        # distort the mean/std computed over all trials.
        print(f"WARNING: no diversity value found for metric '{metric}'; recording 0.0")
        diversity = 0.0

    # Normalize to a builtin float: numpy/torch scalars are not guaranteed
    # to be JSON-serializable when the results are dumped later.
    diversity = float(diversity)

    run_result = RunResult(
        seed=seed,
        metric=metric,
        diversity=diversity,
        training_time_sec=training_time,
        postprocess_time_sec=postprocess_time,
        total_time_sec=total_time,
    )

    cleanup_run_outputs(output_dir)

    print(f"\nRUN {run_idx + 1} COMPLETE:")
    print(f"  {metric} diversity: {run_result.diversity:.6f}")
    print(f"  Training time: {training_time:.1f}s")
    print(f"  Postprocess time: {postprocess_time:.1f}s")
    print(f"  Total time: {total_time:.1f}s")

    return run_result


def compute_summary_stats(values: list[float]) -> dict:
    """Summarize a list of numbers as plain Python floats.

    Args:
        values: The samples to summarize.

    Returns:
        A dict with keys ``mean``, ``std``, ``min``, ``max`` and ``median``
        (in that order). ``std`` uses NumPy's default, i.e. the population
        standard deviation (``ddof=0``).
    """
    samples = np.array(values)
    reducers = (
        ("mean", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
        ("median", np.median),
    )
    # float() strips the NumPy scalar type so the result is JSON-friendly.
    return {label: float(reducer(samples)) for label, reducer in reducers}


def main():
    """Command-line entry point: repeat a configured run over several seeds.

    Parses the arguments, loads the config (optionally applying an ablation
    override), runs one trial per seed, then writes the per-run results and
    summary statistics to the output JSON file and prints a final summary.

    Exits with an argparse usage error when ``--runs`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Reproduce a wandb run multiple times for error bars"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to wandb config JSON file"
    )
    parser.add_argument(
        "--metric",
        type=str,
        required=True,
        choices=VALID_METRICS,
        help=f"Which metric to evaluate. Options: {VALID_METRICS}"
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=5,
        help="Number of runs (default: 5)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="reproduce_results.json",
        help="Output JSON file (default: reproduce_results.json)"
    )
    parser.add_argument(
        "--start-seed",
        type=int,
        default=1,
        help="Starting seed (default: 1, will use 1, 2, 3, ...)"
    )
    parser.add_argument(
        "--ablation",
        type=str,
        choices=["no_experts", "no_directions"],
        default=None,
        help="Ablation mode: 'no_experts' sets num_ensembles=1, 'no_directions' sets projection_dim=1"
    )
    args = parser.parse_args()

    # With zero (or negative) runs there is nothing to summarize; reject it
    # up front instead of failing later inside the statistics computation.
    if args.runs < 1:
        parser.error("--runs must be at least 1")

    # Create the output directory before the (long) trials start, so an
    # unwritable or missing destination fails early rather than after all
    # results have been computed.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Loading config from {args.config}")
    config = load_wandb_config(args.config)

    if args.ablation == "no_experts":
        original_ensembles = config.get("num_ensembles", 20)
        config["num_ensembles"] = 1
        print(f"\n⚠️  ABLATION: no_experts - overriding num_ensembles from {original_ensembles} to 1")
    elif args.ablation == "no_directions":
        original_proj_dim = config.get("projection_dim", 4)
        config["projection_dim"] = 1
        print(f"\n⚠️  ABLATION: no_directions - overriding projection_dim from {original_proj_dim} to 1")

    print("\nConfig summary:")
    print(f"  n_vertices: {config.get('n_vertices', 16)}")
    print(f"  num_ensembles: {config.get('num_ensembles', 20)}")
    print(f"  projection_dim: {config.get('projection_dim', 4)}")
    print(f"  training_budget: {config.get('training_budget', 10000)}")
    print(f"  postprocess_strategy: {config.get('postprocess_strategy', 'iterative_survival')}")
    print(f"  direction_seed: {config.get('direction_seed', 42)}")
    print(f"\nMetric: {args.metric}")
    if args.ablation:
        print(f"Ablation: {args.ablation}")

    # Consecutive seeds: start_seed, start_seed + 1, ...
    seeds = list(range(args.start_seed, args.start_seed + args.runs))

    print(f"\nRunning {args.runs} trials with seeds: {seeds}")

    all_results: list[RunResult] = [
        run_single_trial(config, seed, i, args.metric)
        for i, seed in enumerate(seeds)
    ]

    summary = {
        field: compute_summary_stats([getattr(r, field) for r in all_results])
        for field in (
            "diversity",
            "training_time_sec",
            "postprocess_time_sec",
            "total_time_sec",
        )
    }

    output = {
        "config": config,
        "metric": args.metric,
        "ablation": args.ablation,
        "num_runs": args.runs,
        "seeds": seeds,
        "runs": [asdict(r) for r in all_results],
        "summary": summary,
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"\n{'='*60}")
    print("FINAL SUMMARY")
    print(f"{'='*60}")
    print(f"Results saved to: {output_path}")
    print(f"\nMetric: {args.metric}")
    print(f"  Diversity: {summary['diversity']['mean']:.6f} ± {summary['diversity']['std']:.6f}")
    print(f"  (min={summary['diversity']['min']:.6f}, max={summary['diversity']['max']:.6f})")
    print("\nTiming (mean ± std):")
    print(f"  Training:     {summary['training_time_sec']['mean']:.1f} ± {summary['training_time_sec']['std']:.1f} sec")
    print(f"  Postprocess:  {summary['postprocess_time_sec']['mean']:.1f} ± {summary['postprocess_time_sec']['std']:.1f} sec")
    print(f"  Total:        {summary['total_time_sec']['mean']:.1f} ± {summary['total_time_sec']['std']:.1f} sec")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
