import torch
from torch.utils.data import Dataset
import dgl
from os import path
import networkx as nx
import numpy as np
from dgl.data import DGLDataset
from utils import generate_ising_matrices_torch, obs_ZZ
from dgl.dataloading import GraphCollator
from itertools import combinations



# Root directory of the downloaded datasets. It is a relative path, so it is
# resolved against the current working directory; load_dataset reads the
# files '<datasets_path>/<name>/<name>_*.txt'.
datasets_path = '../datasets'

def load_dataset(name, min_node=0, max_node=12, node_attributes=False, edge_attributes=False):

    '''
    Loads the dataset with the corresponding name. Creates an array of
    graphs with node and edge attributes, and an array of targets.
    The files are read from '<datasets_path>/<name>/<name>_*.txt'.
    Details about the file formats here:
        https://chrsmrrs.github.io/datasets/docs/format/
    The datasets can be downloaded here:
        https://chrsmrrs.github.io/datasets/docs/datasets/

    Arguments:
    ---------
    - name: str, name of the dataset
    - min_node: int, eliminate all the graphs with
                    a number of nodes below the value passed
    - max_node: int, eliminate all the graphs with
                    a number of nodes above the value passed
    - node_attributes: bool, if True load 'node_attributes.txt' and
                    'node_labels.txt' when they exist
    - edge_attributes: bool, if True load 'edge_attributes.txt' and
                    'edge_labels.txt' when they exist

    Returns:
    --------
    - graph_filtered: numpy.ndarray of networkx.Graph objects,
     node/edge attributes are stored under the key 'attr' and
     node/edge labels under the key 'label'
    - targets_filtered: numpy.ndarray of floats,
     discrete values for classification, continuous ones for regression

    Raises:
    -------
    - FileNotFoundError: if neither 'graph_labels.txt' nor
     'graph_attributes.txt' exists for the dataset
    '''

    directory = datasets_path + '/' + name + '/' + name + '_'
    is_node_attr = False
    is_edge_attr = False
    is_node_label = False
    is_edge_label = False

    # One line per node: the (1-based) id of the graph the node belongs to.
    with open(directory + 'graph_indicator.txt') as file:
        all_nodes = np.array(file.read().splitlines()).astype(int)

    # One line per edge: "source, target" with 1-based global node ids.
    with open(directory + 'A.txt') as file:
        edge_lines = file.read().splitlines()

    # graph_attributes (regression) take precedence over graph_labels
    # (classification) when both files are present.
    targets = None
    if path.exists(directory + 'graph_labels.txt'):
        targets = np.loadtxt(directory + 'graph_labels.txt', delimiter=',')

    if path.exists(directory + 'graph_attributes.txt'):
        targets = np.loadtxt(directory + 'graph_attributes.txt', delimiter=',')

    if targets is None:
        raise FileNotFoundError(
            'No graph targets found for dataset ' + repr(name) + ': expected '
            + directory + 'graph_labels.txt or '
            + directory + 'graph_attributes.txt')

    if edge_attributes:
        if path.exists(directory + 'edge_attributes.txt'):
            edge_attr = np.loadtxt(directory
                                   + 'edge_attributes.txt', delimiter=',')
            is_edge_attr = True

        if path.exists(directory + 'edge_labels.txt'):
            edge_label = np.loadtxt(directory
                                    + 'edge_labels.txt', delimiter=',')
            is_edge_label = True

    if node_attributes:
        if path.exists(directory + 'node_attributes.txt'):
            node_attr = np.loadtxt(directory
                                   + 'node_attributes.txt', delimiter=',')
            is_node_attr = True

        if path.exists(directory + 'node_labels.txt'):
            node_label = np.loadtxt(directory
                                    + 'node_labels.txt', delimiter=',')
            is_node_label = True

    # Parse every edge line once; spaces around the comma are optional.
    edge_list = []
    for line in edge_lines:
        parts = line.replace(' ', '').split(',')
        edge_list.append((int(parts[0]), int(parts[1])))

    all_graphs = [nx.Graph() for _ in range(int(np.max(all_nodes)))]

    # Nodes keep their global 1-based id inside their graph.
    for i, graph_id in enumerate(all_nodes):
        node_data = dict()
        if is_node_attr:
            node_data['attr'] = node_attr[i]
        if is_node_label:
            node_data['label'] = node_label[i]
        all_graphs[graph_id - 1].add_node(i + 1, **node_data)

    # An edge is assigned to the graph that owns its source node.
    for i, (source, target) in enumerate(edge_list):
        graph_id = all_nodes[source - 1]
        edge_data = dict()
        if is_edge_attr:
            edge_data['attr'] = edge_attr[i]
        if is_edge_label:
            edge_data['label'] = edge_label[i]
        all_graphs[graph_id - 1].add_edge(source, target, **edge_data)

    # Fill the object array element by element: networkx graphs are
    # iterable and sized, so np.array(list_of_graphs) may try to unpack
    # them as nested sequences instead of storing the graph objects.
    graph_array = np.empty(len(all_graphs), dtype=object)
    for i, graph in enumerate(all_graphs):
        graph_array[i] = graph

    graph_filtered_id = np.array(
        [i for i, graph in enumerate(all_graphs)
         if min_node <= graph.number_of_nodes() <= max_node], dtype=int)

    graph_filtered = graph_array[graph_filtered_id]
    targets_filtered = targets[graph_filtered_id]

    return graph_filtered, targets_filtered

