import os
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
import argparse as argparse
from torch_geometric.data import Data
from utils import *
import pickle
# Pick the compute device. The script targets the second GPU ('cuda:1') when
# more than one is present; on a single-GPU host that ordinal does not exist and
# the first `.to(device)` would fail, so fall back to 'cuda:0' there.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')



parser = argparse.ArgumentParser(description='All feature analysis')
parser.add_argument('--dataset', type=str, default='Cora', help='Dataset name (default: Cora) or dataSim_Y_A,X_A,Y_X for synthetic data')
parser.add_argument('--N_iter', type=int, default=5, help='Number of iterations for averaging results (default: 5)')
args = parser.parse_args()
if args.N_iter < 1:
    # With zero iterations the result lists stay empty and np.mean/np.std
    # in the final summary would only print nan.
    parser.error('--N_iter must be a positive integer')

DS = args.dataset



N_iter = args.N_iter  # number of iterations for averaging results

# Photo/Computers get a longer training budget and sparser progress logging.
if DS in ['Computers', 'Photo']:
    epochs = 800
    verbos_every = 100
else:
    epochs = 400
    verbos_every = 50

# Datasets flagged `ishet` (presumably the heterophilous benchmarks) are trained
# with GCNNodeClassifier_Het instead of GCNNodeClassifier/GIN.
HET_DATASETS = {
    'roman_empire', 'amazon_ratings', 'minesweeper', 'tolokers', 'questions',
    'Texas', 'Cornell', 'Wisconsin', 'Actor',
}
ishet = DS in HET_DATASETS
# 1) Ablation study. For each iteration (all RNGs seeded with the iteration
#    index via set_seed_all) train:
#    (a) the graph model on real features + real graph,
#    (b) an MLP on real features only,
#    (c) the graph model on permuted features + real graph,
#    (d) the graph model on standard-normal random features + real graph,
#    (e) the graph model on real features + the graph from permute_graph.
#    The first two values returned by run_experiment (best validation score and
#    the test score recorded with it) are collected per variant.
print("\n== GNN (real features + real graph) ==")
best_val_or = []
best_test_at_val_or = []

best_val_mlp = []
best_test_at_val_mlp = []

best_val_pt = []
best_test_at_val_pt = []

best_val_rand = []
best_test_at_val_rand = []

best_val_perm_graph = []
best_test_at_val_perm_graph = []


def _build_gnn(in_dim, n_classes):
    """Return the graph model used by every graph-based variant, moved to ``device``.

    ``ishet`` datasets use GCNNodeClassifier_Het, Photo/Computers use GIN, and
    every other dataset uses GCNNodeClassifier (hidden width 512 throughout).
    """
    if ishet:
        return GCNNodeClassifier_Het(in_dim, [512, 512], n_classes).to(device)
    if DS in ['Photo', 'Computers']:
        return GIN(in_dim=in_dim, hid_dim=512, out_dim=n_classes).to(device)
    return GCNNodeClassifier(in_dim, [512, 512], n_classes).to(device)


def _train(model, graph_data):
    """Train ``model`` on ``graph_data`` with the shared Adam settings.

    Returns the first two values produced by run_experiment:
    (best validation score, test score at that best validation).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    best_val, best_test, _ = run_experiment(
        model, optimizer, graph_data, epochs=epochs, verbose_every=verbos_every
    )
    return best_val, best_test


# `run_idx` rather than `iter`, to avoid shadowing the builtin.
for run_idx in range(N_iter):
    print(f"\nIteration {run_idx+1}/{N_iter}")

    # (a) real features + real graph
    set_seed_all(run_idx)
    _, data, num_node_features, num_classes = load_dataset(DS, device)
    data = data.to(device)
    model_or = _build_gnn(num_node_features, num_classes)
    val, test = _train(model_or, data)
    best_val_or.append(val)
    best_test_at_val_or.append(test)

    # (b) features only: MLP ignores the graph structure
    set_seed_all(run_idx)
    mlp = MLPNodeClassifier(num_node_features, [512, 512], num_classes).to(device)
    val, test = _train(mlp, data)
    best_val_mlp.append(val)
    best_test_at_val_mlp.append(test)

    # (c) permuted features + real graph.
    # Seed the permutation with the iteration index (matching permute_graph in
    # (e)); a fixed seed=0 reused the same permutation in every iteration, so the
    # reported std did not reflect permutation variability.
    set_seed_all(run_idx)
    data_perm = permute_features(data, seed=run_idx)
    gcn_perm = _build_gnn(num_node_features, num_classes)
    val, test = _train(gcn_perm, data_perm)
    best_val_pt.append(val)
    best_test_at_val_pt.append(test)

    # (d) random N(0, 1) features of the same shape + real graph
    set_seed_all(run_idx)
    data = data.to(device)  # re-assert placement before cloning; no-op if already there
    data_rand = data.clone()
    data_rand.x = torch.randn_like(data.x)
    gcn_rand = _build_gnn(num_node_features, num_classes)
    val, test = _train(gcn_rand, data_rand)
    best_val_rand.append(val)
    best_test_at_val_rand.append(test)

    # (e) real features + edge_index replaced by permute_graph's output
    set_seed_all(run_idx)
    data_perm_graph = data.clone()
    data_perm_graph.edge_index = permute_graph(data.edge_index, seed=run_idx)
    gcn_perm_graph = _build_gnn(num_node_features, num_classes)
    val, test = _train(gcn_perm_graph, data_perm_graph)
    best_val_perm_graph.append(val)
    best_test_at_val_perm_graph.append(test)





# Summarize each variant as mean ± std, across iterations, of the test score
# recorded at the best validation score.
_summary_rows = (
    ("GCN (Original graph)", best_test_at_val_or),
    ("MLP (features only, no graph)", best_test_at_val_mlp),
    ("GCN (permuted features, original graph)", best_test_at_val_pt),
    ("GCN (random features, original graph)", best_test_at_val_rand),
    ("GCN (original features, ER random graph)", best_test_at_val_perm_graph),
)
print("\nFinal Results:")
for row_label, row_scores in _summary_rows:
    row_mean = np.mean(row_scores)
    row_std = np.std(row_scores)
    print(f"{row_label}: Test {row_mean:.2f} ± {row_std:.2f}")




