import argparse
import os
import pickle
import numpy as np
import torch
import torch_geometric
from sklearn.model_selection import KFold
from utils import preprocess, fix_graph
from torch_geometric.datasets import TUDataset
import random
import pandas as pd
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
from gnn_models import GCN, GIN_Net, APPNP_Net, GAT, GraphSAGE
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch_geometric.transforms as T
from gnn_utils import train_loop, ogb_train_loop

# Command-line interface. Invalid combinations are rejected right after parsing,
# before any dataset is loaded or any model is trained.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, required=True,
                    help='name of dataset, typically in domain TUDataset')
parser.add_argument('--data_random_seed', type=int, default=123456, help='random seed for data shuffling')
parser.add_argument('--domain_shift_type', type=str, default='graph_size',
                    choices=['graph_size', 'graph_density', 'no_shift'],
                    help='domain shift used to order the graphs (graph_size, graph_density or no_shift)')
parser.add_argument('--domain_shift_order', type=str, default='ascending',
                    choices=['ascending', 'descending'], help='ascending or descending')
parser.add_argument('--train_val_ratio', type=float, nargs='+', default=[0.6, 0.2],
                    help='train and val ratio (exactly two values; test gets the remainder)')
parser.add_argument('--val_test_setting', type=str, default='seperate',
                    choices=['seperate', 'mixed'], help='seperate or mixed')
parser.add_argument('--baseline_type', type=str, required=True,
                    choices=['full', 'random', 'KIDD'], help='the type of baselines')
parser.add_argument('--baseline_random_seed', type=int, default=1234, help='random seed for random data selection')
parser.add_argument('--random_selection_ratios', type=float, nargs='+',
                    help='ratios for random data selection (required for the random and KIDD baselines)')
parser.add_argument('--GNN_repeat', type=int, default=3, help='number of repeating times for GNN training avging')
parser.add_argument('--GNN_epoch', type=int, default=200, help='number of epochs for GNN traning')
parser.add_argument('--GNN_hidden_dim', type=int, default=32, help='number of hidden dim for GNN traning')
parser.add_argument('--GNN_batch_size', type=int, default=256, help='batch size for GNN traning')
parser.add_argument('--GNN_device', type=str, default="cuda", help='device type for GNN traning')
parser.add_argument('--GNN_model_names', type=str, nargs='+', default=["GCN", "GIN", "GAT", "SAGE"],
                    choices=['GCN', 'GIN', 'APPNP', 'GAT', 'SAGE'], help='GNN models in use')
# KIDD epoch count; it is part of the KIDD synthetic-graph file name that is loaded
parser.add_argument('--kidd_epochs', type=int, default=30)


# parse args
args = parser.parse_args()

# checks argparse cannot express on its own; parser.error prints usage and exits
if len(args.train_val_ratio) != 2:
    parser.error("--train_val_ratio expects exactly two values: the train ratio and the val ratio")
if args.baseline_type in ('random', 'KIDD') and not args.random_selection_ratios:
    parser.error("--random_selection_ratios is required when --baseline_type is 'random' or 'KIDD'")

# print all the args
print("================================================")
for arg, value in vars(args).items():
    print(f"{arg}: {value}")
print("================================================")

# get graph data: the three OGB molecular benchmarks are loaded through the OGB
# loader (with the project's fix_graph transform); any other name is a TUDataset
graph_name = args.dataset_name
is_ogb_dataset = graph_name in ['ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp']
dataset = (
    PygGraphPropPredDataset(root='dataset', name=graph_name, transform=fix_graph)
    if is_ogb_dataset
    else TUDataset(root='./datasets', name=graph_name, use_node_attr=True)
)

num_graphs = len(dataset)
print("the dataset is: ", graph_name)
print("number of graphs: ", num_graphs)

# summary statistics over all graphs: mean node count and mean edge_index width
avg_graph_size = np.mean([dataset[k].num_nodes for k in range(num_graphs)])
print("avg graph size: ", avg_graph_size)
avg_edge_num = np.mean([dataset[k].edge_index.shape[1] for k in range(num_graphs)])
print("avg edge number: ", avg_edge_num)

# TUDataset graphs are materialised as a list and passed through the project's
# `preprocess` helper. A random shuffle was planned here but is disabled, so the
# list keeps the loader's order.
if graph_name not in ['ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp']:
    dataset = preprocess(list(dataset))

# ENZYMES workaround: when a graph's num_nodes disagrees with the number of rows
# of its feature matrix x, the row count of x is taken as the truth
if graph_name == 'ENZYMES':
    for graph in dataset:
        feature_rows = graph.x.shape[0]
        if graph.num_nodes != feature_rows:
            graph.num_nodes = feature_rows


