import argparse
import os

import numpy as np
import torch
from dgl.data import split_dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
import torch.nn.functional as F

from dataset_loader.loader import dataset_loader
from dataset_loader.metrics import test_error_QM7b, test_error_DBLP, accuracy_error_DBLP, test_error_QM9
from dataset_loader.utils import filter_nb_nodes, in_feats, separate_by_nodes, out_feats
from dataset import CustomDataset, CustomDatasetClassical, build_datasets
from models.GNN import GNN, train_model
from models.quantum import QGraphNetworkCustom, MultiHeadQGraphModel
from utils import generate_ising_matrices_torch, obs_ZZ, compute_all_ising_matrices
from training_quantum import training_loop_parallel, training_loop_single
from torch.utils.data import DataLoader
import uuid
import pickle
import time

import torch.multiprocessing as mp


# ------------------------- DEFINE PARAMETERS
parser = argparse.ArgumentParser()
# Dataset to load; required and restricted to the supported names.
parser.add_argument(
    '--dataset', '-d',
    type=str,
    required=True,
    choices=['QM7', 'DBLP_v1', 'Letter-med', 'QM9', 'Cover', 'ZINC'],
)
# Node-count bounds forwarded to dataset_loader and separate_by_nodes.
parser.add_argument('--min_node', default=4, type=int)
parser.add_argument('--max_node', default=15, type=int)
# random_state for the train/val/test split; also recorded in each run's config.
parser.add_argument('--permutation_seed', '-seed', default=0, type=int)
# Train/val/test fractions (expected to sum to 1).
parser.add_argument('--splits', default=(.8, .1, .1), nargs='+', type=float)
# Number of worker processes launched through torch.multiprocessing.spawn.
parser.add_argument('--n_processes', '-nprocs', default=1, type=int)
# Extra positional arguments for dataset_loader (currently only forwarded for ZINC).
parser.add_argument('--dataset_parameters', '-dp', default=[], nargs='+')