def collator(items):
    """Collate samples of the form (graph, target, ising, features, index).

    Returns a 5-tuple: the graphs batched by dgl's GraphCollator, the
    targets concatenated with torch.cat, the ising entries as a numpy
    object array, the per-sample features as a plain list, and the
    indices as a tensor.
    """
    def column(position):
        # Gather the field stored at `position` from every sample.
        return [sample[position] for sample in items]

    graph_batch = GraphCollator().collate(column(0))
    target_batch = torch.cat(column(1))
    ising_batch = np.array(column(2), dtype=object)
    feature_batch = column(3)
    index_batch = torch.tensor(column(4))
    return graph_batch, target_batch, ising_batch, feature_batch, index_batch

def collator_same_size(items):
    """Collate like `collator`, stacking ising entries when sizes agree.

    When every graph in the batch has the same number of nodes, the ising
    entries are reshaped to columns and concatenated along dim 1 into a
    single tensor; otherwise the object array produced by `collator` is
    returned unchanged.
    """
    node_counts = np.array([sample[0].number_of_nodes() for sample in items])
    graph_batch, target_batch, ising_batch, feature_batch, index_batch = \
        collator(items)
    all_same_size = np.unique(node_counts).size == 1
    if all_same_size:
        ising_columns = [sample[2].reshape(-1, 1) for sample in items]
        ising_batch = torch.cat(ising_columns, dim=1)
    return graph_batch, target_batch, ising_batch, feature_batch, index_batch

def collator_classical(items):
    """Collate samples whose first, second and last fields are
    (graph, target, index).

    Returns the graphs batched by dgl's GraphCollator, the targets
    concatenated with torch.cat, and the indices as a tensor.
    """
    graph_list = [sample[0] for sample in items]
    graph_batch = GraphCollator().collate(graph_list)
    target_list = [sample[1] for sample in items]
    target_batch = torch.cat(target_list)
    # The index is always read from the last position of the sample.
    index_list = [sample[-1] for sample in items]
    index_batch = torch.tensor(index_list)
    return graph_batch, target_batch, index_batch


