"""Convert generated crystals and training-set crystals to Crystal objects."""

from typing import Any, Literal

from pandas import DataFrame
from pymatgen.core import Structure
from pymatgen.io.cif import CifParser

from xgemval.crystal import Crystal


def structure2crystal(
	structure: Structure,
) -> Crystal:
	"""Build a Crystal from a pymatgen Structure.

	Fractional coordinates, per-site atomic numbers and lattice parameters are
	copied out of ``structure`` as plain Python lists.

	Args:
		structure (Structure): Pymatgen structure to convert. ``site.specie``
			is read for every site, so presumably each site must be ordered
			(a single species per site).

	Returns:
		Crystal: Crystal holding the fractional coordinates, atomic numbers,
			lattice lengths and lattice angles of ``structure``.
	"""
	coords = structure.frac_coords.tolist()
	numbers = [site.specie.number for site in structure.sites]
	lattice = structure.lattice
	return Crystal(
		frac_coords=coords,
		atom_types=numbers,
		lengths=list(lattice.abc),
		angles=list(lattice.angles),
	)


def convert_cdvae(
	gen_xtal_raw: dict,
) -> list[Crystal]:
	"""Convert generated crystals from CDVAE format to Crystal objects.

	Each CDVAE array carries one extra leading axis compared with the
	DiffCSP++ output, and only index 0 of that axis is converted. Once that
	axis is dropped, the layout matches DiffCSP++: flattened per-atom arrays
	plus one ``num_atoms`` count per crystal. The splitting is therefore
	delegated to ``convert_diffcsppp`` rather than duplicated here.

	Args:
		gen_xtal_raw (dict): Generated crystal structures with keys
			``frac_coords``, ``num_atoms``, ``atom_types``, ``lengths`` and
			``angles``. Each value is indexed with ``[0]`` before use.

	Returns:
		list[Crystal]: Converted crystal structures.
	"""
	# NOTE(review): the leading axis presumably indexes independent sampling
	# runs, and only the first one is converted. Confirm this is intended.
	first_run = {
		key: gen_xtal_raw[key][0]
		for key in ("frac_coords", "num_atoms", "atom_types", "lengths", "angles")
	}
	return convert_diffcsppp(first_run)


def convert_diffcsp(
	gen_xtal_raw: list[dict],
) -> list[Crystal]:
	"""Convert generated crystals from DiffCSP format to Crystal objects.

	Each entry is a serialized pymatgen Structure. It is rebuilt with
	``Structure.from_dict`` and then converted.

	Args:
		gen_xtal_raw (list[dict]): Generated crystal structures, one
			``Structure.as_dict()``-style mapping per crystal.

	Returns:
		list[Crystal]: Converted crystal structures, in input order.
	"""
	return [
		structure2crystal(Structure.from_dict(structure_dict))
		for structure_dict in gen_xtal_raw
	]


