import argparse
import os
import random
import time

import networkx as nx
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from loader_pyg import DataLoader_pyg

import gengraph
import load_data
import models
import utils.featgen as featgen
import utils.io_utils as io_utils
import utils.parser_utils as parser_utils
import models_pyg

from datasets import RandomDataset, RandomBasisDataset, RandomDatasetPyG, PredefinedDataset
from subgraph_matching import SubgraphMatchingRandom
from subgraph_matching_pyg import SubgraphMatchingRandomPyG
from utils.synthetic_structsim import attach_query_graph


def to_numpy_matrix(G, edge_type=False):
    """Return the adjacency of ``G`` as integers, optionally split into one-hot edge-type masks.

    Adjacency entries come from networkx's default edge ``weight`` attribute (1 for
    edges without one) cast to ``int``; the integer value is treated as the edge type,
    with 0 meaning "no edge" and real edge types numbered 1, 2, ...

    Args:
        G: networkx graph.
        edge_type: if False, return the ``(n, n)`` integer adjacency as an ``np.matrix``
            (the historical return type). If True, return an ndarray of shape
            ``(num_edge_types, n, n)`` whose slice ``k`` is the 0/1 mask of edges of
            type ``k + 1``.
    """
    # nx.to_numpy_matrix was removed in networkx 3.0; to_numpy_array yields the same values.
    adj = nx.to_numpy_array(G).astype(int)
    if not edge_type:
        # Keep returning np.matrix for callers that relied on the old return type.
        return np.asmatrix(adj)
    # A graph without nodes gives a 0x0 array, on which adj.max() would raise.
    n_vals = adj.max() + 1 if adj.size else 1
    # One-hot encode each entry: n x n x (edge_types + 1); index 0 is "no edge".
    adj_categorical = np.eye(n_vals)[adj]
    # Drop the channel for edge type 0 (no edge).
    adj_categorical = adj_categorical[:, :, 1:]
    # Move the edge-type dimension first.
    return adj_categorical.transpose(2, 0, 1)

def _graph_to_batch(G):
    """Convert a networkx graph into ``(adj, feat)`` numpy arrays with a leading batch axis of 1.

    ``adj`` is ``to_numpy_matrix(G, edge_type=True)`` with a batch axis prepended.
    ``feat`` stacks each node's ``'feat'`` attribute in ``G.nodes()`` order, with a batch
    axis prepended. Assumes every node has a ``'feat'`` vector of equal length — confirm.
    """
    adj = np.expand_dims(to_numpy_matrix(G, edge_type=True), axis=0)
    feat = np.expand_dims(np.vstack([G.node[u]['feat'] for u in G.nodes()]), axis=0)
    return adj, feat


