import argparse
import yaml
import pickle
import json
from pathlib import Path
from datetime import datetime

from generate_unified import UnifiedConfig, run_unified_pipeline


# Root directory for per-run output folders and the canonical per-size artifacts.
OUTPUT_BASE = Path("data") / "large_graphs"
# Graphs with at least this many vertices default to the chunked, memory-efficient mode.
MEMORY_EFFICIENT_THRESHOLD = 2**8


def load_yaml_config(yaml_path: str) -> dict:
    """Flatten the ``parameters`` section of a YAML config into a plain dict.

    Each entry under ``parameters`` may be one of:

    * a mapping with a ``value`` key -> that value is used;
    * a mapping with a ``values`` list (sweep-style) -> the first element is used;
    * anything else -> used as-is.

    An empty file or a missing/null ``parameters`` section yields ``{}``.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        Mapping of parameter name to its resolved value.

    Raises:
        ValueError: if a parameter declares an empty ``values`` list.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no config".
        raw_config = yaml.safe_load(f) or {}
    # A key written as `parameters:` with no body loads as None.
    params = raw_config.get("parameters") or {}
    config = {}
    for key, val in params.items():
        if isinstance(val, dict) and "value" in val:
            config[key] = val["value"]
        elif isinstance(val, dict) and "values" in val:
            values = val["values"]
            if not values:
                raise ValueError(
                    f"Parameter {key!r} in {yaml_path} has an empty 'values' list"
                )
            config[key] = values[0]
        else:
            config[key] = val
    return config


def build_feature_flags(config: dict) -> dict[str, bool]:
    """Assemble the feature on/off map passed to the pipeline config.

    Flags come in three fixed groups, emitted in this order:

    * six features that are always enabled;
    * ten optional features, each read from ``config["feat_<name>"]``
      (a missing key means disabled; a present value is passed through as-is);
    * seven features that are always disabled.

    Args:
        config: Flat parameter dict (as produced by ``load_yaml_config``).

    Returns:
        Insertion-ordered dict mapping feature name to its flag value.
    """
    always_on = (
        "adj_m3_m2",
        "adj_m4_m2",
        "adj_m5_m3",
        "adj_m6_m4",
        "adj_m4_m3",
        "adj_m6_m2",
    )
    optional = (
        "adj_m2_norm",
        "adj_m5_m2",
        "adj_m6_m3",
        "adj_m5_m4",
        "regularity_proxy",
        "spectral_spread",
        "clustering_proxy",
        "triangle_density",
        "adj_m3_norm",
        "adj_m4_norm",
    )
    always_off = (
        "lap_m2_m1",
        "lap_m3_m2",
        "lap_m4_m2",
        "lap_m4_m3",
        "adj_lap_m2",
        "adj_lap_m4",
        "lap_m2_norm",
    )

    flags: dict[str, bool] = dict.fromkeys(always_on, True)
    flags.update({name: config.get(f"feat_{name}", False) for name in optional})
    flags.update(dict.fromkeys(always_off, False))
    return flags


def get_chunk_size(
    n_vertices: int,
    *,
    target_chunk_memory: int = 2 * 1024 * 1024 * 1024,
    bytes_per_element: int = 4,
    min_chunk: int = 10,
) -> int:
    """Return how many graphs to process per chunk in memory-efficient mode.

    The chunk is sized so that ``chunk * n_vertices**2 * bytes_per_element``
    stays near ``target_chunk_memory``. The result is then floored at
    ``min_chunk`` and capped by a size-dependent ceiling (500 / 200 / 100 / 50
    for n <= 256 / 512 / 1024 / larger). The ``min_chunk`` floor can exceed
    the memory target for very large graphs.

    Args:
        n_vertices: Number of vertices per graph; must be positive.
        target_chunk_memory: Approximate byte budget for one chunk (default 2 GiB).
        bytes_per_element: Bytes per dense adjacency entry (default 4, i.e. a
            32-bit dtype — assumed; confirm against the pipeline's dtype).
        min_chunk: Lower bound on the returned chunk size.

    Returns:
        The chunk size (number of graphs).

    Raises:
        ValueError: if ``n_vertices`` is not positive (would otherwise divide by zero).
    """
    if n_vertices <= 0:
        raise ValueError(f"n_vertices must be positive, got {n_vertices}")

    bytes_per_graph = n_vertices * n_vertices * bytes_per_element
    chunk_size = max(min_chunk, target_chunk_memory // bytes_per_graph)

    # Ceiling on chunk size by graph size: (max n_vertices, cap).
    for size_limit, cap in ((256, 500), (512, 200), (1024, 100)):
        if n_vertices <= size_limit:
            return min(chunk_size, cap)
    return min(chunk_size, 50)


def _save_results(
    results: dict,
    output_dir: Path,
    n_vertices: int,
    config: dict,
    use_memory_efficient: bool,
    chunk_size: int | None,
    timestamp: str,
) -> None:
    """Persist pipeline results: per-metric pickles, canonical copies, and metadata.json."""
    import shutil

    print(f"\n{'=' * 70}")
    print("SAVING RESULTS")
    print(f"{'=' * 70}")

    for metric_name, result in results.items():
        graphs_path = output_dir / f"diverse_graphs_{n_vertices}_{metric_name}.pkl"
        with open(graphs_path, "wb") as f:
            pickle.dump(result.graphs, f)
        print(f"  Saved {len(result.graphs)} graphs to {graphs_path}")

    # If the pipeline left a probabilities file in the run directory, publish a
    # copy at OUTPUT_BASE (overwriting any previous one for this size).
    probs_src = output_dir / f"diverse_probs_{n_vertices}.pkl"
    if probs_src.exists():
        probs_dst = OUTPUT_BASE / f"diverse_probs_{n_vertices}.pkl"
        shutil.copy(probs_src, probs_dst)
        print(f"  Copied probs to {probs_dst}")

    # The "gcd" metric's graphs become the canonical per-size file at OUTPUT_BASE.
    if "gcd" in results:
        best_graphs_path = OUTPUT_BASE / f"diverse_graphs_{n_vertices}.pkl"
        with open(best_graphs_path, "wb") as f:
            pickle.dump(results["gcd"].graphs, f)
        print(f"  Saved best (GCD) graphs to {best_graphs_path}")

    meta = {
        "config": config,
        "memory_efficient": use_memory_efficient,
        "chunk_size": chunk_size,
        "results": {
            metric: {
                "diversity": result.diversity,
                "energy": result.energy,
                "edge_mean": result.edge_mean,
                "edge_std": result.edge_std,
                "edge_min": result.edge_min,
                "edge_max": result.edge_max,
            }
            for metric, result in results.items()
        },
        "timestamp": timestamp,
    }
    meta_path = output_dir / "metadata.json"
    # NOTE(review): json.dump raises TypeError if any result stat is a
    # non-JSON type (e.g. a float32 scalar) — confirm the pipeline's types.
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)


def _report_results(results: dict, use_wandb: bool) -> None:
    """Print per-metric diversity/energy and, if enabled, log them to wandb and finish the run."""
    print(f"\n{'=' * 70}")
    print("RESULTS SUMMARY")
    print(f"{'=' * 70}")
    for metric_name, result in results.items():
        print(f"  {metric_name}: diversity={result.diversity:.6f}, energy={result.energy:.6f}")

    if use_wandb:
        import wandb

        for metric_name, result in results.items():
            wandb.log({
                f"{metric_name}_diversity": result.diversity,
                f"{metric_name}_energy": result.energy,
            })
        wandb.finish()


def run_config(
    yaml_path: str,
    use_wandb: bool = False,
    dry_run: bool = False,
    force_memory_efficient: bool | None = None,
):
    """Load one YAML config and run the unified pipeline on it.

    Args:
        yaml_path: Path to the YAML config file.
        use_wandb: Start a wandb run and log per-metric diversity/energy.
        dry_run: Only print the resolved configuration and memory estimate.
        force_memory_efficient: ``True``/``False`` forces chunked/standard
            mode; ``None`` picks chunked mode when
            ``n_vertices >= MEMORY_EFFICIENT_THRESHOLD``.

    Returns:
        The pipeline's results mapping (metric name -> result) on success.
        ``None`` for a dry run, or if the run/save/report phase raised (the
        traceback is printed so that a multi-config loop can continue).

    Raises:
        KeyError: if a required key (n_vertices, num_ensembles, batch_size,
            projection_dim, hidden_dim, num_hidden) is missing.
        ValueError: if ``num_ensembles`` is not positive.
    """
    print(f"\n{'=' * 70}")
    print(f"LOADING CONFIG: {yaml_path}")
    print(f"{'=' * 70}")

    config = load_yaml_config(yaml_path)

    n_vertices = config["n_vertices"]
    num_ensembles = config["num_ensembles"]
    batch_size = config["batch_size"]
    projection_dim = config["projection_dim"]
    hidden_dim = config["hidden_dim"]
    num_hidden = config["num_hidden"]
    postprocess_strategy = config.get("postprocess_strategy", "iterative_survival")
    n_top_matrices = config.get("n_top_matrices", 1500)
    sampling_budget = config.get("sampling_budget", 100000)

    # Validate before any side effects (and during dry runs) rather than
    # failing with ZeroDivisionError after the output directory exists.
    if num_ensembles <= 0:
        raise ValueError(f"num_ensembles must be positive, got {num_ensembles} in {yaml_path}")

    feature_flags = build_feature_flags(config)
    num_features = sum(1 for v in feature_flags.values() if v)

    if force_memory_efficient is None:
        use_memory_efficient = n_vertices >= MEMORY_EFFICIENT_THRESHOLD
    else:
        use_memory_efficient = force_memory_efficient

    chunk_size = get_chunk_size(n_vertices) if use_memory_efficient else None

    print("\nConfiguration:")
    for label, value in (
        ("n_vertices", n_vertices),
        ("num_ensembles", num_ensembles),
        ("batch_size", batch_size),
        ("projection_dim", projection_dim),
        ("hidden_dim", hidden_dim),
        ("num_hidden", num_hidden),
        ("num_features", num_features),
        ("postprocess_strategy", postprocess_strategy),
        ("n_top_matrices", n_top_matrices),
        ("sampling_budget", sampling_budget),
        ("memory_efficient", use_memory_efficient),
    ):
        print(f"  {label}: {value}")
    if use_memory_efficient:
        print(f"  chunk_size: {chunk_size}")

    # Rough estimate assuming 4 bytes per dense adjacency entry.
    bytes_per_graph = n_vertices * n_vertices * 4
    batch_memory_mb = (bytes_per_graph * batch_size * num_ensembles) / (1024 * 1024)
    print(f"\n  Estimated batch memory: {batch_memory_mb:.1f} MB")

    if dry_run:
        print("\n[DRY RUN - not executing]")
        return None

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = OUTPUT_BASE / f"N{n_vertices}_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fixed total number of training graphs, split evenly across ensembles.
    TOTAL_TRAINING_BUDGET = 10000
    graphs_per_ensemble = int(TOTAL_TRAINING_BUDGET // num_ensembles)

    unified_config = UnifiedConfig(
        num_ensembles=num_ensembles,
        batch_size=batch_size,
        projection_dim=projection_dim,
        n_vertices=n_vertices,
        hidden_dim=hidden_dim,
        num_hidden=num_hidden,
        graphs_per_ensemble=graphs_per_ensemble,
        num_iterations=5000,
        feature_flags=feature_flags,
        selection_metrics=["all"],
        k_select=100,
        selection_objective="average",
        postprocess_strategy=postprocess_strategy,
        n_top_matrices=n_top_matrices,
        sampling_budget=sampling_budget,
        output_dir=str(output_dir),
        memory_efficient=use_memory_efficient,
        chunk_size=chunk_size,
    )

    if use_wandb:
        import wandb

        wandb.init(
            project="diverse-graphs-large",
            name=f"N{n_vertices}_{timestamp}",
            config=config,
        )

    try:
        results = run_unified_pipeline(unified_config)
        _save_results(
            results,
            output_dir,
            n_vertices,
            config,
            use_memory_efficient,
            chunk_size,
            timestamp,
        )
        _report_results(results, use_wandb)
        return results

    except Exception as e:
        # Best-effort: report and return None so callers running several
        # configs (e.g. --all) can continue with the next one.
        print(f"\nERROR: {e}")
        import traceback

        traceback.print_exc()

        if use_wandb:
            import wandb

            wandb.finish(exit_code=1)

        return None


def main():
    """Parse CLI arguments and run either one config (``--config``) or the bundled large configs (``--all``).

    With ``--all``, each of ``config_256.yaml`` … ``config_2048.yaml`` in the
    current working directory is run in turn; missing files are reported and
    skipped. ``--all`` takes precedence over ``--config``. With neither flag,
    the help text is printed.
    """
    parser = argparse.ArgumentParser(description="Run large graph configurations")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--all", action="store_true", help="Run all large configs (256, 512, 1024, 2048)")
    parser.add_argument("--wandb", action="store_true", help="Enable wandb logging")
    parser.add_argument("--dry-run", action="store_true", help="Just print config, don't execute")
    # The two overrides contradict each other, so argparse rejects passing
    # both instead of one silently winning.
    memory_mode = parser.add_mutually_exclusive_group()
    memory_mode.add_argument("--no-memory-efficient", action="store_true",
                             help="Force standard mode (no chunking) even for large graphs")
    memory_mode.add_argument("--force-memory-efficient", action="store_true",
                             help="Force memory-efficient mode even for small graphs")
    args = parser.parse_args()

    # None lets run_config choose the mode from n_vertices.
    if args.no_memory_efficient:
        force_memory_efficient = False
    elif args.force_memory_efficient:
        force_memory_efficient = True
    else:
        force_memory_efficient = None

    run_kwargs = {
        "use_wandb": args.wandb,
        "dry_run": args.dry_run,
        "force_memory_efficient": force_memory_efficient,
    }

    if args.all:
        configs = [
            "config_256.yaml",
            "config_512.yaml",
            "config_1024.yaml",
            "config_2048.yaml",
        ]
        for config_path in configs:
            if Path(config_path).exists():
                run_config(config_path, **run_kwargs)
            else:
                print(f"WARNING: Config not found: {config_path}")
    elif args.config:
        run_config(args.config, **run_kwargs)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
