#!/usr/bin/python3
"""
Logger to track the optimization state.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import logging
import pandas as pd
import numpy as np
import os
from pydantic import BaseModel
from pathlib import Path
from typing import Any, Dict, Final, List, NamedTuple, Optional, Union

from ..envs.base import BaseTask


logger = logging.getLogger(__name__)


class OptimizerState:
    """
    Tracks the designs sampled by an optimizer together with their scores
    and accumulated metadata, and persists the best designs to disk.
    """

    def __init__(
        self,
        task: BaseTask,
        optimizer_name: str,
        individual: NamedTuple,
        savedir: Optional[Union[Path, str]] = None,
        max_samples: int = 2048,
        seed: Optional[int] = None,
        **kwargs
    ):
        """
        Args:
            task: the optimization task.
            optimizer_name: the name of the backbone optimizer.
            individual: the individual to optimize.
            savedir: the directory to save the optimization results to.
            max_samples: maximum samples allowed. Default 2048.
            seed: optional random seed. Default None.
            kwargs: any additional keyword arguments are stored as attributes
                on the state object.
        """
        self.task: Final[BaseTask] = task
        self.optimizer_name: Final[str] = optimizer_name
        self.individual: Final[NamedTuple] = individual
        self.savedir: Final[Optional[Union[Path, str]]] = savedir
        self.max_samples: Final[int] = max_samples
        # Neither counter is updated inside this class; presumably they are
        # maintained by the calling optimizer loop.
        self.num_fails = 0
        self.best_yq = -np.inf
        self.seed = seed

        # Extra options become plain attributes. They are applied before the
        # history containers below, so a kwarg named `xq`, `yq` or `metadata`
        # is overwritten by the fresh, empty history.
        for key, val in kwargs.items():
            setattr(self, key, val)

        self.xq: List[BaseModel] = []
        self.yq: np.ndarray = np.array([])
        self.metadata: Dict[str, Any] = {}

    def log(
        self, xq: List[BaseModel], yq: np.ndarray, **metadata: Any
    ) -> None:
        """
        Records a set of proposed designs and their fitness scores.
        Input:
            xq: a list of proposed designs of shape N, where N is the number
                of proposed designs.
            yq: the corresponding fitness scores. Any array-like that
                flattens to N values is accepted, e.g. shape N or (N, 1).
            metadata: additional metadata to record.
        Returns:
            None.
        Raises:
            ValueError: if the number of designs and scores differ. The
                check runs before any state is modified.
        """
        yq = np.asarray(yq).reshape(-1)
        if len(xq) != len(yq):
            raise ValueError(
                f"Got {len(xq)} designs but {len(yq)} scores; "
                "the number of designs and scores must match."
            )
        self.xq.extend(xq)
        self.yq = np.concatenate([self.yq, yq])
        for key, val in metadata.items():
            # The first value for a key is stored as-is; later values are
            # clamped at 0 before being accumulated.
            # NOTE(review): confirm this asymmetry is intended.
            if key not in self.metadata:
                self.metadata[key] = val
            else:
                self.metadata[key] += max(val, 0)
        # Skip logging for an empty batch: `iloc[-0:]` would select the whole
        # memory. Also avoid building the table when INFO is disabled.
        if len(xq) > 0 and logger.isEnabledFor(logging.INFO):
            logger.info(
                "Most Recently Sampled Batch:\n"
                f"{self.memory.iloc[-len(xq):].to_markdown()}"
            )

    def __len__(self) -> int:
        """
        Returns the number of designs in the state.
        Input:
            None.
        Returns:
            The number of designs in the state.
        """
        return len(self.xq)

    @property
    def designs(self) -> List[BaseModel]:
        """
        Returns a list of all of the previously sampled designs since the
        most recent restart.
        Input:
            None.
        Returns:
            A list of all the previously sampled designs of shape N, where N
            is the number of previously sampled designs.
        """
        return self.xq

    @property
    def predictions(self) -> np.ndarray:
        """
        Returns an array of all of the predictions associated with the designs.
        Input:
            None.
        Returns:
            An array of all the previous predictions of shape N, where N is the
            number of previously sampled designs.
        """
        return self.yq

    @property
    def has_converged(self) -> bool:
        """
        Returns whether the optimization has converged.
        Input:
            None.
        Returns:
            Whether the sampling budget (`max_samples`) has been reached.
        """
        return len(self.xq) >= self.max_samples

    def evaluate_and_save(
        self, fast_dev_run: bool, **kwargs: Any
    ) -> None:
        """
        Evaluates and saves the best designs and their corresponding scores.
        The designs are written in ascending order of prediction. Row i of
        `designs` corresponds to `predictions[i]` and `scores[i]`.
        Input:
            fast_dev_run: whether we are running a fast development run. If
                True, or if no save directory is configured, nothing is saved.
            kwargs: extra arrays to store in the `.npz` file. Entries in
                `self.metadata` take precedence over same-named kwargs.
        Returns:
            None.
        """
        if self.savedir is None or fast_dev_run:
            return
        savedir = os.path.join(self.savedir, self.task.task_name)
        os.makedirs(savedir, exist_ok=True)

        # Assumes `task.extend` returns one design per entry of `self.xq`,
        # in the same order. TODO: confirm against BaseTask.extend.
        designs = self.task.extend(self.xq, self.individual)
        yq = self.yq.reshape(-1)
        # Keep the (up to) len(designs) top-scoring entries, in ascending
        # order. Slicing from an explicit start index avoids the `[-0:]`
        # pitfall, and a stable sort keeps the order of ties deterministic.
        num_keep = min(len(designs), len(yq))
        idxs = np.argsort(yq, kind="stable")[len(yq) - num_keep:]
        predictions = yq[idxs]
        # Reorder the designs with the same indices so that the saved rows
        # line up with their predictions and scores.
        best_designs = [designs[i] for i in idxs]
        scores = np.array(self.task.predict(best_designs))

        rows = [list(x) for x in best_designs]
        try:
            np_designs = np.array(rows)
        except ValueError:
            # Ragged designs cannot form a regular array; store them as
            # objects instead.
            np_designs = np.array(rows, dtype=object)

        filename = "{optimizer_name}-{seed}-{index}.npz".format(
            optimizer_name=self.optimizer_name,
            seed=self.seed,
            index=getattr(self.individual, "id_", None),
        )
        extras = {**kwargs, **self.metadata}
        np.savez(
            os.path.join(savedir, filename),
            designs=np_designs,
            predictions=predictions,
            scores=scores,
            **extras
        )

    @property
    def memory(self) -> pd.DataFrame:
        """
        Returns the current memory of the optimizer in the current restart.
        Input:
            None.
        Returns:
            A DataFrame with one row per sampled design. The `designs` column
            holds the designs and the `scores` column holds their scores.
        """
        return pd.DataFrame({"designs": self.xq, "scores": self.yq})