from collections import deque
import os
import numpy as np
import wandb
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import PPO
from demo_collection.utils.utils import logging
from iq_learn.utils.utils import gen_frame, save_video

class CustomCallback(BaseCallback):
    """Callback that logs training statistics to wandb, records evaluation
    videos and saves model checkpoints.

    Each of the three periodic actions fires only when ``self.num_timesteps``
    is an exact multiple of its frequency. A falsy frequency (``0`` or
    ``None``) disables the corresponding action.

    NOTE(review): because the schedule relies on an exact multiple, a
    frequency that ``num_timesteps`` never lands on will never fire. With
    vectorized environments ``num_timesteps`` presumably advances by the
    number of envs per call -- confirm the frequencies are compatible.

    Args:
        logger: Object exposing ``wandb_log(log_dict, step)`` and
            ``wandb_log_video(path, step, key)``.
        log_freq: Timestep frequency of the wandb logging.
        ck_save_freq: Timestep frequency of the checkpoint saving.
        ck_save_dir: Directory the checkpoints are written to.
        eval_env: Environment used to roll out the policy for videos.
        video_save_dir: Directory handed to ``save_video``.
        n_eval_episodes: Number of episodes recorded per evaluation video.
        args: Namespace providing ``eval_video_saving_freq`` and ``env_name``.
    """

    def __init__(self, logger, log_freq, ck_save_freq, ck_save_dir, eval_env, video_save_dir, n_eval_episodes, args):
        super().__init__()
        self.wandb_logger = logger

        # Sliding windows over the most recently finished training episodes.
        self.episode_rewards = deque(maxlen=10)
        self.episode_lengths = deque(maxlen=10)
        self.successes = deque(maxlen=100)

        self.log_freq = log_freq
        self.ck_save_freq = ck_save_freq
        self.ck_save_dir = ck_save_dir

        self.eval_env = eval_env
        self.eval_video_saving_freq = args.eval_video_saving_freq
        self.video_save_dir = video_save_dir
        # Path of the latest video that has not been uploaded to wandb yet.
        self.video_path = None
        self.n_eval_episodes = n_eval_episodes  # for eval video
        self.args = args

    def _is_due(self, freq) -> bool:
        """Return True when ``num_timesteps`` is an exact multiple of *freq*.

        A falsy *freq* (``0`` or ``None``) means "disabled" and yields False
        instead of raising on the modulo.
        """
        return bool(freq) and self.num_timesteps % freq == 0

    def eval_video_save(self):
        """Roll out the current policy on ``eval_env`` and save it as a video.

        Runs ``n_eval_episodes`` episodes, renders one annotated frame per
        environment step, and stores the resulting path in ``self.video_path``
        so that the next wandb log can upload it.
        """
        if not self._is_due(self.eval_video_saving_freq):
            return

        policy = self.model.policy
        env = self.eval_env
        frames = []

        for _ in range(self.n_eval_episodes):
            obs = env.reset()
            done = False
            while not done:
                # predict() is called with its default arguments.
                action, _ = policy.predict(obs)
                obs, reward, done, _info = env.step(action)
                frames.append(gen_frame(env.render('rgb_array'), true_reward=reward))

        self.video_path = save_video(self.video_save_dir, np.array(frames), episode_id=self.num_timesteps)
        logging(f"Video saved at {self.video_path}")

    def wandb_step_log(self):
        """Log windowed episode statistics and any pending video to wandb."""
        if not self._is_due(self.log_freq):
            return

        windows = {
            "train/ep_return": self.episode_rewards,
            "train/ep_length": self.episode_lengths,
            "train/ep_success": self.successes,
        }
        # np.mean of an empty window is nan (plus a RuntimeWarning), so only
        # report the statistics that have at least one sample.
        log_dict = {key: np.mean(values) for key, values in windows.items() if values}

        if log_dict:
            self.wandb_logger.wandb_log(log_dict, self.num_timesteps)

        if self.video_path is not None:
            self.wandb_logger.wandb_log_video(self.video_path, self.num_timesteps, "train/video")
            # Reset so the same video is not uploaded twice.
            self.video_path = None

    def ck_save(self):
        """Save the model to ``ck_save_dir`` as ``ppo_<env_name>_<num_timesteps>``."""
        if not self._is_due(self.ck_save_freq):
            return

        save_path = os.path.join(self.ck_save_dir, f"ppo_{self.args.env_name}_{self.num_timesteps}")
        self.model.save(save_path)
        logging(f"Model saved as '{save_path}'.")

    def _on_step(self) -> bool:
        """Collect episode statistics, then run the periodic actions.

        Returns:
            Always True, so this callback never requests training to stop.

        Raises:
            KeyError: If an environment reports ``done`` but its info dict
                has no ``"ep_found_goal"`` entry.
        """
        infos = self.locals.get("infos", [])
        dones = self.locals.get("dones", [])

        # Store the return and length of every episode reported in the infos.
        for info in infos:
            if "episode" in info:
                self.episode_rewards.append(info["episode"]["r"])
                self.episode_lengths.append(info["episode"]["l"])

        # Pair each done flag with its info; zip tolerates a missing or
        # shorter "infos" instead of raising IndexError.
        for done, info in zip(dones, infos):
            if done:
                self.successes.append(info["ep_found_goal"])

        self.eval_video_save()
        self.wandb_step_log()
        self.ck_save()
        return True

    def wandb_rollout_log(self, log_dict):
        """Forward *log_dict* to wandb when the logging frequency is hit.

        Nothing is sent when *log_dict* is empty.
        """
        if self._is_due(self.log_freq) and log_dict:
            self.wandb_logger.wandb_log(log_dict, self.num_timesteps)

    def _on_rollout_end(self) -> None:
        """Mirror the model logger's currently recorded values to wandb."""
        # The logger stores PPO's internal metrics under self.model.logger.name_to_value
        # e.g., 'train/entropy_loss', 'train/approx_kl', 'train/policy_gradient_loss', etc.
        logs = self.model.logger.name_to_value.copy()

        self.wandb_rollout_log(logs)

        # Optionally print them
        if self.verbose > 0:
            print("Rollout end logs:", logs)




