import torch
import random
import numpy as np
import itertools
from datetime import datetime
from train import train_ctqw_model  # Make sure this points to your training script

def set_seed(seed=42):
    """
    Seed Python's, NumPy's and PyTorch's global RNGs so runs are repeatable.

    Args:
        seed: Integer applied to ``random``, ``numpy.random`` and
            ``torch.manual_seed``; when CUDA is available it is also applied
            to every visible CUDA device via ``torch.cuda.manual_seed_all``.
            Defaults to 42.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def run_grid_search(seed=42):
    """
    Run grid search over multiple hyperparameter combinations for training CTQWformer.

    Every combination in the Cartesian product of the search-space lists
    below is passed to ``train_ctqw_model`` in turn. The global RNGs are
    re-seeded before *each* experiment, so a configuration's result does not
    depend on how many (or which) experiments ran before it in the grid.

    Args:
        seed: Seed passed to ``set_seed`` before every experiment.
            Defaults to 42, the value previously hard-coded here.

    Returns:
        A list of ``(config_str, avg_acc)`` tuples, one per experiment, in
        the order the experiments were run. ``avg_acc`` is whatever
        ``train_ctqw_model`` returned for that configuration.
    """
    import os  # local import: only needed to prepare the results directory

    # === Search Space Definition ===
    datasets = ['MUTAG']                     # List of datasets to evaluate
    hidden_dims = [64]                       # Hidden dimension for node/graph embeddings
    lrs = [1e-3]                             # Learning rate
    batch_sizes = [1]                        # Batch size per training step
    dropouts = [0.3]                         # Dropout ratio
    heads_list = [4]                         # Number of attention heads
    fusions = ['cat']                        # Fusion strategy: 'cat' or 'add'
    use_attention_bias_list = [True]         # Whether to use CTQW-based structural attention bias
    use_sequence_model_list = [True]         # Whether to use CTQW sequence module (e.g., Bi-GRU)
    num_layers_list = [2]                    # Number of CTQWformer layers
    time_steps_list = [torch.tensor([1.0, 2.0, 3.0, 4.0])]  # Time steps for CTQW evolution

    # === Training Hyperparameters ===
    epochs = 300
    folds = 10  # k-fold cross-validation
    earlystop = 50  # forwarded as earlystop_patience
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # === Result Logging Path ===
    result_log_path = "results/grid_search_results.csv"
    # Make sure the output directory exists before any training starts, in
    # case train_ctqw_model does not create it itself.
    log_dir = os.path.dirname(result_log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Generate all combinations of hyperparameters
    param_grid = list(itertools.product(
        datasets, hidden_dims, lrs, batch_sizes, dropouts, heads_list,
        fusions, use_attention_bias_list, use_sequence_model_list,
        num_layers_list, time_steps_list
    ))

    print(f"Total experiments to run: {len(param_grid)}")

    # === Run Experiments ===
    results = []
    for config in param_grid:
        (
            dataset_name, hidden_dim, lr, batch_size, dropout, heads,
            fusion, use_attention_bias, use_sequence_model, num_layers, time_steps
        ) = config

        # ":g" keeps enough significant digits that distinct learning rates
        # (e.g. 1.5e-3 vs 2e-3) never print identically, unlike ".0e".
        config_str = (
            f"[Dataset={dataset_name} | hidden_dim={hidden_dim} | lr={lr:g} | "
            f"batch_size={batch_size} | dropout={dropout} | heads={heads} | fusion={fusion} | "
            f"use_attention_bias={use_attention_bias} | use_sequence_model={use_sequence_model} | "
            f"num_layers={num_layers} | time_steps={time_steps.tolist()}]"
        )

        print(f"\n🚀 Running experiment {config_str}")

        # Re-seed per experiment (not once before the loop) so each config
        # starts from the same RNG state regardless of its grid position.
        set_seed(seed)

        avg_acc = train_ctqw_model(
            dataset_name=dataset_name,
            epochs=epochs,
            folds=folds,
            time_steps=time_steps,
            hidden_dim=hidden_dim,
            lr=lr,
            batch_size=batch_size,
            fusion=fusion,
            heads=heads,
            use_attention_bias=use_attention_bias,
            use_sequence_model=use_sequence_model,
            num_layers=num_layers,
            dropout=dropout,
            earlystop_patience=earlystop,
            device=device,
            result_log_path=result_log_path
        )

        print(f"✅ Finished {config_str} → Avg Acc: {avg_acc:.4f}")
        results.append((config_str, avg_acc))

    return results

# Entry point: launch the grid search only when this file is executed as a
# script, never as a side effect of importing it.
if __name__ == '__main__':
    run_grid_search()
