import os
import numpy as np
import torch
import dgl
import networkx as nx
import argparse
import random
import time
from L2 import Regularization
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu
from dgl.data import CoraGraphDataset
from dgl.data import CoraFullDataset
from model.encoder import DiffPool
from data_utils import pre_process
import shutil
import warnings
from torch import Tensor

# Silence every Python warning for the whole process.
warnings.filterwarnings('ignore')

# Wall-clock duration (seconds) of each training epoch; appended to by
# train() and run_fine_tune().
global_train_time_per_epoch = []


def getMasks(model, w_ratio):
    """Draw random pruning masks for every prunable weight of ``model``.

    A parameter is prunable when its name contains ``'weight'`` but neither
    ``'norms'`` nor ``'bn'``. For each such parameter,
    ``int(numel * (1 - w_ratio))`` distinct positions are sampled with
    ``np.random.choice`` and marked as pruned. Weights of any rank are
    supported (the mask is reshaped back to the parameter's own shape).

    Args:
        model: object exposing ``named_parameters()``.
        w_ratio: fraction of each weight tensor that is kept (not pruned).

    Returns:
        ``(masks, unmasks)``: two lists in ``named_parameters()`` order.
        ``masks[k]`` is 1 at pruned positions and 0 elsewhere;
        ``unmasks[k]`` is its complement. Both have the parameter's shape,
        dtype and device.
    """
    masks = []
    unmasks = []
    for name, param in model.named_parameters():
        if 'weight' not in name or 'norms' in name or 'bn' in name:
            continue
        print(name)
        total = param.numel()
        flat_mask = torch.zeros_like(param).reshape(-1)
        # Sample without replacement so exactly the requested number of
        # positions is pruned.
        pruned = np.random.choice(np.arange(total), replace=False,
                                  size=int(total * (1 - w_ratio)))
        flat_mask[pruned] = 1
        mask = flat_mask.reshape(param.shape)
        masks.append(mask)
        unmasks.append(torch.ones_like(param) - mask)
    return masks, unmasks


def get_new_mask(length, ratio):
    """Build a random ``length x length`` 0/1 mask and its complement.

    ``int(length * length * (1 - ratio))`` distinct entries, sampled with
    ``np.random.choice``, are set to 1 in ``mask``.

    NOTE(review): the original comment described ``ratio`` as the masking
    probability, yet the fraction of ones in ``mask`` is ``1 - ratio`` --
    confirm which reading the callers expect.

    Returns:
        ``(mask, unmask)`` where ``unmask == 1 - mask``.
    """
    side = (length, length)
    flat = torch.zeros(side).reshape(-1)
    total = flat.numel()
    picked = np.random.choice(np.arange(total), replace=False,
                              size=int(total * (1 - ratio)))
    flat[picked] = 1
    mask = flat.reshape(length, length)
    return mask, torch.ones(side) - mask


def get_weight_decays(count):
    """Return ``count`` weight-decay values on a linear ramp.

    Entry ``k`` equals ``5e-4 + 1e-7 * k``.
    """
    start, step = 5e-4, 1e-7
    return [start + step * k for k in range(count)]


