import wandb
import argparse
import sys
import numpy as np
import pickle
import shutil
import json
import filelock
from pathlib import Path

from generate_unified import UnifiedConfig, run_unified_pipeline


# Directory holding the best artifacts found so far across sweep runs.
BEST_RESULTS_DIR = Path("data/sweep_best")
# JSON index mapping each "<metric>_N<n_vertices>" key to its best diversity and run id.
BEST_TRACKER_FILE = BEST_RESULTS_DIR.joinpath("best_tracker.json")
# Lock file used to serialise read-modify-write cycles on the tracker.
LOCK_FILE = BEST_RESULTS_DIR.joinpath(".lock")


def load_best_tracker() -> dict:
    """Return the parsed contents of the best-results tracker file.

    Returns:
        The decoded JSON content of ``BEST_TRACKER_FILE``, or an empty dict
        when that file does not exist yet.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not BEST_TRACKER_FILE.exists():
        return {}
    return json.loads(BEST_TRACKER_FILE.read_text())


def save_best_tracker(tracker: dict):
    """Write *tracker* to ``BEST_TRACKER_FILE`` as indented JSON.

    The tracker is serialised in memory first and then written to a sibling
    temporary file that is renamed over the real one. A serialisation error
    or an interrupted write therefore never leaves a truncated tracker
    behind (which would make every later ``json.load`` of it fail).

    Args:
        tracker: JSON-serialisable mapping to persist.

    Raises:
        TypeError: If *tracker* contains values ``json`` cannot serialise;
            the existing tracker file is left untouched in that case.
    """
    BEST_RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    # Serialise before touching the filesystem so failures cannot corrupt
    # the previously saved tracker.
    payload = json.dumps(tracker, indent=2)

    # Fixed temp name: callers in this file hold the tracker lock while saving.
    tmp_path = BEST_TRACKER_FILE.with_name(BEST_TRACKER_FILE.name + ".tmp")
    with open(tmp_path, "w") as f:
        f.write(payload)
    # Rename within the same directory replaces the target in one step.
    tmp_path.replace(BEST_TRACKER_FILE)


def get_best_key(metric: str, n_vertices: int) -> str:
    """Build the tracker key for a metric / graph-size pair.

    Args:
        metric: Name of the diversity metric.
        n_vertices: Number of vertices of the graphs.

    Returns:
        A key of the form ``"<metric>_N<n_vertices>"``.
    """
    return "{}_N{}".format(metric, n_vertices)


def update_best_if_improved(
    metric: str,
    n_vertices: int,
    diversity: float,
    graphs: list,
    prob_matrices: np.ndarray | None,
    run_id: str,
    config_dict: dict,
) -> bool:
    """Persist this run's artifacts if it beats the stored best for its key.

    The tracker is read, compared and updated while holding a file lock on
    ``LOCK_FILE`` (60 s acquisition timeout), so the read-modify-write cycle
    is serialised between processes using the same lock file.

    Args:
        metric: Metric name; combined with *n_vertices* to form the tracker key.
        n_vertices: Number of vertices; also part of the artifact file names.
        diversity: Score of this run. Strictly greater than the stored best
            counts as an improvement; NaN never does.
        graphs: Object pickled to ``diverse_graphs_<n>_<metric>.pkl``.
        prob_matrices: Optional object pickled to ``diverse_probs_<n>.pkl``;
            skipped when ``None``.
        run_id: Identifier recorded in the metadata and the tracker.
        config_dict: JSON-serialisable configuration stored in the metadata.

    Returns:
        ``True`` if a new best was recorded, ``False`` otherwise.

    Raises:
        TypeError: If the metadata cannot be serialised to JSON. This is
            raised before any existing best artifact is overwritten.
    """
    BEST_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    key = get_best_key(metric, n_vertices)
    # Normalise to a built-in float so scalar types that ``json`` rejects
    # (e.g. NumPy float32) cannot break the metadata / tracker dumps.
    diversity = float(diversity)
    lock = filelock.FileLock(LOCK_FILE, timeout=60)

    with lock:
        tracker = load_best_tracker()
        current_best = tracker.get(key, {}).get("diversity", -float("inf"))

        # Written as ``not (a > b)`` so a NaN diversity is treated as
        # "not improved", exactly like a plain ``a > b`` test would.
        if not diversity > current_best:
            print(f"  {key}: {diversity:.6f} (best is {current_best:.6f})")
            return False

        # Serialise the metadata first: if the config is not JSON-friendly we
        # fail here, before the previous best's files have been overwritten.
        meta_payload = json.dumps(
            {
                "diversity": diversity,
                "run_id": run_id,
                "config": config_dict,
            },
            indent=2,
        )

        print(f"\n🎉 NEW BEST for {key}: {diversity:.6f} (was {current_best:.6f})")

        graphs_path = BEST_RESULTS_DIR / f"diverse_graphs_{n_vertices}_{metric}.pkl"
        with open(graphs_path, "wb") as f:
            pickle.dump(graphs, f)
        print(f"  Saved graphs to {graphs_path}")

        if prob_matrices is not None:
            # NOTE(review): this file name has no metric component, so a new
            # best for any metric with the same n_vertices overwrites it --
            # confirm that downstream readers expect this.
            probs_path = BEST_RESULTS_DIR / f"diverse_probs_{n_vertices}.pkl"
            with open(probs_path, "wb") as f:
                pickle.dump(prob_matrices, f)
            print(f"  Saved probability matrices to {probs_path}")

        meta_path = BEST_RESULTS_DIR / f"diverse_graphs_{n_vertices}_{metric}_meta.json"
        with open(meta_path, "w") as f:
            f.write(meta_payload)

        # Update the index last, once every artifact is on disk.
        tracker[key] = {
            "diversity": diversity,
            "run_id": run_id,
        }
        save_best_tracker(tracker)

        return True


def cleanup_run_outputs(output_dir: str):
    """Recursively delete a run's output directory if it exists.

    Args:
        output_dir: Path of the directory to remove. A missing path is a
            silent no-op.
    """
    target = Path(output_dir)
    if not target.exists():
        return
    shutil.rmtree(target)
    print(f"  Cleaned up temporary outputs: {output_dir}")


def build_feature_flags(config: dict) -> dict[str, bool]:
    """Assemble the feature on/off table for a sweep configuration.

    Three groups are emitted, in this order:

    * six adjacency-moment ratios that are always enabled;
    * ten features read from ``config["feat_<name>"]``, defaulting to
      ``False`` when the key is absent;
    * seven Laplacian-based features that are always disabled.

    Args:
        config: Mapping of sweep parameters; only ``feat_*`` keys are read.

    Returns:
        Mapping from feature name to its flag, in the order described above.
    """
    always_on = (
        "adj_m3_m2",
        "adj_m4_m2",
        "adj_m5_m3",
        "adj_m6_m4",
        "adj_m4_m3",
        "adj_m6_m2",
    )
    sweepable = (
        "adj_m2_norm",
        "adj_m5_m2",
        "adj_m6_m3",
        "adj_m5_m4",
        "regularity_proxy",
        "spectral_spread",
        "clustering_proxy",
        "triangle_density",
        "adj_m3_norm",
        "adj_m4_norm",
    )
    always_off = (
        "lap_m2_m1",
        "lap_m3_m2",
        "lap_m4_m2",
        "lap_m4_m3",
        "adj_lap_m2",
        "adj_lap_m4",
        "lap_m2_norm",
    )

    flags = dict.fromkeys(always_on, True)
    for name in sweepable:
        # Sweep parameters carry a "feat_" prefix in the config.
        flags[name] = config.get(f"feat_{name}", False)
    flags.update(dict.fromkeys(always_off, False))
    return flags


def count_enabled_features(feature_flags: dict[str, bool]) -> int:
    """Return how many entries of *feature_flags* have a truthy value."""
    return sum(map(bool, feature_flags.values()))


def _print_banner(title: str) -> None:
    """Print a blank line, then *title* between two 70-character '=' rules."""
    rule = "=" * 70
    print(f"\n{rule}")
    print(title)
    print(rule)


def run_sweep_trial():
    """Run one W&B sweep trial and record its results.

    Steps:

    1. Start a W&B run and read hyper-parameters from ``wandb.config``
       (``n_vertices`` and ``num_ensembles`` are required, everything else
       has a default).
    2. Print and log the effective configuration, plus a diagnostic check of
       whether the sampling schedule can reach ``training_budget``.
    3. Run ``run_unified_pipeline`` and log, per returned metric, its
       diversity, energy, pool size and edge mean / std.
    4. Offer each metric's result to ``update_best_if_improved``.
    5. Delete the run's temporary output directory and finish the W&B run.

    Error handling is staged. Zero diversities are logged only when the
    pipeline or metric-logging stage fails. A failure while recording bests
    is reported under the ``error`` key without overwriting the metrics
    that were already logged. Cleanup and ``wandb.finish()`` always run.
    """
    import traceback

    run = wandb.init()
    config = wandb.config

    n_vertices = config.n_vertices
    num_ensembles = config.num_ensembles
    batch_size = config.get("batch_size", 50)
    projection_dim = config.get("projection_dim", 4)
    hidden_dim = config.get("hidden_dim", 256)
    num_hidden = config.get("num_hidden", 6)
    num_iterations = config.get("num_iterations", 5000)
    feature_flags = build_feature_flags(dict(config))
    num_features = count_enabled_features(feature_flags)

    postprocess_strategy = config.get("postprocess_strategy", "iterative_survival")
    n_top_matrices = config.get("n_top_matrices", 1000)
    sampling_budget = config.get("sampling_budget", 100000)
    selection_metrics = config.get("selection_metrics", ["all"])
    # A single metric may be given as a bare string; normalise to a list.
    if isinstance(selection_metrics, str):
        selection_metrics = [selection_metrics]

    training_budget = config.get("training_budget", 10000)
    direction_seed = config.get("direction_seed", 42)

    _print_banner("SWEEP TRIAL")
    print(f"n_vertices: {n_vertices}")
    print(f"num_ensembles: {num_ensembles}")
    print(f"batch_size: {batch_size}")
    print(f"projection_dim: {projection_dim}")
    print(f"hidden_dim: {hidden_dim}")
    print(f"num_hidden: {num_hidden}")
    print(f"num_features: {num_features}")
    print(f"training_budget: {training_budget}")
    print(f"postprocess_strategy: {postprocess_strategy}")
    print(f"n_top_matrices: {n_top_matrices}")
    print(f"sampling_budget: {sampling_budget}")
    print(f"direction_seed: {direction_seed}")
    print()

    wandb.log({
        "n_vertices": n_vertices,
        "direction_seed": direction_seed,
        "num_ensembles": num_ensembles,
        "batch_size": batch_size,
        "projection_dim": projection_dim,
        "hidden_dim": hidden_dim,
        "num_hidden": num_hidden,
        "num_iterations": num_iterations,
        "num_features": num_features,
        "training_budget": training_budget,
        "postprocess_strategy": postprocess_strategy,
        "n_top_matrices": n_top_matrices,
        "sampling_budget": sampling_budget,
    })

    graphs_per_ensemble = training_budget / num_ensembles

    # Diagnostic estimate only. Clamp to >= 1 so a zero budget cannot make
    # the floor division below raise ZeroDivisionError.
    samples_needed = max(1, int(np.ceil(graphs_per_ensemble / batch_size)))
    sample_interval = max(1, num_iterations // samples_needed)
    expected_collections = (num_iterations // sample_interval) + 1
    expected_total = (
        min(expected_collections * batch_size, graphs_per_ensemble) * num_ensembles
    )

    print(f"\nCollection budget check:")
    print(f"  training_budget: {training_budget}")
    print(f"  graphs_per_ensemble: {graphs_per_ensemble}")
    print(f"  expected_total: {expected_total}")

    if expected_total < training_budget * 0.9:
        print(f"  WARNING: May under-collect! Expected {expected_total} < {training_budget}")

    # Per-run scratch directory; removed in the ``finally`` block below.
    output_dir = f"data/sweep_{run.id}"

    unified_config = UnifiedConfig(
        num_ensembles=num_ensembles,
        batch_size=batch_size,
        projection_dim=projection_dim,
        n_vertices=n_vertices,
        hidden_dim=hidden_dim,
        num_hidden=num_hidden,
        graphs_per_ensemble=int(graphs_per_ensemble),
        num_iterations=num_iterations,
        feature_flags=feature_flags,
        selection_metrics=selection_metrics,
        k_select=100,
        selection_objective="average",
        postprocess_strategy=postprocess_strategy,
        n_top_matrices=n_top_matrices,
        sampling_budget=sampling_budget,
        output_dir=output_dir,
        direction_seed=direction_seed,
    )

    # Stored next to any new best; only enabled features are recorded.
    config_dict = {
        "n_vertices": n_vertices,
        "num_ensembles": num_ensembles,
        "direction_seed": direction_seed,
        "batch_size": batch_size,
        "projection_dim": projection_dim,
        "hidden_dim": hidden_dim,
        "num_hidden": num_hidden,
        "num_iterations": num_iterations,
        "training_budget": training_budget,
        "postprocess_strategy": postprocess_strategy,
        "n_top_matrices": n_top_matrices,
        "sampling_budget": sampling_budget,
        "feature_flags": {k: v for k, v in feature_flags.items() if v},
    }

    results = None

    try:
        # Stage 1: run the pipeline and log its metrics. Only a failure here
        # is reported to the sweep as zero diversity.
        try:
            results = run_unified_pipeline(unified_config)

            log_dict = {}
            for metric_name, result in results.items():
                log_dict[f"{metric_name}_diversity"] = result.diversity
                log_dict[f"{metric_name}_energy"] = result.energy
                log_dict[f"{metric_name}_pool_size"] = result.pool_size
                log_dict[f"{metric_name}_edge_mean"] = result.edge_mean
                log_dict[f"{metric_name}_edge_std"] = result.edge_std

            wandb.log(log_dict)

            _print_banner("SWEEP TRIAL RESULTS")
            for metric_name, result in results.items():
                print(f"{metric_name}: diversity={result.diversity:.6f}, energy={result.energy:.6f}")

        except Exception as e:
            print(f"ERROR during trial: {e}")
            traceback.print_exc()

            wandb.log({
                "gcd_diversity": 0.0,
                "netlsd_heat_diversity": 0.0,
                "netlsd_wave_diversity": 0.0,
                "portrait_div_diversity": 0.0,
                "error": str(e),
            })
            # Do not offer partial results to the best tracker.
            results = None

        # Stage 2: best tracking. Failures here (for example while waiting
        # for the shared lock or while pickling) must not overwrite the
        # metrics already logged above.
        if results is not None:
            try:
                prob_matrices = None
                # Written by the pipeline itself into this run's scratch dir.
                probs_path = Path(output_dir) / f"diverse_probs_{n_vertices}.pkl"
                if probs_path.exists():
                    with open(probs_path, "rb") as f:
                        prob_matrices = pickle.load(f)
                    print(f"Loaded {len(prob_matrices)} probability matrices for best tracking")

                _print_banner("CHECKING FOR NEW BESTS")
                for metric_name, result in results.items():
                    update_best_if_improved(
                        metric=metric_name,
                        n_vertices=n_vertices,
                        diversity=result.diversity,
                        graphs=result.graphs,
                        prob_matrices=prob_matrices,
                        run_id=run.id,
                        config_dict=config_dict,
                    )
            except Exception as e:
                print(f"ERROR during best tracking: {e}")
                traceback.print_exc()
                wandb.log({"error": str(e)})

    finally:
        # Stage 3: always remove scratch outputs and close the W&B run, even
        # if cleanup itself raises.
        try:
            _print_banner("CLEANUP")
            cleanup_run_outputs(output_dir)
        finally:
            wandb.finish()


def run_test_trial(n_vertices: int = 16, training_budget: int = 10000):
    """Run one short pipeline trial locally, without W&B.

    Uses a fixed configuration (20 ensembles, batch size 50, 500 iterations,
    the ``gcd`` selection metric) writing to ``data/sweep_test``, prints the
    diversity of each returned metric and then removes the output directory.

    Args:
        n_vertices: Number of vertices passed to the pipeline configuration.
        training_budget: Total budget; divided by the ensemble count to get
            the per-ensemble graph count.
    """
    print(f"Running test trial for N={n_vertices}, training_budget={training_budget} (no wandb)...")

    scratch_dir = "data/sweep_test"
    ensemble_count = 20
    per_batch = 50

    # Every known feature, in the order the flag table is built.
    all_features = (
        "adj_m3_m2",
        "adj_m4_m2",
        "adj_m5_m3",
        "adj_m6_m4",
        "adj_m4_m3",
        "adj_m6_m2",
        "adj_m2_norm",
        "adj_m5_m2",
        "adj_m6_m3",
        "adj_m5_m4",
        "regularity_proxy",
        "spectral_spread",
        "clustering_proxy",
        "triangle_density",
        "adj_m3_norm",
        "adj_m4_norm",
        "lap_m2_m1",
        "lap_m3_m2",
        "lap_m4_m2",
        "lap_m4_m3",
        "adj_lap_m2",
        "adj_lap_m4",
        "lap_m2_norm",
    )
    # Features switched on for the test trial; all others are off.
    switched_on = {
        "adj_m3_m2",
        "adj_m4_m2",
        "adj_m5_m3",
        "adj_m6_m4",
        "adj_m4_m3",
        "adj_m6_m2",
        "adj_m2_norm",
        "regularity_proxy",
        "clustering_proxy",
        "adj_m3_norm",
    }
    feature_flags = {name: name in switched_on for name in all_features}

    trial_config = UnifiedConfig(
        num_ensembles=ensemble_count,
        batch_size=per_batch,
        projection_dim=4,
        n_vertices=n_vertices,
        graphs_per_ensemble=int(training_budget / ensemble_count),
        num_iterations=500,
        feature_flags=feature_flags,
        selection_metrics=["gcd"],
        k_select=50,
        selection_objective="average",
        postprocess_strategy="iterative_survival",
        n_top_matrices=10_000,
        sampling_budget=100_000,
        output_dir=scratch_dir,
    )

    outcomes = run_unified_pipeline(trial_config)

    print("\nTest results:")
    for metric_name, outcome in outcomes.items():
        print(f"  {metric_name}: diversity={outcome.diversity:.6f}")

    cleanup_run_outputs(scratch_dir)


def main():
    """Parse command-line options and dispatch to the matching trial runner.

    With ``--test`` a local trial is run using ``--n_vertices`` and
    ``--training_budget``; otherwise a W&B sweep trial is run.
    """
    cli = argparse.ArgumentParser(description="WandB Sweep Runner")
    cli.add_argument(
        "--test",
        action="store_true",
        help="Run a single test trial without wandb",
    )
    cli.add_argument(
        "--n_vertices",
        type=int,
        default=16,
        help="Number of vertices for test trial",
    )
    cli.add_argument(
        "--training_budget",
        type=int,
        default=10000,
        help="Training budget for test trial",
    )
    opts = cli.parse_args()

    if not opts.test:
        run_sweep_trial()
        return
    run_test_trial(opts.n_vertices, opts.training_budget)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
