import argparse
import json
import time
import dgl
import os
import numpy as np
import torch
from torch_geometric.data import Batch

import datautils
from config import *


# Command-line configuration. The parsed values are copied into module-level
# globals below, which preprocess_data() / split_data() read directly.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=2025, help='seed')
# Dataset name; it is also used as the directory for outputs (info.json, split files).
parser.add_argument('--data', type=str, default=MCF7, help='data')
# Passed to datautils.stratified_split together with --testsz; presumably the
# validation split receives the remainder (1 - trainsz - testsz) — confirm there.
parser.add_argument('--trainsz', type=float, default=0.2, help='train size')
parser.add_argument('--testsz', type=float, default=0.6, help='test size')
# Integer flags (0 = off, non-zero = on) selecting which stages run.
parser.add_argument('--preprocess', type=int, default=0, help='whether preprocess')
parser.add_argument('--split', type=int, default=0, help='whether split')

args = parser.parse_args()

seed = args.seed
data = args.data
trainsz = args.trainsz
testsz = args.testsz
preprocess = args.preprocess
split = args.split

datautils.set_seed(seed)
print("Generator info:")
print(json.dumps(vars(args), indent='\t'))

def preprocess_data():
    """Run the dataset-specific preprocessing for ``data`` and record summary stats.

    Dispatches on the module-level ``data`` name to the matching
    ``datautils.preprocess_*`` function. It then prints summary statistics of the
    returned graph collection (graph/node/edge counts, labels, node and edge
    types) and writes the same statistics to ``<data>/info.json``.

    Raises:
        ValueError: if ``data`` is not a supported dataset name, or if the
            preprocessing step returned no graphs.
    """
    print("Start preprocess......")

    s = time.time()
    if data in [MCF7, MOLT4, PC3, SW620, NCIH23, OVCAR8, P388, SF295, SN12C, UACC257]:
        graphs = datautils.preprocess_TUDataset(data)

    elif data in [PYGDBLP, PYGIMDB, PYGMAG, PYGFREEBASE]:
        graphs = datautils.preprocess_PyG(data)

    elif data == PDNS:
        graphs = datautils.preprocess_kagglePDNS(data)

    elif data == RCDD:
        graphs = datautils.preprocess_kaggleRCDD(data)

    #elif data == TWIBOT:
    #    graphs = datautils.preprocess_Twibot(data)

    else:
        # Without this branch an unknown name falls through and the first use
        # of ``graphs`` below fails with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported dataset for preprocessing: {data!r}")

    # NOTE(review): ``graphs`` is presumably a PyG ``Batch`` (len() == number of
    # graphs, with concatenated x / y / node_type / edge_type / edge_index) —
    # confirm against datautils.preprocess_*.
    num_graphs = len(graphs)
    if num_graphs == 0:
        # Guard the averages below against a bare ZeroDivisionError.
        raise ValueError(f"Preprocessing {data!r} produced no graphs")

    unique_labels, counts = torch.unique(graphs.y, return_counts=True)
    n_types = torch.unique(graphs.node_type)
    e_types = torch.unique(graphs.edge_type)

    # The "num_grpahs" key is misspelled but intentionally kept: existing
    # info.json files and their readers may rely on it.
    info = {"num_grpahs": num_graphs,
            "avg_nodes": len(graphs.x) / num_graphs,
            # Halved on the assumption that every undirected edge is stored in
            # both directions in edge_index — confirm for each dataset.
            "avg_edges": graphs.edge_index.shape[1] / num_graphs / 2,
            "feature_dim": graphs.x.shape[1],
            "unique_labels": unique_labels.tolist(),
            "label_count": counts.tolist(),
            "node_types": len(n_types),
            "edge_types": len(e_types)}

    print("Data info:")
    print(json.dumps(info, indent='\t'))

    print("Node types: {}".format(n_types))
    print("Edge types: {}".format(e_types))

    with open(os.path.join(data, 'info.json'), 'w') as f:
        json.dump(info, f, indent=4)
    e = time.time()

    print("Preprocess successfully, time cost: {}".format(e - s))

def split_data():
    """Write stratified train/val/test index files for the current dataset.

    Loads the graph collection stored at ``<data>/<data>.pt`` and passes its
    labels to ``datautils.stratified_split`` with the module-level ``trainsz``,
    ``testsz`` and ``seed``. The resulting index lists are saved as integers,
    one per line, to ``train.txt``, ``val.txt`` and ``test.txt`` inside the
    ``data`` directory.
    """
    print("Split......")

    started = time.time()

    # NOTE(review): torch.load unpickles this file. It is presumably written by
    # the datautils preprocessing step, i.e. trusted local input. On
    # torch >= 2.6 the default weights_only=True may refuse pickled graph
    # objects — confirm with the torch version in use.
    stored = torch.load(os.path.join(data, f'{data}.pt'))

    train_idx, val_idx, test_idx = datautils.stratified_split(
        stored.y.tolist(), trainsz, testsz, seed
    )

    outputs = (
        ("train.txt", train_idx),
        ("val.txt", val_idx),
        ("test.txt", test_idx),
    )
    for filename, indices in outputs:
        np.savetxt(os.path.join(data, filename), indices, fmt='%d')

    elapsed = time.time() - started
    print("Finish split and save, time cost: {}".format(elapsed))

# Run the enabled stages in a fixed order: preprocessing first, then splitting.
# split_data loads <data>/<data>.pt, which is presumably produced by the
# preprocessing step.
for _enabled, _stage in ((preprocess, preprocess_data), (split, split_data)):
    if _enabled:
        _stage()