def convert_diffcsppp(
	gen_xtal_raw: dict,
) -> list[Crystal]:
	"""Convert generated crystals from DiffCSP++ format to Crystal objects.

	The per-atom arrays (``frac_coords``, ``atom_types``) are stored flattened
	across all crystals. ``num_atoms`` gives how many consecutive rows belong
	to each crystal. The per-crystal arrays (``lengths``, ``angles``) hold one
	row per crystal.

	Args:
		gen_xtal_raw (dict): Generated crystal structures with keys
			``frac_coords``, ``num_atoms``, ``atom_types``, ``lengths`` and
			``angles``. Each value must provide ``.tolist()`` (e.g. a torch
			tensor or numpy array).

	Returns:
		list[Crystal]: Converted crystal structures, one per ``num_atoms``
			entry.

	Raises:
		ValueError: If any ``num_atoms`` entry is negative, if the atom
			counts do not match the number of per-atom rows, or if the number
			of lattices differs from the number of crystals.
	"""
	# Convert each array to Python lists once. Slicing the lists is much
	# cheaper than one tensor slice plus ``.tolist()`` per crystal.
	frac_coords = gen_xtal_raw["frac_coords"].tolist()  # (total_atoms, 3)
	num_atoms = gen_xtal_raw["num_atoms"].tolist()  # (num_crystals,)
	atom_types = gen_xtal_raw["atom_types"].tolist()  # (total_atoms,)
	lengths = gen_xtal_raw["lengths"].tolist()  # (num_crystals, 3)
	angles = gen_xtal_raw["angles"].tolist()  # (num_crystals, 3)

	# Validate the layout up front. Misaligned arrays would otherwise
	# silently drop trailing atoms or mix atoms between crystals.
	if any(count < 0 for count in num_atoms):
		raise ValueError("num_atoms must contain only non-negative counts")
	total_atoms = sum(num_atoms)
	if len(frac_coords) != total_atoms or len(atom_types) != total_atoms:
		raise ValueError(
			f"num_atoms sums to {total_atoms}, but got {len(frac_coords)} "
			f"frac_coords rows and {len(atom_types)} atom_types entries"
		)
	num_crystals = len(num_atoms)
	if len(lengths) != num_crystals or len(angles) != num_crystals:
		raise ValueError(
			f"Expected {num_crystals} lattices, but got {len(lengths)} "
			f"lengths rows and {len(angles)} angles rows"
		)

	gen_xtal = []
	start = 0
	for count, cell_lengths, cell_angles in zip(num_atoms, lengths, angles):
		stop = start + count
		gen_xtal.append(
			Crystal(
				frac_coords=frac_coords[start:stop],
				atom_types=atom_types[start:stop],
				lengths=cell_lengths,
				angles=cell_angles,
			)
		)
		start = stop
	return gen_xtal


def convert_mattergen(
	gen_xtal_raw: list[str],
) -> list[Crystal]:
	"""Convert generated crystals from Mattergen format to Crystal objects.

	Each entry is CIF text. It is parsed with ``primitive=True``, and only the
	first structure the parser returns is converted.

	Args:
		gen_xtal_raw (list[str]): Generated crystal structures as CIF strings.

	Returns:
		list[Crystal]: Converted crystal structures, in input order.
	"""
	return [
		structure2crystal(
			CifParser.from_str(cif_text).parse_structures(primitive=True)[0]
		)
		for cif_text in gen_xtal_raw
	]


def convert(
	gen_xtal_raw: Any,
	method: Literal["cdvae", "diffcsp", "diffcsppp", "mattergen"],
) -> list[Crystal]:
	"""Dispatch generated crystals to the converter for their source model.

	Args:
		gen_xtal_raw (Any): Generated crystal structures in the raw format
			expected by the selected converter.
		method (Literal): Name of the generative model that produced
			``gen_xtal_raw``.

	Returns:
		list[Crystal]: Converted crystal structures.

	Raises:
		ValueError: If ``method`` does not name a supported converter.
	"""
	# Ordered (name, converter) pairs compared with ``==``, in the same order
	# the names are checked.
	converters = (
		("cdvae", convert_cdvae),
		("diffcsp", convert_diffcsp),
		("diffcsppp", convert_diffcsppp),
		("mattergen", convert_mattergen),
	)
	for name, converter in converters:
		if method == name:
			return converter(gen_xtal_raw)
	raise ValueError(f"Unsupported conversion method: {method}")


def convert_mp20(
	train_xtal_raw: DataFrame,
) -> list[Crystal]:
	"""Convert crystals in the mp20 train dataset to Crystal objects.

	Only the ``cif`` column is read. Each CIF string is parsed with
	``primitive=True``, and the first structure the parser returns is
	converted.

	Args:
		train_xtal_raw (DataFrame): Crystals in the mp20 train dataset. Must
			contain a ``cif`` column of CIF strings.

	Returns:
		list[Crystal]: Converted crystal structures, in row order.
	"""
	# Iterate the column directly. ``DataFrame.iterrows`` builds a whole Series
	# per row, which is needless overhead when only one column is used.
	train_xtal = []
	for cif_text in train_xtal_raw["cif"]:
		structure = CifParser.from_str(cif_text).parse_structures(primitive=True)[0]
		train_xtal.append(structure2crystal(structure))
	return train_xtal
