import argparse
import os
import pickle
import numpy as np
import torch
import torch_geometric
from sklearn.model_selection import KFold
from utils import preprocess, our_selection_method_wrap, fix_graph
from torch_geometric.datasets import TUDataset
import random
import pandas as pd
from torch_geometric.loader import DataLoader
from gnn_models import GCN, GIN_Net, APPNP_Net, GAT, GraphSAGE
from gnn_utils import train_loop, ogb_train_loop
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch_geometric.transforms as T

# Command-line interface of the experiment script.
# The three options without a usable default (--dataset_name,
# --our_method_type, --selection_ratios) are marked as required so that a
# missing value is reported by argparse immediately instead of surfacing as a
# NameError/TypeError later in the run.
parser = argparse.ArgumentParser()

### data specific ###
parser.add_argument('--dataset_name', type=str, required=True, help='name of dataset, typically in domain TUDataset')
parser.add_argument('--data_random_seed', type=int, default=123456, help='random seed for data shuffling')
parser.add_argument('--domain_shift_type', type=str, default='graph_size', help='domain shift (size or density or no shift)')
parser.add_argument('--domain_shift_order', type=str, default='ascending', help='ascending or descending')
# Only the first two values are read (train ratio, val ratio); the test ratio is the remainder.
parser.add_argument('--train_val_ratio', type=float, nargs='+', default=[0.6, 0.2], help='train and val ratio')
# The spelling 'seperate' is the literal value the script compares against.
parser.add_argument('--val_test_setting', type=str, default='seperate', help='seperate or mixed')
### data specific ###

### method specific ###
# parser.add_argument('--running_seed', type=int, default=1234, help='seed for GNN running')

parser.add_argument('--our_method_type', type=str, choices=['ver1', 'lava'], required=True, help='the type of the method')
parser.add_argument('--alphas', type=float, nargs='+', default=[0.5,0.9], help='the LinearFGW ratios to mix feature and structure')
# NOTE(review): the default is the two values 0 and 5; if a single 0.5 was intended
# (compare the --alphas default) this needs to be changed -- confirm with the authors.
parser.add_argument('--label_signals', type=float, nargs='+', default=[0,5], help='the label signals to compute dataset distance')
parser.add_argument('--barycenter_sizes', type=int, nargs='+', default=[5], help='the sizes of barycenters used in LinearFGW computation')
parser.add_argument('--linearfgw_device', type=str, default='cpu', help='the device used in LinearFGW computation')
# Integer used as a boolean flag: non-zero recomputes, 0 loads stored weights.
parser.add_argument('--recompute_weight', type=int, default=1, help='whether to re-compute the training weights')
parser.add_argument('--pre_compute_weight_path', type=str, default='', help='if no re-computation, then load weights from path')
parser.add_argument('--update_steps', type=int, default=10, help='the number of steps for dataset distance minimization')

parser.add_argument('--kmeans_device', type=str, default='cuda')

# When omitted, these two fall back to --alphas / --label_signals.
parser.add_argument('--alphas_to_use', type=float, nargs='+', help='the LinearFGW ratios to use')
parser.add_argument('--label_signals_to_use', type=float, nargs='+', help='the label signals to use')

parser.add_argument('--selection_ratios', type=float, nargs='+', required=True, help='ratios for our data selection')
### method specific ###

### GNN specific ###
parser.add_argument('--GNN_repeat', type=int, default=3, help='number of repeating times for GNN training avging')
parser.add_argument('--GNN_epoch', type=int, default=200, help='number of epochs for GNN traning')
parser.add_argument('--GNN_hidden_dim', type=int, default=32, help='number of hidden dim for GNN traning')
parser.add_argument('--GNN_batch_size', type=int, default=1024, help='batch size for GNN traning')
parser.add_argument('--GNN_device', type=str, default="cuda", help='device type for GNN traning')
parser.add_argument('--GNN_model_names', type=str, nargs='+', default=["GCN", "GIN", "GAT", "SAGE"], help='GNN models in use')
### GNN specific ###

