# src/main.py

import torch
import optuna
from src.lfp_dataset import prepare_all_datasets
from src.objective import objective
from src.train import evaluate_best_model
from src.utils import set_seed
from src.dataset_simulated_setting2 import SyntheticMultimodalGraphDataset
from src.dataset_simulated_setting1 import load_simulated_datasets

def main(
    *,
    dataset: str = "lfp",
    data_dir: str = "lfp_data",
    n_trials: int = 100,
    seed: int = 1,
    device: str = "cpu",
) -> None:
    """Run the end-to-end pipeline: load data, tune with Optuna, evaluate.

    Steps: seed via ``set_seed``, build the GNN datasets, run an Optuna
    study (TPE sampler, direction ``"maximize"``) over ``objective``, print
    the best hyperparameters, then call ``evaluate_best_model`` with them
    and print the returned test accuracy.

    Calling ``main()`` with no arguments reproduces the original hard-coded
    configuration: LFP data from ``"lfp_data"``, 100 trials, seed 1, CPU.

    Args:
        dataset: Which dataset source to use:
            ``"lfp"``  -- ``prepare_all_datasets(data_dir=data_dir)``;
            ``"sim1"`` -- simulation setting 1, ``load_simulated_datasets()``;
            ``"sim2"`` -- simulation setting 2,
            ``SyntheticMultimodalGraphDataset().get_all_modalities()``.
        data_dir: Directory passed to ``prepare_all_datasets``; only used
            when ``dataset == "lfp"``.
        n_trials: Number of Optuna trials to run.
        seed: Seed passed to both ``set_seed`` and the TPE sampler.
        device: Any value accepted by ``torch.device`` (e.g. ``"cpu"``).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Uniform zero-argument loaders; only the selected one is ever called.
    # Validating the name first means a typo fails fast, before seeding or
    # any expensive data loading.
    loaders = {
        "lfp": lambda: prepare_all_datasets(data_dir=data_dir),
        "sim1": lambda: load_simulated_datasets(),
        "sim2": lambda: SyntheticMultimodalGraphDataset().get_all_modalities(),
    }
    if dataset not in loaders:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(loaders)}"
        )

    # Step 0: Set seed and device
    set_seed(seed)
    device = torch.device(device)
    print(f"🖥️  Using device: {device}")

    # Step 1: Load and prepare the GNN datasets for the chosen source
    if dataset == "lfp":
        print("📦 Loading LFP data and preparing balanced GNN datasets...")
    else:
        print(f"📦 Loading simulated datasets ({dataset})...")
    gnn_datasets = loaders[dataset]()

    # Step 2: Run Optuna hyperparameter optimization. The sampler is seeded
    # with the same seed so the sequence of suggested trials is repeatable.
    print("🔍 Running Optuna for hyperparameter tuning...")
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(
        lambda trial: objective(trial, gnn_datasets, device),
        n_trials=n_trials,
    )

    # Step 3: Show best hyperparameters
    print("\n✅ Best hyperparameters:")
    best_params = study.best_trial.params
    for k, v in best_params.items():
        print(f"{k}: {v}")

    # Step 4: Evaluate final model on held-out test fold
    print("\n🧪 Evaluating best model on held-out test fold...")
    # preds/labels are unused here but kept so callers extending this
    # function can persist them (e.g. with torch.save).
    test_acc, preds, labels = evaluate_best_model(best_params, gnn_datasets, device)

    print(f"\n🎯 Final Test Accuracy: {test_acc:.4f}")

# Script entry point: run the full pipeline with its default configuration
# only when this module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
