from typing import Optional, Callable, Union, List, Tuple, Dict
from os.path import join as ospjoin
import torch
from torch import Tensor, sparse_coo_tensor, arange, index_select, eq, eye
from torch.linalg import norm as l2norm, inv
from torch.sparse import mm as spmm
from torch_geometric import datasets
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import get_laplacian


def spdiag(diag_values: Tensor) -> Tensor:
    """Build a sparse COO matrix whose main diagonal is ``diag_values``.

    Args:
        diag_values: 1-D tensor; element ``k`` is placed at position (k, k).

    Returns:
        A sparse ``(n, n)`` COO tensor, ``n = len(diag_values)``, whose index
        tensor lives on the same device as ``diag_values``.
    """
    size = len(diag_values)
    positions = arange(size, device=diag_values.device)
    # identical row and column coordinates -> every entry is on the diagonal
    coords = torch.stack((positions, positions))
    return sparse_coo_tensor(coords, diag_values, size=(size, size))


def compute_resolv_edge_count(node_slice: Tensor) -> int:
    """Count the entries of the dense per-graph matrices described by offsets.

    ``node_slice`` holds cumulative node offsets: graph ``k`` owns the nodes
    ``node_slice[k]:node_slice[k + 1]``. A graph with ``m`` nodes contributes
    ``m ** 2`` (every ordered node pair, self-pairs included). This is the
    edge count of a fully dense resolvent graph. It is an upper bound for
    ``resolv_graph``, whose ``to_sparse`` drops exact zeros.

    Args:
        node_slice: 1-D tensor (or sequence) of non-decreasing integer offsets.

    Returns:
        The total as a Python ``int``. It is 0 when fewer than two offsets are
        given.
    """
    # accept plain sequences too; int64 avoids overflow when squaring sizes
    offsets = torch.as_tensor(node_slice).to(torch.int64)
    if offsets.numel() < 2:
        return 0
    sizes = offsets[1:] - offsets[:-1]
    return int((sizes * sizes).sum())


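# Dataset variants of QM7b and QM9 whose graphs share the following format:
# - z (n_atom,): proton count of each atom (float in QM7b, where it is
#   recovered from the diagonal Coulomb-matrix entries C_ii)
# - y (n_mol, n_target), float: regression targets for each molecule
# - edge_index (2, n_edges), int: atom pairs, with self-loops removed
# - edge_attr (n_edges,), float: off-diagonal Coulomb-matrix entries C_ij as
#   defined at http://quantum-machine.org/datasets/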


def qm7b_pre_transform(data: Data) -> Data:
    """Convert a QM7b graph into this module's format (per-atom ``z``, no self-loops).

    QM7b stores the full Coulomb matrix as edges, including the diagonal
    entries C_ii = 0.5 * Z_i ** 2.4 (http://quantum-machine.org/datasets/).
    These entries are inverted to recover the proton count Z_i of each atom
    and then removed from the edge list.

    Args:
        data: graph with ``edge_index``, ``edge_attr`` (Coulomb entries) and ``y``.
            Assumes exactly one self-loop per node (TODO confirm for all inputs).

    Returns:
        ``Data`` with the off-diagonal ``edge_index``/``edge_attr``, ``y``
        unchanged, and ``z`` ordered by node index.
    """
    e_idx, e_attr = data.edge_index, data.edge_attr
    is_diag = e_idx[0] == e_idx[1]
    # Order the diagonal entries by node index so that z[i] belongs to node i,
    # independent of the order in which the edges happen to be stored.
    diag_nodes = e_idx[0, is_diag]
    diag_attr = e_attr[is_diag][torch.argsort(diag_nodes)]
    # inverse of C_ii = 0.5 * Z_i ** 2.4
    z = (diag_attr * 2) ** (1. / 2.4)
    # keep only the off-diagonal entries as edges
    keep = ~is_diag
    return Data(edge_index=e_idx[:, keep], edge_attr=e_attr[keep], y=data.y, z=z)


class QM7b(datasets.QM7b):
    """PyG QM7b with ``qm7b_pre_transform`` applied at processing time.

    Args:
        root: dataset root directory. The ``__main__`` demo uses a root
            separate from the plain ``datasets.QM7b``, presumably because the
            pre-transformed files are cached per root (confirm).
        transform: optional per-access transform, forwarded to the base class.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ):
        super().__init__(
            root,
            transform=transform,
            pre_transform=qm7b_pre_transform,  # fixed conversion step
            pre_filter=None,
        )


def qm9_pre_transform(data: Data) -> Data:
    """Convert a QM9 graph into this module's format (Coulomb edge weights).

    Self-loops are dropped from ``edge_index``. Every remaining pair (i, j)
    is weighted by C_ij = Z_i * Z_j / |R_i - R_j|, following
    http://quantum-machine.org/datasets/.
    NOTE(review): only pairs already present in ``data.edge_index`` get a
    weight. For PyG's QM9 these are presumably bonded pairs, not all atom
    pairs; confirm this is intended.

    Args:
        data: graph with ``edge_index``, ``pos`` (atom positions), ``z``
            (proton counts) and ``y``.

    Returns:
        ``Data`` with the filtered ``edge_index``, ``edge_attr`` = C_ij, and
        ``y``/``z`` passed through. All other fields are dropped.
    """
    z = data.z
    pos = data.pos
    edges = data.edge_index
    off_diag = (edges[0] != edges[1]).nonzero().flatten()
    edges = index_select(edges, 1, off_diag)
    src, dst = edges[0], edges[1]
    separation = index_select(pos, 0, src) - index_select(pos, 0, dst)
    coulomb = z[src] * z[dst] / l2norm(separation, dim=1)
    return Data(edge_index=edges, edge_attr=coulomb, y=data.y, z=z)


class QM9(datasets.QM9):
    """PyG QM9 with ``qm9_pre_transform`` applied at processing time.

    Args:
        root: dataset root directory. The ``__main__`` demo uses a root
            separate from the plain ``datasets.QM9``, presumably because the
            pre-transformed files are cached per root (confirm).
        transform: optional per-access transform, forwarded to the base class.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None
    ):
        super().__init__(
            root,
            transform=transform,
            pre_transform=qm9_pre_transform,  # fixed conversion step
            pre_filter=None,
        )


# Either of the pre-transformed dataset classes defined above; this is the
# input type accepted by ResolvDataset.
CustomDataset = Union[QM7b, QM9]

# # Sparse-where-possible variant of resolv_graph, kept for reference only;
# # slower (40.63 s) than the dense implementation below.
# def resolv_graph(data: Data, omega: float, nf: float) -> Data:
#     e_idx, e_attr, y, z = data.edge_index, data.edge_attr, data.y, data.z
#     n_node = len(z)
#     inv_z = 1. / z
#     laplace = sparse_coo_tensor(
#         *get_laplacian(e_idx, e_attr), size=(n_node, n_node))
#     weighted_laplace = spmm(spdiag(inv_z), laplace)
#     t = weighted_laplace / nf - omega * spdiag(torch.ones_like(z))
#     r = inv(t.to_dense()).to_sparse()
#     return Data(edge_index=r.indices(), edge_attr=r.values(), y=y, z=z)


# Dense where convenient; faster (26.24 s) than the sparse variant above.
def resolv_graph(data: Data, omega: float, nf: float) -> Data:
    """Replace a molecule graph with the resolvent of its weighted Laplacian.

    With L the Laplacian returned by ``get_laplacian`` for ``edge_index`` /
    ``edge_attr`` and D = diag(1 / z), the resolvent is

        R = (D @ L / nf - omega * I) ** -1

    Its non-zero entries become the edges of the returned graph.

    Args:
        data: graph with ``edge_index``, ``edge_attr``, ``y`` and ``z``
            (one value per node; ``len(z)`` defines the node count).
        omega: shift subtracted on the diagonal.
        nf: factor dividing the weighted Laplacian.

    Returns:
        ``Data`` whose ``edge_index``/``edge_attr`` are the indices/values of
        the non-zero entries of R, with ``y`` and ``z`` passed through.

    Raises:
        Whatever ``torch.linalg.inv`` raises if the shifted matrix is singular.
    """
    e_idx, e_attr, y, z = data.edge_index, data.edge_attr, data.y, data.z
    n_node = len(z)
    # Pass num_nodes explicitly so the Laplacian is sized by the atoms, rather
    # than by a node count inferred from edge_index.
    lap_idx, lap_w = get_laplacian(e_idx, e_attr, num_nodes=n_node)
    laplace = sparse_coo_tensor(
        lap_idx, lap_w, size=(n_node, n_node)).to_dense()
    # Row-scaling by 1 / z equals diag(1 / z) @ L without building the diagonal.
    weighted_laplace = laplace * (1. / z).unsqueeze(1)
    # Build the identity on the Laplacian's device/dtype. A bare eye() lives
    # on the CPU and breaks the subtraction for CUDA inputs.
    identity = eye(n_node, dtype=weighted_laplace.dtype,
                   device=weighted_laplace.device)
    t = weighted_laplace / nf - omega * identity
    r = inv(t).to_sparse()
    return Data(edge_index=r.indices(), edge_attr=r.values(), y=y, z=z)


# Preprocess any dataset defined above into resolvent graphs for the model.
class ResolvDataset(InMemoryDataset):
    """In-memory dataset of resolvent graphs built from a ``CustomDataset``.

    At construction, every graph of ``dataset`` goes through ``resolv_graph``
    with the given ``omega`` and ``nf``. The results are collated into this
    dataset's storage. The raw/processed file lists are empty and
    ``download``/``process`` do nothing, so no data is read from or written to
    files by this class itself. NOTE(review): the PyG base class may still
    create directories under ``<dataset.root>/resolv``; confirm for the
    installed version.

    Args:
        dataset: source dataset (``QM7b`` or ``QM9`` from this module).
        omega: diagonal shift forwarded to ``resolv_graph``.
        nf: normalization factor forwarded to ``resolv_graph``.
    """

    def __init__(
            self,
            dataset: CustomDataset,
            omega: float,
            nf: float):
        super().__init__(
            ospjoin(dataset.root, 'resolv'),
            transform=None,
            pre_transform=None,
            pre_filter=None,
            log=False,
        )
        self.omega = omega
        self.nf = nf
        self.data, self.slices = self._resolv(dataset)

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        # built purely in memory from the source dataset
        return tuple()

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        # built purely in memory from the source dataset
        return tuple()

    def process(self):
        """No-op: the graphs are built in ``__init__`` instead."""
        return None

    def download(self):
        """No-op: there are no raw files to fetch."""
        return None

    def _resolv(
            self, dataset: CustomDataset
    ) -> Tuple[Data, Optional[Dict[str, Tensor]]]:
        """Apply ``resolv_graph`` to each graph of ``dataset`` and collate them."""
        shift, factor = self.omega, self.nf
        resolved = [resolv_graph(graph, shift, factor) for graph in dataset]
        return InMemoryDataset.collate(resolved)


if __name__ == '__main__':
    # Manual smoke test: builds each dataset under ./data and prints its
    # internal storage (PyG may download the raw files on first use).
    # NOTE: _indices / _data_list are private PyG attributes, printed here
    # only for inspection.
    from timeit import default_timer

    # Plain PyG QM7b, for comparison with the custom variant below.
    print('\nQM7b:')
    dataset = datasets.QM7b('./data/qm7b')
    print(dataset.__dict__.keys())
    # print(dataset._indices)
    # print(dataset._data_list)
    print(dataset.data)
    print(dataset.slices.keys())
    print(len(dataset.slices['edge_index']))
    print(dataset.slices['edge_index'])

    # QM7b after qm7b_pre_transform, kept under its own root directory.
    print('\nCustom QM7b:')
    dataset = QM7b('./data/custom_qm7b')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.z[:5])  # proton counts recovered from C_ii
    print(dataset.slices.keys())
    print(len(dataset.slices['edge_index']))
    print(dataset.slices['edge_index'])

    # Plain PyG QM9.
    print('\nQM9:')
    dataset = datasets.QM9('./data/qm9')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['x'])
    # slices['x'] is used as the per-graph node offsets here (presumably one
    # row of x per atom).
    print(f"Resolv edge count: {compute_resolv_edge_count(dataset.slices['x'])}")

    # QM9 after qm9_pre_transform. Its output carries no `x`, so the node
    # offsets are read from slices['z'] instead.
    print('\nCustom QM9:')
    dataset = QM9('./data/custom_qm9')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['z'])
    print(f"Resolv edge count: {compute_resolv_edge_count(dataset.slices['z'])}")

    # Time the resolvent construction (omega=-1.0, nf=1.0) together with the
    # prints that follow it.
    start = default_timer()
    print('\nResolv Custom QM9:')
    dataset = ResolvDataset(dataset, -1.0, 1.0)
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['z'])
    print(f'time elapsed: {default_timer() - start}')  # seconds
