import numpy as np
import torch
import torch.nn.functional
from ase import build
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R

from mace import data, modules, tools
from mace.tools import torch_geometric

torch.set_default_dtype(torch.float64)


def _water_configuration(positions):
    """Return an H2O ``Configuration`` at *positions* with fixed reference labels.

    New label arrays and dicts are built on every call, so two configurations
    produced by this helper never share mutable state.
    """
    labels = {
        "forces": np.array(
            [
                [0.0, -1.3, 0.0],
                [1.0, 0.2, 0.0],
                [0.0, 1.1, 0.3],
            ]
        ),
        "energy": -1.5,
        "charges": np.array([-2.0, 1.0, 1.0]),
        "dipole": np.array([-1.5, 1.5, 2.0]),
    }
    weights = {name: 1.0 for name in ("forces", "energy", "charges", "dipole")}
    return data.Configuration(
        atomic_numbers=np.array([8, 1, 1]),
        positions=positions,
        properties=labels,
        property_weights=weights,
    )


config = _water_configuration(
    np.array(
        [
            [0.0, -2.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.1, 0.0][0:0] or [0.0, 1.0, 0.0],
        ]
    )
)
# Same molecule rotated by 60 degrees about the z axis. The reference labels are
# reused unrotated; the assertions below compare model outputs and only read
# the shape of the dipole label.
rot = R.from_euler("z", 60, degrees=True).as_matrix()
positions_rotated = np.array(rot @ config.positions.T).T
config_rotated = _water_configuration(positions_rotated)
table = tools.AtomicNumberTable([1, 8])
atomic_energies = np.array([1.0, 3.0], dtype=float)


def test_mace():
    """Check that a scripted MACE matches eager mode and is rotation-invariant.

    The batch holds ``config`` and its rotated copy ``config_rotated``. The
    TorchScript-compiled model must reproduce the eager energies for both
    graphs, and both graphs must get the same energy.
    """
    # Create MACE model
    model_config = dict(
        r_max=5,
        num_bessel=8,
        num_polynomial_cutoff=6,
        max_ell=2,
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        num_interactions=5,
        num_elements=2,
        hidden_irreps=o3.Irreps("32x0e + 32x1o"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=atomic_energies,
        avg_num_neighbors=8,
        atomic_numbers=table.zs,
        correlation=3,
        radial_type="bessel",
    )
    model = modules.MACE(**model_config)
    model_compiled = jit.compile(model)

    atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0)
    atomic_data2 = data.AtomicData.from_config(
        config_rotated, z_table=table, cutoff=3.0
    )

    # A fixed graph order keeps failures reproducible; the assertions below do
    # not need a random order.
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[atomic_data, atomic_data2],
        batch_size=2,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    output1 = model(batch.to_dict(), training=True)
    output2 = model_compiled(batch.to_dict(), training=True)
    # Eager and compiled models must agree on every graph, not only the first.
    assert torch.allclose(output1["energy"], output2["energy"])
    # Rotation invariance: both graphs describe the same molecule.
    assert torch.allclose(output2["energy"][0], output2["energy"][1])


def test_dipole_mace():
    """Check AtomicDipolesMACE output shape and rotational equivariance.

    The loader does not shuffle, so the batch holds ``config`` first and
    ``config_rotated`` second. The dipole predicted for the second graph must
    therefore equal ``rot`` applied to the dipole predicted for the first.
    """
    block_cls = modules.interaction_classes["RealAgnosticResidualInteractionBlock"]
    model = modules.AtomicDipolesMACE(
        **{
            "r_max": 5,
            "num_bessel": 8,
            "num_polynomial_cutoff": 5,
            "max_ell": 2,
            "interaction_cls": block_cls,
            "interaction_cls_first": block_cls,
            "num_interactions": 2,
            "num_elements": 2,
            "hidden_irreps": o3.Irreps("16x0e + 16x1o + 16x2e"),
            "MLP_irreps": o3.Irreps("16x0e"),
            "gate": torch.nn.functional.silu,
            "atomic_energies": None,
            "avg_num_neighbors": 3,
            "atomic_numbers": table.zs,
            "correlation": 3,
            "radial_type": "gaussian",
        }
    )

    graphs = [
        data.AtomicData.from_config(cfg, z_table=table, cutoff=3.0)
        for cfg in (config, config_rotated)
    ]
    loader = torch_geometric.dataloader.DataLoader(
        dataset=graphs, batch_size=2, shuffle=False, drop_last=False
    )
    batch = next(iter(loader))
    predicted = model(batch, training=True)["dipole"]

    # One dipole vector per graph, shaped like the reference label.
    assert predicted[:1].shape == graphs[0].dipole.shape
    # Equivariance: rotating the input rotates the predicted dipole.
    dipoles = predicted.detach().numpy()
    assert np.allclose(np.array(rot @ dipoles[0]), dipoles[1])


