import os
import torch
import shutil
from ase import Atoms
from ase.neighborlist import neighbor_list as ase_neighbor_list
from matscipy.neighbours import neighbour_list as msp_neighbor_list
from .base import Transform
from dirsync import sync
import numpy as np
from typing import Optional, Dict, List, Type, Any, Union

# Names exported by `from <this module> import *`.
__all__ = [
    "ASENeighborList",
    "MatScipyNeighborList",
    "TorchNeighborList",
    "CountNeighbors",
    "CollectAtomTriples",
    "CachedNeighborList",
    "NeighborListTransform",
    "WrapPositions",
    "SkinNeighborList",
    "NeighborlistWrapper",
    "FilterNeighbors",
]

import schnetpack as spk
from schnetpack import properties
import fasteners


class CacheException(Exception):
    """Error raised for an invalid neighbor list cache setup, e.g. an already existing cache work directory."""


class NeighborlistWrapper(Transform):
    """
    Chain a neighbor list transform with an arbitrary number of post-processing transforms.
    The neighbor list is applied first, then every post-processing step in the given order.
    """

    def __init__(
        self,
        neighbor_list: Transform,
        nbh_postprocessing: Optional[List[torch.nn.Module]] = None,
    ):
        """
        Args:
            neighbor_list: the neighbor list to use
            nbh_postprocessing: transforms that are applied, in order, to the output of `neighbor_list`.
                Defaults to no post-processing.
        """
        super().__init__()
        self.neighbor_list = neighbor_list
        self.nbh_postprocessing = nbh_postprocessing if nbh_postprocessing else []

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Run the neighbor list followed by all post-processing steps and return the resulting inputs."""
        pipeline = [self.neighbor_list, *self.nbh_postprocessing]
        for step in pipeline:
            inputs = step(inputs)
        return inputs


class CachedNeighborList(Transform):
    """
    Dynamic caching of neighbor lists.
    This wraps a neighbor list and stores the results the first time it is called
    for a dataset entry with the pid provided by AtomsDataset. Particularly,
    for large systems, this speeds up training significantly.
    Note:
        The provided cache location should be unique to the used dataset. Otherwise,
        wrong neighborhoods will be provided. The caching location can be reused
        across multiple runs, by setting `keep_cache=True`.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(
        self,
        cache_path: str,
        neighbor_list: Transform,
        keep_cache: bool = False,
        cache_workdir: Optional[str] = None,
    ):
        """
        Args:
            cache_path: Path of caching directory.
            neighbor_list: the neighbor list to use
            keep_cache: Keep cache at `cache_path` at the end of training, or copy built/updated cache there from
                `cache_workdir` (if set). A pre-existing cache at `cache_path` will not be deleted, while a
                temporary cache at `cache_workdir` will always be removed.
            cache_workdir: If this is set, the cache will be built here, e.g. a cluster scratch space
                for faster performance. An existing cache at `cache_path` is copied here at the beginning of
                training, and afterwards (if `keep_cache=True`) the final cache is copied back to `cache_path`.

        Raises:
            CacheException: if `cache_workdir` is given and already exists.
        """
        super().__init__()
        self.neighbor_list = neighbor_list
        self.keep_cache = keep_cache
        self.cache_path = cache_path
        self.cache_workdir = cache_workdir
        # must be evaluated before `cache_path` is created below
        self.preexisting_cache = os.path.exists(self.cache_path)
        self.has_tmp_workdir = cache_workdir is not None

        os.makedirs(cache_path, exist_ok=True)

        if self.has_tmp_workdir:
            # cache workdir should be empty to avoid loading nbh lists from earlier runs
            if os.path.exists(cache_workdir):
                raise CacheException("The provided `cache_workdir` already exists!")

            # copy existing nbh lists to cache workdir
            # NOTE(review): without a pre-existing cache the workdir is not created here; it has to
            # come into existence before the first cache file is written in `forward` - confirm
            # that the lock creation takes care of this.
            if self.preexisting_cache:
                shutil.copytree(cache_path, cache_workdir)
            self.cache_location = cache_workdir
        else:
            # use cache_path to store and load neighborlists
            self.cache_location = cache_path

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Add the neighbor list of the sample to `inputs`, loading it from the cache if available and
        computing and caching it otherwise.
        """
        sample_id = inputs[properties.idx][0]
        cache_file = os.path.join(self.cache_location, f"cache_{sample_id}.pt")

        # try to read cached NBL
        try:
            data = torch.load(cache_file)
            inputs.update(data)
        except IOError:
            # acquire lock for caching
            lock = fasteners.InterProcessLock(
                os.path.join(self.cache_location, f"cache_{sample_id}.lock")
            )
            with lock:
                # retry reading, in case other process finished in the meantime
                try:
                    data = torch.load(cache_file)
                    inputs.update(data)
                except IOError:
                    # now it is safe to calculate and cache
                    inputs = self.neighbor_list(inputs)
                    data = {
                        properties.idx_i: inputs[properties.idx_i],
                        properties.idx_j: inputs[properties.idx_j],
                        properties.offsets: inputs[properties.offsets],
                    }
                    # The first read attempt above happens without holding the lock. Write to a
                    # temporary file and move it into place, so that a concurrent reader either
                    # finds no cache file or a complete one, never a partially written one.
                    tmp_file = cache_file + ".tmp"
                    torch.save(data, tmp_file)
                    os.replace(tmp_file, cache_file)
                except Exception as e:
                    # best effort: report the problem and hand back the inputs unchanged
                    print(e)
        return inputs

    def teardown(self):
        """
        Clean up the cache directories: remove a cache at `cache_path` that was created by this
        instance unless `keep_cache` is set, copy the work directory back to `cache_path` if
        `keep_cache` is set, and always remove the work directory. All steps are best effort.
        """
        if not self.keep_cache and not self.preexisting_cache:
            try:
                shutil.rmtree(self.cache_path)
            except OSError:
                pass

        if self.cache_workdir is not None:
            if self.keep_cache:
                try:
                    sync(self.cache_workdir, self.cache_path, "sync")
                except Exception:
                    # best effort; do not swallow KeyboardInterrupt/SystemExit like a bare except
                    pass

            try:
                shutil.rmtree(self.cache_workdir)
            except OSError:
                pass


class NeighborListTransform(Transform):
    """
    Common base of all neighbor list transforms. Subclasses implement
    `_build_neighbor_list`; this class takes care of reading the structure from
    the inputs and of writing the neighbor indices and offsets back.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(
        self,
        cutoff: float,
    ):
        """
        Args:
            cutoff: Cutoff radius for neighbor search.
        """
        super().__init__()
        self._cutoff = cutoff

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Compute the neighbor list of the structure in `inputs` and store it in `inputs`."""
        idx_i, idx_j, offset = self._build_neighbor_list(
            inputs[properties.Z],
            inputs[properties.R],
            inputs[properties.cell].view(3, 3),
            inputs[properties.pbc],
            self._cutoff,
        )
        inputs.update(
            {
                properties.idx_i: idx_i.detach(),
                properties.idx_j: idx_j.detach(),
                properties.offsets: offset,
            }
        )
        return inputs

    def _build_neighbor_list(
        self,
        Z: torch.Tensor,
        positions: torch.Tensor,
        cell: torch.Tensor,
        pbc: torch.Tensor,
        cutoff: float,
    ):
        """
        Hook for the concrete neighbor list implementation.

        Has to return the tuple `(idx_i, idx_j, offset)`.
        """
        raise NotImplementedError


class ASENeighborList(NeighborListTransform):
    """
    Neighbor list backed by the ASE `neighbor_list` routine.
    """

    def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
        """Return `(idx_i, idx_j, offset)`, where `offset` is the ASE shift matrix `S` multiplied by `cell`."""
        atoms = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)
        first, second, cell_shifts = ase_neighbor_list(
            "ijS", atoms, cutoff, self_interaction=False
        )

        # shifts are cast to the dtype of the positions for the product with the cell
        cell_shifts = torch.from_numpy(cell_shifts).to(dtype=positions.dtype)
        return (
            torch.from_numpy(first),
            torch.from_numpy(second),
            torch.mm(cell_shifts, cell),
        )


class MatScipyNeighborList(NeighborListTransform):
    """
    Neighbor list backed by the Matscipy package
    (https://github.com/libAtoms/matscipy).
    """

    def _build_neighbor_list(
        self, Z, positions, cell, pbc, cutoff, eps=1e-6, buffer=1.0
    ):
        """
        Return `(idx_i, idx_j, offset)` for the given structure.

        If the cell volume is below `eps`, a cell spanning the extent of the positions plus
        `buffer` is set on the temporary ASE atoms object before the neighbor search.
        """
        atoms = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)

        if atoms.cell.volume < eps:
            # no usable cell: enclose the atoms in a box of their extent plus a small buffer
            extent = np.ptp(atoms.positions, axis=0)
            atoms.set_cell(extent + buffer, scale_atoms=False)
            atoms.center()

        first, second, cell_shifts = msp_neighbor_list("ijS", atoms, cutoff)
        cell_shifts = torch.from_numpy(cell_shifts).to(dtype=positions.dtype)

        # the offsets are computed with the cell that was passed in, not with the temporary one
        return (
            torch.from_numpy(first).long(),
            torch.from_numpy(second).long(),
            torch.mm(cell_shifts, cell),
        )


class SkinNeighborList(Transform):
    """
    Neighbor list provider utilizing a cutoff skin for computational efficiency. Wrapper around neighbor list classes
    such as, e.g., ASENeighborList. Designed for use cases with gradual structural changes such as MD simulations
    and structure relaxations.

    Note:
        - Not meant to be used for training, since the shuffling of training data results in large
          structural deviations between subsequent training samples.
        - Not transferable between different molecule conformations or varying atom indexing.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(
        self,
        neighbor_list: Transform,
        nbh_postprocessing: Optional[List[torch.nn.Module]] = None,
        cutoff_skin: float = 0.3,
    ):
        """
        Args:
            neighbor_list: the neighbor list to use. Its `_cutoff` attribute is increased by `cutoff_skin`.
            nbh_postprocessing: post-processing transforms for manipulating the neighbor lists provided by neighbor_list
            cutoff_skin: float
                If no atom has moved more than the skin-distance since the neighbor list has been updated the last time,
                then the neighbor list is reused. This will save some expensive rebuilds of the list.
                Note:
                    Please choose a sufficiently large cutoff_skin value to ensure that between two subsequent samples
                    no atom can penetrate through the skin into the cutoff sphere of another atom if it is not in the
                    neighbor list of that atom.
        """

        super().__init__()

        self.neighbor_list = neighbor_list
        # remember the original cutoff, then let the wrapped neighbor list search within cutoff + skin
        self.cutoff = neighbor_list._cutoff
        self.cutoff_skin = cutoff_skin
        self.neighbor_list._cutoff = self.cutoff + cutoff_skin
        self.nbh_postprocessing = nbh_postprocessing or []
        self.distance_calculator = spk.atomistic.PairwiseDistances()
        # reference structure and neighbor list of the last rebuild, keyed by sample index
        self.previous_inputs = {}

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Provide the (reused or rebuilt) neighbor list, compute the pairwise distance vectors and drop all
        pairs that lie outside of the original cutoff.
        """
        _, inputs = self._update(inputs)
        inputs = self.distance_calculator(inputs)
        inputs = self._remove_neighbors_in_skin(inputs)

        return inputs

    def reset(self):
        """Forget all stored reference structures, so that the next call rebuilds the neighbor list."""
        self.previous_inputs = {}

    def _remove_neighbors_in_skin(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Keep only the pairs whose distance does not exceed the original cutoff (without skin)."""
        rij = torch.norm(inputs[properties.Rij], dim=-1)
        cidx = torch.nonzero(rij <= self.cutoff).squeeze(-1)

        for key in (
            properties.Rij,
            properties.idx_i,
            properties.idx_j,
            properties.offsets,
        ):
            inputs[key] = inputs[key][cidx]

        return inputs

    def _update(self, inputs):
        """
        Make sure the list is up to date.

        Returns:
            tuple `(update_required, inputs)`, where `update_required` is True if the neighbor list had to
            be rebuilt and False if the stored one has been reused.
        """
        sample_idx = inputs[properties.idx].item()

        # a stored entry only exists if this is not the first update step for this sample
        previous_inputs = self.previous_inputs.get(sample_idx)
        if previous_inputs is not None and self._is_reusable(previous_inputs, inputs):
            # reuse previous neighbor list
            for key in (properties.idx_i, properties.idx_j, properties.offsets):
                inputs[key] = previous_inputs[key].detach().clone()
            return False, inputs

        # build new neighbor list
        inputs = self._build(inputs)
        return True, inputs

    def _is_reusable(self, previous_inputs, inputs) -> bool:
        """Check whether the structure changed little enough for the stored neighbor list to stay valid."""
        # The periodic boundary conditions and the cell have to be identical in ALL entries. Accepting a
        # match in any single entry (e.g. a zero off-diagonal cell element) would reuse stale offsets
        # after the cell has changed.
        if not torch.equal(
            previous_inputs[properties.pbc].detach(), inputs[properties.pbc].detach()
        ):
            return False
        if not torch.equal(
            previous_inputs[properties.cell].detach().view(3, 3),
            inputs[properties.cell].detach().view(3, 3),
        ):
            return False

        # every atom must have moved less than half of the skin distance since the last rebuild
        displacement = (
            previous_inputs[properties.R].detach() - inputs[properties.R].detach()
        )
        max_squared_displacement = (displacement**2).sum(1).max()
        return bool(max_squared_displacement < 0.25 * self.cutoff_skin**2)

    def _build(self, inputs):
        """Rebuild the neighbor list (including post-processing) and store it as the new reference."""
        # apply all transforms to obtain new neighbor list
        inputs = self.neighbor_list(inputs)
        for postprocess in self.nbh_postprocessing:
            inputs = postprocess(inputs)

        # store new reference conformation, replacing the old one of this sample
        sample_idx = inputs[properties.idx].item()
        self.previous_inputs[sample_idx] = {
            key: inputs[key].clone()
            for key in (
                properties.R,
                properties.cell,
                properties.pbc,
                properties.idx_i,
                properties.idx_j,
                properties.offsets,
            )
        }

        return inputs


class TorchNeighborList(NeighborListTransform):
    """
    Environment provider making use of neighbor lists as implemented in TorchAni
    (https://github.com/aiqm/torchani/blob/master/torchani/aev.py).
    Supports cutoffs and PBCs and can be performed on either CPU or GPU.
    """

    def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
        """
        Return `(idx_i, idx_j, offset)` containing every neighbor pair in both directions, sorted by `idx_i`.
        """
        # Check if shifts are needed for periodic boundary conditions
        if torch.all(pbc == 0):
            shifts = torch.zeros(0, 3, device=cell.device, dtype=torch.long)
        else:
            shifts = self._get_shifts(cell, pbc, cutoff)
        idx_i, idx_j, offset = self._get_neighbor_pairs(positions, cell, shifts, cutoff)

        # Create bidirectional id arrays, similar to what the ASE neighbor_list returns
        bi_idx_i = torch.cat((idx_i, idx_j), dim=0)
        bi_idx_j = torch.cat((idx_j, idx_i), dim=0)

        # Sort along first dimension (necessary for atom-wise pooling)
        sorted_idx = torch.argsort(bi_idx_i)
        idx_i = bi_idx_i[sorted_idx]
        idx_j = bi_idx_j[sorted_idx]

        # the reversed pairs carry the negated cell shifts
        bi_offset = torch.cat((-offset, offset), dim=0)
        offset = bi_offset[sorted_idx]
        # convert integer cell shifts to cartesian offsets
        offset = torch.mm(offset.to(cell.dtype), cell)

        return idx_i, idx_j, offset

    def _get_neighbor_pairs(self, positions, cell, shifts, cutoff):
        """Compute pairs of atoms that are neighbors
        Copyright 2018- Xiang Gao and other ANI developers
        (https://github.com/aiqm/torchani/blob/master/torchani/aev.py)
        Arguments:
            positions (:class:`torch.Tensor`): tensor of shape (atoms, 3) for atom coordinates.
            cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
                defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
            shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
            cutoff (float): pairs with a distance below this value are kept.
        Returns:
            tuple of 1d index tensors `atom_index_i`, `atom_index_j` and the integer cell shifts of
            shape (pairs, 3).
        """
        num_atoms = positions.shape[0]
        all_atoms = torch.arange(num_atoms, device=cell.device)

        # 1) Central cell: every unordered pair of distinct atoms, zero shift
        pi_center, pj_center = torch.combinations(all_atoms).unbind(-1)
        shifts_center = shifts.new_zeros(pi_center.shape[0], 3)

        # 2) cells with shifts: every shift combined with every ordered pair of atoms
        num_shifts = shifts.shape[0]
        all_shifts = torch.arange(num_shifts, device=cell.device)
        shift_index, pi, pj = torch.cartesian_prod(
            all_shifts, all_atoms, all_atoms
        ).unbind(-1)
        shifts_outside = shifts.index_select(0, shift_index)

        # 3) combine results for all cells
        shifts_all = torch.cat([shifts_center, shifts_outside])
        pi_all = torch.cat([pi_center, pi])
        pj_all = torch.cat([pj_center, pj])

        # 4) Compute shifts and distance vectors
        shift_values = torch.mm(shifts_all.to(cell.dtype), cell)
        Rij_all = positions[pi_all] - positions[pj_all] + shift_values

        # 5) Compute distances, and find all pairs within cutoff
        distances = torch.norm(Rij_all, dim=1)
        in_cutoff = torch.nonzero(distances < cutoff, as_tuple=False)

        # 6) Reduce tensors to relevant components
        # Only squeeze the trailing dimension: a plain `squeeze()` would turn a single pair into a
        # 0-dim index, and the resulting 0-dim index tensors cannot be concatenated by the caller.
        pair_index = in_cutoff.squeeze(-1)
        atom_index_i = pi_all[pair_index]
        atom_index_j = pj_all[pair_index]
        offsets = shifts_all[pair_index]

        return atom_index_i, atom_index_j, offsets

    def _get_shifts(self, cell, pbc, cutoff):
        """Compute the shifts of unit cell along the given cell vectors to make it
        large enough to contain all pairs of neighbor atoms with PBC under
        consideration.
        Copyright 2018- Xiang Gao and other ANI developers
        (https://github.com/aiqm/torchani/blob/master/torchani/aev.py)
        Arguments:
            cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
            vectors defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
            pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
                if pbc is enabled for that direction.
            cutoff (float): cutoff radius of the neighbor search.
        Returns:
            :class:`torch.Tensor`: long tensor of shifts. the center cell and
                symmetric cells are not included.
        """
        reciprocal_cell = cell.inverse().t()
        inverse_lengths = torch.norm(reciprocal_cell, dim=1)

        num_repeats = torch.ceil(cutoff * inverse_lengths).long()
        # no repeats along non-periodic directions; `zeros_like` keeps dtype and device of
        # `num_repeats`, whereas the legacy `torch.Tensor(...)` constructor is tied to the CPU
        num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))

        r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
        r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
        r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
        o = torch.zeros(1, dtype=torch.long, device=cell.device)

        # one representative of each pair of opposite shifts
        return torch.cat(
            [
                torch.cartesian_prod(r1, r2, r3),
                torch.cartesian_prod(r1, r2, o),
                torch.cartesian_prod(r1, r2, -r3),
                torch.cartesian_prod(r1, o, r3),
                torch.cartesian_prod(r1, o, o),
                torch.cartesian_prod(r1, o, -r3),
                torch.cartesian_prod(r1, -r2, r3),
                torch.cartesian_prod(r1, -r2, o),
                torch.cartesian_prod(r1, -r2, -r3),
                torch.cartesian_prod(o, r2, r3),
                torch.cartesian_prod(o, r2, o),
                torch.cartesian_prod(o, r2, -r3),
                torch.cartesian_prod(o, o, r3),
            ]
        )