class NXDataset(DGLDataset):

    """DGL dataset built from networkx graphs and their targets.

    Parameters
    ----------
    raw_graphs : sequence of networkx.Graph
        Graphs to convert to dgl graphs.
    raw_labels : numpy.ndarray
        One target per graph. For classification the values are remapped
        in place to consecutive class indices 0..C-1.
    max_node : int
        Largest graph size for which ZZ observables are precomputed.
    name, force_reload, verbose
        Forwarded to DGLDataset.
    shuffle : bool
        If True, `self.index` is a random permutation (seeded with `seed`).
    seed : int
        Seed used when `shuffle` is True.
    device : str or torch.device
        Device of the graphs and labels.
    self_loop : bool
        If True, a self loop is added to every node of every graph.
    classification : bool
        Selects long class-index labels (True) or float labels reshaped
        to (1, -1) (False).
    compute_ising : bool
        If True, the ising matrices are computed (on CPU) at construction;
        otherwise `self.ising_matrices` is a list of None.
    """

    def __init__(self, raw_graphs, raw_labels, max_node, name='mydataset',
                 force_reload=False, verbose=False,
                 shuffle=False, seed=87, device='cpu', self_loop=False,
                 classification=True, compute_ising=True):
        # These attributes are read by process(), which DGLDataset runs
        # from its constructor, so they must be set before super().__init__.
        self.raw_graphs = raw_graphs
        self.raw_labels = raw_labels
        self.device = torch.device(device)
        self.self_loop = self_loop
        self.classification_dataset = classification
        super(NXDataset, self).__init__(name=name,
                                        force_reload=force_reload,
                                        verbose=verbose)
        self.num_classes = len(np.unique(raw_labels))

        if shuffle:
            self.shuffle(seed=seed)
        else:
            self.index = np.arange(len(self.raw_graphs)).astype(int)
        if compute_ising:
            self.compute_ising_matrices(max_node, device='cpu')
        else:
            self.ising_matrices = [None]*len(raw_graphs)

    def process(self):
        """Build the dgl graphs, their node features and the label tensors."""
        graphs = []
        for raw_graph in self.raw_graphs:
            graph = dgl.from_networkx(raw_graph)
            if self.self_loop:
                graph = graph.add_self_loop()
            graphs.append(graph.to(self.device))
        self.graphs = graphs

        # Node features: the first 4 columns of the eigenvector matrix of
        # the Laplacian D - A (np.linalg.eigh sorts eigenvalues in
        # ascending order). Graphs with fewer than 4 nodes get fewer
        # feature columns.
        for graph in self.graphs:
            deg = graph.out_degrees().cpu().numpy()
            laplacian = np.diag(deg) - graph.adj().to_dense().cpu().numpy()
            _, eigenvectors = np.linalg.eigh(laplacian)
            graph.ndata['feat'] = torch.tensor(eigenvectors[:, 0:4],
                                               dtype=torch.float,
                                               device=self.device)

        if self.classification_dataset:
            # Remap the class values to 0..C-1 in a single pass. Replacing
            # one value at a time in place can merge classes when a new
            # index equals a not-yet-processed value (e.g. labels {-1, 0}).
            # raw_labels is still updated in place, as before.
            self.raw_labels = np.asarray(self.raw_labels)
            _, codes = np.unique(self.raw_labels, return_inverse=True)
            self.raw_labels[...] = codes.reshape(self.raw_labels.shape)
            self.label = [torch.tensor(label,
                                       dtype=torch.long,
                                       device=self.device)
                          for label in self.raw_labels]
        else:
            self.label = [torch.tensor(label.reshape((1, -1)),
                                       dtype=torch.float,
                                       device=self.device)
                          for label in self.raw_labels]
        self.graph_features = [dict() for _ in range(len(self.graphs))]

    def _precompute_zz(self, max_node, device=None):
        """Return {n: {(i, j): obs_ZZ(n, i, j)}} for 2 <= n <= max_node, i <= j.

        When `device` is given, every matrix is moved to it.
        """
        zz_by_size = dict()
        for n in range(2, max_node+1):
            matrices = dict()
            for i in range(n):
                for j in range(i, n):
                    matrix = obs_ZZ(n, i, j, type_n=True)
                    if device is not None:
                        matrix = matrix.to(device)
                    matrices[(i, j)] = matrix
            zz_by_size[n] = matrices
        return zz_by_size

    def compute_ising_matrices(self, max_node, device='cpu'):
        """Compute `self.ising_matrices` for every raw graph.

        The ZZ observables are precomputed once per graph size and passed
        to generate_ising_matrices_torch together with `device`.
        """
        self.ising_matrices = generate_ising_matrices_torch(
                                    self.raw_graphs,
                                    precomputed_zz=self._precompute_zz(max_node),
                                    device=device)

    def compute_obs_shortest_path(self, max_node, device='cpu'):
        """Store the 'obs' graph feature of every graph.

        NOTE(review): the shortest-path observable is currently disabled.
        The `break` below skips the whole computation, and a (1, 1, 1)
        zero tensor is stored as a placeholder instead of `obs`.
        """
        NN_matrices = self._precompute_zz(max_node, device=device)

        for index, graph in enumerate(self.raw_graphs):
            N = graph.number_of_nodes()
            obs = torch.zeros((N, N, 2**N)).to(device)
            for i, j in combinations(range(N), 2):
                break
                try:
                    node_path = nx.shortest_path(graph, i, j)
                    for k in range(1, len(node_path)):
                        if (node_path[k], node_path[k-1]) in NN_matrices[N]:
                            obs[i, j] += NN_matrices[N][(node_path[k], node_path[k-1])]\
                                / (len(node_path) - 1)
                        else:
                            obs[i, j] += NN_matrices[N][(node_path[k-1], node_path[k])] \
                                / (len(node_path) - 1)
                except nx.NetworkXNoPath:
                    pass
                obs[j, i] = obs[i, j]
            #self.graph_features[index]['obs'] = obs
            self.graph_features[index]['obs'] = torch.zeros((1, 1, 1)).to(device)

    def update_graph_features(self, feat, index, feat_name):
        """Store a CPU copy of `feat` under `feat_name` for graph `index`."""
        self.graph_features[index][feat_name] = feat.clone().cpu()

    def __getitem__(self, idx):
        """ Get graph, label and dataset index by position

        Parameters
        ----------
        idx : int
            Position in the (possibly shuffled) order given by `self.index`

        Returns
        -------
        (dgl.DGLGraph, Tensor, int)
            The graph, its label and its index in the unshuffled dataset.
            The ising matrices and graph features are not returned.
        """
        position = self.index[idx]
        return (self.graphs[position],
                self.label[position],
                position)

    def __len__(self):
        """Number of graphs in the dataset"""
        return len(self.graphs)

    def shuffle(self, seed=84):
        """Replace `self.index` by a random permutation seeded with `seed`."""
        np.random.seed(seed)
        self.index = np.random.permutation(len(self.raw_graphs)).astype(int)

    def to(self, device):
        """Move the graphs (with their features) and the labels to `device`.

        DGLGraph.to moves node and edge data with the graph, so features
        are not moved one by one. Labels are moved rather than rebuilt, so
        their dtype and shape are preserved for regression datasets.
        """
        self.device = torch.device(device)
        self.graphs = [graph.to(self.device) for graph in self.graphs]
        self.label = [label.to(self.device) for label in self.label]