def syn_task1(args, writer=None):
    """Train and evaluate subgraph matching on the synthetic ``syn1`` graph (``gengraph.gen_syn1``).

    Node features are constant all-ones vectors of length ``args.input_dim``.
    ``args.harder_training`` forwards ``harder_training=True`` to the generator.
    ``args.method`` selects what happens next:

    * ``'no training'``: only prints a message.
    * ``'base'``: trains and tests ``models.GcnEncoderMatching`` with ``SubgraphMatching``
      (``args.order_embeddings`` is honoured).
    * ``'manual_label'``: regenerates a smaller graph (``nb_shapes=10``), attaches 20
      query copies with ``attach_query_graph`` and trains/tests with
      ``ManualSubgraphMatching``.

    Any other value does nothing after data generation. Sets ``args.edge_dim = 1``.

    NOTE(review): ``SubgraphMatching`` and ``ManualSubgraphMatching`` are not imported in
    the visible part of this module; confirm where they are expected to come from.
    """
    def make_feature_generator():
        # A fresh all-ones feature generator per generation call, as before.
        return featgen.ConstFeatureGen(np.ones(args.input_dim, dtype=float))

    # data, check if harder_training is setup
    if args.harder_training:
        print("Setting up harder_training data")
        G, G_query, labels, query_labels, _ = gengraph.gen_syn1(
            feature_generator=make_feature_generator(), harder_training=True)
    else:
        print("Not setting up harder_training data")
        G, G_query, labels, query_labels, _ = gengraph.gen_syn1(
            feature_generator=make_feature_generator())

    num_classes = 2
    vars(args)["edge_dim"] = 1
    if args.method == 'no training':
        print('Method: random init')
    elif args.method == 'base':
        print('Method:', args.method)

        # Forward order_embeddings only when requested so the model's own default applies otherwise.
        extra = {'order_embeddings': True} if args.order_embeddings else {}
        model = models.GcnEncoderMatching(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
                                          args.num_gc_layers, bn=args.bn, args=args, **extra)
        if args.gpu:
            model = model.cuda()

        adj, feat = _graph_to_batch(G)
        query_adj, query_feat = _graph_to_batch(G_query)

        matching = SubgraphMatching(model, adj, feat, labels, query_adj, query_feat, [query_labels], args)
        matching.train_subgraph_match(args, writer=writer)
        # Pass the writer here too, consistent with every other test_subgraph_match call.
        matching.test_subgraph_match(args, writer=writer)
    elif args.method == 'manual_label':
        # manually add query graphs to original
        G, G_query, labels, query_labels, _ = gengraph.gen_syn1(
            nb_shapes=10, feature_generator=make_feature_generator())
        G, labels, nb_query = attach_query_graph(G, labels,
                                                 feature_generator=make_feature_generator(),
                                                 random_edges=2, nb_query=20)
        model = models.GcnEncoderMatching(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
                                          args.num_gc_layers, bn=args.bn, args=args)
        if args.gpu:
            model = model.cuda()

        adj, feat = _graph_to_batch(G)
        query_adj, query_feat = _graph_to_batch(G_query)

        matching = ManualSubgraphMatching(model, adj, feat, labels, query_adj, query_feat,
                                          [query_labels], nb_query, args)
        matching.train_subgraph_match(args, writer=writer)
        matching.test_subgraph_match(args, writer=writer)

def syn_dup(args, writer=None):
    """Train subgraph matching on the duplicated syn1 graph produced by ``gengraph.gen_dup``.

    Node features come from ``featgen.DegreeFeatureGen``. ``args.edge_dim`` is set to the
    number of edge-type channels of the target adjacency (axis 1 after batching). Unless
    ``args.method`` is ``'no training'``, a ``GcnEncoderMatching`` model is trained; no
    test pass is run here.

    NOTE(review): ``SubgraphMatching`` is not imported in the visible part of this module;
    confirm where it is expected to come from.
    """
    base_G, base_query, base_labels, base_query_labels, base_name = gengraph.gen_syn1(
        feature_generator=featgen.DegreeFeatureGen(dim=args.input_dim))
    G, G_query, labels, query_labels, _ = gengraph.gen_dup(
        base_G, base_query, base_labels, base_query_labels, base_name, correspond_prob=1.0)

    target_adj = np.expand_dims(to_numpy_matrix(G, edge_type=True), 0)
    vars(args)["edge_dim"] = target_adj.shape[1]

    if args.method == 'no training':
        print('Method: random init')
        return

    print('Method:', args.method)
    model = models.GcnEncoderMatching(args.input_dim, args.hidden_dim, args.output_dim, 2,
                                      args.num_gc_layers, bn=args.bn, args=args)
    if args.gpu:
        model = model.cuda()

    target_feat = np.expand_dims(np.vstack([G.node[v]['feat'] for v in G.nodes()]), 0)
    query_adj = np.expand_dims(to_numpy_matrix(G_query, edge_type=True), 0)
    query_feat = np.expand_dims(np.vstack([G_query.node[v]['feat'] for v in G_query.nodes()]), 0)

    matcher = SubgraphMatching(model, target_adj, target_feat, labels,
                               query_adj, query_feat, [query_labels], args)
    matcher.train_subgraph_match(args=args, writer=writer)