def test_energy_dipole_mace():
    """Check EnergyDipolesMACE: dipole shape, invariant energy, equivariant dipole.

    Graph 0 is ``config`` and graph 1 is ``config_rotated`` (no shuffling).
    Their energies must match, and the second dipole must be ``rot`` applied
    to the first.
    """
    block_cls = modules.interaction_classes["RealAgnosticResidualInteractionBlock"]
    model = modules.EnergyDipolesMACE(
        **{
            "r_max": 5,
            "num_bessel": 8,
            "num_polynomial_cutoff": 5,
            "max_ell": 2,
            "interaction_cls": block_cls,
            "interaction_cls_first": block_cls,
            "num_interactions": 2,
            "num_elements": 2,
            "hidden_irreps": o3.Irreps("16x0e + 16x1o + 16x2e"),
            "MLP_irreps": o3.Irreps("16x0e"),
            "gate": torch.nn.functional.silu,
            "atomic_energies": atomic_energies,
            "avg_num_neighbors": 3,
            "atomic_numbers": table.zs,
            "correlation": 3,
        }
    )

    graphs = [
        data.AtomicData.from_config(cfg, z_table=table, cutoff=3.0)
        for cfg in (config, config_rotated)
    ]
    loader = torch_geometric.dataloader.DataLoader(
        dataset=graphs, batch_size=2, shuffle=False, drop_last=False
    )
    batch = next(iter(loader))
    result = model(batch, training=True)
    predicted_dipoles = result["dipole"]
    energies = result["energy"]

    # One dipole vector per graph, shaped like the reference label.
    assert predicted_dipoles[:1].shape == graphs[0].dipole.shape
    # Invariance: rotation must not change the energy.
    assert torch.allclose(energies[0], energies[1])
    # Equivariance: rotating the input rotates the predicted dipole.
    dipoles = predicted_dipoles.detach().numpy()
    assert np.allclose(np.array(rot @ dipoles[0]), dipoles[1])


