import os
import lmdb
import pickle

from functools import lru_cache

import numpy as np

import torch
from torch.utils.data import Dataset

from gto import element_to_atomic_number, GTOBasis, GTOAuxDensityHelper, GTOProductBasisHelper


def compute_edge_index(coords, r_max, remove_self_loops=True):
    """Return the directed neighbor pairs whose distance is below ``r_max``.

    Args:
        coords: 2-D array-like of point coordinates, one point per row.
        r_max: Strict upper bound on the pairwise distance of a kept pair.
        remove_self_loops: When true, drop pairs ``(i, i)``.

    Returns:
        Integer array of shape ``(2, num_edges)``; row 0 holds the first
        index of each pair and row 1 the second, in row-major order of the
        pairwise distance matrix.
    """
    # Imported lazily so scipy is only needed when edges are actually built.
    from scipy.spatial import distance_matrix

    within_cutoff = distance_matrix(coords, coords) < r_max
    src, dst = np.nonzero(within_cutoff)
    if remove_self_loops:
        not_self = src != dst
        src = src[not_self]
        dst = dst[not_self]
    return np.stack((src, dst), axis=0)


class ShardedLMDBDataset(Dataset):
    """Read-only dataset that presents one or more LMDB shards as one sequence.

    ``data_root`` either contains a ``data.lmdb`` file directly (a single
    shard) or one entry per shard, each holding its own ``data.lmdb``.
    Shards are ordered by sorted name.  Inside a shard, record ``i`` is
    stored under the key ``str(i).encode()`` and its value is a pickled
    object.
    """

    def __init__(self, data_root: str):
        """Scan the shards under ``data_root`` and record their sizes.

        Args:
            data_root: Directory holding ``data.lmdb`` or the shard entries.
        """
        super().__init__()
        self.data_root = data_root
        if os.path.isfile(os.path.join(data_root, 'data.lmdb')):
            self.shards = ['.']
        else:
            self.shards = sorted(os.listdir(data_root))

        # The environments opened here are only needed to read the entry
        # counts; close them explicitly instead of relying on garbage
        # collection so no handle outlives the constructor.
        envs = self.get_envs()
        try:
            self.env_lengths = [env.stat()["entries"] for env in envs]
        finally:
            for env in envs:
                env.close()
        # env_boundaries[k] is the number of records in shards 0..k.
        self.env_boundaries = np.cumsum(self.env_lengths)
        self.len = int(self.env_boundaries[-1])

        self.envs = None  # postpone env initialization until DDP is initialized

    def get_envs(self):
        """Open and return one read-only LMDB environment per shard."""
        return [
            lmdb.Environment(
                os.path.join(self.data_root, shard, 'data.lmdb'),
                map_size=(1024 ** 3) * 256,
                subdir=False,
                readonly=True,
                readahead=True,
                meminit=False,
                lock=False,
            )
            for shard in self.shards
        ]

    def __len__(self):
        """Return the total number of records across all shards."""
        return self.len

    def __getitem__(self, index: int):
        """Load and unpickle the record at global position ``index``.

        Raises:
            IndexError: If ``index`` is negative or not less than ``len(self)``.
            KeyError: If the shard has no value stored for the record's key.
        """
        if index < 0 or index >= self.len:
            raise IndexError(
                f"index {index} is out of range for dataset of length {self.len}"
            )
        if self.envs is None:
            # Opened lazily, on first access, rather than in __init__.
            self.envs = self.get_envs()

        # side='right' maps an index equal to a boundary to the next shard.
        env_idx = int(np.searchsorted(self.env_boundaries, index, 'right'))
        shard_start = int(self.env_boundaries[env_idx - 1]) if env_idx != 0 else 0
        data_idx = int(index) - shard_start

        key = f"{data_idx}".encode()
        # The context manager releases the read transaction deterministically.
        with self.envs[env_idx].begin(write=False) as txn:
            raw = txn.get(key)
        if raw is None:
            raise KeyError(
                f"record {data_idx} not found in shard {self.shards[env_idx]!r}"
            )
        # NOTE: pickle.loads executes arbitrary code from the payload; only
        # point this dataset at LMDB files from a trusted source.
        return pickle.loads(raw)