def _stack_padded(adjs, feats, feat_dim):
    """Zero-pad per-query arrays to a common size and stack them along a new leading axis.

    Args:
        adjs: list of 3-D adjacency arrays, e.g. ``(num_edge_types, n_i, n_i)``.
        feats: list of 2-D feature arrays ``(n_i, d_i)`` aligned with ``adjs``, with
            ``d_i <= feat_dim``.
        feat_dim: width of the stacked feature array.

    Returns:
        ``(adj_stack, feat_stack)`` with shapes ``(len(adjs), *max_adj_shape)`` and
        ``(len(adjs), max_adj_shape[1], feat_dim)``; entries beyond each query's own
        extent are zero.
    """
    max_sizes = [max(a.shape[axis] for a in adjs) for axis in range(3)]
    adj_stack = np.zeros((len(adjs), *max_sizes))
    feat_stack = np.zeros((len(adjs), max_sizes[1], feat_dim))
    for i, (a, f) in enumerate(zip(adjs, feats)):
        adj_stack[i, :a.shape[0], :a.shape[1], :a.shape[2]] = a
        feat_stack[i, :f.shape[0], :f.shape[1]] = f
    return adj_stack, feat_stack


def syn_multiple(args, writer=None):
    """Train and evaluate subgraph matching with many query graphs (``gengraph.gen_syn_multiple``).

    Node features come from ``featgen.DegreeFeatureGen``. ``args.method`` selects:

    * ``'no training'``: only prints a message.
    * ``'base'``: 100 shapes; all query graphs are zero-padded into one stack and
      trained/tested with ``SubgraphMatching``.
    * ``'augm'``: a single shape; additionally two random BFS subgraphs of the target
      (``get_random_subgraph``) are stacked and passed to ``SubgraphMatchingAug``.

    Any other value does nothing after data generation. Sets ``args.edge_dim = 1``.

    NOTE(review): ``SubgraphMatching`` and ``SubgraphMatchingAug`` are not imported in the
    visible part of this module; confirm where they are expected to come from.
    """
    nb_shapes = 1 if args.method == "augm" else 100
    G, G_query, labels, query_labels, _ = gengraph.gen_syn_multiple(
        nb_shapes=nb_shapes, width_basis=300,
        feature_generator=featgen.DegreeFeatureGen(args.input_dim))

    num_classes = 2
    vars(args)["edge_dim"] = 1
    if args.method == 'no training':
        print('Method: random init')
        return
    if args.method not in ('base', 'augm'):
        return
    print('Method:', args.method)

    # Forward order_embeddings only when requested so the model's own default applies otherwise.
    extra = {'order_embeddings': True} if args.order_embeddings else {}
    model = models.GcnEncoderMatching(args.input_dim, args.hidden_dim, args.output_dim, num_classes,
                                      args.num_gc_layers, bn=args.bn, args=args, **extra)
    if args.gpu:
        model = model.cuda()

    adj = np.expand_dims(to_numpy_matrix(G, edge_type=True), axis=0)
    feat = np.expand_dims(np.vstack([G.node[u]['feat'] for u in G.nodes()]), axis=0)

    query_adj = []
    query_feat = []
    for graph in G_query:
        query_adj.append(to_numpy_matrix(graph, edge_type=True))
        query_feat.append(np.vstack([graph.node[u]['feat'] for u in graph.nodes()]))
    query_adj_stack, query_feat_stack = _stack_padded(query_adj, query_feat, args.input_dim)

    if args.method == 'base':
        matching = SubgraphMatching(model, adj, feat, labels, query_adj_stack, query_feat_stack,
                                    query_labels, args)
    else:
        num_rnd_query = 2
        query_adj_rnd, query_feat_rnd, root_list = get_random_subgraph(adj, feat, num_rnd_query)
        query_adj_rnd_stack, query_feat_rnd_stack = _stack_padded(
            query_adj_rnd, query_feat_rnd, args.input_dim)
        matching = SubgraphMatchingAug(model, adj, feat, labels, query_adj_stack, query_feat_stack,
                                       query_labels, query_adj_rnd_stack, query_feat_rnd_stack,
                                       root_list, args)
        matching.visualize_graph(query_adj_rnd, writer, 0)
    matching.train_subgraph_match(args, writer=writer)
    matching.test_subgraph_match(args, writer=writer)

