#!/usr/bin/python3
"""
Defines a community-based equivalence relation.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Traag VA, Waltman L, van Eck NJ. From Louvain to Leiden:
        Guaranteeing well-connected communities. Sci Rep 9(5233).
        (2019). doi: 10.1038/s41598-019-41695-z

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import igraph as ig  # type: ignore
import leidenalg as la  # type: ignore
import numpy as np
import os
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage
)
from llama_index.core.embeddings import BaseEmbedding
from pathlib import Path
from pydantic import BaseModel
from sklearn.decomposition import PCA  # type: ignore
from sklearn.neighbors import NearestNeighbors  # type: ignore
from typing import Any, Dict, Final, Hashable, List, Union

from .base import BaseEquivalenceRelation
from ..envs import BaseTask


class CommunityBasedEquivalenceRelation(BaseEquivalenceRelation):
    """
    Equivalence relation whose classes are communities of a k-nearest-neighbor
    graph built over embeddings of the (reduced) training dataset.

    Training designs are embedded, projected with PCA onto the leading
    components that explain at least 99% of the variance, and linked to their
    k nearest neighbors (cosine distance) by edges weighted with a Gaussian
    kernel of the distance. The graph is partitioned with the Leiden algorithm
    under the Constant Potts Model (CPM) [1]. At query time each design is
    attached to a copy of the training graph as one new vertex whose training
    memberships are fixed, and it is assigned the community the optimiser
    places it in.
    """

    def __init__(
        self,
        task: BaseTask,
        embedder: BaseEmbedding,
        sigma: float = 0.1,
        cache_dir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "neighbors"
        ),
        seed: int = 2025,
        resolution_parameter: float = 0.001,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task.
            embedder: the embedding model to use.
            sigma: bandwidth of the Gaussian kernel exp(-d^2 / (2 sigma^2))
                that converts cosine distances d into edge weights.
            cache_dir: the directory to cache the neighbors.
            seed: random seed.
            resolution_parameter: CPM resolution used by the Leiden
                partition; larger values yield more, smaller communities.
        """
        del kwargs
        # The llama_index index construction/loading in `_build_index` reads
        # the embedding model from the global `Settings`.
        Settings.embed_model = embedder
        # Keep our own handle so queries are embedded with *this* model even
        # if another instance later overwrites the global `Settings`.
        self._embed_model: Final[BaseEmbedding] = embedder
        self._task: Final[BaseTask] = task
        self._embedder: Final[str] = embedder.model_name
        self._sigma: Final[float] = sigma
        self._seed: Final[int] = seed
        self._cache_dir: Final[str] = os.path.join(
            str(cache_dir), self._task.train.__class__.__name__
        )
        os.makedirs(self._cache_dir, exist_ok=True)
        self._index: Final[VectorStoreIndex] = self._build_index()
        embeddings = np.vstack(
            list(self._index.vector_store.to_dict()["embedding_dict"].values())
        )
        self._pca = PCA(n_components=min(*embeddings.shape))
        self._pca.fit(embeddings)
        # Smallest number of leading components whose cumulative explained
        # variance reaches 99%.
        self.n_components = 1 + np.argmax(
            np.cumsum(self._pca.explained_variance_ratio_) >= 0.99
        )

        embeddings = self._project(embeddings)
        n_samples = embeddings.shape[0]
        # k grows logarithmically with the dataset size (at least 10), but it
        # cannot exceed the number of *other* training points, otherwise
        # `kneighbors` raises for small training sets.
        k = min(max(10, int(np.ceil(np.log(n_samples)))), n_samples - 1)
        # One extra neighbor because each training point is (normally) its
        # own nearest neighbor; that column is dropped below.
        self._nbrs = NearestNeighbors(n_neighbors=(k + 1), metric="cosine")
        self._nbrs.fit(embeddings)

        dists, neighs = self._nbrs.kneighbors(embeddings, return_distance=True)
        # NOTE(review): assumes column 0 is the point itself; this may not
        # hold when several training embeddings are identical.
        dists, neighs = dists[:, 1:], neighs[:, 1:]
        weights = self._kernel(dists)
        # Row-major order matches `weights.ravel()`. Mutual neighbors
        # contribute two parallel edges to the undirected graph.
        edges = [(i, int(j)) for i, row in enumerate(neighs) for j in row]
        self._graph = ig.Graph(
            n=n_samples,
            edges=edges,
            directed=False,
            edge_attrs={"weight": weights.ravel().tolist()}
        )
        self._optimiser = la.Optimiser()
        self._optimiser.set_rng_seed(self._seed)
        # `weights="weight"` is required: without it leidenalg treats every
        # edge as having unit weight and the kernel (and `sigma`) is ignored.
        self._partition = la.CPMVertexPartition(
            self._graph,
            weights="weight",
            resolution_parameter=resolution_parameter
        )
        self._optimiser.optimise_partition(self._partition)
        # Disallow moving a vertex into a new, empty community so that query
        # vertices in `__call__` are placed into an existing community.
        self._optimiser.consider_empty_community = False

        super(CommunityBasedEquivalenceRelation, self).__init__(
            n_classes=len(set(self._partition.membership))
        )

    def __call__(self, x: List[BaseModel], s: List[float]) -> List[Hashable]:
        """
        Args:
            x: a list of designs to assign to their equivalence classes.
            s: a list of the corresponding critic-augmented scores (only its
                length is checked; the scores are not used).
        Returns:
            A list of the equivalence classes for each design, given as the
            string form of the Leiden community id.
        """
        assert len(x) == len(s)
        del s
        if not x:
            return []
        embeddings = self._project(np.vstack([
            self._embed_model.get_text_embedding(repr(xi)) for xi in x
        ]))
        # Query points are not in the fitted set, so all k + 1 returned
        # neighbors are genuine training points, including any at distance 0.
        dists, neighs = self._nbrs.kneighbors(embeddings, return_distance=True)
        weights = self._kernel(dists)

        n_train = self._graph.vcount()
        # The query vertex starts in community 0; only it may move.
        initial_membership = self._partition.membership + [0]
        is_membership_fixed = ([True] * n_train) + [False]
        classes: List[Hashable] = []
        for nbr_row, weight_row in zip(neighs, weights):
            # Reseed so each query is placed independently of the others.
            self._optimiser.set_rng_seed(self._seed)
            graph = self._graph.copy()
            graph.add_vertex()  # the query vertex receives index `n_train`
            for nbr, weight in zip(nbr_row, weight_row):
                graph.add_edge(n_train, int(nbr), weight=float(weight))

            partition = la.CPMVertexPartition(
                graph,
                initial_membership=initial_membership,
                weights="weight",
                resolution_parameter=self._partition.resolution_parameter
            )
            self._optimiser.optimise_partition(
                partition,
                is_membership_fixed=is_membership_fixed,
                n_iterations=1
            )
            classes.append(str(partition.membership[-1]))
        return classes

    def _project(self, embeddings: np.ndarray) -> np.ndarray:
        """Projects raw embeddings onto the retained leading PCA components."""
        return self._pca.transform(embeddings)[:, :self.n_components]

    def _kernel(self, dists: np.ndarray) -> np.ndarray:
        """Maps distances d to Gaussian similarities exp(-d^2 / (2 sigma^2))."""
        return np.exp(-np.square(dists) / (2.0 * np.square(self._sigma)))

    def _build_index(self) -> VectorStoreIndex:
        """
        Builds, or loads from the on-disk cache, the index of the training
        dataset embeddings.

        The index is persisted under `<cache_dir>/index`. NOTE(review): the
        cache directory is keyed only by the training dataset's class name,
        so a cached index is reused even if a different embedding model is
        passed in.
        Input:
            None.
        Returns:
            A VectorStoreIndex object.
        """
        cache_path = os.path.join(self._cache_dir, "index")
        if os.path.isdir(cache_path):
            return load_index_from_storage(  # type: ignore
                StorageContext.from_defaults(persist_dir=cache_path)
            )
        train = self._task.train
        designs = self._task.reduce([train[i] for i in range(len(train))])
        index = VectorStoreIndex.from_documents(
            [Document(text=repr(xi)) for xi in designs]
        )
        index.storage_context.persist(persist_dir=cache_path)
        return index