def test_mace_multi_reference():
    """Check a two-head ScaleShiftMACE in eager and TorchScript mode.

    The unrotated configuration is labelled with the "Default" head and the
    rotated one with the "dft" head. The compiled model must reproduce the
    eager energies and return one energy per graph.
    """
    import copy

    atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float)
    model_config = dict(
        r_max=5,
        num_bessel=8,
        num_polynomial_cutoff=6,
        max_ell=3,
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        num_interactions=2,
        num_elements=2,
        hidden_irreps=o3.Irreps("96x0e + 96x1o"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=atomic_energies_multi,
        avg_num_neighbors=8,
        atomic_numbers=table.zs,
        distance_transform=True,
        pair_repulsion=True,
        correlation=3,
        heads=["Default", "dft"],
        # radial_type="chebyshev",
        atomic_inter_scale=[1.0, 1.0],
        atomic_inter_shift=[0.0, 0.1],
    )
    model = modules.ScaleShiftMACE(**model_config)
    model_compiled = jit.compile(model)
    # Label deep copies rather than the module-level configs, so the head
    # assignment does not leak into other tests that reuse them.
    config_default = copy.deepcopy(config)
    config_default.head = "Default"
    config_dft = copy.deepcopy(config_rotated)
    config_dft.head = "dft"
    atomic_data = data.AtomicData.from_config(
        config_default, z_table=table, cutoff=3.0, heads=["Default", "dft"]
    )
    atomic_data2 = data.AtomicData.from_config(
        config_dft, z_table=table, cutoff=3.0, heads=["Default", "dft"]
    )

    # A fixed graph order keeps failures reproducible.
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[atomic_data, atomic_data2],
        batch_size=2,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    output1 = model(batch.to_dict(), training=True)
    output2 = model_compiled(batch.to_dict(), training=True)
    # Eager and compiled models must agree on every graph/head, not only the first.
    assert torch.allclose(output1["energy"], output2["energy"])
    assert output2["energy"].shape[0] == 2


def test_atomic_virials_stresses():
    """
    Test that atomic virials and stresses sum to the total virials and stress.

    A strained silicon diamond cell is evaluated with
    ``compute_atomic_stresses=True``. The per-atom tensors must have shape
    ``(n_atoms, 3, 3)``, and their sums must match the per-structure totals.
    """
    # Double precision keeps the per-atom sums within the tight tolerance used
    # below. Seeding makes the model weights and reference labels reproducible.
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(0)
    rng = np.random.default_rng(0)

    # Create a periodic cell with ASE
    atoms = build.bulk("Si", "diamond", a=5.43)
    # Expand the cell isotropically by 2% so the stress is non-zero
    strain_tensor = np.eye(3) * 1.02
    atoms.set_cell(np.dot(atoms.get_cell(), strain_tensor), scale_atoms=True)

    # Reference labels are attached for completeness; this test does not
    # check their values.
    atoms.arrays["REF_forces"] = rng.normal(0, 0.1, size=atoms.positions.shape)
    atoms.info["REF_energy"] = rng.normal(0, 1)
    atoms.info["REF_stress"] = rng.normal(0, 0.1, size=6)

    # Setup MACE model configuration
    stress_z_table = tools.AtomicNumberTable([14])  # Silicon
    stress_atomic_energies = np.array([0.0])

    model_config = dict(
        r_max=5.0,
        num_bessel=8,
        num_polynomial_cutoff=6,
        max_ell=2,
        interaction_cls=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        interaction_cls_first=modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        num_interactions=3,
        num_elements=1,
        hidden_irreps=o3.Irreps("32x0e + 32x1o"),
        MLP_irreps=o3.Irreps("16x0e"),
        gate=torch.nn.functional.silu,
        atomic_energies=stress_atomic_energies,
        avg_num_neighbors=4.0,
        # Must agree with num_elements=1 and the Si-only table used for the data.
        atomic_numbers=stress_z_table.zs,
        correlation=3,
        atomic_inter_scale=1.0,
        atomic_inter_shift=0.0,
    )

    # Create the model
    model = modules.ScaleShiftMACE(**model_config)

    # Create atomic data
    atomic_data = data.AtomicData.from_config(
        data.config_from_atoms(
            atoms, key_specification=data.KeySpecification.from_defaults()
        ),
        z_table=stress_z_table,
        cutoff=5.0,
    )

    # A single structure, so a batch of one with no shuffling.
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[atomic_data],
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    batch_dict = batch.to_dict()

    # Run the model with compute_atomic_stresses=True
    output = model(
        batch_dict,
        compute_force=True,
        compute_virials=True,
        compute_stress=True,
        compute_atomic_stresses=True,
    )

    # Get total virials/stress and atomic virials/stresses
    total_virials = output["virials"]
    atomic_virials = output["atomic_virials"]
    total_stress = output["stress"]
    atomic_stresses = output["atomic_stresses"]

    # Test that atomic values are not None
    assert atomic_virials is not None, "Atomic virials were not computed"
    assert atomic_stresses is not None, "Atomic stresses were not computed"

    # Test shape of atomic values
    assert atomic_virials.shape[0] == len(atoms), "Wrong shape for atomic virials"
    assert atomic_virials.shape[1:] == (3, 3), "Atomic virials should be 3x3 matrices"
    assert atomic_stresses.shape[0] == len(atoms), "Wrong shape for atomic stresses"
    assert atomic_stresses.shape[1:] == (3, 3), "Atomic stresses should be 3x3 matrices"

    # Compute sum of atomic values
    summed_atomic_virials = torch.sum(atomic_virials, dim=0)
    summed_atomic_stresses = torch.sum(atomic_stresses, dim=0)

    # Test that sums match total values (squeeze the single-structure batch axis)
    assert torch.allclose(
        summed_atomic_virials, total_virials.squeeze(0), atol=1e-6
    ), f"Sum of atomic virials {summed_atomic_virials} does not match total virials {total_virials.squeeze(0)}"

    assert torch.allclose(
        summed_atomic_stresses, total_stress.squeeze(0), atol=1e-6
    ), f"Sum of atomic stresses (normalized by volume) {summed_atomic_stresses} does not match total stress {total_stress.squeeze(0)}"