def arg_parse():
    """Parse the command line of the DiffPool experiment.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``. Besides the
        registered options it carries ``method='diffpool'``, which is set
        only through ``set_defaults``.
    """
    cli = argparse.ArgumentParser(description='DiffPool arguments')

    # The padding reproduces the whitespace that the original
    # backslash-continued literals embedded in these two help strings.
    load_epoch_help = ('load trained model params from' + ' ' * 25
                       + 'SAVE_DICT/DATASET/model-LOAD_EPOCH')
    data_mode_help = ('data' + ' ' * 24
                      + 'preprocessing mode: default, id, degree, or one-hot'
                      + ' ' * 24 + 'vector of degree number')

    # (flag, keyword arguments) pairs, registered in their original order so
    # the generated --help listing keeps the same sequence.
    option_table = (
        ('--dataset', dict(dest='dataset', help='Input Dataset')),
        ('--pool_ratio', dict(dest='pool_ratio', type=float,
                              help='pooling ratio')),
        ('--num_pool', dict(dest='num_pool', type=int,
                            help='num_pooling layer')),
        ('--no_link_pred', dict(dest='linkpred', action='store_false',
                                help='switch of link prediction object')),
        ('--cuda', dict(dest='cuda', type=int, help='switch cuda')),
        ('--lr', dict(dest='lr', type=float, help='learning rate')),
        ('--clip', dict(dest='clip', type=float, help='gradient clipping')),
        ('--batch-size', dict(dest='batch_size', type=int,
                              help='batch size')),
        ('--epochs', dict(dest='epoch', type=int, help='num-of-epoch')),
        ('--train-ratio', dict(dest='train_ratio', type=float,
                               help='ratio of trainning dataset split')),
        ('--test-ratio', dict(dest='test_ratio', type=float,
                              help='ratio of testing dataset split')),
        ('--num_workers', dict(dest='n_worker', type=int,
                               help='number of workers when dataloading')),
        ('--gc-per-block', dict(dest='gc_per_block', type=int,
                                help='number of graph conv layer per block')),
        ('--bn', dict(dest='bn', action='store_const', const=True,
                      default=True, help='switch for bn')),
        ('--dropout', dict(dest='dropout', type=float, help='dropout rate')),
        ('--bias', dict(dest='bias', action='store_const', const=True,
                        default=True, help='switch for bias')),
        ('--save_dir', dict(dest='save_dir',
                            help='model saving directory: SAVE_DICT/DATASET')),
        ('--load_epoch', dict(dest='load_epoch', type=int,
                              help=load_epoch_help)),
        ('--data_mode', dict(dest='data_mode', help=data_mode_help,
                             choices=['default', 'id', 'deg', 'deg_num'])),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    fallback_values = {
        'dataset': 'DD',
        'pool_ratio': 0.15,
        'num_pool': 2,
        'cuda': 0,
        'lr': 1e-3,
        'clip': 2.0,
        'batch_size': 20,
        'epoch': 2000,
        'train_ratio': 0.7,
        'test_ratio': 0.1,
        'n_worker': 1,
        'gc_per_block': 3,
        'dropout': 0.0,
        'method': 'diffpool',
        'bn': True,
        'bias': True,
        'save_dir': "./model_param",
        'load_epoch': -1,
        'data_mode': 'default',
    }
    cli.set_defaults(**fallback_values)
    return cli.parse_args()


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """Optionally pre-process ``dataset`` and wrap it in a DGL graph loader.

    Args:
        dataset: dataset handed to ``dgl.dataloading.GraphDataLoader``.
        prog_args: supplies ``batch_size`` and ``n_worker``.
        train: when truthy the loader shuffles its samples.
        pre_process: optional callable invoked as
            ``pre_process(dataset, prog_args)`` before the loader is built.

    Returns:
        The constructed ``GraphDataLoader``.
    """
    shuffle = True if train else False

    if pre_process:
        pre_process(dataset, prog_args)

    loader_options = dict(batch_size=prog_args.batch_size,
                          shuffle=shuffle,
                          num_workers=prog_args.n_worker)
    return dgl.dataloading.GraphDataLoader(dataset, **loader_options)


def print_model(model):
    """Print every parameter of ``model`` whose name contains ``'mask'``."""
    mask_params = (tensor for label, tensor in model.named_parameters()
                   if 'mask' in label)
    for tensor in mask_params:
        print(tensor)


def one_shot_prune(model, unmasks):
    """Multiply each prunable weight of ``model`` by its keep-mask, in place.

    Prunable weights are the parameters whose name contains ``'weight'`` but
    neither ``'norms'`` nor ``'bn'``. ``unmasks`` is consumed in that same
    ``named_parameters()`` order, one entry per prunable weight.
    """
    cursor = 0
    with torch.no_grad():
        for label, tensor in model.named_parameters():
            prunable = ('weight' in label
                        and 'norms' not in label
                        and 'bn' not in label)
            if not prunable:
                continue
            tensor[:] = tensor * unmasks[cursor]
            cursor += 1


def _reapply_unmasks(model, unmasks):
    """Zero the pruned entries of every prunable weight of ``model`` in place.

    Prunable weights are the parameters whose name contains ``'weight'`` but
    neither ``'norms'`` nor ``'bn'``; ``unmasks`` is consumed in that order.
    """
    position = 0
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' not in name or 'norms' in name or 'bn' in name:
                continue
            param[:] = param * unmasks[position]
            position += 1


def run_fine_tune(mask, model, optimizer, count, prog_args, dataloader, weight_decays, masks, unmasks, logger,
                  val_dataset=None):
    """Fine-tune a pruned model for 1000 epochs, re-pruning after each epoch.

    Args:
        mask: tensor forwarded to ``cau_loss``.
        model: network to fine-tune; must provide ``loss(ypred, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.
        count: selects the weight decay as ``weight_decays[int(count / 10)]``.
        prog_args: reads ``save_dir``, ``dataset``, ``cuda`` and ``clip``.
        dataloader: iterable of ``(batch_graph, graph_labels)`` batches.
        weight_decays: sequence of candidate weight-decay values.
        masks: forwarded to ``Regularization``.
        unmasks: keep-masks multiplied into the prunable weights.
        logger: result dict of a previous run; when its ``'best_epoch'`` is
            not ``-1`` (and ``save_dir`` is set) that checkpoint is loaded
            before fine-tuning.
        val_dataset: optional loader evaluated after every epoch; enables
            best-epoch checkpointing.

    Returns:
        ``{"best_epoch": ..., "val_acc": ...}``; both stay ``-1`` when no
        validation loader is given.
    """
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}
    if logger is not None and prog_args.save_dir is not None:
        if logger.get('best_epoch') != -1:
            print("load..........")
            model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                             + "/model.iter-" + str(logger['best_epoch'])))
            # The checkpoint stores the dense weights, so reloading it undoes
            # any pruning done by the caller; prune again before training.
            _reapply_unmasks(model, unmasks)
    if prog_args.cuda > 0:
        torch.cuda.set_device(0)

    for epoch in range(1000):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for batch_graph, graph_labels in dataloader:
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            model.zero_grad()
            compute_start = time.time()

            ypred = model(batch_graph)

            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]

            loss = model.loss(ypred, graph_labels)
            weight_decay = weight_decays[int(count / 10)]
            reg_loss = Regularization(model, weight_decay, masks, p=2)
            pool_loss = cau_loss(mask, model, weight_decay)

            my_reg = reg_loss(model)
            loss = loss + my_reg + pool_loss

            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        # Re-zero the pruned weights that the optimizer updated this epoch.
        # This deliberately does not touch ``count``: the previous inline
        # loop reused that name and silently changed the weight decay
        # selected from the second epoch onwards.
        _reapply_unmasks(model, unmasks)

        train_accu = accum_correct / total
        print("train accuracy for this epoch {} is {:.2f}%".format(epoch,
                                                                   train_accu * 100))
        elapsed_time = time.time() - begin_time
        print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
            loss.item(), elapsed_time, computation_time))
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation  accuracy {:.2f}%".format(result * 100))
            # A new best epoch must not score above the training accuracy.
            if result >= early_stopping_logger['val_acc'] and result <= \
                    train_accu:
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
            print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
                                                                      early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger


def graph_classify_task(dataset_name, prog_args, g_ratio, w_ratio):
    """Train, prune and fine-tune a DiffPool classifier on one TU dataset.

    Args:
        dataset_name: name passed to ``tu.LegacyTUDataset``; also stored in
            ``prog_args.dataset``.
        prog_args: parsed options; ``pool_ratio`` and ``dataset`` are
            overwritten here.
        g_ratio: pooling ratio written to ``prog_args.pool_ratio``.
        w_ratio: fraction of weights kept, forwarded to ``getMasks``.

    Returns:
        ``(ori, new)``: the test accuracy measured after training and the
        test accuracy measured after pruning plus fine-tuning.
    """
    # NOTE(review): the 18 x 18 size and the 0.2 ratio are hard-coded;
    # presumably they match the model's 'mask' parameter -- confirm.
    mask, _ = get_new_mask(18, 0.2)

    prog_args.pool_ratio = g_ratio
    prog_args.dataset = dataset_name
    dataset = tu.LegacyTUDataset(name=prog_args.dataset)

    num_graphs = len(dataset)
    train_size = int(prog_args.train_ratio * num_graphs)
    test_size = int(prog_args.test_ratio * num_graphs)
    val_size = int(num_graphs - train_size - test_size)

    splits = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size))
    # Order is train, validation, test; none of the loaders shuffles.
    train_dataloader, val_dataloader, test_dataloader = [
        prepare_data(split, prog_args, train=False, pre_process=pre_process)
        for split in splits
    ]
    input_dim, label_dim, max_num_node = dataset.statistics()

    print("++++++++++STATISTICS ABOUT THE DATASET")
    for caption, value in (("dataset feature dimension is", input_dim),
                           ("dataset label dimension is", label_dim),
                           ("the max num node is", max_num_node),
                           ("number of graphs is", len(dataset))):
        print(caption, value)

    hidden_dim = 64
    embedding_dim = 64

    # Assignment dimension: pool_ratio times the node count of the largest
    # graph in the dataset.
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)

    model = DiffPool(input_dim,
                     hidden_dim,
                     embedding_dim,
                     label_dim,
                     F.relu,
                     prog_args.gc_per_block,
                     prog_args.dropout,
                     prog_args.num_pool,
                     prog_args.linkpred,
                     prog_args.batch_size,
                     'meanpool',
                     assign_dim,
                     prog_args.pool_ratio)

    resume = prog_args.load_epoch >= 0 and prog_args.save_dir is not None
    if resume:
        checkpoint = (prog_args.save_dir + "/" + prog_args.dataset
                      + "/model.iter-" + str(prog_args.load_epoch))
        model.load_state_dict(torch.load(checkpoint))

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    count = 1000
    weight_decays = get_weight_decays(count)
    masks, unmasks = getMasks(model, w_ratio)
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(trainable, lr=0.001)

    logger = train(mask, train_dataloader, model, optimizer, prog_args,
                   weight_decays, count, masks, val_dataset=val_dataloader)

    # Accuracy of the trained model, then prune and fine-tune it with the
    # same optimizer (no validation loader is passed to the fine-tune).
    ori = evaluate(test_dataloader, model, prog_args, logger)
    one_shot_prune(model, unmasks)
    new_logger = run_fine_tune(mask, model, optimizer, count, prog_args,
                               train_dataloader, weight_decays, masks,
                               unmasks, logger)
    new = evaluate(test_dataloader, model, prog_args, new_logger)
    return ori, new