# one scalar label per graph, in dataset order
label_list = [dataset[idx].y.item() for idx in range(len(dataset))]
label_set = set(label_list)
num_classes = len(label_set)
print("label set: ", label_set)
print("num classes: ", num_classes)

# the OGB datasets are trained with num_classes = 1 (presumably a single output
# for binary prediction -- see gnn_models / ogb_train_loop). The override happens
# after the printout above, so the log still shows the number of distinct labels.
if graph_name in ['ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp']:
    num_classes = 1

# label -> list of dataset positions carrying that label (ascending positions)
label_dict = {}
for i, label in enumerate(label_list):
    label_dict.setdefault(label, []).append(i)



def _graph_density(graph):
    """Return edge_index.shape[1] / (n * (n - 1)) for `graph`, n = graph.num_nodes.

    A graph with fewer than two nodes has no possible node pairs. Its density
    is defined as 0.0 instead of raising ZeroDivisionError.
    """
    num_nodes = graph.num_nodes
    possible_pairs = num_nodes * (num_nodes - 1)
    if possible_pairs == 0:
        return 0.0
    return graph.edge_index.shape[1] / possible_pairs


# Order the graphs according to the requested domain shift. The first
# train_ratio share of `data_order` becomes the training split, so sorting by a
# graph statistic puts e.g. small/sparse graphs in train and large/dense ones in
# val/test (for 'ascending').
if args.domain_shift_type in ('graph_size', 'graph_density'):
    if args.domain_shift_order not in ('ascending', 'descending'):
        raise ValueError("there is no this type of shift order")
    if args.domain_shift_type == 'graph_size':
        sort_keys = [dataset[i].num_nodes for i in range(len(dataset))]
    else:
        sort_keys = [_graph_density(dataset[i]) for i in range(len(dataset))]
    data_order = torch.sort(torch.tensor(sort_keys))[1]
    # descending is the reversed ascending order (ties are reversed as well)
    if args.domain_shift_order == 'descending':
        data_order = data_order.flip(0)
elif args.domain_shift_type == 'no_shift':
    # NOTE: no shuffle is applied to `dataset` earlier in this script, so this
    # keeps the loader's stored order rather than a random one.
    data_order = torch.arange(len(dataset))
else:
    raise ValueError("there is no this type of this shift type")


# Train/val/test split from the ratios; the test share is whatever remains.
train_ratio = args.train_val_ratio[0]
val_ratio = args.train_val_ratio[1]
test_ratio = 1.0 - train_ratio - val_ratio
if any(ratio <= 0 for ratio in (train_ratio, val_ratio, test_ratio)):
    raise ValueError("train/val/test ratio is negative")

# The train split is always the head of data_order. Everything after it is
# held out and cut into val (first val_len) and test (the rest).
#   'seperate': the held-out part keeps the shift order, so val and test come
#               from different regions of the shift.
#   'mixed':    the held-out part is shuffled (seeded by data_random_seed)
#               before being cut.
np.random.seed(args.data_random_seed)
if args.val_test_setting in ('seperate', 'mixed'):
    train_len = int(len(dataset) * train_ratio)
    val_len = int(len(dataset) * val_ratio)
    train_idx = data_order[:train_len]
    held_out_idx = data_order[train_len:]
    if args.val_test_setting == 'mixed':
        held_out_idx = np.random.permutation(held_out_idx)
    val_idx = held_out_idx[:val_len]
    test_idx = held_out_idx[val_len:]
else:
    raise ValueError("there is no this type of val/test setting")

print(val_idx[:5], test_idx[:5])
print("biased train idx num: ", len(train_idx))
print("biased val idx num: ", len(val_idx))
print("biased test idx num: ", len(test_idx))


def _label_histogram(indices):
    """Count the labels (from label_list) of the graphs at `indices`, in first-seen order."""
    histogram = {}
    for position in indices:
        graph_label = label_list[position]
        histogram[graph_label] = histogram.get(graph_label, 0) + 1
    return histogram


# monitor the label distribution of each split
train_label_recorder = _label_histogram(train_idx)
print("train label distribution: ", train_label_recorder)

val_label_recorder = _label_histogram(val_idx)
print("val label distribution: ", val_label_recorder)

test_label_recorder = _label_histogram(test_idx)
print("test label distribution: ", test_label_recorder)
print("================================================")


# baseline training
# output_dict collects the per-run test scores that are saved to pkl/csv/excel:
#   'full'          -> {model name: [score per repeat]} plus the marker entry
#                      output_dict['full'] = 'full'
#   'random'/'KIDD' -> {model name: {selection ratio: [score per repeat]}}
output_dict = {}
model_names = args.GNN_model_names
print("GNN model considered: ", model_names)

