from time import time
import logging
import os
import os.path as osp
import numpy as np
import pandas as pd
import time
import csv
import pickle as pkl

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_geometric.utils import degree
from torch.autograd import Variable

import random
from torch.optim.lr_scheduler import StepLR
from collections import defaultdict

from utils import stat_graph, split_class_graphs, align_graphs, get_logger, str_to_bool
from utils import two_graphons_mixup, universal_svd
from graphon_estimator import universal_svd
from models import GIN, GCN

from spectral_aug import spectral_noise, spectral_mask

import argparse

def prepare_dataset_x(dataset):
    """Attach degree-based node features to graphs that have none.

    Runs only when ``dataset[0].x is None``; otherwise ``dataset`` is returned
    unchanged. Node degrees are taken from the source row of ``edge_index``.
    If the largest degree over the whole dataset is below 2000, every graph
    gets one-hot degree features of width ``max_degree + 1``; otherwise every
    node gets a single degree feature standardized with the dataset-wide mean
    and standard deviation.

    Side effect: for every graph with at least one edge, ``num_nodes`` is reset
    to ``max(edge_index) + 1`` so it matches the number of feature rows.
    Graphs without edges keep their existing ``num_nodes`` instead of crashing.

    Returns the same list, with its graphs mutated in place.
    """
    if dataset[0].x is None:
        max_degree = 0
        degs = []
        for data in dataset:
            if data.edge_index.numel() > 0:
                # torch.max on an empty edge_index would raise; only recompute
                # the node count when there are edges to derive it from.
                data.num_nodes = int(torch.max(data.edge_index)) + 1
            # Passing num_nodes keeps the degree vector the same length as the
            # node count even when the highest node id never appears as a source.
            deg = degree(data.edge_index[0], num_nodes=data.num_nodes, dtype=torch.long)
            degs.append(deg)
            if deg.numel() > 0:
                max_degree = max(max_degree, deg.max().item())

        if max_degree < 2000:
            for data, deg in zip(dataset, degs):
                data.x = F.one_hot(deg, num_classes=max_degree + 1).to(torch.float)
        else:
            all_degs = torch.cat(degs, dim=0).to(torch.float)
            mean, std = all_degs.mean().item(), all_degs.std().item()
            if not std > 0:
                # Constant degrees (std == 0) or a NaN std would poison every feature.
                std = 1.0
            for data, deg in zip(dataset, degs):
                data.x = ((deg - mean) / std).view(-1, 1)
    return dataset



def prepare_dataset_onehot_y(dataset):
    """Replace each graph's integer label ``y`` with a one-hot float vector.

    Each ``data.y`` must hold a single non-negative integer label; the caller
    in this file reshapes it with ``view(-1)`` first. The class count is
    ``max(label) + 1``. For the usual contiguous labels ``0..K-1`` this equals
    the number of distinct labels. Gaps in the label set (e.g. ``{0, 2}`` or
    1-based ``{1, 2}``) give unused all-zero columns instead of an
    out-of-range error from ``F.one_hot``.

    The resulting ``data.y`` is a 1-D float tensor of length ``num_classes``.
    Graphs are mutated in place and ``dataset`` is returned.
    """
    labels = {int(data.y) for data in dataset}
    num_classes = max(labels) + 1 if labels else 0

    for data in dataset:
        # one_hot on a shape-[1] label yields shape [1, num_classes]; [0] drops the batch dim.
        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]
    return dataset


def mixup_cross_entropy_loss(input, target, size_average=True):
    r"""Cross-entropy against soft (e.g. mixup) targets.

    Origin: https://github.com/moskomule/mixup.pytorch
    PyTorch's built-in cross entropy expects class-index targets; this variant
    takes a full target distribution q instead.

    The body computes ``-sum(input * target)``. That equals
    loss(p, q) = -\sum_i q_i \log p_i
    only if ``input`` already holds log-probabilities \log p_i (e.g. the output
    of a ``log_softmax``). NOTE(review): confirm the models return that.

    Args:
        input: tensor of (presumably) log-probabilities, same size as ``target``.
        target: tensor of per-class target weights.
        size_average: if True, divide the summed loss by ``input.size()[0]``.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        AssertionError: if the sizes differ or either argument is not a tensor.
    """
    assert input.size() == target.size()
    # Variable is a legacy alias; isinstance(x, Variable) is True for any tensor.
    assert isinstance(input, Variable) and isinstance(target, Variable)
    loss = - torch.sum(input * target)
    return loss / input.size()[0] if size_average else loss

