"""Evaluate generated crystals."""

import argparse
import json
import os
import zipfile
from typing import Literal

import torch
from utils import get_path_to_gen_xtal, get_path_to_train_data

from xgemval import Evaluator, convert


def main(
	model: Literal[
		"cdvae", "diffcsp", "adit", "chemeleon", "chemeleon2", "diffcsppp", "mattergen"
	],
	dataset: Literal["mp20"],
	uniqueness: list[Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"]],
	novelty: list[Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"]],
	screen: Literal["none", "smact", "ehull"],
) -> None:
	"""Main function to evaluate generated crystals.

	Loads the raw generated structures for ``model`` on ``dataset``, converts
	them with ``xgemval.convert``, and runs the requested uniqueness and
	novelty evaluations on an ``xgemval.Evaluator``.

	Args:
		model (Literal):
			Model used to generate crystals. This determines both how the raw
			file is read and which conversion method is passed to ``convert``.
		dataset (Literal): Dataset for novelty evaluation.
		uniqueness (list[Literal]):
			Distances for uniqueness evaluation. May be empty, in which case
			no uniqueness evaluation is run.
		novelty (list[Literal]):
			Distances for novelty evaluation. May be empty, in which case no
			novelty evaluation is run.
		screen (Literal):
			Screening method. ``"none"`` disables screening (passed on as
			``None``); ``"ehull"`` additionally gives the evaluator the
			``asset`` directory next to this script as ``phase_dir``.

	Raises:
		ValueError: If ``model`` is not one of the supported models.
	"""
	gen_xtal_path = get_path_to_gen_xtal(dataset=dataset, model=model)
	train_data_path = get_path_to_train_data(dataset=dataset)

	# Load the raw generated structures; the on-disk format depends on the model.
	if model in ("cdvae", "diffcsppp"):
		# NOTE(review): weights_only=False unpickles arbitrary Python objects,
		# so only load files produced by a trusted run of the model.
		gen_xtal_raw = torch.load(gen_xtal_path, weights_only=False)
		# For these two models the conversion method is named after the model.
		convert_method = model
	elif model in ("diffcsp", "adit", "chemeleon", "chemeleon2"):
		# These models share the DiffCSP JSON output format. JSON is UTF-8 by
		# specification, so do not depend on the platform's default encoding.
		with open(gen_xtal_path, encoding="utf-8") as f:
			gen_xtal_raw = json.load(f)
		convert_method = "diffcsp"
	elif model == "mattergen":
		# The output is a zip archive; collect the text of every ``.cif``
		# member in archive order and ignore all other members.
		with zipfile.ZipFile(gen_xtal_path, "r") as zf:
			gen_xtal_raw = [
				zf.read(name).decode("utf-8")
				for name in zf.namelist()
				if name.endswith(".cif")
			]
		convert_method = "mattergen"
	else:
		raise ValueError(f"Unsupported model: {model}")

	# Convert the raw structures into the crystal objects the evaluator uses.
	gen_xtal = convert(gen_xtal_raw=gen_xtal_raw, method=convert_method)

	# save_dir is the directory holding the generated-crystal file; phase_dir
	# is only supplied when screening by energy above hull.
	evaluator = Evaluator(
		gen_xtal=gen_xtal,
		dataset=dataset,
		train_data_path=train_data_path,
		save_dir=os.path.dirname(gen_xtal_path),
		phase_dir=os.path.join(os.path.dirname(__file__), "asset")
		if screen == "ehull"
		else None,
	)

	# The CLI spells "no screening" as "none"; the evaluator receives None.
	# Use a separate name so ``screen`` keeps its declared Literal type.
	screen_method = None if screen == "none" else screen
	for distance in uniqueness:
		evaluator.uniqueness(distance=distance, screen=screen_method)
	for distance in novelty:
		evaluator.novelty(distance=distance, screen=screen_method)


if __name__ == "__main__":
	# Command-line entry point. Every option is required, and each parsed
	# option is forwarded to ``main`` as a keyword argument of the same name.
	model_choices = [
		"cdvae",
		"diffcsp",
		"adit",
		"chemeleon",
		"chemeleon2",
		"diffcsppp",
		"mattergen",
	]
	distance_choices = ["comp", "wyckoff", "smat", "magpie", "pdd", "amd"]

	cli = argparse.ArgumentParser()
	cli.add_argument(
		"--model",
		type=str,
		required=True,
		choices=model_choices,
		help="Model used to generate crystals.",
	)
	cli.add_argument(
		"--dataset",
		type=str,
		required=True,
		choices=["mp20"],
		help="Dataset for novelty evaluation.",
	)
	# Both distance flags accept zero or more values (nargs="*"); passing a
	# flag with no values yields an empty list, so ``main`` runs no
	# evaluation of that kind.
	for flag, help_text in (
		("--uniqueness", "Distance for uniqueness evaluation."),
		("--novelty", "Distance for novelty evaluation."),
	):
		cli.add_argument(
			flag,
			type=str,
			required=True,
			nargs="*",
			choices=distance_choices,
			help=help_text,
		)
	cli.add_argument(
		"--screen",
		type=str,
		required=True,
		choices=["none", "smact", "ehull"],
		help="Screening method.",
	)
	main(**vars(cli.parse_args()))
