import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from dgl.data import split_dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim


from dataset_loader.loader import dataset_loader
from dataset_loader.metrics import test_error_QM7b, test_error_DBLP, accuracy_error_DBLP, test_error_QM9
from dataset_loader.utils import filter_nb_nodes, in_feats, separate_by_nodes, out_feats
from dataset import CustomDataset, CustomDatasetClassical
from models.GNN import GNN, train_model
from models.quantum import QGraphNetworkCustom, MultiHeadQGraphModel
from utils import generate_ising_matrices_torch, obs_ZZ, compute_all_ising_matrices, masked_softmax_batch
from training_quantum import training_loop_parallel, training_loop_single
from torch.utils.data import DataLoader
import uuid
import pickle
import time
import wandb
from torch.utils.data import Dataset

from dataset_loader.metrics import test_error_QM7b, test_error_DBLP, accuracy_error_DBLP, test_error_QM9



import torch.multiprocessing as mp


# ------------------------- DEFINE PARAMETERS
# Command-line interface of the training script.
parser = argparse.ArgumentParser()
# Which dataset to train on; only the names listed in `choices` are accepted.
parser.add_argument(
    '--dataset', '-d',
    choices=['QM7', 'DBLP_v1', 'Letter-med', 'QM9', 'Cover', 'ZINC'],
    type=str,
    required=True,
)
# Node-count window forwarded to the dataset loader and to the per-size grouping.
parser.add_argument('--min_node', type=int, default=4)
parser.add_argument('--max_node', type=int, default=15)
# Extra model hyperparameters, kept as raw strings here (the script casts them to int).
parser.add_argument('--hyperparameters', '-p', nargs='+', default=[])
# Seed used when the dataset is shuffled and split.
parser.add_argument('--permutation_seed', '-seed', type=int, default=0)
# Exactly three fractions are required because the script unpacks them into
# train/val/test; nargs=3 lets argparse reject a wrong count up front instead
# of failing later with an unpacking error.
parser.add_argument(
    '--splits',
    nargs=3,
    type=float,
    default=(.8, .1, .1),
    metavar=('TRAIN', 'VAL', 'TEST'),
    help='train/val/test fractions (three values that sum to 1)',
)
parser.add_argument('--n_processes', '-nprocs', type=int, default=1)
parser.add_argument('--n_epochs', type=int, default=500)
# Extra positional arguments forwarded to the dataset loader (used for ZINC).
parser.add_argument('--dataset_parameters', '-dp', nargs='+', default=[])





class RandomDatasetClassical(Dataset):
    """Map-style dataset pairing node features and labels with precomputed matrices.

    Every graph passed in must expose ``graph.ndata[feat_name]`` with the same
    shape, since the per-graph features are stacked into one tensor.
    """

    def __init__(self, raw_graphs, raw_labels, matrices, n_nodes, feat_name='feat', out_dim=None):
        """Build the stacked feature, label and matrix tensors.

        Parameters
        ----------
        raw_graphs : iterable
            Graph objects; ``graph.ndata[feat_name]`` is read from each one.
        raw_labels : sequence
            Labels, converted with ``torch.tensor``.
        matrices : Tensor
            Precomputed matrices indexed by graph along the first axis (cloned).
        n_nodes : int
            Stored as ``self.n_nodes``; not used for any computation here.
        feat_name : str
            Key of the node-feature field to read.
        out_dim : int or None
            Output dimension; when None it is taken from ``labels.shape[1]``.
        """
        super().__init__()
        node_feats = [graph.ndata[feat_name] for graph in raw_graphs]
        # Stacking adds the leading per-graph axis and copies the data.
        self.features = torch.stack(node_feats, dim=0).float()
        label_tensor = torch.tensor(raw_labels)
        self.labels = label_tensor.clone()

        self.n_nodes = n_nodes
        self.matrices = matrices.clone()
        # Feature width is the third axis of the stacked feature tensor.
        self.in_dim = self.features.shape[2]
        self.out_dim = self.labels.shape[1] if out_dim is None else out_dim

    def __getitem__(self, idx):
        """Return the ``(features, labels, matrices)`` triple stored at ``idx``.

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (features, labels, matrices)
        (Tensor, Tensor, Tensor)
        Each element is the ``idx``-th slice along the first axis of the
        corresponding stored tensor.
        """
        item = (self.features[idx], self.labels[idx], self.matrices[idx])
        return item

    def __len__(self):
        """Number of graphs in the dataset (length of the label tensor)."""
        n_graphs = self.labels.shape[0]
        return n_graphs


