import numpy as np
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.evaluation import evaluate_policy


class NamedEvalCallback(EvalCallback):
    """EvalCallback variant that also logs the mean reward under a custom name.

    On every step the callback hands the current model and the
    ``res_dict_plots`` object to the wrapped evaluation environment, then
    delegates to ``EvalCallback._on_step``. On evaluation steps the last mean
    reward is additionally recorded as ``"<name>/mean_reward"``.
    """

    def __init__(self, *args, name="eval", res_dict_plots, **kwargs):
        """
        Note: ``self.eval_env`` is wrapped, so the custom environment is
        ``self.eval_env.envs[0]``.

        Args:
            *args: Positional arguments forwarded to ``EvalCallback``.
            name: Prefix of the extra logger key. If it contains an
                underscore, the second ``"_"``-separated token is stored in
                ``self.train_or_val``; otherwise the whole name is used.
            res_dict_plots: Keyword-only object passed to the environment
                through ``set_res_dict_plots``.
            **kwargs: Keyword arguments forwarded to ``EvalCallback``.
        """
        super().__init__(*args, **kwargs)
        self.eval_log_name = name
        # A name without an underscore (such as the default "eval") used to
        # raise IndexError here; fall back to the full name in that case.
        name_parts = name.split("_")
        self.train_or_val = name_parts[1] if len(name_parts) > 1 else name
        self.res_dict_plots = res_dict_plots

    def _on_step(self) -> bool:
        """Inject model and results container, run the parent step, log reward.

        Returns:
            The value returned by ``EvalCallback._on_step``.

        Raises:
            ValueError: If the wrapped environment lacks ``set_eval_model`` or
                ``set_res_dict_plots``.
        """
        env = self.eval_env.envs[0]

        # Inject the model in the environment for the 3 validation methods.
        # This happens before the parent step so that an evaluation triggered
        # on this very call already sees the model.
        if not hasattr(env, "set_eval_model"):
            raise ValueError("The eval_env does not have a set_eval_model method. Please ensure the environment is compatible with this callback.")
        env.set_eval_model(self.model)

        # Inject res_dict_plots so the environment can save the results of
        # the 3 validation methods.
        if not hasattr(env, "set_res_dict_plots"):
            raise ValueError("The eval_env does not have a set_res_dict_plots method. Please ensure the environment is compatible with this callback.")
        env.set_res_dict_plots(self.res_dict_plots)

        result = super()._on_step()

        # On evaluation steps, also record the mean reward under
        # "<eval_log_name>/" so several evaluation callbacks can be told apart.
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            if self.last_mean_reward is not None:
                self.logger.record(f"{self.eval_log_name}/mean_reward", self.last_mean_reward)
                # NOTE(review): this lookup discards its result. It only has
                # an effect if res_dict_plots creates missing keys on access,
                # and it raises if the key is absent from a plain dict.
                # Presumably a value was meant to be stored here -- confirm.
                self.res_dict_plots[f"{self.train_or_val}"]
        return result

class MetricLoggingCallback(BaseCallback):
    """Collect "AUC-PR" values found in step infos and print their recent mean.

    Every info dict that carries an ``"AUC-PR"`` entry is counted as one
    episode. Each time the counter is a multiple of ``log_interval``, the mean
    of the last ``log_interval`` scores is printed.
    """

    def __init__(self, verbose=0, log_interval=1):
        """
        Args:
            verbose: Verbosity level forwarded to ``BaseCallback``.
            log_interval: Number of counted episodes between two printouts,
                and size of the averaging window.
        """
        super().__init__(verbose)
        self.log_interval = log_interval
        self.episode_counter = 0
        self.episode_aucpr_scores = []

    def _on_step(self) -> bool:
        """Scan the current ``infos`` for AUC-PR scores; always returns True."""
        for step_info in self.locals.get("infos", []):
            if "AUC-PR" not in step_info:
                continue

            self.episode_aucpr_scores.append(step_info["AUC-PR"])
            self.episode_counter += 1

            # Only report once every log_interval counted episodes.
            if self.episode_counter % self.log_interval != 0:
                continue

            recent_scores = self.episode_aucpr_scores[-self.log_interval:]
            avg_aucpr = np.mean(recent_scores)
            print(f"Episode {self.episode_counter}, Avg AUC-PR Score: {avg_aucpr:.4f}")

        return True


class LoggingCallback(BaseCallback):
    """Track "AUC-PR" values from step infos and report them to the logger.

    Every info dict that carries an ``"AUC-PR"`` entry is counted as one
    episode. Each time the counter is a multiple of ``log_interval``, the mean
    of the last ``log_interval`` scores is printed and recorded as
    ``"custom/avg_auc_pr"``. The current timestep count is recorded as
    ``"timesteps"`` on every step.
    """

    def __init__(self, verbose=0, log_interval=1):
        """
        Args:
            verbose: Verbosity level forwarded to ``BaseCallback``.
            log_interval: Number of counted episodes between two reports,
                and size of the averaging window.
        """
        super().__init__(verbose)
        self.log_interval = log_interval
        self.episode_counter = 0
        self.episode_scores = []

    def _on_step(self) -> bool:
        """Process the current ``infos`` and log timesteps; always returns True."""
        for step_info in self.locals.get("infos", []):
            if "AUC-PR" not in step_info:
                continue

            self.episode_scores.append(step_info["AUC-PR"])
            self.episode_counter += 1

            # Only report once every log_interval counted episodes.
            if self.episode_counter % self.log_interval != 0:
                continue

            recent_scores = self.episode_scores[-self.log_interval:]
            avg_aucpr = np.mean(recent_scores)
            print(f"Episode {self.episode_counter}, Avg AUC-PR Score: {avg_aucpr:.4f}")
            self.logger.record("custom/avg_auc_pr", avg_aucpr)

        self.logger.record("timesteps", self.num_timesteps)
        return True