"""Crystal class."""

from collections import Counter

import numpy as np
from amd import PeriodicSet, periodicset_from_pymatgen_structure
from ase import Atoms
from matminer.featurizers.composition.composite import ElementProperty
from pymatgen.core import Composition, Lattice, Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer


class Crystal:
	"""Crystal structure given by fractional coordinates, species and a cell.

	The unit cell is stored as three lattice lengths and three lattice angles.
	Helper methods convert the crystal to pymatgen, ASE and AMD objects and
	compute composition-, symmetry- and Magpie-based descriptors.
	"""

	# Magpie featurizer shared by every instance. Building it with
	# ``ElementProperty.from_preset`` loads the Magpie element tables, which is
	# wasteful to repeat for each crystal; it is created lazily on first use.
	_shared_magpie_featurizer: ElementProperty | None = None

	def __init__(
		self,
		frac_coords: list[list[float]],
		atom_types: list[int],
		lengths: list[float],
		angles: list[float],
	) -> None:
		"""Initialize the Crystal object.

		The arguments are stored as given; they are neither copied nor validated.

		Args:
			frac_coords (list[list[float]]): Fractional coordinates of the atoms,
				shape (num_atoms, 3).
			atom_types (list[int]): Integer species of each atom, shape
				(num_atoms,). They are handed to pymatgen as species, which
				interprets integers as atomic numbers.
			lengths (list[float]): Lattice lengths a, b, c.
			angles (list[float]): Lattice angles alpha, beta, gamma, in degrees
				(the unit pymatgen's lattice constructors expect).
		"""
		self.frac_coords = frac_coords  # (num_atoms, 3)
		self.atom_types = atom_types  # (num_atoms, )
		self.lengths = lengths  # (3, )
		self.angles = angles  # (3, )

	def get_composition(self) -> tuple[tuple[int, int], ...]:
		"""Get the reduced composition of the crystal.

		Returns:
			tuple[tuple[int, int], ...]: ``(atom_type, count)`` pairs sorted by
				atom type, with all counts divided by their greatest common
				divisor. Empty when the crystal has no atoms.
		"""
		elem_counter = Counter(self.atom_types)
		if not elem_counter:
			# np.gcd.reduce([]) raises TypeError (an empty list becomes a float
			# array, which gcd does not support), so short-circuit here.
			return ()
		# Keys in a Counter are unique, so sorting the items sorts by atom type.
		composition_unnorm = sorted(elem_counter.items())
		gcd = np.gcd.reduce([count for _, count in composition_unnorm]).item()
		return tuple((elem, count // gcd) for elem, count in composition_unnorm)

	def get_composition_pymatgen(self) -> Composition:
		"""Get the pymatgen composition of the crystal.

		Returns:
			Composition: Pymatgen Composition built from the per-type atom counts
				(not reduced).
		"""
		return Composition(Counter(self.atom_types))

	def get_ase_atoms(self) -> Atoms:
		"""Get the ASE Atoms object of the crystal.

		Returns:
			Atoms: ASE Atoms converted from ``get_structure()``.
		"""
		return AseAtomsAdaptor.get_atoms(self.get_structure())

	def get_wyckoff(self) -> tuple[int, tuple[str, ...]] | str:
		"""Get the Wyckoff representation of the crystal.

		Building the structure happens outside the error handling, so invalid
		lattice/species input still raises. Only symmetry-analysis failures are
		turned into a returned message.

		Returns:
			tuple | str: ``(space_group_number, wyckoff_letters)``, where the
				letters are sorted and there is one per symmetry-equivalent site
				group. On a symmetry-analysis failure, the exception message.
		"""
		structure = self.get_structure()
		try:
			sga = SpacegroupAnalyzer(structure)
			sym = sga.get_symmetrized_structure()
			sg = sga.get_space_group_number()
			# wyckoff_symbols has one entry per equivalent-site group, written as
			# multiplicity + letter (e.g. "4a"); keep only the letter.
			# sym.wyckoff_letters is deliberately NOT used: it lists one letter
			# per site, so letters would repeat for every atom of an orbit.
			wyckoff_letters = sorted(symbol[-1] for symbol in sym.wyckoff_symbols)
			return sg, tuple(wyckoff_letters)
		except Exception as e:
			# Best-effort API: report the failure instead of raising.
			return str(e)

	def get_structure(self) -> Structure:
		"""Get the pymatgen Structure object of the crystal.

		Returns:
			Structure: Pymatgen Structure with the stored lattice parameters,
				species and fractional (not Cartesian) coordinates.
		"""
		lattice = Lattice.from_parameters(
			self.lengths[0],
			self.lengths[1],
			self.lengths[2],
			self.angles[0],
			self.angles[1],
			self.angles[2],
		)
		return Structure(
			lattice=lattice,
			species=self.atom_types,
			coords=self.frac_coords,
		)

	def get_magpie(self) -> list[float]:
		"""Get the Magpie embedding of the crystal's composition.

		The featurizer is exposed as ``self.featurizer``. An instance that
		already has one (e.g. set by the caller) keeps using it; otherwise the
		class-wide shared Magpie featurizer is attached.

		Returns:
			list[float]: Magpie feature vector of the crystal.
		"""
		composition = self.get_composition_pymatgen()
		if not hasattr(self, "featurizer"):
			if Crystal._shared_magpie_featurizer is None:
				Crystal._shared_magpie_featurizer = ElementProperty.from_preset(
					"magpie", impute_nan=True
				)
			self.featurizer = Crystal._shared_magpie_featurizer
		feature = self.featurizer.featurize(composition)
		return [float(x) for x in feature]

	def get_periodicset(self) -> PeriodicSet | str:
		"""Get the AMD PeriodicSet of the crystal.

		Building the structure happens outside the error handling, so invalid
		input still raises. Only the AMD conversion failure is turned into a
		returned message.

		Returns:
			PeriodicSet | str: AMD PeriodicSet object, or the exception message
				if the conversion fails.
		"""
		structure = self.get_structure()
		try:
			return periodicset_from_pymatgen_structure(structure)
		except Exception as e:
			# Best-effort API: report the failure instead of raising.
			return str(e)
