# =============================================================================
# Logger
# =============================================================================

import time
import logging
from pathlib import Path
from types import SimpleNamespace

import yaml

from utils.dataset import BiLevelObservedData
from utils.models import RFFHybridModel, RFFModelList



# -----------------------------------------------------------------------------
# Experiment Logger
# -----------------------------------------------------------------------------

class ExperimentLogger:
    """Logger for Bayesian optimization experiments.

    Every message is written both to ``<local_path>/experiment.log`` (the
    file is truncated on construction) and to the console, formatted as
    ``[YYYY-mm-dd HH:MM:SS] - message``.

    Intended lifecycle: construct, call :meth:`start`, log progress with
    :meth:`info` / :meth:`obs` / :meth:`model`, then call :meth:`end`, which
    logs the elapsed time and detaches and closes the logger's handlers.

    NOTE: the underlying logger is ``logging.getLogger(__name__)``, i.e. one
    logger shared by every instance in the process. Handlers added by an
    instance stay attached to it until :meth:`end` runs.

    Args:
        config: Experiment configuration; its ``vars()`` are dumped to
            ``config.yaml`` in ``local_path`` by :meth:`start`.
        local_path: Output directory for this run. :meth:`experiment` reports
            its last three path components as CONFIG / NAME / SEED
            (outermost to innermost).
    """

    def __init__(
        self,
        config: SimpleNamespace,
        local_path: Path,
    ) -> None:

        self.config = config
        self.local_path = local_path
        # Logger (module-wide; shared by all instances)
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        fmt = logging.Formatter(
            fmt="[%(asctime)s] - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S"
        )
        # File Handler (mode="w" overwrites a log left by a previous run)
        file = self.local_path / "experiment.log"
        file_handler = logging.FileHandler(file, mode="w")
        file_handler.setFormatter(fmt)
        self.logger.addHandler(file_handler)
        # Console Handler
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(fmt)
        self.logger.addHandler(console_handler)


    def _save_config(self, config: SimpleNamespace) -> None:
        """Dump ``vars(config)`` as YAML to ``<local_path>/config.yaml``.

        Key order is preserved (``sort_keys=False``). Values must be
        representable by ``yaml.safe_dump``; nested ``SimpleNamespace``
        objects, for example, are not.
        """
        config_path = self.local_path / "config.yaml"
        with open(config_path, mode="w", encoding="utf-8") as f:
            yaml.safe_dump(vars(config), f, sort_keys=False)
        self.logger.info(f"Configuration is saved to {config_path}")


    def _start_ascii(
        self,
    ) -> None:
        """Log the ASCII-art banner that opens a run."""
        self.logger.info(r"==================================================")
        self.logger.info(r"       ____  ____     _____ __             __     ")
        self.logger.info(r"      / __ )/ __ \   / ___// /_____ ______/ /_    ")
        self.logger.info(r"     / __  / / / /   \__ \/ __/ __ `/ ___/ __/    ")
        self.logger.info(r"    / /_/ / /_/ /   ___/ / /_/ /_/ / /  / /_      ")
        self.logger.info(r"   /_____/\____/   /____/\__/\__,_/_/   \__/      ")
        self.logger.info(r"                                                  ")
        self.logger.info(r"==================================================")


    def _end_ascii(
        self,
    ) -> None:
        """Log the ASCII-art banner that closes a run."""
        self.logger.info(r"==================================================")
        self.logger.info(r"       ____  ____     _______       _      __     ")
        self.logger.info(r"      / __ )/ __ \   / ____(_)___  (_)____/ /_    ")
        self.logger.info(r"     / __  / / / /  / /_  / / __ \/ / ___/ __ \   ")
        self.logger.info(r"    / /_/ / /_/ /  / __/ / / / / / (__  ) / / /   ")
        self.logger.info(r"   /_____/\____/  /_/   /_/_/ /_/_/____/_/ /_/    ")
        self.logger.info(r"                                                  ")
        self.logger.info(r"==================================================")


    def _obs(
        self,
        obs: BiLevelObservedData,
    ) -> None:
        """Log the fields of a single observation, one field per line.

        ``c_upper`` / ``c_lower`` are only logged when they are not ``None``.
        """
        self.logger.info(f"  x         : {obs.x}")
        self.logger.info(f"  y_upper   : {obs.y_upper}")
        self.logger.info(f"  y_lower   : {obs.y_lower}")
        if obs.c_upper is not None:
            self.logger.info(f"  c_upper   : {obs.c_upper}")
        if obs.c_lower is not None:
            self.logger.info(f"  c_lower   : {obs.c_lower}")
        self.logger.info(f"  timestamp : {obs.timestamp}")
        self.logger.info(f"  metadata  : {obs.metadata}")


    def _model(
        self,
        model: RFFHybridModel | RFFModelList,
    ) -> None:
        """Log every parameter of ``model.model``, one parameter per line.

        Parameters with a registered constraint are logged after
        ``constraint.transform``; the rest are logged as stored. Labels have
        every ``raw_`` substring removed, and dotted names with more than
        three components are shortened to ``first.second.last``.
        """
        # Assumes ``model.model`` exposes the gpytorch ``Module`` API
        # (``named_parameters`` / ``constraint_for_parameter_name``) -- TODO confirm.
        inner = model.model
        for name, param in inner.named_parameters():
            constraint = inner.constraint_for_parameter_name(name)
            parts = name.replace("raw_", "").split(".")
            # Shorten only long paths: unconditionally indexing parts[1]
            # raised IndexError for one-component names and repeated the
            # last segment for two-component names.
            if len(parts) > 3:
                parts = [parts[0], parts[1], parts[-1]]
            label = ".".join(parts)
            value = param if constraint is None else constraint.transform(param)
            self.logger.info(f"  {label:20} : {value.data}")


    def divider(
        self,
    ) -> None:
        """Log a horizontal separator line."""
        self.logger.info(r"--------------------------------------------------")


    def experiment(
        self,
    ) -> None:
        """Log the run identity derived from ``local_path``.

        CONFIG, NAME and SEED are the stems of ``local_path.parent.parent``,
        ``local_path.parent`` and ``local_path`` respectively.
        """
        self.logger.info(f"CONFIG : {self.local_path.parent.parent.stem}")
        self.logger.info(f"NAME   : {self.local_path.parent.stem}")
        self.logger.info(f"SEED   : {self.local_path.stem}")


    def obs(
        self,
        data: list[BiLevelObservedData],
    ) -> None:
        """Log each observation in ``data`` under an indexed header, between dividers."""
        self.divider()
        for i, obs in enumerate(data):
            self.logger.info(f"[BiLevelObservedData: {i:2d}]")
            self._obs(obs)
        self.divider()


    def model(
        self,
        models: list[RFFHybridModel | RFFModelList],
    ) -> None:
        """Log the parameters of each model under its class name, between dividers."""
        self.divider()
        for model in models:
            self.logger.info(f"[{model.model.__class__.__name__}]")
            self._model(model)
        self.divider()


    def info(
        self,
        message: str,
    ) -> None:
        """Log ``message`` at INFO level."""
        self.logger.info(message)


    def start(self) -> None:
        """Log the opening banner and run identity, save the config, start the timer."""
        self._start_ascii()
        self.experiment()
        self.divider()
        self._save_config(self.config)
        self.start_time = time.perf_counter()


    def end(self) -> None:
        """Log the run identity and elapsed time, then detach all handlers.

        Must be called after :meth:`start`, which sets ``start_time``.
        Every handler on the shared logger is removed and closed.
        """
        self.end_time = time.perf_counter()
        elapsed = self.end_time - self.start_time
        self.divider()
        self.experiment()
        self.divider()
        self.logger.info(f"TIME   : {elapsed} s")
        self._end_ascii()

        # Iterate over a snapshot: removing from ``self.logger.handlers``
        # while looping over it skips every other handler, which left the
        # console handler attached (and accumulating) on the shared logger.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            handler.close()