def get_random_subgraph(adj, feat, n_subgraphs=10, n_hops=1, keep_prob=0.75):
    """Sample random BFS neighbourhoods of a single graph as query subgraphs.

    For every sample, a root node is drawn uniformly at random. A breadth-first search
    then expands from it for at most ``n_hops`` hops. Each non-root node reached is kept
    with probability ``keep_prob``; the root is always kept. Every node is included at
    most once, so the returned adjacency has no duplicated rows or columns.

    Args:
        adj: numpy array indexed as ``adj[0, 0, u, v]``, i.e. shape
            ``(1, num_edge_types, num_nodes, num_nodes)``. Only the first edge-type
            channel is used, both for traversal and for the returned subgraphs.
        feat: numpy feature array of shape ``(1, num_nodes, dim)``.
        n_subgraphs: number of subgraphs to sample.
        n_hops: maximum BFS depth measured from the root.
        keep_prob: probability of keeping each non-root node reached by the BFS
            (default 0.75, the previously hard-coded value).

    Returns:
        ``(query_random_adj, query_random_feat, root_list)``. For the k-th sample,
        ``query_random_adj[k]`` has shape ``(1, m, m)``, ``query_random_feat[k]`` has
        shape ``(m, dim)`` and ``root_list[k]`` is the root's node index. For
        ``n_hops >= 0``, row 0 of both arrays corresponds to the root.
    """
    from collections import deque

    query_random_adj = []
    query_random_feat = []
    root_list = []
    num_nodes = adj.shape[2]
    edges = adj[0, 0]
    for _ in range(n_subgraphs):
        start_node = np.random.randint(num_nodes)
        root_list.append(start_node)
        subg_idx = []
        visited = set()
        queue = deque([(start_node, 0)])
        while queue:
            node, depth = queue.popleft()
            if depth > n_hops:
                # BFS dequeues in non-decreasing depth order: nothing left is within range.
                break
            if node in visited:
                # Already included via another path; adding it again would duplicate it.
                continue
            if depth > 0 and np.random.uniform() > keep_prob:
                continue
            visited.add(node)
            subg_idx.append(node)
            for child in np.flatnonzero(edges[node]):
                if child not in visited:
                    queue.append((int(child), depth + 1))
        subg_idx = np.asarray(subg_idx, dtype=int)
        query_random_adj.append(np.expand_dims(edges[subg_idx][:, subg_idx], axis=0))
        query_random_feat.append(feat[0, subg_idx, :])

    return query_random_adj, query_random_feat, root_list

def random_task(args, writer=None):
    """Train ``GcnEncoderMatching`` on randomly generated graphs from ``RandomDataset``.

    Training uses ``args.num_queries`` queries starting from query size
    ``args.init_hops`` and ``args.init_size`` (``phase='all'``). Validation uses 128
    queries of size 4 (``phase='center'``). Node features come from
    ``featgen.DegreeCenterFeatureGen``. Sets ``args.edge_dim = 1``.
    """
    vars(args)["edge_dim"] = 1
    common = dict(gpu=args.gpu, induced=args.induced)
    train_dataset = RandomDataset(num_queries=args.num_queries,
                                  feature_generator=featgen.DegreeCenterFeatureGen(args.input_dim),
                                  query_size=args.init_hops, init_size=args.init_size,
                                  phase='all', **common)
    val_dataset = RandomDataset(num_queries=128,
                                feature_generator=featgen.DegreeCenterFeatureGen(args.input_dim),
                                query_size=4, phase='center', **common)

    loader_options = dict(batch_size=args.batch_size, pin_memory=False)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, **loader_options)

    model = models.GcnEncoderMatching(args.input_dim, args.hidden_dim, args.output_dim, 2,
                                      args.num_gc_layers, pred_hidden_dims=[120, 32],
                                      bn=args.bn, args=args)
    if args.gpu:
        model = model.cuda()

    trainer = SubgraphMatchingRandom(model, train_loader, args, writer=writer, val_loader=val_loader)
    trainer.train_subgraph_match(args, writer=writer)

