from typing import Optional, Callable, Union, List, Tuple, Dict
from os.path import exists as ospexists, join as ospjoin
import torch
import torch.nn.functional as F
from torch import Tensor, sparse_coo_tensor, arange, index_select, eq, eye
from torch.linalg import norm as l2norm, inv
from torch_geometric import datasets
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import get_laplacian


def spdiag(diag_values: Tensor) -> Tensor:
    """Build a square sparse COO matrix with ``diag_values`` on its diagonal.

    Args:
        diag_values: values to place on the diagonal; their count ``n`` fixes
            the matrix size ``n x n``.

    Returns:
        A sparse COO tensor allocated on the device of ``diag_values``.
    """
    size = len(diag_values)
    positions = arange(size, device=diag_values.device)
    # Row and column index coincide for every diagonal entry.
    indices = torch.stack((positions, positions))
    return sparse_coo_tensor(indices, diag_values, size=(size, size))


def compute_resolv_edge_count(node_slice: Tensor) -> int:
    """Count the entries of the dense per-graph resolvent matrices.

    ``node_slice`` holds cumulative node offsets, as in the ``slices`` dict of
    a collated dataset: graph ``k`` owns nodes
    ``node_slice[k]:node_slice[k + 1]``. A dense ``m x m`` matrix for a graph
    with ``m`` nodes has ``m ** 2`` entries, so the total is the sum of the
    squared segment lengths.

    Args:
        node_slice: 1-D sequence of non-decreasing integer offsets (tensor or
            list).

    Returns:
        The total as a Python ``int``. It is ``0`` when fewer than two offsets
        are given.
    """
    offsets = torch.as_tensor(node_slice, dtype=torch.int64)
    if offsets.numel() < 2:
        return 0
    sizes = offsets[1:] - offsets[:-1]
    # Convert to a plain int so the annotated return type holds and callers
    # formatting the value see "N" instead of "tensor(N)".
    return int((sizes * sizes).sum())


def qm7b_pre_transform(data: Data) -> Data:
    """Split a QM7b Coulomb-matrix graph into off-diagonal edges and charges.

    The diagonal Coulomb entries ``C_ii = 0.5 * Z_i ** 2.4`` (see
    http://quantum-machine.org/datasets/) are inverted to recover the nuclear
    charges ``Z_i``. Every remaining (off-diagonal) entry is kept as an edge.

    Args:
        data: graph whose ``edge_index``/``edge_attr`` encode Coulomb matrix
            entries, with the diagonal stored as self-loops. Assumes exactly
            one self-loop per node, ordered by node index (TODO confirm).

    Returns:
        A new ``Data`` with the self-loops removed, the original ``y``, and
        the recovered integer-valued charges in ``z``.
    """
    e_idx, e_attr = data.edge_index, data.edge_attr
    is_diag = eq(e_idx[0, :], e_idx[1, :])
    diag_attr = e_attr[is_diag.nonzero().flatten()]
    # Invert C_ii = 0.5 * Z ** 2.4. Charges are integers, so round off the
    # floating-point error of the fractional power (e.g. 7.9999995 instead
    # of 8). Otherwise a later integer cast, such as the one-hot encoding in
    # ``resolv_graph``, can truncate to the wrong element.
    z = ((diag_attr * 2) ** (1. / 2.4)).round()
    # Exclude diagonal entries from the edge information.
    nondiag_idx = (~is_diag).nonzero().flatten()
    new_idx = index_select(e_idx, 1, nondiag_idx)
    new_attr = e_attr[nondiag_idx]
    data = Data(edge_index=new_idx, edge_attr=new_attr, y=data.y, z=z)
    data.num_nodes = len(z)
    return data


