"""Compute distances between a pair of materials."""

import amd
import numpy as np
from pymatgen.analysis.structure_matcher import StructureMatcher

from xgemval.crystal import Crystal


def d_comp(xtal_1: Crystal, xtal_2: Crystal) -> float:
	"""Return a binary composition distance for a pair of crystals.

	The values returned by ``get_composition()`` on each crystal are
	compared with ``==``; no partial similarity is computed.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.

	Returns:
		float: 0.0 when both compositions compare equal, 1.0 otherwise.
	"""
	if xtal_1.get_composition() == xtal_2.get_composition():
		return 0.0
	return 1.0


def d_wyckoff(xtal_1: Crystal, xtal_2: Crystal) -> float:
	"""Return a binary Wyckoff distance for a pair of crystals.

	Each crystal's ``get_wyckoff()`` result is fetched first. A result that
	is a ``str`` is treated as an error message: it is printed and the
	comparison is abandoned. Otherwise the two results are compared with
	``==``.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.

	Returns:
		float: 0.0 when the Wyckoff representations compare equal, 1.0 when
		they differ, and -1.0 when either representation could not be
		obtained.
	"""
	wyckoff_1 = xtal_1.get_wyckoff()
	wyckoff_2 = xtal_2.get_wyckoff()

	# xtal_1 is checked first, so only its failure is reported when both fail.
	if isinstance(wyckoff_1, str):
		print(f"Failed to get Wyckoff representation of xtal_1. Error message: {wyckoff_1}")
		return -1.0
	if isinstance(wyckoff_2, str):
		print(f"Failed to get Wyckoff representation of xtal_2. Error message: {wyckoff_2}")
		return -1.0

	if wyckoff_1 == wyckoff_2:
		return 0.0
	return 1.0


def d_smat(xtal_1: Crystal, xtal_2: Crystal) -> float:
	"""Return a binary StructureMatcher distance for a pair of crystals.

	The structures returned by ``get_structure()`` are passed to the
	``fit`` method of a ``StructureMatcher`` built with its default
	arguments.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.

	Returns:
		float: 0.0 when ``fit`` reports a match, 1.0 otherwise.
	"""
	structure_1 = xtal_1.get_structure()
	structure_2 = xtal_2.get_structure()
	if StructureMatcher().fit(structure_1, structure_2):
		return 0.0
	return 1.0


def d_magpie(xtal_1: Crystal, xtal_2: Crystal) -> float:
	"""Return the Euclidean distance between two Magpie embeddings.

	The values returned by ``get_magpie()`` on each crystal are converted
	to NumPy arrays, and the square root of the summed squared differences
	is returned.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.

	Returns:
		float: Magpie distance as a plain Python float.
	"""
	vec_1 = np.array(xtal_1.get_magpie())
	vec_2 = np.array(xtal_2.get_magpie())
	delta = vec_1 - vec_2
	squared_norm = np.sum(delta**2)
	# .item() unwraps the NumPy scalar into a built-in float.
	return np.sqrt(squared_norm).item()


def d_pdd(xtal_1: Crystal, xtal_2: Crystal, *, k: int = 100) -> float:
	"""Compute the PDD distance between two crystals.

	Each crystal's ``get_periodicset()`` result is fetched first. A result
	that is a ``str`` is treated as an error message: it is printed and the
	computation is abandoned. Both results are validated before any PDD is
	computed, so no work is spent on ``xtal_1`` when ``xtal_2`` has already
	failed.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.
		k (int): Value forwarded as ``k`` to ``amd.PDD`` for both crystals
			(the number of nearest neighbours, per the amd documentation).
			Defaults to 100.

	Returns:
		float: PDD distance, or -1.0 when the periodic set of either
		crystal could not be obtained.
	"""
	pset_1 = xtal_1.get_periodicset()
	pset_2 = xtal_2.get_periodicset()

	# xtal_1 is checked first, so only its failure is reported when both fail.
	if isinstance(pset_1, str):
		print(
			f"Failed to get the periodicset representation of xtal_1. Error message: {pset_1}"
		)
		return -1.0
	if isinstance(pset_2, str):
		print(
			f"Failed to get the periodicset representation of xtal_2. Error message: {pset_2}"
		)
		return -1.0

	emb_1 = amd.PDD(pset_1, k=k)
	emb_2 = amd.PDD(pset_2, k=k)
	# PDD_cdist returns a distance matrix; with one item per side it is 1x1.
	return amd.PDD_cdist([emb_1], [emb_2])[0][0].item()


def d_amd(xtal_1: Crystal, xtal_2: Crystal, *, k: int = 100) -> float:
	"""Compute the AMD distance between two crystals.

	Each crystal's ``get_periodicset()`` result is fetched first. A result
	that is a ``str`` is treated as an error message: it is printed and the
	computation is abandoned. Both results are validated before any AMD is
	computed, so no work is spent on ``xtal_1`` when ``xtal_2`` has already
	failed.

	Args:
		xtal_1 (Crystal): First crystal.
		xtal_2 (Crystal): Second crystal.
		k (int): Value forwarded as ``k`` to ``amd.AMD`` for both crystals
			(the number of nearest neighbours, per the amd documentation).
			Defaults to 100.

	Returns:
		float: AMD distance, or -1.0 when the periodic set of either
		crystal could not be obtained.
	"""
	pset_1 = xtal_1.get_periodicset()
	pset_2 = xtal_2.get_periodicset()

	# xtal_1 is checked first, so only its failure is reported when both fail.
	if isinstance(pset_1, str):
		print(
			f"Failed to get the periodicset representation of xtal_1. Error message: {pset_1}"
		)
		return -1.0
	if isinstance(pset_2, str):
		print(
			f"Failed to get the periodicset representation of xtal_2. Error message: {pset_2}"
		)
		return -1.0

	emb_1 = amd.AMD(pset_1, k=k)
	emb_2 = amd.AMD(pset_2, k=k)
	# AMD_cdist returns a distance matrix; with one item per side it is 1x1.
	return amd.AMD_cdist([emb_1], [emb_2])[0][0].item()