def random_task_pyg(args, writer=None):
    """Train the PyTorch Geometric ``SiameseGNN`` on ``RandomDatasetPyG`` queries, then checkpoint it.

    Uses 64 training and 32 validation queries and a fixed three-layer specification
    (``input_dim -> hidden_dim -> hidden_dim -> output_dim``). The same ``layers`` list
    is passed for both of ``SiameseGNN``'s layer arguments. After training, the model
    is saved with ``io_utils.save_checkpoint`` without optimizer state. Sets
    ``args.edge_dim = 1``.

    NOTE(review): unlike ``random_task``, this path ignores ``args.num_gc_layers`` and
    ``args.num_queries`` and never calls ``model.cuda()`` even when ``args.gpu`` is set.
    Confirm whether ``SubgraphMatchingRandomPyG`` handles device placement itself.
    """
    vars(args)["edge_dim"] = 1
    train_dataset = RandomDatasetPyG(num_queries=64, feature_generator=featgen.DegreeCenterFeatureGen(args.input_dim),
                                     gpu=args.gpu, induced=args.induced)
    val_dataset = RandomDatasetPyG(num_queries=32, feature_generator=featgen.DegreeCenterFeatureGen(args.input_dim),
                                   gpu=args.gpu, induced=args.induced)
    train_loader = DataLoader_pyg(train_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=True)
    val_loader = DataLoader_pyg(val_dataset, batch_size=args.batch_size, pin_memory=False)

    layers = [{'in_dim': args.input_dim, 'hidden_dim': args.hidden_dim, 'out_dim': args.hidden_dim},
              {'in_dim': args.hidden_dim, 'hidden_dim': args.hidden_dim, 'out_dim': args.hidden_dim},
              {'in_dim': args.hidden_dim, 'hidden_dim': args.hidden_dim, 'out_dim': args.output_dim}]
    model = models_pyg.SiameseGNN(layers, layers)
    matching = SubgraphMatchingRandomPyG(model, train_loader, args, val_loader, writer)
    start = time.time()
    matching.train_subgraph_match(args, writer=writer)
    # This used to lack the f-prefix, so the braces were printed literally.
    print(f'Done. Took {time.time() - start} seconds')

    # save trained model; None for optimizer
    io_utils.save_checkpoint(model, None, args)

