import argparse
import yaml
import pickle
import json
import shutil
from pathlib import Path
from datetime import datetime

from generate_unified import UnifiedConfig, run_unified_pipeline


# Directory that run_k_scaling writes its per-k graph pickles, metadata JSON
# and probability files into. This is a relative path, so it resolves against
# the current working directory, not against this script's location.
OUTPUT_DIR = Path("data/k_scaling")

# YAML configs executed (in this order) when the script is run with `--all`.
# main() checks each with Path(...).exists(), so these are also resolved
# relative to the current working directory.
ALL_K_CONFIGS = [
    "config_k1024.yaml",
    "config_k2048.yaml",
    "config_k4096.yaml",
    "config_k8192.yaml",
]


def load_yaml_config(yaml_path: str) -> dict:
    with open(yaml_path, "r") as f:
        raw_config = yaml.safe_load(f)
    params = raw_config.get("parameters", {})
    config = {}
    for key, val in params.items():
        if isinstance(val, dict) and "value" in val:
            config[key] = val["value"]
        elif isinstance(val, dict) and "values" in val:
            config[key] = val["values"][0]
        else:
            config[key] = val
    return config


def build_feature_flags(config: dict) -> dict[str, bool]:
    """Build the feature on/off map that is handed to ``UnifiedConfig``.

    The map is assembled from three groups, inserted in this order:
      1. six flags that are always enabled,
      2. ten flags each read from a ``feat_<name>`` key of ``config``
         (the raw config value is used; missing keys default to False),
      3. seven flags that are always disabled.

    Key insertion order is kept identical to the historical literal in case
    downstream code depends on it (NOTE(review): confirm whether it does).

    Args:
        config: Flattened parameter dict, e.g. from ``load_yaml_config``.

    Returns:
        Ordered dict of feature name -> flag value.
    """
    always_enabled = (
        "adj_m3_m2",
        "adj_m4_m2",
        "adj_m5_m3",
        "adj_m6_m4",
        "adj_m4_m3",
        "adj_m6_m2",
    )
    # (feature name, config key that toggles it)
    config_driven = (
        ("adj_m2_norm", "feat_adj_m2_norm"),
        ("adj_m5_m2", "feat_adj_m5_m2"),
        ("adj_m6_m3", "feat_adj_m6_m3"),
        ("adj_m5_m4", "feat_adj_m5_m4"),
        ("regularity_proxy", "feat_regularity_proxy"),
        ("spectral_spread", "feat_spectral_spread"),
        ("clustering_proxy", "feat_clustering_proxy"),
        ("triangle_density", "feat_triangle_density"),
        ("adj_m3_norm", "feat_adj_m3_norm"),
        ("adj_m4_norm", "feat_adj_m4_norm"),
    )
    always_disabled = (
        "lap_m2_m1",
        "lap_m3_m2",
        "lap_m4_m2",
        "lap_m4_m3",
        "adj_lap_m2",
        "adj_lap_m4",
        "lap_m2_norm",
    )

    flags = dict.fromkeys(always_enabled, True)
    for feature, config_key in config_driven:
        flags[feature] = config.get(config_key, False)
    flags.update(dict.fromkeys(always_disabled, False))
    return flags


def _print_run_config(settings: list[tuple[str, object]]) -> None:
    """Print the resolved run settings, one ``  name: value`` line per entry."""
    print("\nConfiguration:")
    for name, value in settings:
        print(f"  {name}: {value}")