# Parse the command line once; every later stage reads its settings from `args`.
args = parser.parse_args()

# Echo the complete configuration, framed by separator lines, so the log of a
# run records exactly which settings produced it.
banner = "================================================"
print(banner)
for option_name, option_value in vars(args).items():
    print(f"{option_name}: {option_value}")
print(banner)

# Load the graph dataset. The three OGB molecule benchmarks go through the OGB
# loader with `fix_graph` passed as its transform; every other name is loaded
# as a TUDataset with node attributes enabled.
graph_name = args.dataset_name
is_ogb_dataset = graph_name in ("ogbg-molhiv", 'ogbg-molbace', 'ogbg-molbbbp')
if is_ogb_dataset:
    dataset = PygGraphPropPredDataset(root='dataset', name=graph_name, transform=fix_graph)
else:
    dataset = TUDataset(root='./datasets', name=graph_name, use_node_attr=True)

num_graphs = len(dataset)
print("the dataset is: ", graph_name)
print("number of graphs: ", num_graphs)

# Report the mean node count and the mean number of edge_index columns per graph.
node_counts = [dataset[idx].num_nodes for idx in range(num_graphs)]
avg_graph_size = np.mean(node_counts)
print("avg graph size: ", avg_graph_size)

edge_counts = [dataset[idx].edge_index.shape[1] for idx in range(num_graphs)]
avg_edge_num = np.mean(edge_counts)
print("avg edge number: ", avg_edge_num)

is_tu_dataset = graph_name not in ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp')
if is_tu_dataset:
    # Materialise the dataset as a plain list and run the project's
    # `preprocess` helper on it. Seeded shuffling (--data_random_seed) used to
    # happen here but is currently disabled, so the original order is kept.
    dataset = preprocess(list(dataset))

# ENZYMES workaround: make `num_nodes` agree with the number of rows of the
# node-feature matrix for every graph where the two differ.
if graph_name == 'ENZYMES':
    for graph in dataset:
        feature_rows = graph.x.shape[0]
        if graph.num_nodes != feature_rows:
            graph.num_nodes = feature_rows

# Scalar label of every graph, in dataset order, plus the set of distinct labels.
label_list = [dataset[graph_idx].y.item() for graph_idx in range(len(dataset))]
label_set = set(label_list)
print("label set: ", label_set)
print("num classes: ", len(label_set))

# The value printed above is always the number of distinct labels; for the OGB
# benchmarks the value handed to the models is forced to 1 instead.
if graph_name in ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp'):
    num_classes = 1
else:
    num_classes = len(label_set)

# label -> list of dataset positions carrying that label (positions ascending).
label_dict = {}
for graph_idx, label in enumerate(label_list):
    label_dict.setdefault(label, []).append(graph_idx)



# Order the graphs along the requested domain-shift axis. The split that
# follows slices this ordering, so the graphs placed first end up in the
# training portion.
#   graph_size    -> sort by number of nodes
#   graph_density -> sort by edge_index columns / (n * (n - 1))
#   no_shift      -> keep the dataset's current order (the order flag is ignored)
_shift_type = args.domain_shift_type
if _shift_type in ('graph_size', 'graph_density'):
    if args.domain_shift_order not in ('ascending', 'descending'):
        raise ValueError("there is no this type of shift order")

    if _shift_type == 'graph_size':
        _sort_keys = [dataset[i].num_nodes for i in range(len(dataset))]
    else:
        _sort_keys = []
        for i in range(len(dataset)):
            _graph = dataset[i]
            _possible_edges = _graph.num_nodes * (_graph.num_nodes - 1)
            # Graphs with fewer than two nodes have no possible edge; give them
            # density 0.0 instead of dividing by zero.
            if _possible_edges > 0:
                _sort_keys.append(_graph.edge_index.shape[1] / _possible_edges)
            else:
                _sort_keys.append(0.0)

    # torch.sort returns (values, indices); only the permutation is needed.
    data_order = torch.sort(torch.tensor(_sort_keys))[1]
    if args.domain_shift_order == 'descending':
        data_order = data_order.flip(0)