# datasets that are trained with the OGB train loop
_OGB_DATASET_NAMES = ('ogbg-molhiv', 'ogbg-molbace', 'ogbg-molbbbp')

# backbone name (as given on the command line) -> model class
_MODEL_CLASSES = {
    'GCN': GCN,
    'GIN': GIN_Net,
    'APPNP': APPNP_Net,
    'GAT': GAT,
    'SAGE': GraphSAGE,
}


class KIDDGraphDataset(Dataset):
    """Dense (adjacency, feature, label) triples exposed as PyG ``Data`` graphs.

    Defined at module level so the name ``KIDDGraphDataset`` exists in this
    script's namespace. Presumably this is what lets ``torch.load`` unpickle
    the KIDD synthetic-graph files loaded below -- verify against the KIDD
    generator script before renaming or moving it.
    NOTE(review): ``__init__`` does not call ``Dataset.__init__``.
    """

    def __init__(self, adj_matrices, feature_matrices, labels):
        # all three containers must hold the same number of graphs along dim 0
        assert adj_matrices.size(0) == feature_matrices.size(0) == labels.size(0), \
            "The number of graphs must be consistent across adjacency matrices, feature matrices, and labels."
        self.adj_matrices = adj_matrices
        self.feature_matrices = feature_matrices
        self.labels = labels

    def __len__(self):
        return self.adj_matrices.size(0)

    def __getitem__(self, idx):
        adj_matrix = self.adj_matrices[idx]
        feature_matrix = self.feature_matrices[idx]
        label = self.labels[idx]

        # every non-zero adjacency entry becomes an edge (COO edge_index) ...
        edge_index = adj_matrix.nonzero(as_tuple=False).t()

        # ... and its adjacency value becomes that edge's attribute
        edge_attr = adj_matrix[edge_index[0], edge_index[1]].view(-1)

        return Data(
            x=feature_matrix,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=label.unsqueeze(0),  # graph-level label with a leading dim of 1
        )


def _build_model(model_name):
    """Return a freshly initialised `model_name` backbone moved to args.GNN_device.

    Input width comes from dataset[0].num_node_features and output width from
    num_classes. `model_name` must be a key of _MODEL_CLASSES; this is checked
    before training starts.
    """
    model_cls = _MODEL_CLASSES[model_name]
    return model_cls(dataset[0].num_node_features, args.GNN_hidden_dim, num_classes).to(args.GNN_device)


def _graphs_on_device(indices):
    """Return [dataset[j].to(args.GNN_device) for j in indices]."""
    return [dataset[j].to(args.GNN_device) for j in indices]


def _repeat_training(model_name, train_loader, val_loader, test_loader, print_separator):
    """Train `args.GNN_repeat` fresh `model_name` models and return their test scores.

    Every repeat builds a newly initialised model, so the repeats are independent
    runs rather than continued training of one instance. Each model is trained
    with `ogb_train_loop` for the OGB datasets and `train_loop` otherwise, and
    test_dict["test_acc_for_best_val"] is recorded.

    Args:
        model_name: key of _MODEL_CLASSES.
        train_loader, val_loader, test_loader: passed straight to the train loop.
        print_separator: also print a "---------" line after each score (the
            log format of the random/KIDD baselines).

    Returns:
        One score per repeat, in run order.
    """
    scores = []
    for _ in range(args.GNN_repeat):
        model = _build_model(model_name)
        if graph_name in _OGB_DATASET_NAMES:
            test_dict = ogb_train_loop(model, args.GNN_epoch, train_loader, val_loader, test_loader, 0, args.GNN_device, graph_name)
        else:
            test_dict = train_loop(model, args.GNN_epoch, train_loader, val_loader, test_loader, 0, args.GNN_device)
        perf = test_dict["test_acc_for_best_val"]
        print(f"test acc on best eval model for {model_name}: ", perf)
        if print_separator:
            print("---------")
        scores.append(perf)
    return scores


# fail fast, before any training, on configurations the branches below cannot handle
unknown_model_names = [name for name in model_names if name not in _MODEL_CLASSES]
if unknown_model_names:
    raise ValueError(f"unknown GNN model name(s): {unknown_model_names}; "
                     f"expected a subset of {sorted(_MODEL_CLASSES)}")
if args.baseline_type not in ('full', 'random', 'KIDD'):
    raise ValueError(f"unknown baseline type: {args.baseline_type}")
if args.baseline_type != 'full' and not args.random_selection_ratios:
    raise ValueError("--random_selection_ratios is required for the 'random' and 'KIDD' baselines")

