"""Evaluator class."""

import gzip
import json
import os
import pickle
import time
import warnings
from collections import Counter
from typing import Any, Literal

import amd
import numpy as np
import pandas as pd
import requests
from amd import PeriodicSet
from mace.calculators import mace_mp
from pymatgen.analysis.phase_diagram import PatchedPhaseDiagram, PDEntry
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.ext.matproj import MPRester
from scipy.spatial.distance import squareform
from smact.screening import smact_validity

from xgemval.convert import convert_mp20
from xgemval.crystal import Crystal


class Evaluator:
	"""Evaluator class for evaluating generated crystals."""

	def __init__(
		self,
		gen_xtal: list[Crystal],
		dataset: Literal["mp20"],
		train_data_path: str,
		save_dir: str,
		phase_dir: str | None = None,
	) -> None:
		"""Initialize the Evaluator.

		Args:
			gen_xtal (list[Crystal]): Generated crystal structures.
			dataset (Literal): Dataset for novelty evaluation.
			train_data_path (str): Path to the training dataset.
			save_dir (str): Directory to save evaluation results.
			phase_dir (str | None): Directory to save phase diagram data.
		"""
		self.gen_xtal = gen_xtal
		self.dataset = dataset
		self.train_data_path = train_data_path
		self.train_data_dir = os.path.dirname(train_data_path)
		self.save_dir = save_dir
		os.makedirs(self.train_data_dir, exist_ok=True)
		os.makedirs(self.save_dir, exist_ok=True)
		if phase_dir is not None:
			self.phase_dir = phase_dir
			os.makedirs(self.phase_dir, exist_ok=True)

	def _prepare_train_xtal(self) -> None:
		"""Prepare the training crystals for novelty evaluation."""
		if hasattr(self, "train_xtal"):
			return
		if self.dataset == "mp20":
			# download the dataset from web
			if not os.path.exists(self.train_data_path):
				url = "https://raw.githubusercontent.com/txie-93/cdvae/refs/heads/main/data/mp_20/train.csv"
				urldata = requests.get(url).text
				with open(self.train_data_path, "w") as f:
					f.write(urldata)
			# convert the dataset to Crystal objects
			train_xtal_raw = pd.read_csv(self.train_data_path)
			self.train_xtal = convert_mp20(train_xtal_raw=train_xtal_raw)
		else:
			raise ValueError(f"Unsupported dataset: {self.dataset}.")

	def _embed(
		self,
		distance: Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"],
		category: Literal["generated", "train"],
	) -> tuple[Any, float]:
		"""Compute the embeddings of crystals.

		Args:
			distance (Literal): Type of distance.
			category (Literal): Category of crystals to compute embeddings for.

		Returns:
			tuple: A tuple containing the embeddings of crystals
				   and the time taken to compute them.
		"""
		save_path = (
			os.path.join(self.save_dir, f"gen_{distance}.pkl.gz")
			if category == "generated"
			else os.path.join(self.train_data_dir, f"train_{distance}.pkl.gz")
		)
		time_embed = -1
		if os.path.exists(save_path):
			with gzip.open(save_path, "rb") as f:
				embeddings = pickle.load(f)
		else:
			xtals = self.gen_xtal if category == "generated" else self.train_xtal
			start_time_embed = time.time()
			if distance == "comp":
				embeddings = [xtal.get_composition() for xtal in xtals]
			elif distance == "wyckoff":
				embeddings = [xtal.get_wyckoff() for xtal in xtals]
			elif distance == "smat":
				embeddings = [xtal.get_structure() for xtal in xtals]
			elif distance == "magpie":
				embeddings = np.array([xtal.get_magpie() for xtal in xtals])
			elif distance == "pdd":
				psets = [xtal.get_periodicset() for xtal in xtals]
				embeddings = [
					amd.PDD(pset, k=100) if isinstance(pset, PeriodicSet) else pset
					for pset in psets
				]
			elif distance == "amd":
				psets = [xtal.get_periodicset() for xtal in xtals]
				embeddings = [
					amd.AMD(pset, k=100) if isinstance(pset, PeriodicSet) else pset
					for pset in psets
				]
			else:
				raise ValueError(f"Unsupported distance: {distance}.")
			end_time_embed = time.time()
			time_embed = end_time_embed - start_time_embed
			with gzip.open(save_path, "wb") as f:
				pickle.dump(embeddings, f)
		return embeddings, time_embed

	def _distance_matrix(
		self,
		metric: Literal["uniqueness", "novelty"],
		distance: Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"],
		embeddings_1: Any,
		embeddings_2: Any,
	) -> tuple[np.ndarray, float]:
		"""Compute the distance matrix for the embeddings.

		Args:
			metric (Literal): Metric to compute later.
			distance (Literal): Type of distance.
			embeddings_1 (Any): Embeddings of crystals.
			embeddings_2 (Any): Embeddings of crystals to compare against.

		Returns:
			tuple[np.ndarray, float]: A tuple containing the distance matrix
				and the time taken to compute it.
		"""
		if metric == "uniqueness":
			path = os.path.join(self.save_dir, f"mtx_uni_{distance}.pkl.gz")
		else:  # novelty
			path = os.path.join(self.save_dir, f"mtx_nov_{distance}.pkl.gz")
		if os.path.exists(path):
			with gzip.open(path, "rb") as f:
				d_mtx = pickle.load(f)
			time_matrix = -1
		else:
			start_time_matrix = time.time()
			if distance == "comp":
				d_mtx = np.ones((len(embeddings_1), len(embeddings_2)))
				for i, emb_i in enumerate(embeddings_1):
					for j, emb_j in enumerate(embeddings_2):
						if emb_i == emb_j:
							d_mtx[i, j] = 0
			elif distance == "wyckoff":
				d_mtx = np.ones((len(embeddings_1), len(embeddings_2)))
				for i, emb_i in enumerate(embeddings_1):
					if isinstance(emb_i, str):  # error
						d_mtx[i, :] = -1
					for j, emb_j in enumerate(embeddings_2):
						if isinstance(emb_j, str):  # error
							d_mtx[i, j] = -1
						elif emb_i == emb_j:
							d_mtx[i, j] = 0
			elif distance == "smat":
				d_mtx = np.ones((len(embeddings_1), len(embeddings_2)))
				matcher = StructureMatcher()
				for i, emb_i in enumerate(embeddings_1):
					for j, emb_j in enumerate(embeddings_2):
						if matcher.fit(emb_i, emb_j):
							d_mtx[i, j] = 0
			elif distance == "magpie":
				d_mtx = np.zeros((len(embeddings_1), len(embeddings_2)))
				for i, emb in enumerate(embeddings_1):
					d_sq = (emb[np.newaxis, :] - embeddings_2) ** 2
					d_euclidean = np.sqrt(np.sum(d_sq, axis=1))
					d_mtx[i, :] = d_euclidean
			elif distance in ["pdd", "amd"]:
				valids_1 = [x for x in embeddings_1 if isinstance(x, np.ndarray)]
				error_indices_1 = [
					i for i, x in enumerate(embeddings_1) if isinstance(x, str)
				]
				valids_2 = [x for x in embeddings_2 if isinstance(x, np.ndarray)]
				error_indices_2 = [
					i for i, x in enumerate(embeddings_2) if isinstance(x, str)
				]
				if metric == "uniqueness":
					d_mtx = squareform(
						amd.PDD_pdist(valids_1, n_jobs=1)
						if distance == "pdd"
						else amd.AMD_pdist(valids_1)
					)
				else:
					d_mtx = (
						amd.PDD_cdist(valids_1, valids_2, n_jobs=1)
						if distance == "pdd"
						else amd.AMD_cdist(valids_1, valids_2)
					)
				for i in error_indices_1:
					d_mtx = np.insert(d_mtx, i, -1, axis=0)
				for i in error_indices_2:
					d_mtx = np.insert(d_mtx, i, -1, axis=1)
				assert d_mtx.shape == (len(embeddings_1), len(embeddings_2))
			end_time_matrix = time.time()
			time_matrix = end_time_matrix - start_time_matrix
			with gzip.open(path, "wb") as f:
				pickle.dump(d_mtx, f)
		return d_mtx, time_matrix

	def _smact_screen(self, xtals: list[Crystal]) -> np.ndarray[bool]:
		"""Screen the crystals using SMACT.

		Args:
			xtals (list[Crystal]): List of crystals to screen.

		Returns:
			np.ndarray[bool]: Array indicating which crystals pass the screening.
		"""
		path = os.path.join(self.save_dir, "screen_smact.pkl.gz")
		if os.path.exists(path):
			with gzip.open(path, "rb") as f:
				screened = pickle.load(f)
		else:
			screened = np.array(
				[smact_validity(xtal.get_composition_pymatgen()) for xtal in xtals]
			)
			with gzip.open(path, "wb") as f:
				pickle.dump(screened, f)
		return screened

	def _ehull_screen(self, xtals: list[Crystal]) -> np.ndarray[bool]:
		"""Screen the crystals using the energy above hull computed for MP.

		Args:
			xtals (list[Crystal]): List of crystals to screen.

		Returns:
			np.ndarray[bool]: Array indicating which crystals pass the screening.
		"""
		if os.path.exists(os.path.join(self.save_dir, "screen_ehull.pkl.gz")):
			with gzip.open(
				os.path.join(self.save_dir, "screen_ehull.pkl.gz"), "rb"
			) as f:
				screened = pickle.load(f)
		else:
			MP_API_KEY = os.getenv("MP_API_KEY")
			# pre-compute phase diagram
			assert hasattr(self, "phase_dir"), "Phase directory is not set."
			phase_path = os.path.join(
				self.phase_dir, "ppd-mp_all_entries_uncorrected_250618.pkl.gz"
			)
			if os.path.exists(phase_path):
				with gzip.open(phase_path, "rb") as f:
					ppd_mp = pickle.load(f)
			else:
				mpr = MPRester(MP_API_KEY)
				response = mpr.request("materials/thermo/?_fields=entries&formula=")
				all_entries = []
				for dct in response:
					all_entries.extend(dct["entries"].values())
				with warnings.catch_warnings():
					warnings.filterwarnings(
						"ignore", message="Failed to guess oxidation states.*"
					)
					all_entries = MaterialsProject2020Compatibility().process_entries(
						all_entries, clean=True
					)
				all_entries = list(set(all_entries))  # 152976 entries
				all_entries = [
					e for e in all_entries if e.data["run_type"] in ["GGA", "GGA_U"]
				]  # 152976 entries
				all_entries_uncorrected = [
					PDEntry(composition=e.composition, energy=e.uncorrected_energy)
					for e in all_entries
				]  # 152976 entries
				ppd_mp = PatchedPhaseDiagram(all_entries_uncorrected)
				with gzip.open(phase_path, "wb") as f:
					pickle.dump(ppd_mp, f)
			# compute energy above hull for each generated crystal
			calculator = mace_mp(model="medium-mpa-0")
			screened = np.zeros(len(xtals), dtype=bool)
			e_above_hulls = np.zeros(len(xtals), dtype=float)
			for idx, xtal in enumerate(xtals):
				try:
					mace_energy = calculator.get_potential_energy(xtal.get_ase_atoms())
					gen_entry = ComputedEntry(
						xtal.get_composition_pymatgen(), mace_energy
					)
					e_above_hulls[idx] = ppd_mp.get_e_above_hull(
						gen_entry, allow_negative=True
					)
					if e_above_hulls[idx] <= 0.1:
						screened[idx] = True
				except ValueError:
					# Samples contain elements not in the MP20 train set
					# Or no suitable PhaseDiagrams found
					screened[idx] = False
					e_above_hulls[idx] = np.nan
			with gzip.open(
				os.path.join(self.save_dir, "screen_ehull.pkl.gz"), "wb"
			) as f:
				pickle.dump(screened, f)
			with gzip.open(os.path.join(self.save_dir, "ehull.pkl.gz"), "wb") as f:
				pickle.dump(e_above_hulls, f)
		return screened

	def _save_metrics(
		self,
		distance: Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"],
		path: str,
		metrics: dict,
	) -> None:
		"""Save the metrics to a JSON file.

		Args:
			distance (Literal): Distance used for metrics.
			path (str): Path to save the metrics.
			metrics (dict): Metrics to save.
		"""
		if os.path.exists(path):
			with open(path) as f:
				existing_metrics = json.load(f)
		else:
			existing_metrics = {}
		if distance not in existing_metrics:
			existing_metrics[distance] = {}
		existing_metrics[distance].update(metrics)
		with open(path, "w") as f:
			json.dump(existing_metrics, f, indent=4, sort_keys=True)

	def uniqueness(
		self,
		distance: Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"],
		screen: Literal[None, "smact", "ehull"] = None,
	) -> None:
		"""Evaluate the uniqueness of generated crystals.

		Args:
			distance (Literal): Distance used for uniqueness evaluation.
			screen (Literal): Method to screen the generated crystals.
		"""
		if distance not in ["comp", "wyckoff", "smat", "magpie", "pdd", "amd"]:
			raise ValueError(f"Unsupported distance: {distance}.")

		embeddings, time_embed_gen = self._embed(
			distance=distance, category="generated"
		)

		# distance matrix
		d_mtx, time_matrix = self._distance_matrix(
			metric="uniqueness",
			distance=distance,
			embeddings_1=embeddings,
			embeddings_2=embeddings,
		)

		# screening
		valid_indices = np.ones(len(embeddings), dtype=bool)
		if distance in ["wyckoff", "pdd", "amd"]:
			valid_indices &= np.array([not isinstance(x, str) for x in embeddings])
		if screen == "smact":
			valid_indices &= self._smact_screen(xtals=self.gen_xtal)
		elif screen == "ehull":
			valid_indices &= self._ehull_screen(xtals=self.gen_xtal)
		d_mtx = d_mtx[valid_indices][:, valid_indices]

		# compute uniqueness
		start_time_metric = time.time()
		if distance in ["comp", "wyckoff", "smat"]:
			n_unique = sum(
				[1 if np.all(d_mtx[i, :i] != 0) else 0 for i in range(len(d_mtx))]
			)
			uniqueness = n_unique / len(embeddings)
		elif distance in ["magpie", "pdd", "amd"]:
			uniqueness = float(
				np.sum(d_mtx) / (len(embeddings) * (len(embeddings) - 1))
			)
		end_time_metric = time.time()
		time_metric = end_time_metric - start_time_metric

		# save metrics
		if screen is None:
			metrics = {
				"uniqueness": uniqueness,
				"time_uniqueness": time_metric,
			}
		elif screen == "smact":
			metrics = {
				"uniqueness_smact": uniqueness,
			}
		elif screen == "ehull":
			metrics = {
				"uniqueness_ehull": uniqueness,
			}
		if time_embed_gen != -1:
			metrics["time_embed_gen"] = time_embed_gen
		if time_matrix != -1:
			metrics["time_uniqueness_matrix"] = time_matrix
		if distance in ["wyckoff", "pdd", "amd"]:
			errors = [x for x in embeddings if isinstance(x, str)]
			if len(errors) > 0:
				metrics["error_embed_gen"] = dict(Counter(errors))
		self._save_metrics(
			distance=distance,
			path=os.path.join(self.save_dir, "metrics.json"),
			metrics=metrics,
		)

	def novelty(
		self,
		distance: Literal["comp", "wyckoff", "smat", "magpie", "pdd", "amd"],
		screen: Literal[None, "smact", "ehull"] = None,
	) -> None:
		"""Evaluate the novelty of generated crystals.

		Args:
			distance (Literal): Distance used for novelty evaluation.
			screen (Literal): Method to screen the generated crystals.
		"""
		if distance not in ["comp", "wyckoff", "smat", "magpie", "pdd", "amd"]:
			raise ValueError(f"Unsupported distance: {distance}.")

		self._prepare_train_xtal()

		gen_embeddings, time_embed_gen = self._embed(
			distance=distance, category="generated"
		)
		train_embeddings, time_embed_train = self._embed(
			distance=distance, category="train"
		)

		# distance matrix
		d_mtx, time_matrix = self._distance_matrix(
			metric="novelty",
			distance=distance,
			embeddings_1=gen_embeddings,
			embeddings_2=train_embeddings,
		)

		# screening
		gen_valid_indices = np.ones(len(gen_embeddings), dtype=bool)
		train_valid_indices = np.ones(len(train_embeddings), dtype=bool)
		if distance in ["wyckoff", "pdd", "amd"]:
			gen_valid_indices &= np.array(
				[not isinstance(x, str) for x in gen_embeddings]
			)
			train_valid_indices &= np.array(
				[not isinstance(x, str) for x in train_embeddings]
			)
		if screen == "smact":
			gen_valid_indices &= self._smact_screen(xtals=self.gen_xtal)
			# smact screening is not applied to training data
			# tmp = self._smact_screen(xtals=self.train_xtal)
			# print(np.sum(tmp))  # 8348
			# print(len(tmp))  # 10000
		elif screen == "ehull":
			gen_valid_indices &= self._ehull_screen(xtals=self.gen_xtal)
		d_mtx = d_mtx[gen_valid_indices][:, train_valid_indices]

		# compute novelty
		start_time_metric = time.time()
		if distance in ["comp", "wyckoff", "smat"]:
			n_novel = sum(
				[1 if np.all(d_mtx[i] != 0) else 0 for i in range(len(d_mtx))]
			)
			novelty = n_novel / len(gen_embeddings)
		elif distance in ["magpie", "pdd", "amd"]:
			novelty = float(np.sum(np.min(d_mtx, axis=1)) / len(gen_embeddings))
		end_time_metric = time.time()
		time_metric = end_time_metric - start_time_metric

		# save
		if screen is None:
			metrics = {
				"novelty": novelty,
				"time_novelty": time_metric,
			}
		elif screen == "smact":
			metrics = {
				"novelty_smact": novelty,
			}
		elif screen == "ehull":
			metrics = {
				"novelty_ehull": novelty,
			}
		if time_embed_gen != -1:
			metrics["time_embed_gen"] = time_embed_gen
		if time_matrix != -1:
			metrics["time_novelty_matrix"] = time_matrix
		if distance in ["wyckoff", "pdd", "amd"]:
			gen_errors = [x for x in gen_embeddings if isinstance(x, str)]
			if len(gen_errors) > 0:
				metrics["error_embed_gen"] = dict(Counter(gen_errors))
		self._save_metrics(
			distance=distance,
			path=os.path.join(self.save_dir, "metrics.json"),
			metrics=metrics,
		)

		if time_embed_train != -1:
			train_metrics = {"time_embed_train": time_embed_train}
			if distance in ["wyckoff", "pdd", "amd"]:
				train_errors = [x for x in train_embeddings if isinstance(x, str)]
				if len(train_errors) > 0:
					train_metrics["error_embed_train"] = dict(Counter(train_errors))
			self._save_metrics(
				distance=distance,
				path=os.path.join(self.train_data_dir, "metrics.json"),
				metrics=train_metrics,
			)