elif _shift_type == 'no_shift':
    data_order = torch.arange(len(dataset))
else:
    raise ValueError("there is no this type of this shift type")


# Fixed train/val/test proportions: the first two entries of --train_val_ratio
# are the train and val shares, the test share is whatever remains.
train_ratio, val_ratio = args.train_val_ratio[0], args.train_val_ratio[1]
test_ratio = 1.0 - train_ratio - val_ratio
if any(share <= 0 for share in (train_ratio, val_ratio, test_ratio)):
    raise ValueError("train/val/test ratio is negative")

# Seed NumPy so the shuffle used by the 'mixed' setting is reproducible.
np.random.seed(args.data_random_seed)

split_mode = args.val_test_setting
if split_mode not in ('seperate', 'mixed'):
    raise ValueError("there is no this type of val/test setting")

# The training portion is always the head of the ordering.
train_len = int(len(dataset) * train_ratio)
val_len = int(len(dataset) * val_ratio)
train_idx = data_order[:train_len]
val_test_idx = data_order[train_len:]

if split_mode == 'seperate':
    # Validation and test are consecutive slices of the ordered remainder.
    val_idx = val_test_idx[:val_len]
    test_idx = val_test_idx[val_len:]
else:
    # 'mixed': shuffle the remainder first, then cut it into val and test.
    shuffled_val_test_idx = np.random.permutation(val_test_idx)
    val_idx = shuffled_val_test_idx[:val_len]
    test_idx = shuffled_val_test_idx[val_len:]

print("biased train idx num: ", len(train_idx))
print("biased val idx num: ", len(val_idx))
print("biased test idx num: ", len(test_idx))

# Report how the labels are distributed inside each of the three splits.
def _label_histogram(indices):
    """Return a {label: count} dict for the graphs at `indices`.

    Labels appear in the dict in the order they are first encountered.
    """
    histogram = {}
    for graph_idx in indices:
        label_value = label_list[graph_idx]
        histogram[label_value] = histogram.get(label_value, 0) + 1
    return histogram


train_label_recorder = _label_histogram(train_idx)
print("train label distribution: ", train_label_recorder)

val_label_recorder = _label_histogram(val_idx)
print("val label distribution: ", val_label_recorder)

test_label_recorder = _label_histogram(test_idx)
print("test label distribution: ", test_label_recorder)
print("================================================")


# Container for every test score produced by the training stage, and the GNN
# backbones that stage will train.
output_dict = {}
model_names = args.GNN_model_names
print("GNN model considered: ", model_names)


# Selection weights are stored under "<method root>/<dataset>/pkl_files".
# Method types without an entry here are rejected explicitly; previously they
# left `save_weight_dir_path` unbound and failed with a NameError.
_WEIGHT_ROOT_BY_METHOD = {
    "ver1": "our_methods",
    "ver2": "our_methods",
    "lava": "lava_methods",
    "kmed": "kmed_methods",
}
if args.our_method_type not in _WEIGHT_ROOT_BY_METHOD:
    raise ValueError(
        f"unsupported our_method_type {args.our_method_type!r}; "
        f"expected one of {sorted(_WEIGHT_ROOT_BY_METHOD)}"
    )
save_weight_dir_path = f"{_WEIGHT_ROOT_BY_METHOD[args.our_method_type]}/{args.dataset_name}"

# Directory for the pickled weight dicts; makedirs also creates the parent
# directories, so a single call covers the whole path.
weight_file_dir_path = save_weight_dir_path + "/pkl_files"
os.makedirs(weight_file_dir_path, exist_ok=True)

# Obtain `new_weight_dict`, which the training stage indexes as
# [label signal][selection ratio][alpha]. It is either recomputed and stored,
# or read back from a previously stored pickle.
# Every weight pickle shares this path apart from the trailing epoch count.
weight_pkl_prefix = weight_file_dir_path + f"/{args.dataset_name}_{args.domain_shift_type}_{args.domain_shift_order}_{args.val_test_setting}_{args.our_method_type}_epoch_"

