"""Utility functions for experiments."""

from typing import Literal

from path import (
	PATH_MP20,
	PATH_MP20_ADIT,
	PATH_MP20_CDVAE,
	PATH_MP20_CHEMELEON,
	PATH_MP20_CHEMELEON2,
	PATH_MP20_DIFFCSP,
	PATH_MP20_DIFFCSPPP,
	PATH_MP20_MATTERGEN,
)


def get_path_to_train_data(dataset: Literal["mp20"]) -> str:
	"""Return the location of the training data for ``dataset``.

	Args:
		dataset (Literal): Name of the dataset; only "mp20" is recognized.

	Returns:
		str: Path to the training data of the requested dataset.

	Raises:
		ValueError: If ``dataset`` is not a recognized dataset name.
	"""
	# Reject anything other than the single supported dataset up front.
	if dataset != "mp20":
		raise ValueError(f"Unknown dataset: {dataset}")
	return PATH_MP20


def get_path_to_gen_xtal(
	dataset: Literal["mp20"],
	model: Literal[
		"cdvae", "diffcsp", "adit", "chemeleon", "chemeleon2", "diffcsppp", "mattergen"
	],
) -> str:
	"""Return the location of crystal structures generated by ``model``.

	Args:
		dataset (Literal): Dataset the generative model was trained on; only
			"mp20" is recognized.
		model (Literal): Name of the generative model whose sampled crystals
			are requested.

	Returns:
		str: Path to the generated crystal structures for that model/dataset.

	Raises:
		ValueError: If ``dataset`` is unknown (checked first) or if ``model``
			is not one of the recognized model names.
	"""
	# Validate the dataset before looking at the model, so an unknown
	# dataset is always reported in preference to an unknown model.
	if dataset != "mp20":
		raise ValueError(f"Unknown dataset: {dataset}")

	# Plain equality checks (rather than a dict lookup) keep unhashable
	# ``model`` values on the ValueError path below.
	if model == "cdvae":
		return PATH_MP20_CDVAE
	if model == "diffcsp":
		return PATH_MP20_DIFFCSP
	if model == "adit":
		return PATH_MP20_ADIT
	if model == "chemeleon":
		return PATH_MP20_CHEMELEON
	if model == "chemeleon2":
		return PATH_MP20_CHEMELEON2
	if model == "diffcsppp":
		return PATH_MP20_DIFFCSPPP
	if model == "mattergen":
		return PATH_MP20_MATTERGEN

	raise ValueError(f"Unknown model: {model}")
