import argparse
import os

import numpy as np
import torch
from dgl.data import split_dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
import torch.nn.functional as F

from dataset_loader.loader import dataset_loader
from dataset_loader.metrics import test_error_QM7b, test_error_DBLP, accuracy_error_DBLP, test_error_QM9
from dataset_loader.utils import filter_nb_nodes, in_feats, separate_by_nodes, out_feats
from dataset import CustomDataset, CustomDatasetClassical, build_datasets
from models.GNN import GNN, train_model
from models.quantum import QGraphNetworkCustom, MultiHeadQGraphModel, HybridModel
from utils import generate_ising_matrices_torch, obs_ZZ, compute_all_ising_matrices
from training_quantum import training_loop_parallel, training_loop_single

# ------------------------- DEFINE PARAMETERS
def str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``, so
    with it ``--fast False`` would *enable* fast mode. This accepts
    (case-insensitively) true/t/yes/y/on/1 and false/f/no/n/off/0; the empty
    string maps to False, exactly as ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports a
            usage error instead of silently treating it as True.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', '-d', choices=['QM7', 'DBLP_v1', 'Letter-med', 'QM9', 'Cover', 'CoverGen', 'ZINC'], type=str, required=True)
parser.add_argument('--min_node', type=int, default=4)
parser.add_argument('--max_node', type=int, default=15)
parser.add_argument('--permutation_seed', '-seed', type=int, default=0)
parser.add_argument('--model', '-m', choices=['QGNN', 'SAGE', 'GCN', 'GAT', 'Hybrid'], type=str, required=True)
parser.add_argument('--layers', '-l', nargs='+', default=['128'])
parser.add_argument('--hyperparameters', '-p', nargs='+', default=[])
# train / val / test fractions
parser.add_argument('--splits', nargs='+', type=float, default=(.8, .1, .1))
parser.add_argument('--n_epochs', type=int, default=20)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--gamma', type=float, default=1.)
parser.add_argument('--n_processes', '-nprocs', type=int, default=1)
# argparse.BooleanOptionalAction (``--fast`` / ``--no-fast``) would be cleaner but
# needs Python 3.9, which the cluster does not have yet; str2bool keeps the
# existing ``--fast True`` / ``--fast False`` syntax while parsing it correctly.
parser.add_argument('--fast', '-f', type=str2bool, default=False)
parser.add_argument('--wandb', '-wb', type=str2bool, default=False)
parser.add_argument('--quantum_every', '-q', type=int, default=4)
parser.add_argument('--dataset_parameters', '-dp', nargs='+', default=[])


def merge_with_defaults(provided, defaults):
    """Fill positions missing from ``provided`` with the matching ``defaults``.

    Position ``i`` of the result is ``provided[i]`` when the caller supplied it
    and ``defaults[i]`` otherwise. Values supplied beyond ``len(defaults)`` are
    kept rather than truncated.

    Args:
        provided: Values given on the command line (may be empty).
        defaults: Default values, indexed by position.

    Returns:
        A new list; neither argument is modified.
    """
    provided = list(provided)
    return provided + list(defaults[len(provided):])


def validate_splits(splits):
    """Check that ``splits`` holds exactly three fractions summing to 1.

    The sum is compared with a floating-point tolerance because, e.g.,
    ``0.7 + 0.2 + 0.1`` evaluates to ``0.9999999999999999``.

    Raises:
        ValueError: describing the first problem found.
    """
    import math

    if len(splits) != 3:
        raise ValueError(f'--splits expects 3 values (train val test), got {len(splits)}')
    total = sum(splits)
    if not math.isclose(total, 1.0):
        raise ValueError(f'--splits must sum to 1, got {total}')


def batch_sizes_by_node_count(max_node):
    """Return an int array whose index ``n`` is the batch size for ``n``-node graphs.

    Graphs with 1..15 nodes get 360; the size then decreases to 1 at 24 nodes.
    Index 0 is 1. The array has at least ``max_node + 1`` entries; node counts
    above 24 also get a batch size of 1 instead of raising ``IndexError``.
    """
    sizes = np.ones(max(25, max_node + 1), dtype=int)
    sizes[1:15] = 360
    sizes[15:25] = [360, 200, 100, 50, 30, 15, 6, 3, 2, 1]
    return sizes


