from time import time
import logging
import os.path as osp
import numpy as np
import pickle
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
import random
from torch.optim.lr_scheduler import StepLR
from utils import stat_graph
from utils import two_graphons_mixup
from utils import prepare_dataset_onehot_y, prepare_dataset_x
from utils import train, test
from models import GIN
from get_clusters import *
import argparse


# Raise the 'matplotlib' logger to WARNING so its lower-level records are dropped.
logging.getLogger('matplotlib').setLevel(logging.WARNING)

# Shared record layout: "<date> - <LEVEL>: - <message>"; datefmt keeps the date only.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s: - %(message)s',
    datefmt='%Y-%m-%d',
)

# Root logger, opened up to DEBUG so every record is passed on to attached handlers.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)



if __name__ == '__main__':
    # Script entry point. In order, it:
    #   1. parses the command line,
    #   2. loads a TUDataset and logs its statistics,
    #   3. shuffles and splits it 70% / 10% / 20% into train / val / test,
    #   4. optionally prepends graphs produced by two_graphons_mixup to the
    #      training portion,
    #   5. trains a GIN classifier and logs the test accuracy at the epoch with
    #      the best validation accuracy, plus the mean test accuracy over the
    #      last 10 epochs.
    import ast  # local import: only needed to parse literal-valued CLI strings

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default="./")
    parser.add_argument('--dataset', type=str, default="IMDB-BINARY")
    parser.add_argument('--model', type=str, default="GIN")
    parser.add_argument('--epoch', type=int, default=800)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_hidden', type=int, default=64)
    parser.add_argument('--lam_range', type=str, default="[0.1,0.2]") # [0.1,0.2]
    parser.add_argument('--aug_ratio', type=float, default=0.05)
    parser.add_argument('--aug_num', type=int, default=10)
    parser.add_argument('--log_screen', type=str, default="True")
    parser.add_argument('--gmixup', type=str, default="True")
    parser.add_argument('--ge', type=str, default="ISGL")
    parser.add_argument('--seed', type=int, default=1314)
    parser.add_argument('--n_epochs_inr', type=int, default=20)

    args = parser.parse_args()

    # NOTE: --model, --aug_num and --n_epochs_inr are accepted for CLI
    # compatibility but are not read below; a GIN model is always built.
    data_path = args.data_path
    dataset_name = args.dataset
    seed = args.seed
    # ast.literal_eval only accepts Python literals, so the command-line
    # strings can no longer execute arbitrary code the way eval() could.
    lam_range = ast.literal_eval(args.lam_range)
    log_screen = bool(ast.literal_eval(args.log_screen))
    gmixup = bool(ast.literal_eval(args.gmixup))
    num_epochs = args.epoch

    num_hidden = args.num_hidden
    batch_size = args.batch_size
    learning_rate = args.lr
    ge = args.ge
    aug_ratio = args.aug_ratio

    # setup_seed is not imported by name in this file; presumably it is
    # provided by `from get_clusters import *` — TODO confirm.
    setup_seed(seed)

    if log_screen:
        # Mirror all log records to the console.
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    logger.info('parser.prog: {}'.format(parser.prog))
    logger.info("args:{}".format(args))

    # NOTE(review): the GPU index is hard-coded to 2; on a CUDA machine with
    # fewer than three devices moving the model there is expected to fail.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    logger.info(f"runing device: {device}")

    path = osp.join(data_path, dataset_name)
    dataset = TUDataset(path, name=dataset_name)
    dataset = list(dataset)

    # Debug aid: show the raw node features of the first graph.
    print(dataset[0].x)

    # Flatten every label tensor to 1-D before the label preparation step.
    for graph in dataset:
        graph.y = graph.y.view(-1)

    dataset = prepare_dataset_onehot_y(dataset)

    (avg_num_nodes, max_num_nodes, avg_num_edges, avg_density,
     median_num_nodes, median_num_edges, median_density,
     min_num_nodes, std_num_nodes) = stat_graph(dataset)
    logger.info(f"Info of dataset: {dataset_name}")
    logger.info(f"num graphs: { len(dataset) }")
    logger.info(f"avg num nodes of graphs: { avg_num_nodes }")
    logger.info(f"max num nodes of graphs: { max_num_nodes }")
    logger.info(f"min num nodes of graphs: { min_num_nodes }")
    logger.info(f"avg num edges of graphs: { avg_num_edges }")
    logger.info(f"avg density of graphs: { avg_density }")
    logger.info(f"median num nodes of graphs: { median_num_nodes }")
    logger.info(f"median num edges of graphs: { median_num_edges }")
    logger.info(f"median density of graphs: { median_density }")
    logger.info(f"std num nodes of graphs: { std_num_nodes }")

    # Shuffle the graphs together with their original positions so the
    # original indices of the train+val graphs can be handed to the graphon
    # builder below.
    indexed_dataset = list(enumerate(dataset))  # [(0, item0), (1, item1), ...]
    random.shuffle(indexed_dataset)
    dataset = [data for _, data in indexed_dataset]
    idx = [i for i, _ in indexed_dataset]  # original index of each shuffled graph
    train_nums = int(len(dataset) * 0.7)
    train_val_nums = int(len(dataset) * 0.8)
    train_val_idx = idx[:train_val_nums]

    if gmixup:
        J = 20
        nCluster = -1
        # NOTE(review): graphons are built from train AND validation indices
        # while the mixed graphs are added to the training set only — confirm
        # that this is intended.
        graphons = build_graphons_for_trainval(train_val_idx, dataset_name, seed, J, nCluster)

        # One synthetic graph is generated per sampled mixing coefficient.
        num_sample = int(train_nums * aug_ratio)
        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(num_sample,))

        new_graph = []
        for lam in lam_list:
            logger.info( f"lam: {lam}" )
            logger.info(f"num_sample: {num_sample}")

            # Re-draw until the two picked entries differ at [0][0], i.e. come
            # from different classes.
            # NOTE(review): only that single element is compared, and the loop
            # never terminates if all entries agree there — confirm the
            # structure returned by build_graphons_for_trainval.
            two_graphons = random.sample(graphons, 2)
            while two_graphons[0][0][0] == two_graphons[1][0][0]:
                two_graphons = random.sample(graphons, 2)

            # Resolution range passed to the mixup: a random lower bound
            # between the dataset's min and avg node counts, and the avg node
            # count as the upper bound.
            upper_bound = avg_num_nodes
            lower_bound = random.randint(int(min_num_nodes), int(avg_num_nodes))
            new_graph += two_graphons_mixup(two_graphons, la=lam, num_sample=1, ge=ge, resolution=[lower_bound, upper_bound])
            logger.info(f"label: {new_graph[-1].y}")

        if new_graph:
            # Statistics are only computed when at least one graph was
            # generated (num_sample can be 0 for a small aug_ratio). Separate
            # names keep the dataset statistics above from being overwritten.
            (new_avg_num_nodes, new_max_num_nodes, new_avg_num_edges, new_avg_density,
             new_median_num_nodes, new_median_num_edges, new_median_density,
             new_min_num_nodes, _new_std_num_nodes) = stat_graph(new_graph)
            logger.info(f"avg num nodes of new graphs: { new_avg_num_nodes }")
            logger.info(f"min num nodes of new graphs: { new_min_num_nodes }")
            logger.info(f"max num nodes of new graphs: { new_max_num_nodes }")
            logger.info(f"avg num edges of new graphs: { new_avg_num_edges }")
            logger.info(f"avg density of new graphs: { new_avg_density }")
            logger.info(f"median num nodes of new graphs: { new_median_num_nodes }")
            logger.info(f"median num edges of new graphs: { new_median_num_edges }")
            logger.info(f"median density of new graphs: { new_median_density }")

        # Synthetic graphs go to the front of the list, so shifting both split
        # points by their count keeps them inside the training slice.
        dataset = new_graph + dataset
        logger.info( f"real aug ratio: {len( new_graph ) / train_nums }" )
        train_nums = train_nums + len( new_graph )
        train_val_nums = train_val_nums + len( new_graph )

    dataset = prepare_dataset_x(dataset)

    logger.info(f"num_features: {dataset[0].x.shape}" )
    logger.info(f"num_classes: {dataset[0].y.shape}"  )

    num_features = dataset[0].x.shape[1]
    num_classes = dataset[0].y.shape[0]

    # Re-seed before the training-set shuffle and model initialisation.
    setup_seed(seed)

    train_dataset = dataset[:train_nums]
    random.shuffle(train_dataset)
    val_dataset = dataset[train_nums:train_val_nums]
    test_dataset = dataset[train_val_nums:]

    logger.info(f"train_dataset size: {len(train_dataset)}")
    logger.info(f"val_dataset size: {len(val_dataset)}")
    logger.info(f"test_dataset size: {len(test_dataset)}" )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, exclude_keys=['edge_attr'])
    val_loader = DataLoader(val_dataset, batch_size=batch_size, exclude_keys=['edge_attr'])
    test_loader = DataLoader(test_dataset, batch_size=batch_size, exclude_keys=['edge_attr'])

    model = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    last_10_epoch_acc = []
    # Initialised up front so the summary below cannot hit an unbound name
    # when the loop body never runs.
    best_val_acc = 0
    best_test_acc = 0
    best_epoch = 0
    # Inclusive upper bound: run exactly num_epochs epochs, so that the
    # "epoch > num_epochs - 10" filter really collects the last 10 of them.
    for epoch in range(1, num_epochs + 1):
        model, train_loss = train(model, train_loader, optimizer, num_classes)
        val_acc, val_loss, _ = test(model, val_loader, num_classes)
        test_acc, test_loss, _ = test(model, test_loader, num_classes)
        scheduler.step()

        # ">=" keeps the latest epoch among ties in validation accuracy.
        if val_acc >= best_val_acc:
            best_epoch = epoch
            best_val_acc = val_acc
            best_test_acc = test_acc

        if epoch > num_epochs - 10:
            last_10_epoch_acc.append(test_acc)

        logger.info('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f},  Val Acc: {: .6f}, Test Acc: {: .6f}'.format(
            epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

    logger.info(f"last 10 epoch average acc: {np.round(np.mean(np.array(last_10_epoch_acc)),3)}")
    logger.info(f"best test acc based on val: {best_test_acc} at epoch {best_epoch}")
        