from typing import Callable, Optional

import numpy as np
import torch
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.utils import get_laplacian


def compute_coulomb(z, r, eps=1e-7):
    """Return batched Coulomb matrices for charges ``z`` at positions ``r``.

    Off-diagonal entries are ``z_i * z_j / (||r_i - r_j|| + eps)`` (Euclidean
    norm); the diagonal is then overwritten with ``z_i ** 2.4 / 2``.

    Args:
        z: charges indexed as ``z[molecule, atom]``.
        r: coordinates indexed as ``r[molecule, atom, axis]``.
        eps: added to every pairwise distance so coincident points (e.g.
            zero-padded atoms sitting at the same location) never divide by 0.

    Returns:
        Array of shape ``(molecules, atoms, atoms)``.
    """
    # Pairwise charge products and pairwise Euclidean distances per molecule.
    pair_charge = z[:, :, np.newaxis] * z[:, np.newaxis, :]
    displacement = r[:, :, np.newaxis, :] - r[:, np.newaxis, :, :]
    distance = np.linalg.norm(displacement, axis=-1)
    coulomb = pair_charge / (distance + eps)

    # Overwrite the (i, i) entries of every molecule with the self-term.
    atoms = np.arange(coulomb.shape[1])
    coulomb[:, atoms, atoms] = z ** 2.4 / 2
    return coulomb


def preprocess_data_list(coulomb_matrix, target, omega, nf):
    """Convert batched Coulomb matrices into PyG graphs with resolvent edges.

    For each molecule ``i``:

    * every nonzero entry of ``coulomb_matrix[i]`` (diagonal included) becomes
      an edge whose attribute is that entry;
    * the nuclear charges ``Z`` are recovered from the diagonal, which is
      expected to hold ``Z ** 2.4 / 2`` (as produced by ``compute_coulomb``);
    * ``T = diag(Z)^-1 @ L / nf - omega * I`` is formed from the graph
      Laplacian ``L`` returned by ``get_laplacian``, and its inverse (the
      resolvent) is stored sparsely as ``r_edge_index`` / ``r_edge_attr``.

    Args:
        coulomb_matrix: tensor indexed as ``coulomb_matrix[i, row, col]``;
            presumably ``(num_molecules, max_atoms, max_atoms)`` with zero
            padding for absent atoms -- confirm against caller.
        target: per-molecule targets; ``target[i]`` is reshaped to ``(1, -1)``.
        omega: scalar subtracted along the diagonal before inversion.
        nf: normalizing factor dividing the charge-weighted Laplacian.

    Returns:
        A list with one ``Data`` per molecule carrying ``edge_index``,
        ``edge_attr``, ``Z`` (shape ``(num_nodes, 1)``), ``Z_one_hot``
        (17 classes), ``y``, ``r_edge_index``, ``r_edge_attr`` and
        ``num_nodes``.
    """
    data_list = []
    for i in range(coulomb_matrix.shape[0]):
        edge_index = coulomb_matrix[i].nonzero(as_tuple=False).t().contiguous()
        edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]]
        # Node count = largest index with a nonzero entry + 1 (assumes the
        # zero-padded atoms are trailing -- TODO confirm).
        num_nodes = int(edge_index.max()) + 1
        size = (num_nodes, num_nodes)

        # Dense Coulomb matrix with the charges encoded on the diagonal.
        C = torch.sparse_coo_tensor(edge_index, edge_attr, size).to_dense()
        # Invert diag = Z**2.4 / 2 and round: the float round trip can land
        # just below the integer (e.g. 5.9999995 for carbon), which the int64
        # cast for the one-hot encoding would otherwise truncate to the wrong
        # element.
        charges = torch.round((2 * torch.diag(C, 0)) ** (1 / 2.4))
        Z = charges.reshape(-1, 1)
        Z_one_hot = F.one_hot(charges.to(torch.int64), num_classes=17).to(Z.dtype)

        # diag(Z) is diagonal, so its inverse is just the reciprocals.
        inverse_weight_matrix = torch.diag(1.0 / charges)
        L_index, L_weight = get_laplacian(edge_index, edge_attr, num_nodes=num_nodes)
        L = torch.sparse_coo_tensor(L_index, L_weight, size).to_dense()
        weighted_L = inverse_weight_matrix @ L
        T = weighted_L / nf - omega * torch.eye(num_nodes, dtype=weighted_L.dtype)
        R = torch.linalg.inv(T).to_sparse()

        data = Data(
            edge_index=edge_index,
            edge_attr=edge_attr,
            Z=Z,
            y=target[i].view(1, -1),
            r_edge_index=R.indices(),
            r_edge_attr=R.values(),
            Z_one_hot=Z_one_hot,
        )
        data.num_nodes = num_nodes
        data_list.append(data)
    return data_list
        

class QM7DataWithResolvents(InMemoryDataset):
    """QM7 molecular graphs with precomputed resolvent edges.

    Downloads ``qm7.mat``, builds Coulomb-matrix graphs via
    ``compute_coulomb`` and ``preprocess_data_list``, and caches the collated
    result as ``<root>/processed/omega{omega}_nf{nf}.pt``. Because that file
    is also reported by ``processed_file_names``, PyG skips ``process()`` once
    the cache for a given ``(omega, nf)`` pair exists.

    Args:
        root: dataset root directory (``raw/`` and ``processed/`` live below).
        omega: shift forwarded to ``preprocess_data_list``.
        nf: positive normalizing factor forwarded to ``preprocess_data_list``.
        transform: optional PyG transform applied on access.
        pre_transform: optional PyG transform applied before saving.
        pre_filter: optional PyG predicate applied before saving.

    Raises:
        ValueError: if ``nf`` is not positive.
    """

    url = 'http://quantum-machine.org/data/qm7.mat'

    def __init__(
        self,
        root: str,
        omega: float,
        nf: float,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        # Fail fast, before the base constructor downloads/processes anything.
        if not nf > 0:
            raise ValueError(f'nf must be positive, got {nf}')
        # processed_file_names depends on these, and the base constructor
        # consults it to decide whether to run process(), so set them first.
        self.omega = omega
        self.nf = nf
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may reject the pickled Data objects; if so, switch to
        # InMemoryDataset.load or weights_only=False -- confirm versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        return 'qm7.mat'

    @property
    def processed_file_names(self) -> str:
        # One cache file per (omega, nf); this is exactly the file process()
        # writes, so an existing cache is reused instead of recomputed.
        return 'omega' + str(self.omega) + '_nf' + str(self.nf) + '.pt'

    def download(self):
        download_url(self.url, self.raw_dir)

    def process(self):
        from scipy.io import loadmat
        print(f'Calculating Resolvents at omega = {self.omega} and normalizing factor {self.nf}')
        data_in = loadmat(self.raw_paths[0])
        coulomb_matrix = torch.from_numpy(compute_coulomb(data_in['Z'], data_in['R']))
        # Transpose so target[i] indexes molecules (assumes 'T' is stored as
        # (1, num_molecules) -- confirm against the .mat file).
        target = torch.transpose(torch.from_numpy(data_in['T']).to(torch.float), 0, 1)

        data_list = preprocess_data_list(coulomb_matrix, target, self.omega, self.nf)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
