import argparse
import numpy as np
import os
import time
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from variation_gnn._models import EPM_VAE_Loss, Modules_E, Modules_M
from variation_gnn._processor import graph_batch, preprocess_feat
from variation_gnn._utils import load_data_grl, separate_data, debug_print, debugger, get_logger

def unsup_pretrain(module_E, optimizer, train_graphs, args, epoch_idx, logger):
    """
    Run one warm-up epoch that pretrains ``module_E`` with its unsupervised loss.

    The training graphs are shuffled, split into mini-batches of
    ``args.batch_size``, converted with ``graph_batch`` and passed through
    ``module_E`` with ``pretrain=True``. Each batch's ``EPM_VAE_Loss`` is
    divided by the number of graphs actually in that batch.

    Args:
        module_E: the inference module being pretrained (put in train mode here).
        optimizer: optimizer for ``module_E``; when ``None`` the loss is only
            computed (no backward pass / parameter update).
        train_graphs: sequence of training graphs accepted by ``graph_batch``.
        args: parsed command-line namespace (reads ``batch_size``, ``device``,
            ``lite_mode``).
        epoch_idx: epoch index, only used for the progress-bar description.
        logger: logger that receives the epoch's average training loss.

    Returns:
        float: the average per-batch loss over the epoch (0.0 when
        ``train_graphs`` is empty).
    """
    module_E.train()
    batch_size = args.batch_size
    tot_iters = int(np.ceil(len(train_graphs) / batch_size))
    pbar = tqdm(range(tot_iters), unit='batch')
    # The description is the same for every batch of the epoch, so set it once.
    pbar.set_description('epoch: %d' % (epoch_idx))

    loss_epoch = 0.
    tot_idx = np.random.permutation(len(train_graphs))

    for batch_idx in pbar:
        # select a batch of graphs and convert them to graph_batch
        selected_idx = tot_idx[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        batch_graph = [train_graphs[idx] for idx in selected_idx]
        gb = graph_batch(batch_graph, args.device)

        batch_adj_E, batch_fts, batch_adj_labels_, batch_num_nodes_, batch_kl_w \
            = gb.norm_adj, gb.all_fts, gb.bin_adj_, gb.num_nodes_, gb.kl_weights

        # return preds and inferred parameters
        preds, _, _, k, lbd = module_E(
            batch_adj_E, batch_fts, pretrain=True, lite_mode=args.lite_mode, num_nodes_=batch_num_nodes_
        )

        # The loss object is rebuilt per batch because it is parameterised by
        # this batch's KL weights.
        L_inf = EPM_VAE_Loss(batch_kl_w, lite_mode=args.lite_mode, device=args.device)
        loss = L_inf(preds, batch_adj_labels_, k, lbd) / len(batch_graph)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_epoch += loss.item()

    # Guard against an empty training set (tot_iters == 0).
    avg_loss = loss_epoch / tot_iters if tot_iters else 0.
    print("training loss = {:.4f}".format(avg_loss))
    logger.info("training loss = {:.4f}".format(avg_loss))
    return avg_loss


def CODGNN_EM(module_E, module_M, optim_E, optim_M, train_graphs, args, epoch_idx, logger):
    """
    Train ``module_E`` and ``module_M`` jointly for one epoch with alternating M/E steps.

    For every shuffled mini-batch of ``args.batch_size`` graphs:

    1. ``module_E`` (train mode) returns ``preds, G_norm_, phi, k, lbd`` and
       ``phi`` is merged into the node features by ``preprocess_feat``.
    2. M-step: ``module_M`` (train mode) is updated ``args.ME_ratio`` times
       with a cross-entropy classification loss.
    3. E-step: ``module_M`` is switched to eval mode and ``module_E`` is
       updated with cross-entropy plus the ``EPM_VAE_Loss`` term divided by
       the number of graphs in the batch.

    NOTE(review): ``module_E`` and ``EPM_VAE_Loss`` are called with
    ``lite_mode=True`` regardless of ``args.lite_mode`` (unlike
    ``unsup_pretrain``) — confirm this is intentional.

    Args:
        module_E: inference module whose outputs feed ``module_M``.
        module_M: classification module; its ``aggregation`` attribute selects
            the normalised (``'average'``) or the binary adjacency.
        optim_E: optimizer for ``module_E``; ``None`` skips the E-step update.
        optim_M: optimizer for ``module_M``; ``None`` skips the M-step update.
        train_graphs: sequence of training graphs accepted by ``graph_batch``.
        args: parsed namespace (reads ``batch_size``, ``device``, ``ME_ratio``,
            ``module_M_num_communities``).
        epoch_idx: epoch index used in the printed / logged summary.
        logger: logger that receives the epoch summary.

    Returns:
        float: the average per-batch E-step loss (0.0 when ``train_graphs``
        is empty).
    """
    batch_size = args.batch_size
    total_iters = int(np.ceil(len(train_graphs) / batch_size))
    pbar = tqdm(range(total_iters), unit='batch')
    tot_idx = np.random.permutation(len(train_graphs))

    loss = 0.
    for batch_idx in pbar:
        # select a batch of graphs and convert them to graph_batch
        selected_idx = tot_idx[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        batch_graph = [train_graphs[idx] for idx in selected_idx]
        gb = graph_batch(batch_graph, args.device)

        batch_norm_adj, batch_bin_adj, batch_fts, batch_labels, batch_adj_labels_, batch_num_nodes_, batch_kl_w \
            = gb.norm_adj, gb.bin_adj, gb.all_fts, gb.labels, gb.bin_adj_, gb.num_nodes_, gb.kl_weights

        batch_adj_E = batch_norm_adj
        batch_adj_M = batch_norm_adj if module_M.aggregation == 'average' else batch_bin_adj

        # module_E is optimised in the E-step of this same batch, so its forward
        # pass must run in train mode every time. Previously it ran in eval mode
        # for the first batch only, because train() was re-enabled after the
        # forward pass.
        module_E.train()
        preds, G_norm_, phi, k, lbd = module_E(
            batch_adj_E, batch_fts, args.module_M_num_communities, lite_mode=True, num_nodes_=batch_num_nodes_
        )
        batch_combined_fts = preprocess_feat(batch_fts, phi)

        # ---- M-step: update module_M args.ME_ratio times ----
        module_M.train()
        for m_step in range(args.ME_ratio):
            batch_scores, _ = module_M(G_norm_, batch_adj_M, batch_combined_fts, batch_num_nodes_)
            loss_M = F.cross_entropy(batch_scores, batch_labels)

            if optim_M is not None:
                optim_M.zero_grad()
                # G_norm_ / batch_combined_fts come from module_E's forward pass;
                # retain that graph so later M-steps and the E-step below can
                # backpropagate through it again.
                loss_M.backward(retain_graph=True)
                optim_M.step()

        # ---- E-step: update module_E against the (frozen-mode) module_M ----
        module_M.eval()
        batch_scores, _ = module_M(G_norm_, batch_adj_M, batch_combined_fts, batch_num_nodes_)
        loss_E = F.cross_entropy(batch_scores, batch_labels)
        L_inf = EPM_VAE_Loss(batch_kl_w, lite_mode=True, device=args.device)
        # Normalise by the real number of graphs (the last batch may be smaller
        # than args.batch_size), consistent with unsup_pretrain.
        loss_E = loss_E + L_inf(preds, batch_adj_labels_, k, lbd) / len(batch_graph)

        if optim_E is not None:
            optim_E.zero_grad()
            loss_E.backward()
            optim_E.step()

        loss += loss_E.item()

    # Guard against an empty training set (total_iters == 0).
    avg_loss = loss / total_iters if total_iters else 0.
    print(f"epoch {epoch_idx}, loss = {avg_loss:.4f}")
    logger.info(f"epoch {epoch_idx}, loss = {avg_loss:.4f}")
    return avg_loss


def eval(module_E, module_M, _graphs, args):
    """
    Return the classification accuracy of the E/M model on ``_graphs``.

    NOTE: this name shadows the builtin ``eval`` inside this module; it is kept
    for compatibility with existing callers.

    Both modules are switched to eval mode. The graphs are processed in their
    given order in mini-batches of ``args.batch_size``. Each batch goes through
    ``module_E`` and then ``module_M``, with no autograd tracking.

    Args:
        module_E: inference module providing ``G_norm_`` and ``phi``.
        module_M: classification module; its ``aggregation`` attribute selects
            the normalised (``'average'``) or the binary adjacency.
        _graphs: sequence of graphs accepted by ``graph_batch``.
        args: parsed namespace (reads ``batch_size``, ``device``,
            ``module_M_num_communities``).

    Returns:
        float: fraction of graphs whose arg-max score equals their label;
        0.0 when ``_graphs`` is empty.
    """
    module_E.eval()
    module_M.eval()

    num_graphs = len(_graphs)
    if num_graphs == 0:
        # Avoid a ZeroDivisionError on an empty evaluation set.
        return 0.0

    batch_size = args.batch_size
    tot_iters = int(np.ceil(num_graphs / batch_size))
    num_correct = 0

    # Pure inference: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for i in range(tot_iters):
            _batch = _graphs[i * batch_size:(i + 1) * batch_size]
            gb = graph_batch(_batch, args.device)

            batch_norm_adj, batch_bin_adj, batch_fts, batch_labels, batch_num_nodes_ \
                = gb.norm_adj, gb.bin_adj, gb.all_fts, gb.labels, gb.num_nodes_
            batch_adj_E = batch_norm_adj
            batch_adj_M = batch_norm_adj if module_M.aggregation == 'average' else batch_bin_adj

            _, G_norm_, phi, _, _ = module_E(
                batch_adj_E, batch_fts, args.module_M_num_communities, lite_mode=True, num_nodes_=batch_num_nodes_
            )
            batch_combined_fts = preprocess_feat(batch_fts, phi)
            batch_scores, _ = module_M(G_norm_, batch_adj_M, batch_combined_fts, batch_num_nodes_)

            batch_preds = batch_scores.max(1).indices
            num_correct += batch_preds.eq(batch_labels).sum().cpu().item()

    return num_correct / num_graphs


def execute(args):
    """
    Run one complete experiment described by ``args``.

    Steps: create the log directory and a timestamped log file, log every
    argument, seed torch/numpy, pick the device (stored in ``args.device``),
    load and split the dataset, build ``module_E`` / ``module_M`` and their
    optimizers, run ``args.wu_epochs`` warm-up epochs, then ``args.epochs``
    joint training epochs with train/test evaluation after each one.

    Side effects (files written):
        * ``<args.log_path>/<dataset>_<communities>_<date>_<time>_<fold>.log``
        * ``./module_E_<dataset>_warm_up.pt`` after warm-up
        * ``./module_E_<dataset>_train.pt`` and ``./module_M_<dataset>_train.pt``
          whenever the test accuracy improves
    """
    # make dir for log files
    os.makedirs(args.log_path, exist_ok=True)

    # One explicit, locale-independent format. The previous '%X' + split(' ')
    # broke when the locale's time representation contained spaces (e.g. AM/PM).
    time_stamp = time.strftime('%Y-%m-%d_%H_%M_%S')
    log_file_name = f'{args.dataset}_{args.module_M_num_communities}_{time_stamp}_{args.fold_idx}.log'
    logger = get_logger(os.path.join(args.log_path, log_file_name))

    for attr, value in sorted(vars(args).items()):
        print(f'{attr.upper()}, {value}')
        logger.info(f'{attr.upper()}, {value}')

    # set up seeds and gpu device
    torch.manual_seed(0)
    np.random.seed(0)
    use_cuda = torch.cuda.is_available()
    args.device = torch.device("cuda:" + str(args.gpu_idx)) if use_cuda else torch.device("cpu")
    if use_cuda:
        torch.cuda.manual_seed_all(0)

    # read dataset, split into train / test
    graphs, num_classes = load_data_grl(args.dataset, args.degree_as_tag)
    train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)

    edge_sum = sum(len(graph.g.edges) for graph in graphs)
    node_sum = sum(len(graph.node_tags) for graph in graphs)

    # initialize model, optimizer
    ft_in = train_graphs[0].node_features.shape[1]
    print(f'ft_in:{ft_in}, edges:{edge_sum/len(graphs)}, nodes:{node_sum/len(graphs)}')
    module_E = Modules_E(ft_in, args.module_E_h1_dim, args.module_E_h2_dim, args.module_E_dropout, lite_mode=args.lite_mode).to(args.device)
    module_M = Modules_M(
        ft_in + args.module_E_h2_dim,
        args.module_M_h_dims,
        num_classes,
        args.module_M_num_layers_Bcat,
        args.module_M_num_layers_Acat,
        args.module_M_num_communities,
        aggregation='sum'
    ).to(args.device)
    optim_E_wu = optim.Adam(module_E.parameters(), lr=args.module_E_wu_lr, weight_decay=args.module_E_l2coef)
    optim_E = optim.Adam(module_E.parameters(), lr=args.module_E_tr_lr, weight_decay=args.module_E_l2coef)
    optim_M = optim.Adam(module_M.parameters(), lr=args.module_M_tr_lr, weight_decay=args.module_M_l2coef)

    # Checkpoint names follow the dataset instead of a hard-coded 'MUTAG', so
    # runs on different datasets no longer overwrite each other's files.
    ckpt_tag = args.dataset

    # run warm-up
    for wu_epoch in range(args.wu_epochs):
        unsup_pretrain(module_E, optim_E_wu, train_graphs, args, wu_epoch, logger)

    torch.save(module_E.state_dict(), f'./module_E_{ckpt_tag}_warm_up.pt')

    # run training
    best_epoch = 0
    best_acc = 0
    for tr_epoch in range(args.epochs):
        CODGNN_EM(module_E, module_M, optim_E, optim_M, train_graphs, args, tr_epoch, logger)
        train_acc = eval(module_E, module_M, train_graphs, args)
        test_acc = eval(module_E, module_M, test_graphs, args)
        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = tr_epoch
            torch.save(module_E.state_dict(), f'./module_E_{ckpt_tag}_train.pt')
            torch.save(module_M.state_dict(), f'./module_M_{ckpt_tag}_train.pt')

        print(f"train ACC: {train_acc:.4f}, test ACC: {test_acc:.4f}")
        print(f"Best test acc up to now: {best_acc:.4f}, at epoch {best_epoch}")
        logger.info(f"train ACC: {train_acc:.4f}, test ACC: {test_acc:.4f}")
        logger.info(f"Best test acc up to now: {best_acc:.4f}, at epoch {best_epoch}")

    setup_msg = "model setup: num_heads = {}, head_layers (exclude input) = {}, tail_layers = {}".format(
        args.module_M_num_communities, args.module_M_num_layers_Bcat-1, args.module_M_num_layers_Acat
    )
    print(setup_msg)
    logger.info(setup_msg)


def str2bool(value):
    """
    Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter is used instead so that
    ``--lite_mode False`` really disables the option.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='PyTorch graph convolutional neural net for whole-graph classification')
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='name of dataset (default: MUTAG)')
    parser.add_argument('--lite_mode', type=str2bool, default=False,
                        help='whether to use the lite mode (true/false, default: False)')

    # batch training hyperparameters
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    # NOTE: currently not read anywhere in this script.
    parser.add_argument('--iters_per_epoch', type=int, default=50,
                        help='number of iterations per each epoch (default: 50)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--wu_epochs', type=int, default=1500,
                        help='number of epochs to warm-up (default: 1500)')
    parser.add_argument('--ME_ratio', type=int, default=1,
                        help='number of M-step updates per batch before each E-step (default: 1)')

    # estimate module hyperparameters
    parser.add_argument('--module_E_h1_dim', type=int, default=200,
                        help='number of hidden units in the 1st layer of the inference module (default: 200)')
    parser.add_argument('--module_E_h2_dim', type=int, default=100,
                        help='number of hidden units in the 2nd layer of the inference module (default: 100)')
    parser.add_argument('--module_E_wu_lr', type=float, default=.01,
                        help='warming-up learning rate (default: 0.01)')
    parser.add_argument('--module_E_tr_lr', type=float, default=.001,
                        help='training learning rate of the inference module (default: 0.001)')
    parser.add_argument('--module_E_dropout', type=float, default=0.5,
                        help='inference module dropout rate (default: 0.5)')
    parser.add_argument('--module_E_l2coef', type=float, default=0.,
                        help='inference module l2 reg coef (default: 0.)')

    # maximize module hyperparameters
    parser.add_argument('--module_M_h_dims', type=int, default=64,
                        help='number of hidden units in module_M (default: 64)')
    parser.add_argument('--module_M_num_communities', type=int, default=4,
                        help='number of communities / heads used by module_E and module_M (default: 4)')
    parser.add_argument('--module_M_num_layers_Bcat', type=int, default=3,
                        help='number of module_M head layers, including the input layer (default: 3)')
    parser.add_argument('--module_M_num_layers_Acat', type=int, default=2,
                        help='number of module_M tail layers (default: 2)')
    parser.add_argument('--module_M_tr_lr', type=float, default=.001,
                        help='training learning rate of module_M (default: 0.001)')
    # NOTE: currently not passed to Modules_M by execute().
    parser.add_argument('--module_M_dropout', type=float, default=0.5,
                        help='module_M dropout rate (default: 0.5)')
    parser.add_argument('--module_M_l2coef', type=float, default=0.,
                        help='module_M l2 reg coef (default: 0.)')

    # global randomness
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed for splitting the dataset into 10 folds (default: 0)')

    # IOs
    parser.add_argument('--degree_as_tag', action="store_true",
                        help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
    parser.add_argument('--fold_idx', type=int, default=8,
                        help='the index of fold in 10-fold validation. Should be less than 10.')
    parser.add_argument('--gpu_idx', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    # NOTE: currently not read anywhere in this script.
    parser.add_argument('--filename', type=str, default="",
                        help='output file')
    parser.add_argument('--log_path', type=str, default='./log/',
                        help='directory for log files (default: ./log/)')

    args = parser.parse_args()

    global_debug_ckpt = 'epm_main'

    execute(args)