class QM7b(datasets.QM7b):
    """PyG ``QM7b`` dataset whose graphs are preprocessed by ``qm7b_pre_transform``.

    Args:
        root: dataset root directory, passed through to the PyG base class.
        transform: optional transform forwarded to the PyG base class.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ):
        super().__init__(
            root,
            transform=transform,
            pre_transform=qm7b_pre_transform,
            pre_filter=None,
        )


def qm9_pre_transform(data: Data) -> Data:
    """Replace QM9 edge features with Coulomb-matrix off-diagonal entries.

    For every edge ``(i, j)`` with ``i != j``, the new scalar edge weight is
    ``C_ij = Z_i * Z_j / ||r_i - r_j||`` (see
    http://quantum-machine.org/datasets/). It is computed from the charges
    ``z`` and the positions ``pos``. Self-loops, if any, are dropped.

    Args:
        data: graph providing ``edge_index``, ``pos``, ``z`` and ``y``.

    Returns:
        A new ``Data`` holding:

        - the filtered ``edge_index``;
        - the ``C_ij`` weights as ``edge_attr``;
        - the original ``y``;
        - the charges (cast to the default float dtype) as
          ``z.unsqueeze(1)``, i.e. ``(n, 1)`` for a 1-D ``z``.
    """
    charges = data.z.to(dtype=torch.get_default_dtype())
    src, dst = data.edge_index[0, :], data.edge_index[1, :]
    off_diagonal = src != dst
    kept_index = data.edge_index[:, off_diagonal]
    row, col = kept_index[0], kept_index[1]
    distance = l2norm(data.pos[row] - data.pos[col], dim=1)
    coulomb = charges[row] * charges[col] / distance
    result = Data(edge_index=kept_index, edge_attr=coulomb, y=data.y,
                  z=charges.unsqueeze(1))
    result.num_nodes = len(charges)
    return result


class QM9(datasets.QM9):
    """PyG ``QM9`` dataset whose graphs are preprocessed by ``qm9_pre_transform``.

    Args:
        root: dataset root directory, passed through to the PyG base class.
        transform: optional transform forwarded to the PyG base class.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None
    ):
        super().__init__(
            root,
            transform=transform,
            pre_transform=qm9_pre_transform,
            pre_filter=None,
        )


# Either of the preprocessed molecule datasets defined above; this is the
# input type accepted by ``ResolvDataset``.
CustomDataset = Union[QM7b, QM9]


def resolv_graph(data: Data, omega: float, nf: float,
                 num_classes: int = 10) -> Data:
    """Attach the dense resolvent of the charge-weighted graph Laplacian.

    With ``L`` the unnormalized Laplacian built from ``edge_index`` and
    ``edge_attr``, and ``Z`` the diagonal matrix of node charges, this
    computes ``R = (Z^{-1} L / nf - omega * I)^{-1}``. The non-zero entries
    of ``R`` are stored as a second edge set.

    Args:
        data: graph with ``edge_index``, scalar ``edge_attr``, ``y`` and
            per-node charges ``z`` (any shape that flattens to ``(n,)``).
        omega: shift subtracted on the diagonal before inversion.
        nf: factor dividing the weighted Laplacian.
        num_classes: width of the one-hot charge encoding. It must be greater
            than the largest charge in the data, otherwise ``F.one_hot``
            raises.

    Returns:
        A new ``Data`` with:

        - the original ``edge_index``/``edge_attr``;
        - the resolvent edges ``r_edge_index``/``r_edge_attr``;
        - ``y``;
        - the charges ``z`` as ``(n, 1)``;
        - ``z_one_hot`` as ``(n, num_classes)`` floats.
    """
    e_idx, e_attr, y = data.edge_index, data.edge_attr, data.y
    z = data.z.flatten()
    n_node = len(z)
    # Charges are integral but may carry floating-point error (e.g. when
    # recovered from Coulomb diagonals). Round before the integer cast so
    # that 7.9999 maps to 8 rather than being truncated to 7.
    z_one_hot = F.one_hot(
        z.round().to(torch.int64), num_classes=num_classes).to(torch.float)
    # Pass num_nodes explicitly so trailing isolated nodes are not dropped
    # when the node count is inferred from the edges.
    laplace = sparse_coo_tensor(
        *get_laplacian(e_idx, e_attr, num_nodes=n_node),
        size=(n_node, n_node)).to_dense()
    # Scale row i by 1 / z_i, i.e. form Z^{-1} L.
    weighted_laplace = laplace * (1. / z).unsqueeze(1)
    # Match the Laplacian's dtype and device so the shift also works for
    # float64 data and for tensors living on a GPU.
    identity = eye(n_node, dtype=laplace.dtype, device=laplace.device)
    r = inv(weighted_laplace / nf - omega * identity).to_sparse()
    data = Data(
        edge_index=e_idx, edge_attr=e_attr, r_edge_index=r.indices(),
        r_edge_attr=r.values(), y=y, z=z.unsqueeze(1), z_one_hot=z_one_hot)
    data.num_nodes = n_node
    return data


