#!/usr/bin/python3
"""
Defines a k-means clustering-based equivalence relation over embedded designs.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import os
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage
)
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from pydantic import BaseModel
from sklearn.cluster import KMeans  # type: ignore
from sklearn.decomposition import PCA  # type: ignore
from typing import Any, Dict, Final, Hashable, List, Union

from .base import BaseEquivalenceRelation
from ..envs import BaseTask


class KMeansEquivalenceRelation(BaseEquivalenceRelation):
    """
    Equivalence relation that groups designs by k-means cluster membership.

    The training designs of the task are embedded (through a disk-cached
    VectorStoreIndex), projected with PCA onto the leading components that
    together explain at least 99% of the variance, and clustered with
    k-means. The number of clusters is selected with the elbow method. New
    designs are mapped to the label predicted by the fitted k-means model.
    """

    def __init__(
        self,
        task: BaseTask,
        embedder: BaseEmbedding,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "neighbors"
        ),
        seed: int = 2025,
        **kwargs: Any
    ):
        """
        Args:
            task: the optimization task; its `train` dataset is embedded
                and clustered.
            embedder: the embedding model to use.
            cache_dir: the root directory under which the embedding index
                is cached (a subdirectory named after the training dataset
                class is created inside it).
            seed: the random seed used for every k-means fit.
            **kwargs: accepted for signature compatibility and ignored.
        """
        del kwargs
        # NOTE: this sets the process-wide llama_index default embedder, which
        # is what `_build_index` relies on when embedding the documents.
        Settings.embed_model = embedder
        self._task: Final[BaseTask] = task
        self._embedder: Final[str] = embedder.model_name
        # Keep the embedder itself so that `__call__` does not depend on the
        # mutable global `Settings.embed_model` still pointing at it later.
        self._embedding_model: Final[BaseEmbedding] = embedder
        self._cache_dir: Final[str] = os.path.join(
            str(cache_dir), self._task.train.__class__.__name__
        )
        os.makedirs(self._cache_dir, exist_ok=True)
        self.seed: Final[int] = seed
        self._index: Final[VectorStoreIndex] = self._build_index()
        embeddings = np.vstack(
            list(self._index.vector_store.to_dict()["embedding_dict"].values())
        )

        # Fit a full-rank PCA, then keep the smallest number of leading
        # components whose cumulative explained variance reaches 99%.
        # `argmax` on a boolean array returns the index of the first True.
        self._pca = PCA(n_components=min(embeddings.shape))
        self._pca.fit(embeddings)
        self.n_components = int(1 + np.argmax(
            np.cumsum(self._pca.explained_variance_ratio_) >= 0.99
        ))

        embeddings = self._pca.transform(embeddings)[:, :self.n_components]
        n_classes = self._compute_optimal_k(embeddings)
        super().__init__(n_classes)
        self._kmeans = KMeans(n_clusters=self._n_classes, random_state=seed)
        self._kmeans.fit(embeddings)
        self._inertia: Final[float] = self._kmeans.inertia_

    def __call__(self, x: List[BaseModel], s: List[float]) -> List[Hashable]:
        """
        Args:
            x: a list of designs to assign to their equivalence classes.
            s: a list of the corresponding critic-augmented scores (only
                its length is checked; the values are not used).
        Returns:
            A list of the equivalence classes for each design, each given
            as the string form of the predicted k-means cluster index.
        """
        assert len(x) == len(s)
        del s
        if not x:
            # `np.vstack` raises on an empty sequence; nothing to classify.
            return []
        embeddings = np.vstack([
            self._embedding_model.get_text_embedding(repr(xi)) for xi in x
        ])
        # Apply the same projection and truncation used when fitting k-means.
        embeddings = self._pca.transform(embeddings)[:, :self.n_components]
        return [str(idx) for idx in self._kmeans.predict(embeddings)]

    def _build_index(self) -> VectorStoreIndex:
        """
        Builds the index of the training dataset embeddings, or loads it
        from the on-disk cache when a persisted copy already exists.
        Input:
            None.
        Returns:
            A VectorStoreIndex object.
        """
        # NOTE(review): the cache location depends only on the training
        # dataset class name, not on the embedder; a cache written with a
        # different embedding model would presumably be reused as-is.
        # Confirm whether `self._embedder` should be part of this path.
        cache_path = os.path.join(self._cache_dir, "index")
        if os.path.isdir(cache_path):
            index: VectorStoreIndex = load_index_from_storage(  # type: ignore
                StorageContext.from_defaults(persist_dir=cache_path)
            )
            return index

        # Index-based access is kept on purpose: only `__len__` and
        # `__getitem__` are assumed of the training dataset.
        train = self._task.train
        designs = self._task.reduce([train[i] for i in range(len(train))])
        index = VectorStoreIndex.from_documents(
            [Document(text=repr(xi)) for xi in designs]
        )
        index.storage_context.persist(persist_dir=cache_path)
        return index

    def _compute_optimal_k(
        self, embeddings: np.ndarray, kmax: int = 10
    ) -> int:
        """
        Computes the optimal number of clusters using the elbow method.

        K-means is fitted for every k from 2 up to `kmax` (capped at the
        number of samples, since k-means cannot fit more clusters than
        samples). The chosen k is the one whose (k, inertia) point lies
        farthest from the straight line joining the first and last points
        of the inertia curve.
        Args:
            embeddings: the embeddings to cluster, one row per sample.
            kmax: the largest number of clusters to consider.
        Returns:
            The optimal number of clusters.
        """
        upper = min(kmax, int(embeddings.shape[0]))
        if upper < 2:
            # Too few samples (or too small a kmax) to compare candidates.
            return max(upper, 1)

        ks = list(range(2, upper + 1))
        if len(ks) == 1:
            # A single candidate: the chord below would have zero length.
            return ks[0]

        inertias = []
        for k in ks:
            km = KMeans(n_clusters=k, random_state=self.seed)
            km.fit(embeddings)
            inertias.append(km.inertia_)

        x = np.array(ks, dtype=float)
        y = np.array(inertias, dtype=float)

        # End points of the chord through the inertia curve.
        x1, y1 = x[0], y[0]
        x2, y2 = x[-1], y[-1]
        # x2 > x1 is guaranteed here, so the chord length is non-zero.
        denom = np.sqrt(np.square(y2 - y1) + np.square(x2 - x1))

        # Perpendicular distance of every (k, inertia) point to the chord.
        dists = (1.0 / denom) * np.abs(
            ((y2 - y1) * x) - ((x2 - x1) * y) + (x2 * y1) - (y2 * x1)
        )

        return ks[int(np.argmax(dists))]