class CustomDataset(Dataset):

    """Dataset of graphs with all the graphs of the same size

    Parameters
    ----------
    raw_graphs : list of dgl.DGLGraph
        List of graphs of N nodes
    raw_labels : array-like
        Labels of the graphs, converted with torch.tensor
    ising_matrices : Tensor
        Tensor of ising matrices, indexed by sample along dim 1
        (expected shape (2**N, n_samples))
    n_nodes : int
        Number of nodes in the graphs
    feat_name : str
        Key of the node features read from graph.ndata
        default = 'feat'
    out_dim : int or None
        Output dimension; when None it is taken from labels.shape[1],
        which requires 2-D labels
    """

    def __init__(self, raw_graphs, raw_labels, ising_matrices, n_nodes, feat_name='feat', out_dim=None):
        super(CustomDataset, self).__init__()
        # torch.stack copies its inputs, so the dataset owns its tensors.
        self.features = torch.stack(
            [graph.ndata[feat_name] for graph in raw_graphs]).float()
        self.adjacency_matrices = torch.stack(
            [graph.adj().to_dense() for graph in raw_graphs])
        self.labels = torch.tensor(raw_labels)
        self.n_nodes = n_nodes
        self.ising_matrices = ising_matrices.clone()
        self.in_dim = self.features.shape[2]
        if out_dim is None:
            self.out_dim = self.labels.shape[1]
        else:
            self.out_dim = out_dim

    def __getitem__(self, idx):
        """ Get elements of the dataset by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (features, labels, ising_matrix, adjacency_matrix, index)
        (Tensor, Tensor, Tensor, Tensor, int)
        features: row idx of the stacked node features
        labels: row idx of the labels
        ising_matrix: column idx of the ising matrices
        adjacency_matrix: row idx of the stacked adjacency matrices

        """
        return (self.features[idx],
                self.labels[idx],
                self.ising_matrices[:, idx],
                self.adjacency_matrices[idx],
                idx)

    def __len__(self):
        """Number of graphs in the dataset"""
        return self.labels.shape[0]

    def update_attention_matrices(self, new_matrices, idx, idx_batch):
        """ Update the value of the attention matrices

        The constructor does not allocate `attention_matrices`; this
        method only works once that attribute has been attached to the
        dataset.

        Parameters
        ----------
        new_matrices : Tensor of shape (batch_size, n_node, n_node)
            New values of the matrices
        idx : int
            Index of the attention matrix to overwrite
        idx_batch : list of int
            Indices of the samples to overwrite

        Raises
        ------
        AttributeError
            If the dataset has no `attention_matrices` attribute

        """
        if not hasattr(self, 'attention_matrices'):
            raise AttributeError(
                "CustomDataset has no 'attention_matrices' storage; "
                "set dataset.attention_matrices before updating it")
        self.attention_matrices[idx_batch, idx] = new_matrices.cpu().clone()