def build_zz_observables(max_node):
    """Precompute the ZZ observables passed to ``compute_all_ising_matrices``.

    Returns:
        dict ``{n: {(i, j): obs_ZZ(n, i, j, type_n=True)}}`` for every
        ``2 <= n <= max_node`` and every pair ``0 <= i <= j < n``.
    """
    return {
        n: {(i, j): obs_ZZ(n, i, j, type_n=True) for i in range(n) for j in range(i, n)}
        for n in range(2, max_node + 1)
    }


if __name__ == '__main__':
    # ------------------------- DEFINE OPTIONS
    layers_models = {
        'SAGE': SAGEConv,
        'GAT': GATConv,
        'GCN': GraphConv
    }
    error_functions = {
        'QM7': test_error_QM7b,
        'DBLP_v1': test_error_DBLP,
        'Letter-med': test_error_DBLP,
        'QM9': test_error_QM9,
        'Cover': test_error_DBLP,
        'CoverGen': test_error_DBLP,
        'ZINC': test_error_QM9,
    }

    loss_functions = {
        'QM7': F.l1_loss,
        'DBLP_v1': F.cross_entropy,
        'Letter-med': F.cross_entropy,
        'Cover': F.cross_entropy,
        'CoverGen': F.cross_entropy,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    loss_functions_test = {
        'QM7': F.l1_loss,
        'DBLP_v1': accuracy_error_DBLP,
        'Letter-med': accuracy_error_DBLP,
        'Cover': accuracy_error_DBLP,
        'CoverGen': accuracy_error_DBLP,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    # None for the L1-loss datasets; an int for the cross-entropy datasets
    # (presumably the number of classes -- TODO confirm against build_datasets).
    out_dims = {
        'QM7': None,
        'DBLP_v1': 2,
        'Letter-med': 15,
        'QM9': None,
        'Cover': 2,
        'CoverGen': 3,
        'ZINC': None
    }

    # Defaults inside the library functions follow legacy usage, so the desired
    # defaults are defined here instead; that way they are stored in W&B.
    dataset_default_params = {
        'QM7': [False, True]  # don't use positions & use embed atomic number Z
    }

    # ------------------------- COLLECT PARAMETERS
    # --- WARNING: write every updated value back into args; args is what is eventually sent to W&B
    args = parser.parse_args()

    args.hyperparameters = [int(p) if p.isdigit() else p for p in args.hyperparameters]
    args.layers = [int(p) if p.isdigit() else p for p in args.layers]

    # Complete the positional dataset parameters with this dataset's defaults.
    # NOTE(review): values passed via -dp arrive as strings (e.g. 'False'), whereas
    # the defaults are bools -- confirm dataset_loader interprets both correctly.
    if args.dataset in dataset_default_params:
        args.dataset_parameters = merge_with_defaults(
            args.dataset_parameters, dataset_default_params[args.dataset]
        )

    print(args)

    try:
        validate_splits(args.splits)
    except ValueError as err:
        parser.error(str(err))

    # Set the W&B mode once so the classical and the quantum training paths agree.
    os.environ["WANDB_MODE"] = 'offline' if args.fast else 'online'

    # ------------------------- MAIN LOGIC
    # The full dataset is also used below for feature sizes (in_feats/out_feats) and logging.
    dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node)
    batch_size = 500

    print(dataset[0])

    # NOTE: --fast does not shrink the dataset; it only switches W&B to offline mode.
    # (Assigning ``dataset.__len__`` on the instance has no effect, because len()
    # looks up ``__len__`` on the type, not the instance.)
    if args.dataset == 'ZINC':
        train_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='train')
        val_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='val')
        test_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='test')
    elif args.dataset == 'CoverGen':
        # NOTE(review): three identical loader calls -- presumably each call generates
        # a fresh random sample; otherwise train/val/test would be the same data.
        train_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node)
        val_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node)
        test_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node)
    else:
        train_dataset, val_dataset, test_dataset = split_dataset(
            dataset, frac_list=args.splits, shuffle=True, random_state=args.permutation_seed
        )

    if args.model not in ('QGNN', 'Hybrid'):
        # ---------------- Classical GNN path
        conv = layers_models[args.model]
        loss_function = loss_functions[args.dataset]
        error_function = error_functions[args.dataset]

        train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size, drop_last=False)
        val_dataloader = GraphDataLoader(val_dataset, batch_size=batch_size, drop_last=False)
        test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size, drop_last=False)

        model = GNN(conv, in_feats(dataset), args.layers, out_feats(dataset), *args.hyperparameters, flatten_labels=False)
        train_model(model, train_dataloader, loss_function, lr=args.lr, gamma=args.gamma, n_epochs=args.n_epochs,
                    error_function=error_function, test_dataloader=test_dataloader, val_dataloader=val_dataloader,
                    dataset_name=args.dataset, seed=args.permutation_seed, wb=args.wandb, cli_args=args.__dict__)
    else:
        # ---------------- Quantum / hybrid path
        # Unzip (graph, target) pairs; val/test may be empty (e.g. a 0 split fraction).
        train_graphs, train_targets = list(zip(*train_dataset))
        val_graphs, val_targets = [], []
        if len(val_dataset) > 0:
            val_graphs, val_targets = list(zip(*val_dataset))
        test_graphs, test_targets = [], []
        if len(test_dataset) > 0:
            test_graphs, test_targets = list(zip(*test_dataset))

        # Presumably groups graphs/targets by node count; nodes_* lists the count of
        # each group (used below to index the per-node-count batch sizes).
        train_graphs_list, train_targets_list, nodes_train = separate_by_nodes(train_graphs, args.min_node, args.max_node, train_targets)
        val_graphs_list, val_targets_list, nodes_val = separate_by_nodes(val_graphs, args.min_node, args.max_node, val_targets)
        test_graphs_list, test_targets_list, nodes_test = separate_by_nodes(test_graphs, args.min_node, args.max_node, test_targets)

        print(len(dataset))

        NN_matrices = build_zz_observables(args.max_node)

        train_ising_matrices_list = compute_all_ising_matrices(train_graphs_list, NN_matrices)
        val_ising_matrices_list = compute_all_ising_matrices(val_graphs_list, NN_matrices)
        test_ising_matrices_list = compute_all_ising_matrices(test_graphs_list, NN_matrices)

        int_hyperparameters = [int(p) for p in args.hyperparameters]
        attention_params = int_hyperparameters[1:3]
        out_dim = out_dims[args.dataset]

        dataset_train_list, dataset_train_classical_list = build_datasets(train_graphs_list, train_targets_list, train_ising_matrices_list, nodes_train, attention_params, 'attr', out_dim=out_dim)
        dataset_val_list, dataset_val_classical_list = build_datasets(val_graphs_list, val_targets_list, val_ising_matrices_list, nodes_val, attention_params, 'attr', out_dim=out_dim)
        dataset_test_list, dataset_test_classical_list = build_datasets(test_graphs_list, test_targets_list, test_ising_matrices_list, nodes_test, attention_params, 'attr', out_dim=out_dim)

        batch_sizes = batch_sizes_by_node_count(args.max_node)
        batch_sizes_train = batch_sizes[nodes_train]
        batch_sizes_val = batch_sizes[nodes_val] if len(nodes_val) > 0 else []
        batch_sizes_test = batch_sizes[nodes_test] if len(nodes_test) > 0 else []

        # Prepend the input and output dimensions taken from the first training group.
        hyperparameters = [dataset_train_list[0].in_dim, dataset_train_list[0].out_dim] + int_hyperparameters
        print(hyperparameters)

        model_cls = MultiHeadQGraphModel if args.model == 'QGNN' else HybridModel
        model = model_cls(*hyperparameters)

        print('Preprocessing done')
        for dataset_train, n in zip(dataset_train_list, nodes_train):
            print(n, len(dataset_train))

        if args.n_processes > 1:
            training_loop_parallel(model, dataset_train_list, dataset_train_classical_list, args.n_epochs, args.n_processes, batch_sizes_train, batch_sizes_val, batch_sizes_test,
                                   args.quantum_every, dataset_val_list, dataset_val_classical_list, dataset_test_list, dataset_test_classical_list,
                                   lr=args.lr, gamma=args.gamma, dataset_name=args.dataset, wb=args.wandb, cli_args=args.__dict__, loss_func=loss_functions[args.dataset], loss_func_test=loss_functions_test[args.dataset])
        else:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            training_loop_single(model, dataset_train_list, dataset_train_classical_list, args.n_epochs, batch_sizes_train, batch_sizes_val, batch_sizes_test, device, args.quantum_every,
                                 dataset_val_list, dataset_val_classical_list, dataset_test_list, dataset_test_classical_list, lr=args.lr, gamma=args.gamma, dataset_name=args.dataset,
                                 wb=args.wandb, cli_args=args.__dict__, seed=args.permutation_seed, loss_func=loss_functions[args.dataset], loss_func_test=loss_functions_test[args.dataset])