class ResolvDataset(InMemoryDataset):
    """In-memory dataset of graphs augmented by ``resolv_graph``.

    On construction, every graph of ``dataset`` is converted with
    ``resolv_graph(d, omega, nf)``. The results are collated into
    ``self.data`` / ``self.slices``. ``_resolv`` does not read any previously
    saved file.

    Args:
        dataset: source dataset to resolve (``QM7b`` or ``QM9`` here).
        omega: shift forwarded to ``resolv_graph``.
        nf: normalization factor forwarded to ``resolv_graph``.
        log: print ``'resolving ... done!'`` around the computation.
        save: additionally ``torch.save`` the collated ``(data, slices)`` to
            ``self.processed_paths[0]``.
    """
    def __init__(
            self,
            dataset: CustomDataset,
            omega: float,
            nf: float,
            log: bool = True,
            save: bool = False):
        # Set before ``super().__init__``: ``processed_file_names`` formats
        # these values, and the PyG base constructor may query it.
        self.omega = float(omega)
        self.nf = float(nf)
        # NOTE(review): newer PyG releases define ``InMemoryDataset.save``.
        # This bool instance attribute shadows it; confirm nothing relies on
        # calling ``self.save(...)`` on this dataset.
        self.save = save
        # Root is ``<dataset.root>/resolv``. ``log`` is not forwarded to the
        # base class (the commented-out line shows the variant that did).
        super().__init__(
            # ospjoin(dataset.root, 'resolv'), None, None, None, log)
            ospjoin(dataset.root, 'resolv'), None, None, None)
        if log:
            print('resolving ... ', end='')
        self.data, self.slices = self._resolv(dataset)
        if log:
            print('done!')

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return ()  # this dataset is purely in-memory

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        # Encodes omega and nf, so different settings map to different files.
        return f'data_{self.omega}_{self.nf}.pt'

    def process(self):
        # Intentionally empty: the data is built in ``_resolv``.
        pass

    def download(self):
        # Intentionally empty: there are no raw files to fetch.
        pass

    def _resolv(
            self, dataset: CustomDataset
    ) -> Tuple[Data, Optional[Dict[str, Tensor]]]:
        """Apply ``resolv_graph`` to every graph and collate the results.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``. It is
            also written to ``processed_paths[0]`` when ``self.save`` is set.
        """
        processed_filename = self.processed_paths[0]
        # NOTE(review): loading a previously saved file is disabled, so a
        # file written with ``save=True`` is never read back here.
        # # # # # if ospexists(processed_filename):
        # # # # #     return torch.load(processed_filename)
        omega, nf = self.omega, self.nf
        data_list = [resolv_graph(d, omega, nf) for d in dataset]
        data, slices = self.collate(data_list)
        if self.save:
            torch.save((data, slices), processed_filename)
        return data, slices
    

def max_norm(dataset):
    """Return the largest spectral norm of the graphs' weighted adjacency matrices.

    Each graph's ``edge_index``/``edge_attr`` is densified into a
    ``len(z) x len(z)`` matrix, and its matrix 2-norm is taken. The running
    best starts at the integer ``0``, which is also the result for an empty
    ``dataset``.
    """
    best = 0
    for graph in dataset:
        size = len(graph.z)
        dense = sparse_coo_tensor(
            graph.edge_index, graph.edge_attr, size=(size, size)).to_dense()
        candidate = torch.linalg.norm(dense, ord=2)
        best = candidate if candidate > best else best
    return best
    