class MultipartLMDBDataset(Dataset):
    """Dataset that merges aligned records from several named parts.

    Each part is a ``ShardedLMDBDataset`` rooted at ``data_root/<part>``.
    Record ``i`` of this dataset is the union of record ``i`` of every part.
    """

    def __init__(self, data_root: str, parts_to_load: list[str] = ['base']):
        """Open one sub-dataset per part and check that they are aligned.

        Args:
            data_root: Directory containing one sub-directory per part.
            parts_to_load: Names of the parts to open.  The default list is
                only iterated, never stored or mutated.

        Raises:
            ValueError: If no part is requested or the parts differ in length.
        """
        super().__init__()
        self.data_root = data_root
        if not parts_to_load:
            raise ValueError('parts_to_load must name at least one part.')
        self.subdatasets = {
            part: ShardedLMDBDataset(os.path.join(data_root, part))
            for part in parts_to_load
        }
        lengths = {part: len(subdataset) for part, subdataset in self.subdatasets.items()}
        self.len = next(iter(lengths.values()))
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if any(length != self.len for length in lengths.values()):
            raise ValueError(
                f'All parts must contain the same number of records, got {lengths}.'
            )

    def __len__(self):
        """Return the number of records shared by all parts."""
        return self.len

    def __getitem__(self, index: int):
        """Return the merged record at ``index``.

        Parts are merged in the order they were requested; when two parts
        provide the same key, the value from the later part wins.
        """
        ret = {}
        for subdataset in self.subdatasets.values():
            ret.update(subdataset[index])
        return ret