def build_datasets(graphs_list, targets_list, matrices_list, nodes_list, feat_name='feat', out_dim=None):
    """Create one RandomDatasetClassical per group of graphs.

    The four lists are walked in lockstep: entry ``i`` of each list supplies
    the graphs, targets, matrices and node count of the ``i``-th dataset.
    Only the lengths of ``graphs_list`` and ``nodes_list`` are checked.

    Returns
    -------
    list
        The datasets, in the order of the input lists.
    """
    assert len(graphs_list) == len(nodes_list)
    grouped = zip(graphs_list, targets_list, matrices_list, nodes_list)
    return [
        RandomDatasetClassical(graphs, targets, matrices, n_nodes, feat_name=feat_name, out_dim=out_dim)
        for graphs, targets, matrices, n_nodes in grouped
    ]

def read_matrices(folders, nodes, split='train', dataset_name='QM7'):
    """Load the precomputed matrices for every node count from every folder.

    Files are read from ``<dataset_name>/<folder>/<split>/matrices_<n>.pt``,
    resolved relative to the current working directory. This matches how the
    script lists and checks these folders; the path used to be anchored at the
    filesystem root, which only agreed with that listing when run from ``/``.

    Parameters
    ----------
    folders : iterable of str
        Sub-folder names below ``dataset_name``; one tensor is loaded per folder.
    nodes : iterable
        Node counts; one output tensor is produced per entry.
    split : str
        Name of the split sub-directory.
    dataset_name : str
        Top-level directory holding the folders.

    Returns
    -------
    list of Tensor
        One tensor per node count. Each loaded tensor gets a new axis at
        position 1 and the per-folder tensors are concatenated along it.
    """
    print(folders)
    all_matrices = []
    for n in nodes:
        per_folder = [
            torch.load(os.path.join(dataset_name, folder, split, f'matrices_{n}.pt')).unsqueeze(1)
            for folder in folders
        ]
        all_matrices.append(torch.cat(per_folder, dim=1))
    return all_matrices

def return_dataloaders(dataset_list, batch_size=1000):
    """Wrap each dataset in a sequential DataLoader and count the total items.

    Parameters
    ----------
    dataset_list : iterable
        Datasets to wrap; each must support ``len()``.
    batch_size : int
        Batch size given to every DataLoader. Defaults to 1000, the value that
        used to be hard-coded.

    Returns
    -------
    (list of DataLoader, int)
        One loader per dataset (no shuffling, last partial batch kept) and the
        summed length of all datasets.
    """
    dataloader_list = []
    num_items = 0
    for dataset in dataset_list:
        dataloader_list.append(
            DataLoader(dataset, batch_size=batch_size, drop_last=False, shuffle=False)
        )
        num_items += len(dataset)

    return dataloader_list, num_items


class Model(nn.Module):
    """Learnable mixing of precomputed matrices into attention for a MultiHeadQGraphModel.

    ``gamma`` holds one coefficient vector per matrix set. In ``forward`` each
    set is contracted with its (clipped) coefficients, the result is turned
    into row-normalised attention with a zeroed diagonal, and that attention is
    handed to the wrapped model as ``precomputed_attention``.
    """

    def __init__(self, n_matrices, hyperparameters):
        """
        Parameters
        ----------
        n_matrices : int
            Number of matrix sets, i.e. number of rows of ``gamma``.
        hyperparameters : sequence
            Positional arguments forwarded to ``MultiHeadQGraphModel``.
        """
        super().__init__()
        # The trailing size 9 must match the last axis of `matrices` given to
        # forward(). NOTE(review): presumably the number of precomputed
        # quantities per node pair - confirm against the matrix generator.
        self.gamma = nn.Parameter(torch.randn((n_matrices, 9)))
        self.qgnn_model = MultiHeadQGraphModel(*hyperparameters, apply_softmax=True)
        self.args = hyperparameters

    def forward(self,
                in_feat,
                matrices,
                target_att_shape):
        '''
        Build the attention tensor from ``matrices`` and run the wrapped model.

        Parameters
        ----------
        in_feat : Tensor
            Node features, passed through to the wrapped model unchanged.
        matrices : Tensor
            Indexed here as (batch_size, n_matrices, N, N, K), with K equal to
            the second axis of ``gamma``.
        target_att_shape : sequence of int
            Shape the ``n_matrices`` axis is reshaped into; its product must
            equal ``n_matrices``. Any rank is accepted.

        Returns
        -------
        Whatever the wrapped model returns for a ``precomputed_attention`` of
        shape (batch_size, *target_att_shape, N, N).
        '''
        gamma = torch.clip(self.gamma, -3, 3)
        batch_size = matrices.shape[0]
        n_sets = gamma.shape[0]
        N = matrices.shape[2]
        if matrices.shape[1] < n_sets:
            # Fail loudly rather than letting a size-1 axis broadcast silently.
            raise IndexError(
                f'matrices provides {matrices.shape[1]} matrix sets but gamma expects {n_sets}'
            )
        # Weighted sum over the last axis for all matrix sets at once.
        logits = (matrices[:, :n_sets] * gamma.view(1, n_sets, 1, 1, -1)).sum(dim=-1)
        logits = logits.reshape([batch_size] + list(target_att_shape) + [N, N])
        # Zero on the diagonal, one elsewhere; it broadcasts over every leading
        # axis, so target_att_shape may have any rank.
        off_diagonal = 1 - torch.eye(N, device=logits.device)
        weights = torch.exp(logits) * off_diagonal
        # Normalise each row; the epsilon guards an all-zero row (e.g. N == 1).
        attention_matrices = weights / (weights.sum(dim=-1, keepdim=True) + 1e-15)
        return self.qgnn_model(in_feat, None, None, batch_size, precomputed_attention=attention_matrices)

    def reset_parameters(self):
        """Reset the wrapped model's parameters; ``gamma`` is left untouched."""
        self.qgnn_model.reset_parameters()
        # NOTE(review): re-initialising gamma (init.normal_(self.gamma, 0, 1))
        # was disabled in the original code - confirm whether that is intended.