if args.recompute_weight:
    # Recompute unconditionally and store the result, tagged with --GNN_epoch.
    new_weight_dict = our_selection_method_wrap(
        args, dataset, graph_name, label_dict, label_list, train_idx, val_idx,
    )

    pkl_file_name = weight_pkl_prefix + f"{args.GNN_epoch}.pkl"
    with open(pkl_file_name, 'wb') as weight_file:
        pickle.dump(new_weight_dict, weight_file)
elif args.pre_compute_weight_path:
    # NOTE(review): --pre_compute_weight_path is only tested for truthiness;
    # the file actually opened is the conventional name below, using a
    # per-dataset epoch count rather than --GNN_epoch. Confirm this is intended.
    epoch_by_dataset = {
        "ENZYMES": 400,
        "ogbg-molhiv": 100,
        "ogbg-molbace": 100,
        'ogbg-molbbbp': 100,
    }
    default_epoch = epoch_by_dataset.get(args.dataset_name, 200)
    pkl_file_name = weight_pkl_prefix + f"{default_epoch}.pkl"

    # pickle can execute arbitrary code while loading: only read files that
    # were written by this script.
    with open(pkl_file_name, 'rb') as weight_file:
        new_weight_dict = pickle.load(weight_file)
else:
    raise ValueError("not recomputing nor giving precomputed weights")


# 2. decide which hyper-parameter settings (alpha, label signal) to evaluate;
#    the matching selected train indices are looked up in the weight dict.
# 3. train GNNs on the selected indices.
# If no explicit subset was requested, evaluate everything that was computed.
if not args.alphas_to_use:
    args.alphas_to_use = args.alphas
if not args.label_signals_to_use:
    args.label_signals_to_use = args.label_signals

print("alpha to use: ", args.alphas_to_use)
print("label signals to use: ", args.label_signals_to_use)

if not args.selection_ratios:
    raise ValueError("--selection_ratios must be given to train on the selected subsets")

# Backbone name -> model class. Unknown names are rejected before any training
# starts; previously they silently reused the previous backbone's model (or
# raised a NameError when they came first).
_MODEL_BUILDERS = {
    'GCN': GCN,
    'GIN': GIN_Net,
    'APPNP': APPNP_Net,
    'GAT': GAT,
    'SAGE': GraphSAGE,
}
_unknown_model_names = [name for name in model_names if name not in _MODEL_BUILDERS]
if _unknown_model_names:
    raise ValueError(
        f"unknown GNN model name(s) {_unknown_model_names}; "
        f"expected a subset of {sorted(_MODEL_BUILDERS)}"
    )

_is_ogb_run = graph_name in ['ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp']
_num_node_features = dataset[0].num_node_features

# The validation and test sets do not depend on any loop variable below, so
# they are moved to the device and wrapped in loaders once.
val_dataset = [dataset[j].to(args.GNN_device) for j in val_idx]
test_dataset = [dataset[j].to(args.GNN_device) for j in test_idx]
val_loader = DataLoader(val_dataset, batch_size=args.GNN_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=args.GNN_batch_size, shuffle=False)