def cau_loss(mask, model, weight_decay):
    """Weighted L2 penalty over the masked entries of each ``'mask'`` parameter.

    For every parameter whose name contains ``'mask'`` the L2 norm of
    ``parameter * mask`` is accumulated; the sum is scaled by
    ``weight_decay``.

    The product is computed with tensor ops on the parameter's own device.
    The previous ``.data`` -> NumPy -> ``.cuda()`` round trip detached the
    result from autograd (so the penalty never produced a gradient) and
    raised on hosts without CUDA.

    Args:
        mask: tensor broadcastable against the ``'mask'`` parameters.
        model: object exposing ``named_parameters()``.
        weight_decay: scalar multiplier of the penalty.

    Returns:
        A scalar tensor on the parameters' device, or the number
        ``weight_decay * 0`` when no parameter name contains ``'mask'``.
    """
    reg_loss = 0
    for name, w in model.named_parameters():
        if 'mask' in name:
            selected = w * mask.to(device=w.device, dtype=w.dtype)
            reg_loss = reg_loss + torch.norm(selected, p=2)
    return weight_decay * reg_loss


def train(mask, dataset, model, optimizer, prog_args, weight_decays, count, masks, val_dataset=None, same_feat=True):
    """Train ``model`` and checkpoint its best validation epoch.

    Args:
        mask: tensor forwarded to ``cau_loss``.
        dataset: iterable of ``(batch_graph, graph_labels)`` batches (a data
            loader, despite the name).
        model: network to train; must provide ``loss(ypred, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.
        prog_args: reads ``save_dir``, ``dataset``, ``cuda``, ``epoch`` and
            ``clip``.
        weight_decays: sequence of candidate weight-decay values.
        count: selects the weight decay as ``weight_decays[int(count / 10)]``.
        masks: forwarded to ``Regularization``.
        val_dataset: optional loader evaluated after every epoch; enables
            best-epoch checkpointing.
        same_feat: unused; kept so existing callers keep working.

    Returns:
        ``{"best_epoch": ..., "val_acc": ...}``; both stay ``-1`` when no
        epoch qualified.
    """
    if prog_args.save_dir is not None:
        # Start from an empty checkpoint directory. Skipped entirely when
        # checkpointing is disabled (save_dir is None), which previously
        # raised TypeError while building the path.
        checkpoint_dir = prog_args.save_dir + "/" + prog_args.dataset
        if os.path.exists(checkpoint_dir):
            print("remove...")
            shutil.rmtree(checkpoint_dir)
        os.makedirs(checkpoint_dir, exist_ok=True)
    dataloader = dataset

    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)
    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for batch_graph, graph_labels in dataloader:
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            model.zero_grad()
            compute_start = time.time()

            ypred = model(batch_graph)

            indi = torch.argmax(ypred, dim=1)
            correct = torch.sum(indi == graph_labels).item()
            accum_correct += correct
            total += graph_labels.size()[0]

            loss = model.loss(ypred, graph_labels)
            weight_decay = weight_decays[int(count / 10)]
            reg_loss = Regularization(model, weight_decay, masks, p=2)
            pool_loss = cau_loss(mask, model, weight_decay)

            my_reg = reg_loss(model)
            loss = loss + my_reg + pool_loss

            loss.backward()
            batch_compute_time = time.time() - compute_start
            computation_time += batch_compute_time
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        train_accu = accum_correct / total
        elapsed_time = time.time() - begin_time
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation  accuracy {:.2f}%".format(result * 100))
            # A new best epoch must not score above the training accuracy.
            if result >= early_stopping_logger['val_acc'] and result <= \
                    train_accu:
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
            print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
                                                                      early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Accuracy is ``correct / samples actually seen``. Dividing by
    ``len(dataloader) * batch_size`` (the previous behaviour) under-reports
    accuracy whenever the last batch is smaller than ``batch_size``.

    Args:
        dataloader: iterable of ``(batch_graph, graph_labels)`` batches.
        model: network to evaluate; switched to eval mode.
        prog_args: reads ``save_dir`` and ``dataset`` to locate a checkpoint.
        logger: optional result dict; when given (and ``save_dir`` is set)
            it is printed, and the checkpoint of its ``'best_epoch'`` is
            loaded into ``model`` unless that epoch is ``-1``.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when the loader yields no samples.
    """
    if logger is not None and prog_args.save_dir is not None:
        print(logger, type(logger))
        if logger.get('best_epoch') != -1:
            print("load..........")
            model.load_state_dict(torch.load(prog_args.save_dir + "/" + prog_args.dataset
                                             + "/model.iter-" + str(logger['best_epoch'])))
    model.eval()
    correct_label = 0
    seen = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            for (key, value) in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if torch.cuda.is_available():
                batch_graph = batch_graph.to(torch.cuda.current_device())
                graph_labels = graph_labels.cuda()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct_label += torch.sum(indi == graph_labels).item()
            seen += graph_labels.size()[0]
    if seen == 0:
        return 0.0
    return correct_label / seen


def main():
    """Sweep the weight-keep ratio on each dataset and log the accuracies.

    For every dataset the pruned fraction runs from 0.05 to 1.0 in steps of
    0.05 (``w_ratio`` from 0.95 down to 0.0). Each run appends one line to
    ``temp.txt``; once a dataset's sweep is finished a summary block is
    appended to ``classification_weight.txt``.
    """
    # Declared global so the per-run reset below rebinds the module-level
    # list instead of creating a dead function local.
    global global_train_time_per_epoch

    prog_args = arg_parse()
    print(prog_args)

    datasets = ['ENZYMES', 'DD']
    for dataset in datasets:
        i_s = []
        ori_accs = []
        new_accs = []
        memorys = []
        # Stop at step 20 (w_ratio == 0.0). Larger steps make w_ratio
        # negative, which asks getMasks to sample more positions than exist
        # and raises before this dataset's summary is written.
        for step in range(1, 21):
            prog_args = arg_parse()
            global_train_time_per_epoch = []
            i = step * 0.05
            ori, new = graph_classify_task(dataset, prog_args, g_ratio=0.15, w_ratio=1 - i)
            i = round(i, 3)
            ori = round(ori, 3)
            new = round(new, 3)
            with open("temp.txt", "a") as r:
                # One record per line; without the newline all runs were
                # concatenated into a single unreadable line.
                r.write(str(1 - i) + " " + str(ori) + " " + str(new) + "\n")
            if torch.cuda.is_available():
                max_memory = torch.cuda.max_memory_allocated(0) / (1024 * 1024)
            else:
                max_memory = 0.0
            max_memory = round(max_memory, 3)
            i_s.append(i)
            ori_accs.append(ori)
            new_accs.append(new)
            memorys.append(max_memory)
        with open("classification_weight.txt", "a") as f:
            f.write(dataset + '\n')
            f.write(str(i_s)[1:-1] + '\n')
            f.write(str(ori_accs)[1:-1] + '\n')
            f.write(str(new_accs)[1:-1] + '\n')
            f.write(str(memorys)[1:-1] + '\n')
            f.write('\n')
        torch.cuda.empty_cache()


# Run the experiment sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