def train(model, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Reads the module-level ``device``, ``optimizer`` and ``num_classes`` that
    the ``__main__`` section assigns. Targets are reshaped to rows of width
    ``num_classes`` and scored with ``mixup_cross_entropy_loss``.

    Returns:
        (model, loss): the same model object, and the batch losses averaged per
        graph (each batch weighted by its ``num_graphs``).
    """
    model.train()
    weighted_sum = 0
    seen = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = mixup_cross_entropy_loss(out, batch.y.view(-1, num_classes))
        batch_loss.backward()
        weighted_sum += batch_loss.item() * batch.num_graphs
        seen += batch.num_graphs
        optimizer.step()
    return model, weighted_sum / seen

def test(model, loader):
    """Evaluate ``model`` on every batch of ``loader``.

    Reads the module-level ``device`` and ``num_classes`` that the
    ``__main__`` section assigns. Targets are reshaped to rows of width
    ``num_classes``. The true class is each row's argmax, and the predicted
    class is the argmax of the model output.

    Returns:
        (acc, loss): fraction of correctly classified graphs, and the
        per-graph mean of ``mixup_cross_entropy_loss``.
    """
    model.eval()
    correct = 0
    total = 0
    loss = 0
    # Evaluation never back-propagates; skipping autograd bookkeeping saves
    # memory and time without changing any returned value.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data.x, data.edge_index, data.batch)
            pred = output.max(dim=1)[1]
            y = data.y.view(-1, num_classes)
            loss += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
            y = y.max(dim=1)[1]
            correct += pred.eq(y).sum().item()
            total += data.num_graphs
    acc = correct / total
    loss = loss / total
    return acc, loss


def _log_graph_stats(logger, graphs, tag):
    """Log the six ``stat_graph`` summaries of ``graphs`` and return them.

    ``tag`` is inserted into every message (e.g. "training" or "new").
    Returns ``(avg_num_nodes, avg_num_edges, avg_density, median_num_nodes,
    median_num_edges, median_density)``.
    """
    (avg_num_nodes, avg_num_edges, avg_density,
     median_num_nodes, median_num_edges, median_density) = stat_graph(graphs)
    logger.info(f"avg num nodes of {tag} graphs: {avg_num_nodes}")
    logger.info(f"avg num edges of {tag} graphs: {avg_num_edges}")
    logger.info(f"avg density of {tag} graphs: {avg_density}")
    logger.info(f"median num nodes of {tag} graphs: {median_num_nodes}")
    logger.info(f"median num edges of {tag} graphs: {median_num_edges}")
    logger.info(f"median density of {tag} graphs: {median_density}")
    return (avg_num_nodes, avg_num_edges, avg_density,
            median_num_nodes, median_num_edges, median_density)


if __name__ == '__main__':
    # Literal-only parsing of the list/bool CLI strings below (safe replacement for eval).
    import ast

    torch.set_num_threads(8)

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default="./datasets")
    parser.add_argument('--dataset', type=str, default="IMDB-BINARY")
    parser.add_argument('--model', type=str, default="GIN")
    parser.add_argument('--epoch', type=int, default=800)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_hidden', type=int, default=64)
    parser.add_argument('--gmixup', type=str, default="False")
    parser.add_argument('--lam_range', type=str, default="[0.005, 0.01]")
    parser.add_argument('--aug_ratio', type=float, default=0.2)
    parser.add_argument('--aug_num', type=int, default=10)
    parser.add_argument('--gnn', type=str, default="gin")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_screen', type=str, default="False")
    parser.add_argument('--save_loss', type=str_to_bool, default=False)
    parser.add_argument('--ge', type=str, default="MC")
    parser.add_argument('--spectral_aug', type=str, default="noise")
    parser.add_argument('--std_dev', type=float, default=1.0)
    parser.add_argument('--freq', type=str, default="high", help='low,high,None')
    parser.add_argument('--freq_ratio', type=float, default=0.2, help="how much to be auged")
    parser.add_argument('--aug_prob', type=float, default=0.5, help="aug with probability aug_prob")
    parser.add_argument('--n_exp', type=int, default=0)
    parser.add_argument('--mark', type=str, default="")
    # store_true: with store_false and default=False the flag could never be enabled.
    parser.add_argument('--use_eigen_lock', action='store_true', default=False, help="whether use eigen lock to protect cpu running")

    args = parser.parse_args()

    folder_name = '{}-{}-{}-{}-{}-{}-{}-{}-{}-{}'.format(args.seed, args.spectral_aug, args.std_dev,
                                                         args.aug_ratio, args.freq, args.freq_ratio, args.aug_prob,
                                                         args.batch_size, args.lr, args.mark)
    log_dir = './logs/{}/{}/{}'.format(args.model, args.dataset, folder_name)
    result_dir = './results/{}/{}/'.format(args.model, args.dataset)
    os.makedirs(result_dir, exist_ok=True)
    logger, formatter = get_logger(log_dir, None, 'info_{}.log'.format(args.n_exp), level=logging.INFO)

    data_path = args.data_path
    dataset_name = args.dataset
    seed = args.seed
    # literal_eval accepts the same literal inputs eval() did ("[0.005, 0.01]",
    # "True", "False") but cannot execute arbitrary expressions.
    lam_range = ast.literal_eval(args.lam_range)
    log_screen = ast.literal_eval(args.log_screen)
    gmixup = ast.literal_eval(args.gmixup)
    num_epochs = args.epoch
    spectral_aug = args.spectral_aug
    std_dev = args.std_dev
    freq = args.freq
    freq_ratio = args.freq_ratio
    aug_prob = args.aug_prob
    use_eigen_lock = args.use_eigen_lock

    # NOTE(review): argparse yields strings, so "--freq None" gives "None", not
    # None, and this check can never fire. Confirm what spectral_noise/mask
    # expect before tightening it.
    assert not (freq is None and freq_ratio < 1)
    assert (spectral_aug == "noise") or (spectral_aug == "mask" and std_dev == 1.0)

    num_hidden = args.num_hidden
    batch_size = args.batch_size
    learning_rate = args.lr
    ge = args.ge
    aug_ratio = args.aug_ratio
    aug_num = args.aug_num
    model_name = args.model

    if log_screen:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    logger.info('parser.prog: {}'.format(parser.prog))
    logger.info("args:{}".format(args))

    torch.manual_seed(seed)
    # numpy drives the lam sampling and augmentation indices below; seed it too
    # so runs with the same --seed are reproducible.
    np.random.seed(seed)

    # NOTE(review): build_lock / acquire_lock / release_lock are not imported
    # anywhere in this file; confirm which module is meant to provide them.
    lock = build_lock(use_eigen_lock)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"runing device: {device}")

    dataset = TUDataset(data_path, name=dataset_name)
    dataset = list(dataset)

    for graph in dataset:
        graph.y = graph.y.view(-1)

    dataset = prepare_dataset_onehot_y(dataset)

    random.seed(seed)
    random.shuffle(dataset)

    # After shuffling, the list is laid out as [train | val | test].
    train_nums = int(len(dataset) * 0.7)
    train_val_nums = int(len(dataset) * 0.8)

    median_num_nodes = _log_graph_stats(logger, dataset[:train_nums], "training")[3]
    resolution = int(median_num_nodes)

    new_graph = None
    if gmixup:
        class_graphs = split_class_graphs(dataset[:train_nums])
        graphons = []
        for label, graphs in class_graphs:
            logger.info(f"label: {label}, num_graphs:{len(graphs)}")
            align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
                graphs, padding=True, N=resolution)
            logger.info(f"aligned graph {align_graphs_list[0].shape}")

            logger.info(f"ge: {ge}")
            graphon = universal_svd(align_graphs_list, threshold=0.2)
            graphons.append((label, graphon))

        for label, graphon in graphons:
            logger.info(f"graphon info: label:{label}; mean: {graphon.mean()}, shape, {graphon.shape}")

        num_sample = int(train_nums * aug_ratio / aug_num)
        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))

        random.seed(seed)
        new_graph = []
        for lam in lam_list:
            logger.info(f"lam: {lam}")
            logger.info(f"num_sample: {num_sample}")
            two_graphons = random.sample(graphons, 2)
            new_graph += two_graphons_mixup(two_graphons, la=lam, num_sample=num_sample)
            logger.info(f"label: {new_graph[-1].y}")

    elif spectral_aug in ("noise", "mask"):
        random.seed(seed)

        augment_count = int(train_nums * aug_ratio)
        # Sample source graphs from the training slice only. The augmented graphs
        # join the training set, so drawing from range(len(dataset)) would train
        # on perturbed copies of validation/test graphs (data leakage).
        indices = np.random.choice(train_nums, size=augment_count, replace=True)

        acquire_lock(lock)
        try:
            if spectral_aug == "noise":
                new_graph = spectral_noise(dataset, std_dev=std_dev, aug_prob=aug_prob, aug_freq=freq,
                                           aug_freq_ratio=freq_ratio, aug_indices=indices)
            else:
                new_graph = spectral_mask(dataset, aug_prob=aug_prob, aug_freq=freq,
                                          aug_freq_ratio=freq_ratio, aug_indices=indices)
        finally:
            # Release even if augmentation raises, so the lock is never left held.
            release_lock(lock)

    if new_graph is not None:
        _log_graph_stats(logger, new_graph, "new")
        # Prepend the synthetic graphs so they fall inside the training slice.
        dataset = new_graph + dataset
        logger.info(f"real aug ratio: {len(new_graph) / train_nums}")
        train_nums = train_nums + len(new_graph)
        train_val_nums = train_val_nums + len(new_graph)

    dataset = prepare_dataset_x(dataset)

    logger.info(f"num_features: {dataset[0].x.shape}")
    logger.info(f"num_classes: {dataset[0].y.shape}")

    num_features = dataset[0].x.shape[1]
    # train()/test() read num_classes, device and optimizer as module globals.
    num_classes = dataset[0].y.shape[0]

    train_dataset = dataset[:train_nums]
    random.shuffle(train_dataset)
    val_dataset = dataset[train_nums:train_val_nums]
    test_dataset = dataset[train_val_nums:]

    logger.info(f"train_dataset size: {len(train_dataset)}")
    logger.info(f"val_dataset size: {len(val_dataset)}")
    logger.info(f"test_dataset size: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model_classes = {"GIN": GIN, "GCN": GCN}
    if model_name not in model_classes:
        logger.info(f"No model.")
        # Fail fast instead of crashing later on `str.parameters()`.
        raise ValueError(f"Unknown --model {model_name!r}; expected one of {sorted(model_classes)}")
    model = model_classes[model_name](num_features=num_features, num_classes=num_classes,
                                      num_hidden=num_hidden).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    # Placeholders keep the results row writable even if no epoch runs (--epoch 0).
    train_loss = val_loss = test_loss = val_acc = test_acc = float('nan')
    loss_curve = defaultdict(list)
    # Epochs are numbered 1..num_epochs, so exactly --epoch epochs run, matching
    # the 'epoch' value written to the results CSV.
    for epoch in range(1, num_epochs + 1):
        model, train_loss = train(model, train_loader)
        val_acc, val_loss = test(model, val_loader)
        test_acc, test_loss = test(model, test_loader)
        scheduler.step()

        logger.info('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f},  Val Acc: {: .6f}, Test Acc: {: .6f}'.format(
            epoch, train_loss, val_loss, test_loss, val_acc, test_acc))

        for key, value in zip(['train_loss', 'val_loss', 'test_loss', 'val_acc', 'test_acc'],
                              [train_loss, val_loss, test_loss, val_acc, test_acc]):
            loss_curve[key].append(value)

    #### save the result #######
    if args.save_loss:
        stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        loss_path = os.path.join(result_dir, 'loss_{}-{}-{}.pkl'.format(args.spectral_aug, args.std_dev, stamp))
        with open(loss_path, 'wb') as f:
            pkl.dump(loss_curve, f)

    csv_path = os.path.join(result_dir, '{}.csv'.format(args.spectral_aug))
    print(csv_path)
    if not os.path.exists(csv_path):
        df = pd.DataFrame(columns=['hp', 'end_time',
                                   'epoch', 'test_acc', 'val_acc',
                                   'train_loss', 'val_loss', 'test_loss'])
        df.to_csv(csv_path, index=False)

    # newline='' is what the csv module expects; without it Windows gets blank rows.
    with open(csv_path, 'a', newline='') as f:
        csv_write = csv.writer(f)
        data_row = [folder_name, time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()),
                    num_epochs, test_acc, val_acc,
                    train_loss, val_loss, test_loss
                    ]
        csv_write.writerow(data_row)