class CustomDatasetClassical(Dataset):

    """Dataset of same-size graphs with trainable-free attention storage

    Parameters
    ----------
    raw_graphs : list of dgl.DGLGraph
        List of graphs of N nodes
    raw_labels : array-like
        Labels of the graphs, converted with torch.tensor
    n_nodes : int
        Number of nodes in the graphs
    att_shape : int or sequence of int
        Leading dimensions of the attention storage of each sample; the
        storage has shape (n_samples, *att_shape, n_nodes, n_nodes) and
        is initialised to ones
        default = 2
    feat_name : str
        Key of the node features read from graph.ndata
        default = 'feat'
    out_dim : int or None
        Stored as `self.out_dim`; when None and the labels are at least
        2-D it is taken from labels.shape[1], otherwise it stays None
    """

    def __init__(self, raw_graphs, raw_labels, n_nodes, att_shape=2, feat_name='feat', out_dim=None):
        super(CustomDatasetClassical, self).__init__()
        # torch.stack copies its inputs, so the dataset owns its tensors.
        self.features = torch.stack(
            [graph.ndata[feat_name] for graph in raw_graphs]).float()
        self.adjacency_matrices = torch.stack(
            [graph.adj().to_dense() for graph in raw_graphs])
        self.labels = torch.tensor(raw_labels)
        self.n_nodes = n_nodes
        self.att_shape = att_shape
        # Keep out_dim instead of silently dropping the argument.
        if out_dim is None and self.labels.dim() > 1:
            out_dim = self.labels.shape[1]
        self.out_dim = out_dim
        if isinstance(att_shape, int):
            att_shape = [att_shape]
        attention_shape = [self.labels.shape[0]] + list(att_shape) + [n_nodes, n_nodes]
        self.attention_matrices = torch.ones(attention_shape)

    def __getitem__(self, idx):
        """ Get elements of the dataset by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (features, labels, adjacency_matrix, attention_matrix, index)
        (Tensor, Tensor, Tensor, Tensor, int)
        features: row idx of the stacked node features
        labels: row idx of the labels
        adjacency_matrix: row idx of the stacked adjacency matrices
        attention_matrix: row idx of the attention storage,
            shape (*att_shape, n_node, n_node)
        """
        return (self.features[idx],
                self.labels[idx],
                self.adjacency_matrices[idx],
                self.attention_matrices[idx],
                idx)

    def __len__(self):
        """Number of graphs in the dataset"""
        return self.labels.shape[0]

    def update_attention_matrices(self, new_matrices, idx, idx_batch):
        """ Overwrite stored attention matrices

        Parameters
        ----------
        new_matrices : Tensor
            New values, moved to CPU and copied before being stored
        idx : int or sequence of int
            Position(s) along the `att_shape` dimensions to overwrite
        idx_batch : list of int
            Indices of the samples to overwrite
        """
        if isinstance(idx, (int, np.integer)):
            idx = [idx]
        # Index with an explicit tuple: a list holding sequences is only
        # interpreted as a multi-dimensional index through legacy
        # behaviour, and is read as plain row selection when idx_batch is
        # a single integer.
        self.attention_matrices[(idx_batch, *idx)] = new_matrices.cpu().clone()