# val/test graphs are identical for every model and selection ratio, so the
# loaders are built once and shared
val_loader = DataLoader(_graphs_on_device(val_idx), batch_size=args.GNN_batch_size, shuffle=False)
test_loader = DataLoader(_graphs_on_device(test_idx), batch_size=args.GNN_batch_size, shuffle=False)

# full training set
if args.baseline_type == 'full':
    # to indicate the baseline type
    output_dict[args.baseline_type] = 'full'
    train_loader = DataLoader(_graphs_on_device(train_idx), batch_size=args.GNN_batch_size, shuffle=True)
    for cur_model_name in model_names:
        output_dict[cur_model_name] = _repeat_training(
            cur_model_name, train_loader, val_loader, test_loader, print_separator=False)

# random selection of a fraction of the training set
elif args.baseline_type == 'random':
    # NOTE(review): the Python RNG is seeded once for all models, so (unless
    # something else reseeds `random`) each model draws a different subset for
    # the same selection ratio.
    random.seed(args.baseline_random_seed)
    for cur_model_name in model_names:
        output_dict[cur_model_name] = {}
        for selection_ratio in args.random_selection_ratios:
            print("=====")
            print("Current selection ratio in train set: ", selection_ratio)
            print("=====")
            random_selected_idx = random.sample(list(train_idx), int(len(train_idx) * selection_ratio))
            random_train_loader = DataLoader(_graphs_on_device(random_selected_idx),
                                             batch_size=args.GNN_batch_size, shuffle=True)
            output_dict[cur_model_name][selection_ratio] = _repeat_training(
                cur_model_name, random_train_loader, val_loader, test_loader, print_separator=True)

# training graphs synthesised by KIDD-LR, loaded from precomputed .pt files
else:
    random.seed(args.baseline_random_seed)
    kidd_dataset_path = f"baseline_methods/KIDD/gda_synthetic_graphs/{graph_name}/{args.domain_shift_type}/"
    for cur_model_name in model_names:
        output_dict[cur_model_name] = {}
        for selection_ratio in args.random_selection_ratios:
            print("=====")
            print("Current selection ratio in train set: ", selection_ratio)
            print("=====")
            dataset_train_kidd_name = \
                kidd_dataset_path + \
                f"LowRank_{graph_name}_{selection_ratio}_epoch_{args.kidd_epochs}_{args.domain_shift_type}_{args.domain_shift_order}_{args.val_test_setting}.pt"
            # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
            # which refuses arbitrary pickled objects; if this load fails, check
            # the installed torch version.
            dataset_train_kidd = torch.load(dataset_train_kidd_name)
            kidd_train_dataset = [dataset_train_kidd[j].to(args.GNN_device) for j in range(len(dataset_train_kidd))]
            kidd_train_loader = DataLoader(kidd_train_dataset, batch_size=args.GNN_batch_size, shuffle=True)
            output_dict[cur_model_name][selection_ratio] = _repeat_training(
                cur_model_name, kidd_train_loader, val_loader, test_loader, print_separator=True)

# Persist output_dict as pickle, CSV and Excel under one per-baseline,
# per-dataset directory.
dir_path = f"../new_{args.baseline_type}_script_outputs/{args.dataset_name}"
os.makedirs(dir_path, exist_ok=True)

# shared file-name stem so the three output formats always match
result_stem = (f"{args.dataset_name}_{args.domain_shift_type}_{args.domain_shift_order}"
               f"_{args.val_test_setting}_{args.baseline_type}_epoch_{args.GNN_epoch}")

# raw results as a pickle
pkl_path = f"{dir_path}/pickle_files"
os.makedirs(pkl_path, exist_ok=True)
pkl_file_name = f"{pkl_path}/{result_stem}.pkl"
with open(pkl_file_name, "wb") as pkl_file:
    pickle.dump(output_dict, pkl_file)

# Tabular view with one row per output_dict key. For the ratio-based baselines
# each cell holds the list of scores for one (model, selection ratio) pair.
if args.baseline_type != 'full':
    df = pd.DataFrame({model: dict(per_ratio) for model, per_ratio in output_dict.items()}).T
    df.index.name = 'GNN model'
else:
    df = pd.DataFrame(output_dict).T

# CSV; index=True keeps the row labels (model names) as the first column
csv_path = f"{dir_path}/csv_files"
os.makedirs(csv_path, exist_ok=True)
csv_file_name = f"{csv_path}/{result_stem}.csv"
df.to_csv(csv_file_name, index=True)

# Excel, same table
excel_path = f"{dir_path}/excel_files"
os.makedirs(excel_path, exist_ok=True)
excel_file_name = f"{excel_path}/{result_stem}.xlsx"
df.to_excel(excel_file_name, index=True, sheet_name="Sheet1")