# Result layout: output_dict[(alpha, label signal)][model name][selection ratio]
# is the list of test scores, one entry per repeat.
for alpha_in_use in args.alphas_to_use:
    for label_signal_in_use in args.label_signals_to_use:
        output_dict[(alpha_in_use, label_signal_in_use)] = {}
        # torch.manual_seed(args.running_seed)
        for cur_model_name in model_names:
            output_dict[(alpha_in_use, label_signal_in_use)][cur_model_name] = {}

            for selection_ratio in args.selection_ratios:
                output_dict[(alpha_in_use, label_signal_in_use)][cur_model_name][selection_ratio] = []
                print("=====")
                print("Current selection ratio in train set: ", selection_ratio)
                print("=====")
                # Selected training indices for this hyper-parameter combination.
                our_selected_idx = new_weight_dict[label_signal_in_use][selection_ratio][alpha_in_use]
                if isinstance(our_selected_idx, torch.Tensor):
                    our_selected_idx = our_selected_idx.tolist()
                our_train_dataset = [dataset[j].to(args.GNN_device) for j in our_selected_idx]
                our_train_loader = DataLoader(our_train_dataset, batch_size=args.GNN_batch_size, shuffle=True)

                for _ in range(args.GNN_repeat):
                    # Build a fresh model for every run. A single instance used to
                    # be shared by all selection ratios and repeats of a backbone,
                    # so unless the train loop re-initialises the parameters
                    # itself, later runs started from already-trained weights.
                    model = _MODEL_BUILDERS[cur_model_name](
                        _num_node_features, args.GNN_hidden_dim, num_classes
                    ).to(args.GNN_device)

                    if _is_ogb_run:
                        test_dict = ogb_train_loop(model, args.GNN_epoch, our_train_loader, val_loader, test_loader, 0, args.GNN_device, graph_name)
                    else:
                        test_dict = train_loop(model, args.GNN_epoch, our_train_loader, val_loader, test_loader, 0, args.GNN_device)
                    perf = test_dict["test_acc_for_best_val"]
                    print(f"test acc on best eval model for {cur_model_name}: ",perf)
                    print("---------")

                    # add to output
                    output_dict[(alpha_in_use, label_signal_in_use)][cur_model_name][selection_ratio].append(perf)



# Record the results as pickle, CSV and Excel files under
# "../new_<tag>_script_outputs/<dataset>".
# ver1/ver2 share the "ours" tag; every other method type uses its own name
# ('lava' -> "../new_lava_script_outputs"). Deriving the tag this way keeps
# `dir_path` defined for any method type, so finished training results are
# never lost to a NameError at save time.
if args.our_method_type in ['ver1', 'ver2']:
    _results_tag = "ours"
else:
    _results_tag = args.our_method_type
dir_path = f"../new_{_results_tag}_script_outputs/{args.dataset_name}"
os.makedirs(dir_path, exist_ok=True)


# Save the raw nested result dict as a pickle file.
pkl_path = dir_path + '/pickle_files'
os.makedirs(pkl_path, exist_ok=True)
pkl_file_name = pkl_path + f"/{args.dataset_name}_{args.domain_shift_type}_{args.domain_shift_order}_{args.val_test_setting}_{args.our_method_type}_epoch_{args.GNN_epoch}.pkl"

with open(pkl_file_name, "wb") as pkl_file:
    pickle.dump(output_dict, pkl_file)


# Flatten the nested results into a table: one row per GNN model, one column
# per ((alpha, label signal), selection ratio) pair.
# The selection ratios are taken from the first model recorded for each
# hyper-parameter key; `.get` leaves None where a model has no entry for a ratio.
table_columns = {}
for hyper_key, per_model in output_dict.items():
    ratio_keys = next(iter(per_model.values())).keys()
    for ratio in ratio_keys:
        table_columns[(hyper_key, ratio)] = {
            model_key: ratio_scores.get(ratio, None)
            for model_key, ratio_scores in per_model.items()
        }
df = pd.DataFrame(table_columns)


# Name the two column levels and the row index.
df.columns = pd.MultiIndex.from_tuples(df.columns, names=["alpha/label signal", "selection ratio"])

df.index.name = 'GNN model'
    
# Both exports share the same file stem and differ only in folder and extension.
results_stem = f"/{args.dataset_name}_{args.domain_shift_type}_{args.domain_shift_order}_{args.val_test_setting}_{args.our_method_type}_epoch_{args.GNN_epoch}"

# CSV export; index=True keeps the 'GNN model' index as the first column.
csv_path = dir_path + '/csv_files'
os.makedirs(csv_path, exist_ok=True)
csv_file_name = csv_path + results_stem + ".csv"
df.to_csv(csv_file_name, index=True)

# Excel export of the same table into a single sheet.
excel_path = dir_path + '/excel_files'
os.makedirs(excel_path, exist_ok=True)
excel_file_name = excel_path + results_stem + ".xlsx"

df.to_excel(excel_file_name, index=True, sheet_name="Sheet1")