def random_basis_task(args, writer=None):
    """Train ``GcnEncoderMatching`` on random queries sampled from a real "basis" graph dataset.

    ``args.basis`` selects which graphs ``load_data`` reads. The model's input dimension
    is the size of the ``'feat'`` attribute of a random node of the first graph, and its
    hidden dimension is half of that. Training uses ``args.num_queries`` queries
    starting from ``args.init_size``/``args.init_hops``. Validation uses 128 queries of
    size 4 (``phase='center'``). Sets ``args.edge_dim = 1``.

    Raises:
        NotImplementedError: if ``args.basis`` is not a supported dataset name.
    """
    vars(args)["edge_dim"] = 1
    benchmark_names = ('DD', 'ENZYMES', 'COX2', 'MSRC_21', 'FIRSTMM_DB')
    if args.basis == 'siemens':
        graphs = load_data.read_siemens()
    elif args.basis == 'nn_code_sample':
        # query_graphs is only needed by the PredefinedDataset validation set commented out below.
        graphs, query_graphs = load_data.read_arqui(path=os.path.join(args.datadir, args.basis))
    elif args.basis in benchmark_names:
        # NOTE(review): benchmark files are read from the hard-coded 'data' directory,
        # not args.datadir (which the WN and nn_code_sample branches use).
        graphs = load_data.read_graphfile('data', args.basis)
    elif args.basis == 'enzymes':
        graphs = load_data.read_graphfile('data', 'ENZYMES')
    elif args.basis == 'WN':
        graphs = load_data.read_WN(path=os.path.join(args.datadir, 'WN18.gpickle'))
    elif args.basis == 'ppi':
        graphs = load_data.read_ppi()
    else:
        raise NotImplementedError(
            f"Unsupported --basis {args.basis!r}; expected one of: siemens, nn_code_sample, "
            f"{', '.join(benchmark_names)}, enzymes, WN, ppi")

    # Neighbourhood sampling is enabled only for the WN and ppi bases.
    sample = args.basis in ('WN', 'ppi')
    train_dataset = RandomBasisDataset(graphs, num_queries=args.num_queries, gpu=args.gpu, init_size=args.init_size,
                                       query_size=args.init_hops, sample_neighborhoods=sample)
    val_dataset = RandomBasisDataset(graphs, num_queries=128, gpu=args.gpu, query_size=4, phase='center',
                                     sample_neighborhoods=sample)
    #val_dataset = PredefinedDataset(graphs, query_graphs, gpu=args.gpu)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=False, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, pin_memory=False)
    # The feature width is read from a single random node; assumes all nodes share it — confirm.
    feat_dim = graphs[0].nodes[random.choice(list(graphs[0].nodes))]['feat'].size
    print(args.num_gc_layers)
    model = models.GcnEncoderMatching(feat_dim, feat_dim // 2, args.output_dim, 2,
                                      args.num_gc_layers, pred_hidden_dims=[120, 32], bn=args.bn, args=args)
    if args.gpu:
        model = model.cuda()
    train_dataset.visualize(writer)
    matching = SubgraphMatchingRandom(model, train_loader, args, writer=writer, val_loader=val_loader)
    matching.train_subgraph_match(args, writer=writer)