class SCFBenchDataset(Dataset):
    """
    Dataset that converts merged multipart LMDB records into torch tensors.

    Unit assumption:
        atomic coordinates: angstrom
        multipole moments: atomic unit
        auxdensity: atomic unit
        dm: atomic unit
        fock: atomic unit

    Args:
        data_root: Root directory passed to ``MultipartLMDBDataset``.
        r_max: Distance cutoff for building ``edge_index``; compared against
            distances computed from ``atom_coords``.
        type_names: Element symbols; an atom's ``z`` is the position of its
            element in this list.
        remove_self_loops: Whether ``edge_index`` omits ``(i, i)`` pairs.
        parts_to_load: Names of the dataset parts to read.
        aobasis: Basis name used for the ``dm`` / ``fock`` / ``mo`` parts.
        auxbasis: Basis name used for the ``auxdensity*`` part.

    Raises:
        ValueError: If more than one ``auxdensity*`` part is requested.
    """
    def __init__(
        self,
        data_root,
        r_max=5.0,
        type_names=['H', 'C', 'N', 'O', 'F', 'P', 'S'],
        remove_self_loops=True,
        parts_to_load=['base', 'dm', 'fock', 'auxdensity.denfit'],
        aobasis='def2-svp',
        auxbasis='def2-universal-jfit',
    ):
        super().__init__()

        # Store private copies: keeping the argument objects themselves would
        # alias the shared default lists, so mutating an instance attribute
        # would silently change the defaults for every later instance.
        type_names = list(type_names)
        parts_to_load = list(parts_to_load)

        # Same prefix test that decides below (and in __getitem__) whether an
        # auxdensity part is present.  Raised explicitly so the check is not
        # stripped under ``python -O``.
        auxdensity_parts = [p for p in parts_to_load if p.startswith('auxdensity')]
        if len(auxdensity_parts) > 1:
            raise ValueError('Only one kind of auxdensity can be loaded.')

        self.data_root = data_root
        self.parts_to_load = parts_to_load

        self.dataset = MultipartLMDBDataset(self.data_root, parts_to_load=self.parts_to_load)

        self.type_names = type_names
        self.atom_numbers = [element_to_atomic_number[e] for e in self.type_names]
        self.atom_number_to_index = {z: i for i, z in enumerate(self.atom_numbers)}

        self.data_r_max = r_max
        self.remove_self_loops = remove_self_loops

        if auxdensity_parts:
            self.auxbasis = GTOBasis.from_basis_name(auxbasis, elements=type_names)

        if 'dm' in parts_to_load or 'fock' in parts_to_load or 'mo' in parts_to_load:
            self.aobasis = GTOBasis.from_basis_name(aobasis, elements=type_names)
            self.ao_prod_basis = GTOProductBasisHelper(self.aobasis)

    def __len__(self):
        """Return the number of records in the underlying dataset."""
        return len(self.dataset)

    def _padded_block_tensors(self, prefix, atom_number, matrix):
        """Split a PySCF-ordered matrix into padded blocks, as ``<prefix>_*`` tensors."""
        diag_blocks, diag_masks, tril_blocks, tril_masks, tril_edge_index = (
            self.ao_prod_basis.split_matrix_to_padded_blocks(
                atom_number,
                self.ao_prod_basis.transform_from_pyscf_to_std(atom_number, matrix),
            )
        )
        return {
            f'{prefix}_diag_blocks': torch.FloatTensor(diag_blocks),
            f'{prefix}_diag_masks': torch.BoolTensor(diag_masks),
            f'{prefix}_tril_blocks': torch.FloatTensor(tril_blocks),
            f'{prefix}_tril_masks': torch.BoolTensor(tril_masks),
            f'{prefix}_tril_edge_index': torch.IntTensor(tril_edge_index),
        }

    # NOTE: lru_cache on a method keys on ``(self, idx)`` in a cache shared by
    # all instances, and a cache hit returns the *same* dict and tensors as
    # before -- callers must not modify the returned sample in place.
    @lru_cache(maxsize=16)
    def __getitem__(self, idx):
        """Build the tensor sample for record ``idx``.

        Always returns ``z``, ``pos``, ``net_charge``, ``spin`` and
        ``edge_index``; adds ``auxdensity`` / ``species_indices``, ``dm_*``
        and ``fock_*`` entries when the corresponding parts were requested.
        """
        d = self.dataset[idx]
        atom_number = d['atom_number']

        edge_index = compute_edge_index(d['atom_coords'], self.data_r_max, self.remove_self_loops)

        ret = {
            'z': torch.LongTensor([self.atom_number_to_index[n] for n in atom_number]),
            'pos': torch.FloatTensor(d['atom_coords']),
            'net_charge': torch.LongTensor([int(d['net_charge'])]),
            'spin': torch.LongTensor([int(d['spin'])]),

            'edge_index': torch.LongTensor(edge_index),
        }

        if any(p.startswith('auxdensity') for p in self.parts_to_load):
            # NOTE(review): the record key is fixed regardless of which
            # ``auxdensity*`` part was requested -- confirm that every
            # auxdensity variant stores its data under this key.
            auxdensity_key = 'aux_density_denfit'
            gtoaux = GTOAuxDensityHelper(atom_number, self.auxbasis)
            auxdensity_by_element = gtoaux.split_ao_by_elements(
                gtoaux.transform_from_pyscf_to_std(d[auxdensity_key])
            )
            ret.update({
                'auxdensity': {k: torch.FloatTensor(t) for k, t in auxdensity_by_element.items()},
                'species_indices': {k: torch.IntTensor(t) for k, t in gtoaux.atom_indices_by_element.items()},
            })

        if 'dm' in self.parts_to_load:
            ret.update(self._padded_block_tensors('dm', atom_number, d['density_matrix']))

        if 'fock' in self.parts_to_load:
            ret.update(self._padded_block_tensors('fock', atom_number, d['fock']))

        return ret