class DatasetClassical(DGLDataset):

    """Thin DGLDataset wrapper around ready-made graphs and targets.

    Parameters
    ----------
    graphs : sequence
        Graphs, returned as-is by __getitem__
    targets : sequence
        Targets, one per graph, returned as-is by __getitem__
    name, force_reload, verbose
        Forwarded to DGLDataset
    shuffle, seed, device, self_loop, classification, compute_ising
        Accepted but not used by this class
    """

    def __init__(self, graphs, targets, name='mydataset',
                 force_reload=False, verbose=False,
                 shuffle=False, seed=87, device='cpu', self_loop=False,
                 classification=True, compute_ising=True):
        self.graphs = graphs
        self.targets = targets

        super(DatasetClassical, self).__init__(name=name,
                                               force_reload=force_reload,
                                               verbose=verbose)

    def process(self):
        """Nothing to build: graphs and targets are supplied ready-made."""
        pass

    def __getitem__(self, idx):
        """Return the (graph, target) pair stored at position `idx`."""
        return self.graphs[idx], self.targets[idx]

    def __len__(self):
        """Number of graphs in the dataset"""
        return len(self.graphs)


def build_datasets(graphs_list, targets_list, ising_matrices_list, nodes_list, attention_params, feat_name='feat', out_dim=None):
    """Build one CustomDataset and one CustomDatasetClassical per graph size.

    The four input lists are walked in lockstep: entry k holds the graphs,
    targets, ising matrices and node count of the k-th group.
    `attention_params` is passed to CustomDatasetClassical as its
    attention shape.

    Returns the pair (list of CustomDataset, list of CustomDatasetClassical).
    """
    assert len(graphs_list) == len(nodes_list)
    quantum_sets, classical_sets = [], []
    groups = zip(graphs_list, targets_list, ising_matrices_list, nodes_list)
    for group_graphs, group_targets, group_ising, n_nodes in groups:
        quantum_sets.append(
            CustomDataset(group_graphs, group_targets, group_ising, n_nodes,
                          feat_name=feat_name, out_dim=out_dim))
        classical_sets.append(
            CustomDatasetClassical(group_graphs, group_targets, n_nodes,
                                   attention_params,
                                   feat_name=feat_name, out_dim=out_dim))
    return quantum_sets, classical_sets