def return_dataloaders(dataset_list, batch_sizes):
    """Wrap each dataset in a sequential, non-shuffling ``DataLoader``.

    Args:
        dataset_list: sequence of map-style datasets (one loader is built per entry).
        batch_sizes: sequence of batch sizes aligned with ``dataset_list``;
            each value is cast to ``int`` (it may be a NumPy integer).

    Returns:
        ``(dataloader_list, num_items)`` where ``dataloader_list[k]`` iterates
        ``dataset_list[k]`` in order and ``num_items`` is the total number of
        samples across all datasets.

    Raises:
        ValueError: if ``dataset_list`` and ``batch_sizes`` differ in length.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if len(dataset_list) != len(batch_sizes):
        raise ValueError(
            f'got {len(dataset_list)} datasets but {len(batch_sizes)} batch sizes')

    # shuffle=False / drop_last=False: every sample is visited exactly once, in
    # dataset order (presumably so saved outputs can be matched back by position).
    dataloader_list = [
        DataLoader(dataset, batch_size=int(batch_size), drop_last=False, shuffle=False)
        for dataset, batch_size in zip(dataset_list, batch_sizes)
    ]
    num_items = sum(len(dataset) for dataset in dataset_list)

    return dataloader_list, num_items


def _save_raw_matrices(model, dataloader, device, path):
    """Run the first attention head over ``dataloader``, save the stacked result to ``path``.

    Returns the wall-clock seconds spent computing the matrices (the
    ``torch.save`` call is not included in the timing).
    """
    matrices = []
    t0 = time.perf_counter()
    for features, labels, ising, adj, idx in dataloader:
        mat = model.layers[0].heads[0].attention(
            features.shape[1], ising.to(device),
            batch_size=labels.shape[0], return_raw_matrices=True)
        # Move axis 3 to the front so batches can be concatenated along dim 0.
        matrices.append(mat.permute(3, 0, 1, 2).cpu())
    elapsed = time.perf_counter() - t0
    torch.save(torch.cat(matrices), path)
    return elapsed


def generate_matrices(rank, model, datasets_train, datasets_val_list, datasets_test_list, batch_sizes_train, batch_sizes_val, batch_sizes_test, n_iter=20, dataset_name='QM9', seed=0):
    """Draw random model parameters repeatedly and dump the raw attention matrices.

    Used as the target of ``torch.multiprocessing.spawn`` (see ``__main__``),
    which passes the process index as ``rank``. For each of ``n_iter``
    iterations, the parameters are re-initialised with ``model.reset_parameters()``.
    A new ``{dataset_name}/<10-hex-id>/`` directory is then written containing:
    ``config.pickle``, ``parameters.pt``, and one ``{train,val,test}/matrices_{n_nodes}.pt``
    file per dataset. The matrices are produced by ``model.layers[0].heads[0].attention``
    with ``return_raw_matrices=True``.

    Args:
        rank: process index; selects ``cuda:{rank}`` when CUDA is available and
            seeds torch's RNG with ``97 * rank + 1``.
        model: object exposing ``to``, ``reset_parameters``, ``state_dict`` and
            ``layers[0].heads[0].attention``.
        datasets_train: list of datasets, each with an ``n_nodes`` attribute.
        datasets_val_list, datasets_test_list: same as ``datasets_train``; pass
            ``None`` to skip that split (previously this raised ``NameError``).
        batch_sizes_train, batch_sizes_val, batch_sizes_test: batch sizes aligned
            with the matching dataset list.
        n_iter: number of random parameter draws.
        dataset_name: root output directory.
        seed: only recorded in ``config.pickle``; it does NOT seed the RNG.
    """
    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # (split name, datasets, dataloaders); optional splits are simply left out.
    split_specs = [('train', datasets_train, batch_sizes_train)]
    if datasets_val_list is not None:
        split_specs.append(('val', datasets_val_list, batch_sizes_val))
    if datasets_test_list is not None:
        split_specs.append(('test', datasets_test_list, batch_sizes_test))
    splits = [
        (name, datasets, return_dataloaders(datasets, batch_sizes)[0])
        for name, datasets, batch_sizes in split_specs
    ]

    torch.random.manual_seed(97 * rank + 1)

    # exist_ok=True: several spawned workers may try to create the root at once;
    # an exists()/mkdir() pair races and can raise FileExistsError.
    os.makedirs(dataset_name, exist_ok=True)

    with torch.no_grad():
        for iteration in range(n_iter):
            model.reset_parameters()
            exp_id = uuid.uuid4().hex[:10]
            output_file = f'{dataset_name}/{exp_id}'

            # NOTE(review): n_layers is hard-coded; confirm it matches the model.
            config = {
                "n_layers": 4,
                "random_seed": seed,
            }

            os.mkdir(output_file)
            # Keep the full train/val/test layout even when a split is skipped.
            for sub in ('train', 'val', 'test'):
                os.mkdir(f'{output_file}/{sub}')

            with open(f'{output_file}/config.pickle', 'wb') as f:
                pickle.dump(config, f)
            torch.save(model.state_dict(), f'{output_file}/parameters.pt')

            for name, datasets, dataloaders in splits:
                for dataloader, dataset in zip(dataloaders, datasets):
                    elapsed = _save_raw_matrices(
                        model, dataloader, device,
                        f'{output_file}/{name}/matrices_{dataset.n_nodes}.pt')
                    # Report per-node-count timing once, for the training split only.
                    if name == 'train' and iteration == 0:
                        print(dataset.n_nodes, elapsed)



if __name__ == '__main__':
    # ------------------------- DEFINE OPTIONS
    # NOTE(review): layers_models, error_functions, loss_functions,
    # loss_functions_test and dataset_default_params are not referenced below in
    # this script; only out_dims is used (by build_datasets).

    layers_models = {
        'SAGE': SAGEConv,
        'GAT': GATConv,
        'GCN': GraphConv
    }
    error_functions = {
        'QM7': test_error_QM7b,
        'DBLP_v1': test_error_DBLP,
        'Letter-med': test_error_DBLP,
        'QM9': test_error_QM9,
        'Cover': test_error_DBLP,
        'ZINC': test_error_QM9,
    }

    loss_functions = {
        'QM7': F.l1_loss,
        'DBLP_v1': F.cross_entropy,
        'Letter-med': F.cross_entropy,
        'Cover': F.cross_entropy,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    loss_functions_test = {
        'QM7': F.l1_loss,
        'DBLP_v1': accuracy_error_DBLP,
        'Letter-med': accuracy_error_DBLP,
        'Cover': accuracy_error_DBLP,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    # Output dimension per dataset, forwarded to build_datasets as out_dim
    # (None for the regression datasets).
    out_dims = {
        'QM7': None,
        'DBLP_v1': 2,
        'Letter-med': 15,
        'QM9': None,
        'Cover': 2,
        'ZINC': None
    }

    # default params in functions should be based on legacy usage
    # -> define desired default params here so they are stored in W&B
    dataset_default_params = {
        'QM7': [False, True]  # don't use positions & use embed atomic number Z
    }

    # ------------------------- COLLECT PARAMETERS
    # --- WARNING: Keep values in args when update as it's the variable eventually sent to W&B
    args = parser.parse_args()

    # Validate --splits explicitly. The former `assert ... == 1.` compared floats
    # exactly and rejected valid input such as `--splits 0.7 0.2 0.1`
    # (0.7 + 0.2 + 0.1 == 0.9999999999999999). `assert` is also stripped under `-O`.
    if len(args.splits) != 3:
        parser.error(f'--splits expects exactly 3 fractions (train val test), got {len(args.splits)}')
    train_size, val_size, test_size = args.splits
    if not np.isclose(train_size + val_size + test_size, 1.0):
        parser.error(f'--splits must sum to 1, got {list(args.splits)}')

    # ------------------------- MAIN LOGIC
    dataset = dataset_loader(args.dataset, min_node=args.min_node, max_node=args.max_node)
    batch_size = 500

    # ZINC ships predefined splits; other datasets are split randomly here.
    if args.dataset == 'ZINC':
        train_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='train')
        val_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='val')
        test_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='test')
    else:
        train_dataset, val_dataset, test_dataset = split_dataset(
            dataset, frac_list=args.splits, shuffle=True, random_state=args.permutation_seed
        )

    # Each dataset item is a (graph, target) pair; unzip into two tuples.
    train_graphs, train_targets = list(zip(*train_dataset))
    val_graphs, val_targets = list(zip(*val_dataset))
    test_graphs, test_targets = list(zip(*test_dataset))

    # Group graphs by node count (one list entry per count in nodes_*).
    train_graphs_list, train_targets_list, nodes_train = separate_by_nodes(train_graphs, args.min_node, args.max_node, train_targets)
    val_graphs_list, val_targets_list, nodes_val = separate_by_nodes(val_graphs, args.min_node, args.max_node, val_targets)
    test_graphs_list, test_targets_list, nodes_test = separate_by_nodes(test_graphs, args.min_node, args.max_node, test_targets)

    print(len(dataset))

    # NN_matrices[n][(i, j)] = obs_ZZ(n, i, j, type_n=True) for every i <= j < n,
    # for n = 2 .. max_node; precomputed once and shared by all splits.
    NN_matrices = dict()
    for n in range(2, args.max_node+1):
        matrices = dict()
        for i in range(n):
            for j in range(i, n):
                matrix = obs_ZZ(n, i, j, type_n=True)
                matrices[(i, j)] = matrix
        NN_matrices[n] = matrices
        del matrices

    train_ising_matrices_list = compute_all_ising_matrices(train_graphs_list, NN_matrices)
    val_ising_matrices_list = compute_all_ising_matrices(val_graphs_list, NN_matrices)
    test_ising_matrices_list = compute_all_ising_matrices(test_graphs_list, NN_matrices)

    dataset_train_list, dataset_train_classical_list = build_datasets(train_graphs_list, train_targets_list, train_ising_matrices_list, nodes_train, (1, 1), 'attr', out_dim=out_dims[args.dataset])
    dataset_val_list, dataset_val_classical_list = build_datasets(val_graphs_list, val_targets_list, val_ising_matrices_list, nodes_val, (1, 1), 'attr', out_dim=out_dims[args.dataset])
    dataset_test_list, dataset_test_classical_list = build_datasets(test_graphs_list, test_targets_list, test_ising_matrices_list, nodes_test, (1, 1), 'attr', out_dim=out_dims[args.dataset])

    # Batch size indexed by node count: 3600 for 1..15 nodes, then shrinking
    # (2000, 1000, 500, 300, 150, 60) for 16..21 nodes; index 0 is 10.
    # NOTE(review): the table only covers node counts <= 21; a larger
    # --max_node will raise IndexError below.
    batch_sizes = np.ones(22).astype(int)
    batch_sizes[1:15] = 360
    batch_sizes[15::] = np.array([360, 200, 100, 50, 30, 15, 6]).astype(int)

    batch_sizes = batch_sizes * 10

    batch_sizes_train = batch_sizes[nodes_train]
    batch_sizes_val = batch_sizes[nodes_val]
    batch_sizes_test = batch_sizes[nodes_test]

    model = MultiHeadQGraphModel(1, 1, 1, 1, 1, 4)

    print('Preprocessing done')
    for dataset_train, n in zip(dataset_train_list, nodes_train):
        print(n, len(dataset_train))

    # One worker per process; spawn passes the process index as the first
    # argument (rank) of generate_matrices.
    mp.spawn(generate_matrices,
        args=(model, dataset_train_list, dataset_val_list, dataset_test_list, batch_sizes_train, batch_sizes_val, batch_sizes_test, 20, args.dataset, args.permutation_seed),
        nprocs=args.n_processes,
        join=True)