#!/usr/bin/python3
"""
Entropy-penalized Language Model-based Optimization (ELMO).

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from __future__ import annotations
import logging
import numpy as np
import torch
from pydantic import BaseModel
from typing import (
    Any, Dict, Final, Hashable, List, NamedTuple, Optional, Tuple
)

from .base import BaseEquivalenceRelation
from ..utils import Pydantic2Rd
from ..envs import BaseTask
from ..model import LipschitzMLP


# Module-level logger; `EntropyPenalizedTransform.__call__` writes per-design
# scoring details through it at INFO level.
logger: logging.Logger = logging.getLogger(__name__)


class QuotientSet:
    """
    Tracks, for every equivalence class of designs, how many designs have
    been assigned to it and the best-scoring design seen so far.
    """

    def __init__(self, equiv_relation: BaseEquivalenceRelation):
        """
        Args:
            equiv_relation: a function mapping designs to equivalence classes.
        """
        self._equiv_relation: Final[BaseEquivalenceRelation] = equiv_relation

        # Class IDs are the strings "0", ..., "num_classes - 1". `assign`
        # looks the relation's returned IDs up in these dicts, so the relation
        # must return IDs of exactly this form (anything else is a KeyError).
        class_ids = [str(idx) for idx in range(equiv_relation.num_classes)]
        self._counts: Dict[Hashable, int] = {c: 0 for c in class_ids}
        self._best_s: Dict[Hashable, float] = {c: -np.inf for c in class_ids}
        self._best_x: Dict[Hashable, Optional[BaseModel]] = {
            c: None for c in class_ids
        }
        # Placeholder score reported by `s_star` for classes that have not
        # been observed yet (their tracked best score is still -inf).
        self.ninf: Final[float] = -10.0

    def assign(self, x: List[BaseModel], s: np.ndarray) -> List[Hashable]:
        """
        Assigns the equivalence-class IDs for input designs, and updates the
        internal counts and best scores.
        Input:
            x: a list of input designs.
            s: the critic-rewarded scores of the input designs (one per
                design; a 0-d array is treated as a single score).
        Returns:
            The equivalence class IDs for the input designs.
        """
        # A single-design batch may arrive as a 0-d array, which can be
        # neither indexed nor converted to a list of scores.
        s = np.atleast_1d(s)
        cls = self._equiv_relation(x, s.tolist())
        for i, c in enumerate(cls):
            self._counts[c] += 1
            if s[i] > self._best_s[c]:
                self._best_x[c], self._best_s[c] = x[i], s[i]
        return cls

    @property
    def equivalence_classes(self) -> List[str]:
        """
        Returns the IDs of all equivalence classes, including classes that
        have not been observed yet.
        Input:
            None.
        Returns:
            The class IDs, sorted as strings. All other per-class properties
            are reported in this same order.
        """
        return sorted(str(key) for key in self._counts)

    def __len__(self) -> int:
        """
        Returns the total number of logged designs.
        Input:
            None.
        Returns:
            The total number of logged designs.
        """
        return int(sum(self._counts.values()))

    @property
    def fractional_occupancy(self) -> np.ndarray:
        """
        Returns the fractional occupancy of each equivalence class.
        Input:
            None.
        Returns:
            The fraction of logged designs assigned to each equivalence class.
            All zeros if no designs have been logged yet.
        """
        classes = self.equivalence_classes
        total = len(self)
        if total == 0:
            # Nothing logged yet: avoid a ZeroDivisionError.
            return np.zeros(len(classes))
        return np.array(
            [self._counts[c] / total for c in classes], dtype=float
        )

    @property
    def s_star(self) -> np.ndarray:
        """
        Returns the score of the best design (according to the critic-rewarded
        proxy) in each equivalence class.
        Input:
            None.
        Returns:
            The best score of each equivalence class; classes with no logged
            design report `self.ninf`.
        """
        # Only the -inf "never observed" sentinel is replaced; a genuine
        # +inf best score is reported as-is.
        return np.array([
            self.ninf if np.isneginf(self._best_s[c]) else self._best_s[c]
            for c in self.equivalence_classes
        ])

    @property
    def x_star(self) -> List[Optional[BaseModel]]:
        """
        Returns the best design (according to the critic-rewarded proxy) in
        each equivalence class.
        Input:
            None.
        Returns:
            The best design of each equivalence class, or None for classes
            with no logged design.
        """
        return [self._best_x[c] for c in self.equivalence_classes]


class EntropyPenalizedTransform:
    """
    Scores designs as `task prediction + lambda * critic score`, then rescales
    them by `mu`. When `mu` is not fixed, it is estimated as the least-squares
    slope (floored at eps) of the log fractional class occupancy against the
    best per-class score. `fit()` trains the critic and updates `lambda`.
    """

    name: Final[str] = "EntropyPenalizedTransform"

    def __init__(
        self,
        task: BaseTask,
        critic: LipschitzMLP,
        equiv_relation: BaseEquivalenceRelation,
        lambda_: Optional[float] = None,
        lambda_lr: float = 0.1,
        mu_: Optional[float] = None,
        W0: float = 1.0,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task.
            critic: the source critic model.
            equiv_relation: a function mapping designs to equivalence classes.
            lambda_: an optional fixed value for the lambda hyperparameter. If
                None, lambda starts at float32 machine eps and is updated in
                `fit()`.
            lambda_lr: the learning rate for the lambda hyperparameter.
            mu_: an optional fixed value for the mu hyperparameter. If None,
                mu is re-estimated on every forward pass.
            W0: the upper bound on the 1-Wasserstein distance.
            kwargs: ignored.
        """
        del kwargs
        self.eps: Final[float] = float(np.finfo(np.float32).eps)
        self.task: Final[BaseTask] = task
        self.critic: Final[LipschitzMLP] = critic
        self.quotient_set: Final[QuotientSet] = QuotientSet(equiv_relation)
        self.fixed_lambda_: Final[Optional[float]] = lambda_
        # Compare with None explicitly: `lambda_ or eps` would silently turn
        # a user-fixed lambda of 0.0 into eps.
        self.lambda_: float = lambda_ if lambda_ is not None else self.eps
        self.lambda_lr: Final[float] = lambda_lr
        self.fixed_mu_: Final[Optional[float]] = mu_
        self.mu_: float = mu_ if mu_ is not None else 1.0
        self.W0: Final[float] = W0
        # Number of `fit()` calls; the lambda step size decays as 1/sqrt(t).
        self.t: int = 0

    def _as_critic_input(self, designs: List[BaseModel]) -> torch.Tensor:
        """Converts designs into a 2-D tensor with `task.ndim` columns, on the
        same device and dtype as the critic's parameters."""
        arr = torch.from_numpy(Pydantic2Rd(self.task, designs, continuous=True))
        arr = arr.to(next(self.critic.parameters()))
        return arr.reshape(-1, self.task.ndim)

    def __call__(self, x: List[NamedTuple]) -> Tuple[np.ndarray, float, float]:
        """
        Forward pass with entropy-penalized regularization.
        Input:
            x: a list of designs to evaluate.
        Returns:
            The entropy-penalized predicted scores of each design, the value
            of mu_hat, and the R-squared value of the linear regression of the
            log fractional occupancy on the best per-class scores. If there
            are no equivalence classes, the unscaled scores are returned with
            mu = 1.0 and R-squared = 1.0 (no transform applied).
        """
        xt = torch.vstack([xi.as_tensor() for xi in x])  # type: ignore
        xt = xt.to(next(self.critic.parameters()))
        with torch.no_grad():
            # `atleast_1d` keeps a single-design batch iterable; a bare
            # `squeeze()` would collapse it into a 0-d array.
            fvals = np.atleast_1d(np.array(self.task(x)).squeeze())
            cvals = self.critic(xt[..., :self.task.ndim])
            np_cvals = np.atleast_1d(cvals.detach().cpu().numpy().squeeze())
            svals = fvals + (self.lambda_ * np_cvals)

        logger.info("### Status: Intermediate Design Scoring ###")
        for i, (xi, f, c, s) in enumerate(zip(x, fvals, np_cvals, svals)):
            logger.info("  Design %d: %s", i, xi)
            logger.info("    Surrogate Estimate: %s", f)
            logger.info("    Source Critic Score: %s", c)
            logger.info("    Total Estimated Score: %s", s)

        self.quotient_set.assign(self.task.reduce(x), svals)

        if not self.quotient_set.equivalence_classes:
            # Nothing to regress over; keep the same 3-tuple contract.
            return svals, 1.0, 1.0

        # Clamp to eps so classes with zero occupancy have a finite log.
        logp = np.log(
            np.maximum(self.quotient_set.fractional_occupancy, self.eps)
        )
        dlogp = logp - logp.mean()
        s_star = self.quotient_set.s_star
        if self.fixed_mu_ is None:
            # Least-squares slope of log-occupancy vs. best class score,
            # floored at eps so mu stays strictly positive.
            ds = s_star - s_star.mean()
            self.mu_ = np.maximum(
                self.eps, np.sum(ds * dlogp) / (np.sum(ds * ds) + self.eps)
            )
            b0 = logp.mean() - (self.mu_ * s_star.mean())
            ssres = np.sum(np.square(logp - (b0 + (self.mu_ * s_star))))
        else:
            self.mu_ = self.fixed_mu_
            ssres = 0.0

        sstot = np.sum(np.square(dlogp))
        return self.mu_ * svals, self.mu_, 1.0 - (ssres / (sstot + self.eps))

    def fit(
        self, Xp: List[BaseModel], Xq: List[BaseModel]
    ) -> EntropyPenalizedTransform:
        """
        Fits the source critic model and, unless lambda is fixed, takes one
        projected step on lambda.
        Input:
            Xp: the dataset of real designs.
            Xq: the dataset of `fake` generated designs.
        Returns:
            The transform itself.
        """
        self.t += 1
        P = self._as_critic_input(Xp)
        Q = self._as_critic_input(Xq)
        self.critic.fit(P, Q)

        if self.fixed_lambda_ is not None:
            # The lambda update below would be discarded anyway.
            return self

        # Update direction for lambda: W0 minus the mean critic score on the
        # real designs, plus the occupancy-weighted critic score of each
        # class's best design. Autograd is not needed for these evaluations.
        with torch.no_grad():
            grad = self.W0 - self.critic(P).mean().item()
            # With no logged designs every x_star is None, so skip the loop.
            if len(self.quotient_set) > 0:
                for pi, xistar in zip(
                    self.quotient_set.fractional_occupancy,
                    self.quotient_set.x_star,
                ):
                    if xistar is None:
                        continue
                    cistar = self.critic(self._as_critic_input([xistar]))
                    grad += pi * cistar.item()

        # Projected step (lambda >= 0) with a 1/sqrt(t) decaying step size.
        self.lambda_ = max(
            0.0, self.lambda_ - (self.lambda_lr * grad / np.sqrt(self.t))
        )
        return self


class IdentityTransform:
    """
    Transform that applies no regularization: designs are scored by the
    task's forward model alone. Exposes the same `name` / `__call__` / `fit`
    interface as `EntropyPenalizedTransform`.
    """

    name: Final[str] = "IdentityTransform"

    def __init__(
        self,
        task: BaseTask,
        *args: Any,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task: the optimization task.
            args: ignored.
            kwargs: ignored.
        """
        del args, kwargs
        self.task: Final[BaseTask] = task

    def __call__(self, x: List[NamedTuple]) -> Tuple[np.ndarray, float, float]:
        """
        Forward pass through the identity regularizer (i.e., no regularization
        of the OOD objective function is applied).
        Input:
            x: the input batch of designs to the OOD objective function.
        Returns:
            The forward model predictions of the input designs as an array,
            followed by the constants 1.0 (mu) and 1.0 (R-squared).
        """
        predictions = np.array(self.task(x))
        return predictions, 1.0, 1.0

    def fit(self, *args: Any, **kwargs: Any) -> IdentityTransform:
        """
        Does nothing; there is no state to fit.
        Input:
            Any positional/keyword arguments, all ignored.
        Returns:
            This transform instance, unchanged.
        """
        del args, kwargs
        return self