def _save_metric_outputs(
    metric_name: str,
    result,
    *,
    n_vertices,
    k_select,
    config: dict,
    timestamp: str,
) -> dict:
    """Write one metric's graphs (pickle) and metadata (JSON) into OUTPUT_DIR.

    Files are named by k and metric only, so a later run with the same k
    overwrites them.

    Returns:
        ``{"diversity": ..., "energy": ...}`` for the results summary.
    """
    print(f"\n  {metric_name}:")
    print(f"    diversity={result.diversity:.6f}, energy={result.energy:.6f}")

    graphs_path = OUTPUT_DIR / f"diverse_graphs_k{k_select}_{metric_name}.pkl"
    with open(graphs_path, "wb") as f:
        pickle.dump(result.graphs, f)
    print(f"    Saved {len(result.graphs)} graphs to {graphs_path}")

    meta = {
        "n_vertices": n_vertices,
        "k_select": k_select,
        "metric": metric_name,
        "diversity": result.diversity,
        "energy": result.energy,
        "edge_stats": {
            "mean": result.edge_mean,
            "std": result.edge_std,
            "min": result.edge_min,
            "max": result.edge_max,
        },
        # NOTE(review): json.dump fails if the config or result fields hold
        # non-JSON types (e.g. YAML dates, numpy float32) — confirm upstream types.
        "config": config,
        "timestamp": timestamp,
    }
    meta_path = OUTPUT_DIR / f"diverse_graphs_k{k_select}_{metric_name}_meta.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    return {
        "diversity": result.diversity,
        "energy": result.energy,
    }


def _print_summary(k_select, all_results: dict) -> None:
    """Print a fixed-width table of diversity/energy per metric."""
    print(f"\n{'=' * 70}")
    print(f"RESULTS SUMMARY (k={k_select})")
    print(f"{'=' * 70}")
    print(f"\n{'Metric':<15} {'Diversity':>12} {'Energy':>12}")
    print("-" * 42)
    for metric_name, scores in all_results.items():
        print(f"{metric_name:<15} {scores['diversity']:>12.6f} {scores['energy']:>12.6f}")


