import itertools
import pickle

import numpy as np
import torch
from tqdm import tqdm

from arg_parser import get_args
from model import GRAND_ASC, test, train
from read_data import read_dataset, split_setter
from utils import set_seed

# Datasets treated as homophilic; every other dataset name is run as heterophilic.
_HOMOPHILIC_DATASETS = frozenset({"Cora", "CiteSeer", "PubMed"})


def _train_one_split(model, optimizer, data, max_epochs, patience):
    """Train ``model`` on one data split with early stopping on validation accuracy.

    Args:
        model: network to train (a ``GRAND_ASC`` instance in this script).
        optimizer: optimizer over ``model``'s parameters.
        data: data object with the current split applied by ``split_setter``.
        max_epochs: maximum number of training epochs.
        patience: stop after this many consecutive epochs without a strict
            improvement in validation accuracy.

    Returns:
        ``(best_val_acc, test_acc_at_best_val)``. Both are ``0.0`` when no
        epoch ran (``max_epochs == 0``).
    """
    # Use a None sentinel rather than 0.0 so the first epoch is always
    # recorded; with 0.0 and a strict ">" a split whose validation accuracy
    # never exceeds 0 would silently report a test accuracy of 0.0.
    best_val_acc = None
    best_test_acc = 0.0
    patience_counter = 0

    for epoch in tqdm(range(max_epochs)):
        train(model, optimizer, data)
        _, val_acc, test_acc = test(model, data)

        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    return (0.0 if best_val_acc is None else best_val_acc), best_test_acc


def _summarize(values):
    """Return ``(mean, std)`` of ``values`` as Python floats.

    ``np.std`` is the population standard deviation (``ddof=0``). An empty
    input yields ``(nan, nan)`` without NumPy's empty-slice RuntimeWarning.
    """
    if not values:
        return float("nan"), float("nan")
    return float(np.mean(values)), float(np.std(values))


def _select_best(results, key):
    """Return the entry of ``results`` with the largest ``key`` (first wins on ties).

    Returns ``None`` for an empty ``results`` list instead of raising.
    """
    if not results:
        return None
    return max(results, key=lambda r: r[key])


def main():
    """Grid-search GRAND_ASC hyperparameters on each requested dataset.

    For every hyperparameter combination, the model is trained once per
    Monte-Carlo split. The test accuracy at the best-validation epoch is
    averaged over those splits. Per-dataset results are pickled to
    ``results_<dataset>.pkl``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device is",device,"\n\n")
    args = get_args()

    for dataset_name in args.datasets:
        results = []

        print(f"dataset: {dataset_name}")
        data,num_features,num_classes = read_dataset(dataset_name,device=device)
        relation_type = 'homophilic' if dataset_name in _HOMOPHILIC_DATASETS else 'heterophilic'
        print(f"Dataset: {dataset_name} → Relation type selected: {relation_type}")

        grid = itertools.product(
            args.learning_rates,
            args.weight_decays,
            args.hidden_dims,
            args.hidden_layers_list,
            args.dropout_rates,
            args.Time,
        )
        for lr, wd, hd_dim, hd_layers, dropout_rate, Tt in grid:

            print(f"\nTraining on {dataset_name} with lr={lr}, weight_decay={wd}, hidden_dim={hd_dim}, hidden_layers={hd_layers}, dropout={dropout_rate}", f"Time={Tt}")

            val_accs = []
            test_accs = []

            for run_monte_carlo in range(args.monte_carlo_iteration):
                # NOTE(review): presumably split_setter applies the fixed split
                # number `split_id` to `data`. set_seed runs *after* it, so if
                # split_setter draws random numbers the split is not seeded by
                # this call — confirm it is deterministic for a given split_id.
                data = split_setter(dataset_name,data,split_id = run_monte_carlo,device=device)
                set_seed(seed=42 + run_monte_carlo)

                model = GRAND_ASC(
                    in_dim=num_features,
                    hidden_dim=hd_dim,
                    out_dim=num_classes,
                    num_steps=hd_layers,
                    input_dropout=dropout_rate,
                    heads=args.attention_head,
                    T=Tt,
                    relation_type = relation_type
                ).to(device)

                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

                best_val_acc, best_test_acc = _train_one_split(
                    model, optimizer, data, args.learning_iterations, args.patience
                )
                val_accs.append(best_val_acc)
                test_accs.append(best_test_acc)

            test_mean, test_std = _summarize(test_accs)
            val_mean, val_std = _summarize(val_accs)

            results.append({
                'lr': lr,
                'weight_decay': wd,
                'hidden_dim': hd_dim,
                'hidden_layers': hd_layers,
                'dropout_rate': dropout_rate,
                # Time is part of the grid, so it must be recorded to make
                # the selected configuration reproducible.
                'Time': Tt,
                'test_acc_mean': test_mean,
                'test_acc_std': test_std,
                'val_acc_mean': val_mean,
                'val_acc_std': val_std,
            })

            print(f"Mean test accuracy: {test_mean}")
            print(f"Standard deviation of test accuracy: {test_std}")

        # 'best' keeps the original criterion (mean *test* accuracy) for
        # backward compatibility, but selecting on the test set leaks test
        # information into model selection. 'best_by_val' applies the standard
        # protocol: pick the configuration by mean validation accuracy.
        best_res = _select_best(results, 'test_acc_mean')
        best_by_val = _select_best(results, 'val_acc_mean')

        save_path = f'results_{dataset_name}.pkl'
        with open(save_path, 'wb') as f:
            pickle.dump({'results': results, 'best': best_res, 'best_by_val': best_by_val}, f)
        print(f"Saved results for {dataset_name} to {save_path}")
        print(f"Best results for {dataset_name}:")
        print(best_res)
        print(f"Best results for {dataset_name} (selected by validation accuracy):")
        print(best_by_val)


if __name__ == "__main__":
    main()