def _batch_loss(criterion, pred, labels, reduction):
    """Evaluate ``criterion`` on one batch, retrying with flattened labels on RuntimeError."""
    try:
        return criterion(pred, labels, reduction=reduction)
    except RuntimeError:
        # NOTE(review): presumably hit when class-index targets carry an extra
        # axis that the criterion rejects - confirm against the datasets.
        return criterion(pred, labels.reshape((-1,)), reduction=reduction)


def _average_loss(model, dataloaders, criterion, num_items, device, att_shape):
    """Sum-reduced ``criterion`` over every batch of every loader, divided by ``num_items``."""
    total = 0
    for dataloader in dataloaders:
        for features, labels, matrices in dataloader:
            pred = model(features.to(device), matrices.to(device), att_shape)
            total += _batch_loss(criterion, pred, labels.to(device), 'sum') / num_items
            del pred
            torch.cuda.empty_cache()
    return total


if __name__ == '__main__':
    # ------------------------- DEFINE OPTIONS
    layers_models = {
        'SAGE': SAGEConv,
        'GAT': GATConv,
        'GCN': GraphConv
    }
    error_functions = {
        'QM7': test_error_QM7b,
        'DBLP_v1': test_error_DBLP,
        'Letter-med': test_error_DBLP,
        'QM9': test_error_QM9,
        'Cover': test_error_DBLP,
        'ZINC': test_error_QM9,
    }

    # Training criterion per dataset.
    loss_functions = {
        'QM7': F.l1_loss,
        'DBLP_v1': F.cross_entropy,
        'Letter-med': F.cross_entropy,
        'Cover': F.cross_entropy,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    # Criterion reported on the validation and test sets.
    loss_functions_test = {
        'QM7': F.l1_loss,
        'DBLP_v1': accuracy_error_DBLP,
        'Letter-med': accuracy_error_DBLP,
        'Cover': accuracy_error_DBLP,
        'QM9': F.l1_loss,
        'ZINC': F.l1_loss
    }

    # Output dimension per dataset; None means "take it from the label tensor".
    out_dims = {
        'QM7': None,
        'DBLP_v1': 2,
        'Letter-med': 15,
        'QM9': None,
        'Cover': 2,
        'ZINC': None
    }

    # default params in functions should be based on legacy usage
    # -> define desired default params here so they are stored in W&B
    dataset_default_params = {
        'QM7': [False, True]  # don't use positions & use embed atomic number Z
    }

    # ------------------------- COLLECT PARAMETERS
    # --- WARNING: Keep values in args when update as it's the variable eventually sent to W&B
    args = parser.parse_args()

    train_size, val_size, test_size = args.splits
    # Compare with a tolerance: exact float equality rejects valid splits such
    # as 0.7/0.2/0.1, and a bare assert would be stripped under `python -O`.
    if abs(train_size + val_size + test_size - 1.) > 1e-8:
        raise ValueError(f'--splits must sum to 1, got {list(args.splits)}')

    # ------------------------- MAIN LOGIC
    dataset = dataset_loader(args.dataset, min_node=args.min_node, max_node=args.max_node)

    if args.dataset == 'ZINC':
        # ZINC is loaded once per split instead of being split here.
        train_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='train')
        val_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='val')
        test_dataset = dataset_loader(args.dataset, *args.dataset_parameters, min_node=args.min_node, max_node=args.max_node, split='test')
    else:
        train_dataset, val_dataset, test_dataset = split_dataset(
            dataset, frac_list=args.splits, shuffle=True, random_state=args.permutation_seed
        )

    train_graphs, train_targets = list(zip(*train_dataset))
    val_graphs, val_targets = list(zip(*val_dataset))
    test_graphs, test_targets = list(zip(*test_dataset))

    train_graphs_list, train_targets_list, nodes_train = separate_by_nodes(train_graphs, args.min_node, args.max_node, train_targets)
    val_graphs_list, val_targets_list, nodes_val = separate_by_nodes(val_graphs, args.min_node, args.max_node, val_targets)
    test_graphs_list, test_targets_list, nodes_test = separate_by_nodes(test_graphs, args.min_node, args.max_node, test_targets)

    # os.listdir returns entries in arbitrary order; sorting makes the subset
    # selected below reproducible across runs and machines.
    folders = sorted(os.listdir(f'{args.dataset}'))
    # Keep only folders that contain the 9-node test file.
    # NOTE(review): presumably a marker that the folder was fully generated - confirm.
    folders = [folder for folder in folders if os.path.exists(f'{args.dataset}/{folder}/test/matrices_9.pt')]
    print(len(folders))
    print(nodes_test)
    print(test_graphs_list[0])

    # Number of matrix sets fed to the model, and the shape that axis is
    # reshaped into inside Model.forward (the product must equal n_matrices).
    n_matrices = 64
    attention_shape = (8, 8)
    learning_rate = 0.0001

    matrices_train = read_matrices(folders[0:n_matrices], nodes_train, dataset_name=args.dataset)
    matrices_val = read_matrices(folders[0:n_matrices], nodes_val, 'val', dataset_name=args.dataset)
    matrices_test = read_matrices(folders[0:n_matrices], nodes_test, 'test', dataset_name=args.dataset)

    datasets_train = build_datasets(train_graphs_list, train_targets_list, matrices_train, nodes_train, 'attr', out_dim=out_dims[args.dataset])
    datasets_val = build_datasets(val_graphs_list, val_targets_list, matrices_val, nodes_val, 'attr', out_dim=out_dims[args.dataset])
    datasets_test = build_datasets(test_graphs_list, test_targets_list, matrices_test, nodes_test, 'attr', out_dim=out_dims[args.dataset])

    dataloader_train_list, num_train = return_dataloaders(datasets_train)
    dataloader_val_list, num_val = return_dataloaders(datasets_val)
    dataloader_test_list, num_test = return_dataloaders(datasets_test)

    n_epochs = args.n_epochs

    # Input/output widths come from the first training dataset, followed by
    # the integer hyperparameters given on the command line.
    hyperparameters = [datasets_train[0].in_dim, datasets_train[0].out_dim] + [int(h) for h in args.hyperparameters]
    print(hyperparameters)
    print(datasets_train[0].matrices.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Model(n_matrices, hyperparameters).to(device)

    loss_func = loss_functions[args.dataset]
    loss_func_test = loss_functions_test[args.dataset]

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    exp_id = uuid.uuid4().hex[:10]

    name = f"QGNNrandom_{args.dataset}"
    config = {
        "lr": learning_rate,
        "gamma": 1,
        "n_epochs": n_epochs,
        "dataset": args.dataset,
        "model": 'QGNN',
        "args": model.args,
        "seed": args.permutation_seed,
        "random": True,
        "cli_args": args.__dict__
    }

    wandb.init(
        project="QGNN",
        entity="qgnn",
        name=f"{name}_{exp_id}",
        config=config
    )
    wandb.watch(model)
    for epoch in range(n_epochs):
        # Only the optimisation pass is timed; evaluation below is excluded.
        t0 = time.perf_counter()
        for dataloader in dataloader_train_list:
            for features, labels, matrices in dataloader:
                pred = model(features.to(device), matrices.to(device), attention_shape)
                loss = _batch_loss(loss_func, pred, labels.to(device), 'mean')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                del pred
                torch.cuda.empty_cache()
        t1 = time.perf_counter()

        # NOTE(review): the model is not switched to eval() here; confirm that
        # the wrapped network has no train-only layers (dropout, batch norm).
        with torch.no_grad():
            loss_train = _average_loss(model, dataloader_train_list, loss_func, num_train, device, attention_shape)
            loss_val = _average_loss(model, dataloader_val_list, loss_func_test, num_val, device, attention_shape)
            loss_test = _average_loss(model, dataloader_test_list, loss_func_test, num_test, device, attention_shape)

        logs = {
            'epoch': epoch+1,
            'loss_val': loss_val,
            'loss_test': loss_test,
            'loss_train': loss_train,
            'time': t1-t0,
        }
        wandb.log(logs)
        print(f'Epoch {epoch+1} | Time {t1 - t0} | Loss train {loss_train} | Loss val {loss_val} | Loss test {loss_test}')
    wandb.finish()