def run_k_scaling(yaml_path: str, dry_run: bool = False):
    """Run the unified pipeline for one k-scaling YAML config and save its outputs.

    Args:
        yaml_path: Path to a YAML config whose ``parameters`` section supplies at
            least ``n_vertices``, ``k_select``, ``num_ensembles``, ``batch_size``,
            ``projection_dim``, ``hidden_dim`` and ``num_hidden``.
        dry_run: If True, only print the resolved configuration and return.

    Returns:
        ``None`` on a dry run; otherwise a dict mapping each metric name returned
        by the pipeline to ``{"diversity": ..., "energy": ...}``.

    Raises:
        KeyError: If a required parameter is missing from the config.
        ValueError: If ``num_ensembles`` leaves fewer than one training graph per
            ensemble out of the fixed training budget.

    Side effects:
        Writes ``diverse_graphs_k{k}_{metric}.pkl`` and ``..._meta.json`` (plus
        ``diverse_probs_k{k}.pkl`` when the pipeline left a probs file in the run
        directory) into OUTPUT_DIR, overwriting earlier files for the same k.
        The timestamped run directory is deleted after a successful run; if an
        exception propagates, it is left on disk.
    """
    print(f"\n{'=' * 70}")
    print(f"K-SCALING EXPERIMENT: {yaml_path}")
    print(f"{'=' * 70}")

    config = load_yaml_config(yaml_path)

    # Required parameters: a missing key raises KeyError here.
    n_vertices = config["n_vertices"]
    k_select = config["k_select"]
    num_ensembles = config["num_ensembles"]
    batch_size = config["batch_size"]
    projection_dim = config["projection_dim"]
    hidden_dim = config["hidden_dim"]
    num_hidden = config["num_hidden"]
    # Optional parameters with script-level defaults.
    postprocess_strategy = config.get("postprocess_strategy", "iterative_survival")
    n_top_matrices = config.get("n_top_matrices", 1500)
    sampling_budget = config.get("sampling_budget", 100000)

    feature_flags = build_feature_flags(config)
    num_features = sum(1 for enabled in feature_flags.values() if enabled)

    _print_run_config(
        [
            ("n_vertices", n_vertices),
            ("k_select", k_select),
            ("num_ensembles", num_ensembles),
            ("batch_size", batch_size),
            ("projection_dim", projection_dim),
            ("hidden_dim", hidden_dim),
            ("num_hidden", num_hidden),
            ("num_features", num_features),
            ("postprocess_strategy", postprocess_strategy),
            ("n_top_matrices", n_top_matrices),
            ("sampling_budget", sampling_budget),
        ]
    )

    if dry_run:
        print("\n[DRY RUN - not executing]")
        return None

    # The total number of training graphs is fixed and split evenly across
    # ensembles. Validate before creating any directories so a bad config
    # does not leave an empty run directory behind.
    total_training_budget = 10000
    graphs_per_ensemble = (
        int(total_training_budget // num_ensembles) if num_ensembles > 0 else 0
    )
    if graphs_per_ensemble < 1:
        raise ValueError(
            f"num_ensembles={num_ensembles!r} leaves no training graphs per "
            f"ensemble (total training budget is {total_training_budget})"
        )

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = OUTPUT_DIR / f"k{k_select}_{timestamp}"
    run_dir.mkdir(parents=True, exist_ok=True)

    unified_config = UnifiedConfig(
        num_ensembles=num_ensembles,
        batch_size=batch_size,
        projection_dim=projection_dim,
        n_vertices=n_vertices,
        hidden_dim=hidden_dim,
        num_hidden=num_hidden,
        graphs_per_ensemble=graphs_per_ensemble,
        num_iterations=5000,
        feature_flags=feature_flags,
        selection_metrics=["all"],
        k_select=k_select,
        selection_objective="average",
        postprocess_strategy=postprocess_strategy,
        n_top_matrices=n_top_matrices,
        sampling_budget=sampling_budget,
        output_dir=str(run_dir),
    )

    print(f"\n{'=' * 40}")
    print("RUNNING PIPELINE")
    print(f"{'=' * 40}")

    results = run_unified_pipeline(unified_config)

    print(f"\n{'=' * 40}")
    print("SAVING RESULTS FOR EACH METRIC")
    print(f"{'=' * 40}")

    all_results = {}
    for metric_name, result in results.items():
        all_results[metric_name] = _save_metric_outputs(
            metric_name,
            result,
            n_vertices=n_vertices,
            k_select=k_select,
            config=config,
            timestamp=timestamp,
        )

    # Promote the probabilities file (if the pipeline wrote one into the run
    # directory) to a per-k name in OUTPUT_DIR before the run dir is removed.
    probs_src = run_dir / f"diverse_probs_{n_vertices}.pkl"
    if probs_src.exists():
        probs_dst = OUTPUT_DIR / f"diverse_probs_k{k_select}.pkl"
        shutil.copy(probs_src, probs_dst)
        print(f"\nSaved probs to {probs_dst}")

    _print_summary(k_select, all_results)

    # Everything worth keeping has been copied out; drop the working directory.
    shutil.rmtree(run_dir)
    print(f"\nCleaned up temporary directory: {run_dir}")

    return all_results


def main():
    """Command-line entry point.

    ``--all`` runs every config in ALL_K_CONFIGS that exists on disk, warning
    about missing ones; ``--config PATH`` runs a single config. ``--all`` takes
    precedence when both are given. With neither, the help text is printed.
    ``--dry-run`` is forwarded to run_k_scaling in either mode.
    """
    cli = argparse.ArgumentParser(description="Run k-scaling experiments")
    cli.add_argument("--config", type=str, help="Path to YAML config file")
    cli.add_argument("--all", action="store_true", help="Run all k configs")
    cli.add_argument("--dry-run", action="store_true", help="Just print config, don't execute")
    opts = cli.parse_args()

    if not (opts.all or opts.config):
        cli.print_help()
        return

    if not opts.all:
        run_k_scaling(opts.config, dry_run=opts.dry_run)
        return

    print(f"Running all k-scaling configs: {ALL_K_CONFIGS}")
    for cfg in ALL_K_CONFIGS:
        if not Path(cfg).exists():
            print(f"WARNING: Config not found: {cfg}")
            continue
        run_k_scaling(cfg, dry_run=opts.dry_run)


if __name__ == "__main__":
    main()