def max_norm_(dataset):
    """Return the largest spectral norm of the graphs' weighted adjacency matrices.

    Unlike ``max_norm``, this has no starting value, so an empty ``dataset``
    raises ``ValueError`` (from the builtin ``max``).
    """
    def _spectral_norm(graph):
        n = len(graph.z)
        adjacency = sparse_coo_tensor(
            graph.edge_index, graph.edge_attr, size=(n, n))
        return torch.linalg.norm(adjacency.to_dense(), ord=2)

    return max(map(_spectral_norm, dataset))


def max_entry_(dataset):
    """Return the largest ``edge_attr`` value over all graphs as a Python number.

    Raises ``ValueError`` (from the builtin ``max``) when ``dataset`` is empty.
    """
    per_graph_max = [graph.edge_attr.max().item() for graph in dataset]
    return max(per_graph_max)


def min_entry_(dataset):
    """Return the smallest ``edge_attr`` value over all graphs as a Python number.

    Raises ``ValueError`` (from the builtin ``min``) when ``dataset`` is empty.
    """
    def _compute_min_weight(data):
        # Must be the per-graph minimum (not maximum) so that the outer
        # ``min`` yields the global smallest entry.
        return data.edge_attr.min().item()

    return min(_compute_min_weight(data) for data in dataset)

def max_charge_(dataset):
    """Return the largest node charge ``z`` over all graphs as a Python number.

    Raises ``ValueError`` (from the builtin ``max``) when ``dataset`` is empty.
    """
    return max(graph.z.max().item() for graph in dataset)


if __name__ == '__main__':
    # Ad-hoc inspection script. It prints the internals of the stock PyG
    # QM7b/QM9 datasets and of the custom variants defined in this module,
    # then times the resolvent preprocessing of the custom QM9 dataset.
    # NOTE(review): ``_indices`` / ``_data_list`` are private PyG attributes,
    # and direct ``.data`` access may warn in newer PyG, so output can differ
    # between PyG versions.
    from timeit import default_timer

    print('\nQM7b:')
    dataset = datasets.QM7b('./data/qm7b')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.slices.keys())
    print(len(dataset.slices['edge_index']))
    print(dataset.slices['edge_index'])

    print('\nCustom QM7b:')
    dataset = QM7b('./data/custom_qm7b')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    # First few charges recovered by qm7b_pre_transform.
    print(dataset.data.z[:5])
    print(dataset.slices.keys())
    print(len(dataset.slices['edge_index']))
    print(dataset.slices['edge_index'])

    print('\nQM9:')
    dataset = datasets.QM9('./data/qm9')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['x'])
    # Total entries the dense per-graph resolvents would have, computed from
    # the per-graph node offsets.
    print(
        f"Resolv edge count: {compute_resolv_edge_count(dataset.slices['x'])}")

    print('\nCustom QM9:')
    dataset = QM9('./data/custom_qm9')
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['z'])
    print(
        f"Resolv edge count: {compute_resolv_edge_count(dataset.slices['z'])}")

    # The timer covers building the ResolvDataset and every diagnostic below.
    start = default_timer()
    print('\nResolv Custom QM9:')
    # omega = -1.0, nf = 1.0
    dataset = ResolvDataset(dataset, -1.0, 1.0)
    print(dataset.__dict__.keys())
    print(dataset._indices)
    print(dataset._data_list)
    print(dataset.data)
    print(dataset.data.edge_attr[0])
    print(dataset.data.z_one_hot)
    print(dataset.slices.keys())
    print(dataset.slices['edge_index'])
    print(dataset.slices['z'])
    # NOTE(review): assumes attribute access is forwarded from the dataset
    # to its collated data; confirm for the installed PyG version.
    print(f'datataset.z {dataset.z}')
    print(max_entry_(dataset))
    print(min_entry_(dataset))
    print(f'max. charge is {max_charge_(dataset)}')
    print(f'time elapsed: {default_timer() - start}')  #

    print()
