from copy import deepcopy
from pathlib import Path

import ase.build
import h5py
import numpy as np
import torch

from mace.data import (
    AtomicData,
    Configuration,
    HDF5Dataset,
    config_from_atoms,
    get_neighborhood,
    save_configurations_as_HDF5,
)
from mace.tools import AtomicNumberTable, torch_geometric

mace_path = Path(__file__).parent.parent


class TestAtomicData:
    config = Configuration(
        atomic_numbers=np.array([8, 1, 1]),
        positions=np.array(
            [
                [0.0, -2.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ]
        ),
        properties={
            "forces": np.array(
                [
                    [0.0, -1.3, 0.0],
                    [1.0, 0.2, 0.0],
                    [0.0, 1.1, 0.3],
                ]
            ),
            "energy": -1.5,
        },
        property_weights={
            "forces": 1.0,
            "energy": 1.0,
        },
    )
    config_2 = deepcopy(config)
    config_2.positions = config.positions + 0.01

    table = AtomicNumberTable([1, 8])

    def test_atomic_data(self):
        data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)

        assert data.edge_index.shape == (2, 4)
        assert data.forces.shape == (3, 3)
        assert data.node_attrs.shape == (3, 2)

    def test_data_loader(self):
        data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)
        data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)

        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[data1, data2],
            batch_size=2,
            shuffle=True,
            drop_last=False,
        )

        for batch in data_loader:
            assert batch.batch.shape == (6,)
            assert batch.edge_index.shape == (2, 8)
            assert batch.shifts.shape == (8, 3)
            assert batch.positions.shape == (6, 3)
            assert batch.node_attrs.shape == (6, 2)
            assert batch.energy.shape == (2,)
            assert batch.forces.shape == (6, 3)

    def test_to_atomic_data_dict(self):
        data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)
        data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0)

        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[data1, data2],
            batch_size=2,
            shuffle=True,
            drop_last=False,
        )
        for batch in data_loader:
            batch_dict = batch.to_dict()
            assert batch_dict["batch"].shape == (6,)
            assert batch_dict["edge_index"].shape == (2, 8)
            assert batch_dict["shifts"].shape == (8, 3)
            assert batch_dict["positions"].shape == (6, 3)
            assert batch_dict["node_attrs"].shape == (6, 2)
            assert batch_dict["energy"].shape == (2,)
            assert batch_dict["forces"].shape == (6, 3)

    def test_hdf5_dataloader(self):
        datasets = [self.config, self.config_2] * 5
        # get path of the mace package
        with h5py.File(str(mace_path) + "test.h5", "w") as f:
            save_configurations_as_HDF5(datasets, 0, f)
        train_dataset = HDF5Dataset(
            str(mace_path) + "test.h5", z_table=self.table, r_max=3.0
        )
        train_loader = torch_geometric.dataloader.DataLoader(
            dataset=train_dataset,
            batch_size=2,
            shuffle=False,
            drop_last=False,
        )
        batch_count = 0
        for batch in train_loader:
            batch_count += 1
            assert batch.batch.shape == (6,)
            assert batch.edge_index.shape == (2, 8)
            assert batch.shifts.shape == (8, 3)
            assert batch.positions.shape == (6, 3)
            assert batch.node_attrs.shape == (6, 2)
            assert batch.energy.shape == (2,)
            assert batch.forces.shape == (6, 3)
        print(batch_count, len(train_loader), len(train_dataset))
        assert batch_count == len(train_loader) == len(train_dataset) / 2
        train_loader_direct = torch_geometric.dataloader.DataLoader(
            dataset=[
                AtomicData.from_config(config, z_table=self.table, cutoff=3.0)
                for config in datasets
            ],
            batch_size=2,
            shuffle=False,
            drop_last=False,
        )
        for batch_direct, batch in zip(train_loader_direct, train_loader):
            assert torch.all(batch_direct.edge_index == batch.edge_index)
            assert torch.all(batch_direct.shifts == batch.shifts)
            assert torch.all(batch_direct.positions == batch.positions)
            assert torch.all(batch_direct.node_attrs == batch.node_attrs)
            assert torch.all(batch_direct.energy == batch.energy)
            assert torch.all(batch_direct.forces == batch.forces)


class TestNeighborhood:
    def test_basic(self):
        positions = np.array(
            [
                [-1.0, 0.0, 0.0],
                [+0.0, 0.0, 0.0],
                [+1.0, 0.0, 0.0],
            ]
        )

        indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5)
        assert indices.shape == (2, 4)
        assert shifts.shape == (4, 3)
        assert unit_shifts.shape == (4, 3)

    def test_signs(self):
        positions = np.array(
            [
                [+0.5, 0.5, 0.0],
                [+1.0, 1.0, 0.0],
            ]
        )

        cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        edge_index, shifts, unit_shifts, _ = get_neighborhood(
            positions, cutoff=3.5, pbc=(True, False, False), cell=cell
        )
        num_edges = 10
        assert edge_index.shape == (2, num_edges)
        assert shifts.shape == (num_edges, 3)
        assert unit_shifts.shape == (num_edges, 3)


# Based on mir-group/nequip
def test_periodic_edge():
    atoms = ase.build.bulk("Cu", "fcc")
    dist = np.linalg.norm(atoms.cell[0]).item()
    config = config_from_atoms(atoms)
    edge_index, shifts, _, _ = get_neighborhood(
        config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell
    )
    sender, receiver = edge_index
    vectors = (
        config.positions[receiver] - config.positions[sender] + shifts
    )  # [n_edges, 3]
    assert vectors.shape == (12, 3)  # 12 neighbors in close-packed bulk
    assert np.allclose(
        np.linalg.norm(vectors, axis=-1),
        dist,
    )


def test_half_periodic():
    atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0)
    assert all(atoms.pbc == (True, True, False))
    config = config_from_atoms(atoms)  # first shell dist is 2.864A
    edge_index, shifts, _, _ = get_neighborhood(
        config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell
    )
    sender, receiver = edge_index
    vectors = (
        config.positions[receiver] - config.positions[sender] + shifts
    )  # [n_edges, 3]
    # Check number of neighbors:
    _, neighbor_count = np.unique(edge_index[0], return_counts=True)
    assert (neighbor_count == 6).all()  # 6 neighbors
    # Check not periodic in z
    assert np.allclose(
        vectors[:, 2],
        np.zeros(vectors.shape[0]),
    )