class FilterNeighbors(Transform):
    """
    Filter out all neighbor list indices corresponding to interactions between a set of atoms. This set of atoms must
    be specified in the input data.
    """

    def __init__(
        self, selection_name: str, additional_inputs: Dict[str, torch.Tensor] = None
    ):
        """
        Args:
            selection_name (str): key in the input data corresponding to the set of atoms between which no interactions
                should be considered.
            additional_inputs (dict): additional inputs required for this transform. When setting up the AtomsConverter,
                those additional inputs will be stored to the input batch.
        """
        self.selection_name = selection_name
        self.additional_inputs = additional_inputs or {}
        super().__init__()

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Remove every pair (i, j) from the neighbor list for which both atoms belong to the selection stored
        under `selection_name`. Pairs with at most one selected atom are kept.
        """
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]

        # Convert to Python objects once and use a set for the membership test, instead of one
        # `.item()` call per index and a linear scan of the selection for every pair.
        selected = set(inputs[self.selection_name].reshape(-1).tolist())
        kept = [
            nbh_idx
            for nbh_idx, (i, j) in enumerate(zip(idx_i.tolist(), idx_j.tolist()))
            if i not in selected or j not in selected
        ]
        kept_nbh_indices = torch.tensor(kept, dtype=torch.long, device=idx_i.device)

        inputs[properties.idx_i] = idx_i[kept_nbh_indices]
        inputs[properties.idx_j] = idx_j[kept_nbh_indices]
        inputs[properties.offsets] = inputs[properties.offsets][kept_nbh_indices]

        return inputs


class CollectAtomTriples(Transform):
    """
    Generate the index tensors for all triples between atoms within the cutoff shell.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Using the neighbors contained within the cutoff shell, generate all unique pairs of neighbors and convert
        them to index arrays. Applied to the neighbor arrays, these arrays generate the indices involved in the atom
        triples.
        E.g.:
            idx_j[idx_j_triples] -> j atom in triple
            idx_j[idx_k_triples] -> k atom in triple
            Rij[idx_j_triples] -> Rij vector in triple
            Rij[idx_k_triples] -> Rik vector in triple

        The neighbors of each central atom are expected to be stored consecutively in `idx_i`.
        """
        idx_i = inputs[properties.idx_i]

        # central atom and number of neighbors of every consecutive block in idx_i
        centers, n_neighbors = torch.unique_consecutive(idx_i, return_counts=True)

        offset = 0
        center_parts = []
        pair_parts = []
        for center, count in zip(centers.tolist(), n_neighbors.tolist()):
            # a triple needs two distinct neighbors of the same central atom
            if count >= 2:
                pairs = torch.combinations(
                    torch.arange(offset, offset + count, device=idx_i.device), r=2
                )
                # Store the index of the central atom itself. The running block number only agrees
                # with it as long as every atom has at least one neighbor.
                center_parts.append(
                    torch.full(
                        (pairs.shape[0],),
                        center,
                        dtype=torch.long,
                        device=idx_i.device,
                    )
                )
                pair_parts.append(pairs)
            offset += count

        if pair_parts:
            idx_i_triples = torch.cat(center_parts)
            idx_jk_triples = torch.cat(pair_parts)
        else:
            # no atom has two or more neighbors: return empty index tensors instead of failing
            idx_i_triples = torch.zeros(0, dtype=torch.long, device=idx_i.device)
            idx_jk_triples = torch.zeros((0, 2), dtype=torch.long, device=idx_i.device)

        inputs[properties.idx_i_triples] = idx_i_triples
        inputs[properties.idx_j_triples] = idx_jk_triples[:, 0]
        inputs[properties.idx_k_triples] = idx_jk_triples[:, 1]
        return inputs


class CountNeighbors(Transform):
    """
    Count how often each center index occurs in the neighbor list and store the counts in the inputs.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(self, sorted: bool = True):
        """
        Args:
            sorted: Set to false if chosen neighbor list yields unsorted center indices (idx_i).
        """
        super().__init__()
        self.sorted = sorted

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Write the number of occurrences of every center index in `idx_i` to the inputs."""
        # for sorted center indices it is sufficient to count consecutive runs
        count_fn = torch.unique_consecutive if self.sorted else torch.unique
        _, counts = count_fn(inputs[properties.idx_i], return_counts=True)

        inputs[properties.n_nbh] = counts
        return inputs


class WrapPositions(Transform):
    """
    Wrap atom positions into the periodic cell. The cell has to be invertible (non-zero).
    Along periodic directions, the fractional coordinates are mapped to the interval
    [-eps, 1 - eps); non-periodic directions are left untouched.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(self, eps: float = 1e-6):
        """
        Args:
            eps (float): small offset for numerical stability.
        """
        super().__init__()
        self.eps = eps

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Replace the positions in `inputs` by their wrapped counterparts."""
        positions = inputs[properties.R]
        cell = inputs[properties.cell].view(3, 3)
        periodic_mask = inputs[properties.pbc][None, ...]

        # fractional coordinates: positions multiplied by the inverse cell
        cell_inv = torch.inverse(cell)
        fractional = torch.sum(positions[..., None] * cell_inv[None, ...], dim=1)

        # wrap the periodic components only; eps is added before and removed after the modulo
        wrapped = torch.masked_select(fractional, periodic_mask)
        wrapped = (wrapped + self.eps) % 1.0 - self.eps
        fractional.masked_scatter_(periodic_mask, wrapped)

        # back to cartesian coordinates
        inputs[properties.R] = torch.sum(
            fractional[..., None] * cell[None, ...], dim=1
        )

        return inputs