def arg_parse():
    """Build the command-line parser for the subgraph-matching experiments and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option below plus the optimizer options
        registered by ``parser_utils.parse_optimizer``. Defaults come from the
        ``set_defaults`` call at the end.
    """
    def str2bool(value):
        """Parse a textual boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='GraphPool arguments.')
    io_parser = parser.add_mutually_exclusive_group(required=False)
    io_parser.add_argument('--dataset', dest='dataset',
            help='Input dataset.')
    # --bmname used to be added through an argument group nested inside the mutually
    # exclusive group. argparse never supported that nesting (deprecated since Python 3.11),
    # and it did not make --bmname exclusive with --dataset/--pkl, so register it directly.
    parser.add_argument('--bmname', dest='bmname',
            help='Name of the benchmark dataset')
    io_parser.add_argument('--pkl', dest='pkl_fname',
            help='Name of the pkl data file')

    softpool_parser = parser.add_argument_group()
    softpool_parser.add_argument('--assign-ratio', dest='assign_ratio', type=float,
            help='ratio of number of nodes in consecutive layers')
    softpool_parser.add_argument('--num-pool', dest='num_pool', type=int,
            help='number of pooling layers')
    parser.add_argument('--linkpred', dest='linkpred', action='store_const',
            const=True, default=False,
            help='Whether link prediction side objective is used')

    parser_utils.parse_optimizer(parser)

    parser.add_argument('--basis', dest='basis', help='Basis to use for basis dataset.')
    parser.add_argument('--datadir', dest='datadir',
            help='Directory where benchmark is located')
    parser.add_argument('--logdir', dest='logdir',
            help='Tensorboard log directory')
    parser.add_argument('--ckptdir', dest='ckptdir',
            help='Model checkpoint directory')
    parser.add_argument('--cuda', dest='cuda',
            help='CUDA.')
    parser.add_argument('--gpu', dest='gpu', action='store_const',
            const=True, default=False,
            help='whether to use GPU.')
    parser.add_argument('--max-nodes', dest='max_nodes', type=int,
            help='Maximum number of nodes (ignore graphs with nodes exceeding the number.')
    # Default (80) is set by set_defaults below; a per-argument default here would be dead.
    parser.add_argument('--batch_size', dest='batch_size', type=int,
            help='Batch size.')
    # A proportion, so it must accept fractional values (default 0.5).
    parser.add_argument('--pos-ratio', dest='pos_ratio', type=float,
                        help='Proportion of batch to be positive.')
    parser.add_argument('--epochs', dest='num_epochs', type=int,
            help='Number of epochs to train.')
    parser.add_argument('--train-ratio', dest='train_ratio', type=float,
            help='Ratio of number of graphs training set to all graphs.')
    parser.add_argument('--num_workers', dest='num_workers', type=int,
            help='Number of workers to load data.')
    parser.add_argument('--feature', dest='feature_type',
            help='Feature used for encoder. Can be: id, deg')
    parser.add_argument('--input-dim', dest='input_dim', type=int,
            help='Input feature dimension')
    parser.add_argument('--hidden-dim', dest='hidden_dim', type=int,
            help='Hidden dimension')
    parser.add_argument('--output-dim', dest='output_dim', type=int,
            help='Output dimension')
    parser.add_argument('--num-classes', dest='num_classes', type=int,
            help='Number of label classes')
    parser.add_argument('--num-gc-layers', dest='num_gc_layers', type=int,
            help='Number of graph convolution layers before each pooling')
    parser.add_argument('--bn', dest='bn', action='store_const',
            const=True, default=False,
            help='Whether batch normalization is used')
    parser.add_argument('--dropout', dest='dropout', type=float,
            help='Dropout rate.')
    parser.add_argument('--nobias', dest='bias', action='store_const',
            const=False, default=True,
            help='Whether to add bias. Default to True.')
    parser.add_argument('--weight-decay', dest='weight_decay', type=float,
            help='Weight decay regularization constant.')
    parser.add_argument('--use_nbr', dest='use_nbr', action='store_const',
            const=True, default=False,
            help='whether to only use neighbor adjacency & features.')
    parser.add_argument('--method', dest='method',
            help="Method. Possible values: 'no training', base, manual_label (syn1 only), "
                 "augm (synmultiple only)")
    parser.add_argument('--name-suffix', dest='name_suffix',
            help='suffix added to the output filename')
    # Previously an untyped string, so "--harder_training False" was truthy. Accept a bare
    # flag (-> True) or an explicit boolean value.
    parser.add_argument('--harder_training', dest='harder_training', type=str2bool,
            nargs='?', const=True, default=False,
            help='Setup harder_training data (bare flag, or an explicit true/false value)')
    parser.add_argument('--order_embeddings', dest='order_embeddings', action='store_const',
            const=True, default=False,
            help='Use order embeddings to train')
    parser.add_argument('--ntn', dest='ntn', action='store_const',
            const=True, default=False,
            help='Add neural tensor network layer to the last of the prediction layers')
    parser.add_argument('--compute-dataset-recall', dest='compute_dataset_recall', action='store_const',
            const=True, default=False,
            help='Compute entire dataset recall')
    parser.add_argument('--withhold', type=int, default=0,
            help="Number of query graphs to withhold from training (synmultiple only)")
    parser.add_argument('--init_size', type=int, default=1,
            help='Number of queries to start with in curriculum')
    parser.add_argument('--init_hops', type=int, default=3,
                        help='Query size to start with in curriculum')
    # 'none' is the default set below, so it must also be an accepted choice.
    parser.add_argument('--opt_scheduler', choices=['none', 'step', 'cos', 'cosine_schedule'],
                        help='Learning-rate scheduler type (default: none).')
    parser.add_argument('--opt_restart', type=int)
    parser.add_argument('--induced', action='store_true',
                        help='Passed as induced= to the random datasets.')
    parser.add_argument('--graph_embeddings', action='store_true')
    parser.add_argument('--num_queries', type=int,
                        help='Number of training queries for the random and basis datasets.')
    parser.add_argument('--multi', action='store_true',
                        help='Run the basis task once per num_gc_layers in [lower, upper).')
    parser.add_argument('--lower', type=int,
                        help='Inclusive lower bound of num_gc_layers for --multi.')
    parser.add_argument('--upper', type=int,
                        help='Exclusive upper bound of num_gc_layers for --multi.')

    parser.set_defaults(datadir='data', # io_parser
                        logdir='log',
                        ckptdir='ckpt',
                        dataset='syn1',
                        opt='adam',   # opt_parser
                        opt_scheduler='none',
                        max_nodes=100,
                        cuda='1',
                        feature_type='default',
                        lr=0.005,
                        clip=2.0,
                        batch_size=80,
                        num_epochs=1000,
                        train_ratio=0.8,
                        test_ratio=0.1,
                        num_workers=1,
                        input_dim=10,
                        hidden_dim=20,
                        output_dim=20,
                        num_classes=2,
                        num_gc_layers=3,
                        dropout=0.0,
                        weight_decay=0.005,
                        method='base',
                        name_suffix='',
                        assign_ratio=0.1,
                        pos_ratio=0.5
                       )
    return parser.parse_args()

def main():
    """Parse the command line, configure the device and run the selected experiment(s).

    With ``--multi``, ``random_basis_task`` runs once per ``num_gc_layers`` in
    ``range(--lower, --upper)``. Each run gets its own TensorBoard writer at
    ``<logdir><num_gc_layers>_conv``. Otherwise the task named by ``--dataset`` runs with
    a single writer at ``<logdir>/<io_utils.gen_prefix(args)>``. Every writer is closed
    when its run finishes, including when it raises.

    Raises:
        ValueError: if ``--multi`` is given without both ``--lower`` and ``--upper``, or if
            ``--dataset`` names an unknown task.
    """
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = prog_args.cuda
        print('CUDA', prog_args.cuda)
    else:
        print('Using CPU')

    if prog_args.multi:
        if prog_args.lower is None or prog_args.upper is None:
            raise ValueError('--multi requires both --lower and --upper')
        logdir = prog_args.logdir
        for num_conv_layers in range(prog_args.lower, prog_args.upper):
            prog_args.logdir = logdir + str(num_conv_layers) + '_conv'
            writer = SummaryWriter(prog_args.logdir)
            prog_args.num_gc_layers = num_conv_layers
            try:
                random_basis_task(prog_args, writer=writer)
            finally:
                # Previously only the last per-run writer was ever closed.
                writer.close()
        return

    if prog_args.dataset is None:
        return

    tasks = {
        'syn1': syn_task1,
        'syndup': syn_dup,
        'synmultiple': syn_multiple,
        'random': random_task,
        'basis': random_basis_task,
        'random_pyg': random_task_pyg,
    }
    task = tasks.get(prog_args.dataset)
    if task is None:
        raise ValueError(f'Unknown --dataset {prog_args.dataset!r}; expected one of: {", ".join(tasks)}')

    # TensorBoard writer for this run.
    path = os.path.join(prog_args.logdir, io_utils.gen_prefix(prog_args))
    writer = SummaryWriter(path)
    try:
        task(prog_args, writer=writer)
    finally:
        writer.close()

if __name__ == '__main__':
    # Script entry point: parse the command line and run the selected experiment.
    main()
