import copy

import numpy as np
import matplotlib.pyplot as plt
import cv2
from stable_baselines3.common.logger import Figure
from IPython import embed
from stable_baselines3.common.callbacks import BaseCallback

class LogCoverageCallback(BaseCallback):
    """Log per-episode coverage statistics and, optionally, coverage heatmaps.

    On every step, for each environment whose episode just finished, the
    values stored under ``coverage``, ``task_reward``, ``episodic_coverage``
    and ``visited_biomes`` in *that environment's* ``info`` dict (when
    present) are buffered. At the end of each rollout the mean of every
    non-empty buffer is recorded under ``train/`` and all buffers are cleared.

    When ``log_heatmap_every > 0`` and more than ``log_heatmap_every`` steps
    have elapsed since the last heatmap, a heatmap figure is logged for every
    environment that finishes on the current step, and the step counter is
    reset.

    :param log_heatmap_every: Minimum number of steps between heatmap logs;
        ``0`` or a negative value disables heatmaps.
    :param map_size: Side length of the heatmap grid. The rendered background
        is cropped to ``map_size * 8`` pixels and resized to
        ``map_size x map_size``.
    :param verbose: Verbosity level forwarded to ``BaseCallback``; when > 0,
        heatmap failures are printed instead of being silently ignored.
    """

    def __init__(self, log_heatmap_every=1, map_size=32, verbose=0):
        super().__init__(verbose)

        # Per-episode values, buffered until the end of the current rollout.
        self.mean_ep_coverage = []
        self.mean_ep_task_reward = []
        self.mean_ep_coverage_in_global = []
        self.mean_ep_visited_biomes = []

        # Steps elapsed since the last heatmap was logged.
        self.log_count = 0
        self.log_heatmap_every = log_heatmap_every

        self.map_size = map_size

        # Private copy of "Greens" so set_under() does not mutate the globally
        # registered colormap; values below vmin=1 are drawn fully transparent.
        self._heatmap_cmap = copy.copy(plt.get_cmap("Greens"))
        self._heatmap_cmap.set_under((0, 0, 0, 0))

    def _on_step(self) -> bool:
        """Buffer statistics of finished episodes and log heatmaps when due.

        :return: Always ``True``. Stable-Baselines3 stops training when a
            callback's step hook returns a falsy value.
        """
        dones = self.locals["dones"]
        infos = self.locals["infos"]
        self.log_count += 1

        heatmap_due = self.log_heatmap_every > 0 and self.log_count > self.log_heatmap_every
        heatmap_attempted = False

        for traj, done in enumerate(dones):
            if not done:
                continue
            info = infos[traj]
            self._collect_episode_stats(info)
            if heatmap_due:
                self._log_heatmap(traj, info)
                heatmap_attempted = True

        # Reset once per logging event, regardless of which env indices finished.
        if heatmap_attempted:
            self.log_count = 0

        return True

    def _collect_episode_stats(self, info: dict) -> None:
        """Buffer the per-episode values present in one finished env's ``info``."""
        if "coverage" in info:
            self.mean_ep_coverage.append(info["coverage"])
        if "task_reward" in info:
            self.mean_ep_task_reward.append(info["task_reward"])
        if "episodic_coverage" in info:
            self.mean_ep_coverage_in_global.append(info["episodic_coverage"])
        if "visited_biomes" in info:
            self.mean_ep_visited_biomes.append(info["visited_biomes"])

    def _log_heatmap(self, traj: int, info: dict) -> None:
        """Best-effort: overlay ``info["heatmap"]`` on the rendered env and log it.

        The figure is recorded under ``trajectory/heatmap_<traj>`` and excluded
        from the stdout/log/json/csv outputs. Any failure is ignored (printed
        when ``verbose > 0``) so heatmap logging can never crash training.
        """
        fig = None
        try:
            fig = plt.figure(num=1, clear=True)
            # NOTE(review): this renders the whole vectorized env, not env
            # ``traj`` specifically; the crop below keeps the top-left
            # ``map_size * 8`` pixels -- confirm it corresponds to env ``traj``.
            background_img = self.training_env.render(mode="rgb_array")

            pixel_size = self.map_size * 8
            if background_img.shape[0] > pixel_size:
                background_img = background_img[0:pixel_size, 0:pixel_size]

            # Swap the two heatmap axes, then map each cell to sign(value - 1).
            heatmap = np.sign(info["heatmap"].transpose(1, 0) - 1)

            background = cv2.resize(
                background_img,
                dsize=(self.map_size, self.map_size),
                interpolation=cv2.INTER_AREA,
            )
            plt.imshow(background, alpha=1)
            plt.imshow(heatmap, cmap=self._heatmap_cmap, vmin=1, interpolation="nearest")
            plt.xticks([])
            plt.yticks([])
            plt.colorbar()
            self.logger.record(
                f"trajectory/heatmap_{traj}",
                Figure(fig, close=True),
                exclude=("stdout", "log", "json", "csv"),
            )
        except Exception as exc:
            if self.verbose > 0:
                print(f"LogCoverageCallback: could not log heatmap for env {traj}: {exc!r}")
        finally:
            # Always release the pyplot figure so the next call to
            # plt.figure(num=1) starts from a fresh Figure object.
            if fig is not None:
                plt.close(fig)

    def _on_rollout_end(self) -> None:
        """Record the mean of each non-empty buffer, then clear all buffers."""
        for key, values in (
            ("train/avg_coverage", self.mean_ep_coverage),
            ("train/avg_ep_task_reward", self.mean_ep_task_reward),
            ("train/avg_ep_coverage", self.mean_ep_coverage_in_global),
            ("train/avg_visited_biomes", self.mean_ep_visited_biomes),
        ):
            # Each buffer is guarded on its own: np.mean([]) yields NaN plus a RuntimeWarning.
            if values:
                self.logger.record(key, np.mean(values))
            values.clear()

class LogRewardsGodotCallback(BaseCallback):
    """Track per-episode return, length and coverage and log their averages.

    Step rewards and step counts are accumulated per environment. When an
    environment reports ``done``, its accumulated return and length are
    buffered along with ``info["coverage"]`` (if present), and its
    accumulators are reset. Once more than ``log_every`` steps have elapsed
    since the last log, the means of the non-empty buffers and the number of
    finished episodes are recorded under ``train/``, and the buffers are
    cleared.

    :param log_every: Number of steps that must be exceeded between two logs.
    :param num_envs: Initial size of the per-env accumulators. They are
        resized automatically if the rollout's reward array has a different
        shape.
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, log_every=3000, num_envs=320, verbose=0):
        super().__init__(verbose)

        # Per-episode values, buffered until the next logging event.
        self.mean_ep_task_reward = []
        self.mean_ep_length = []
        self.mean_coverage = []
        # Running return / length of the current episode in each environment.
        self.episode_rewards = np.zeros((num_envs,))
        self.episode_lengths = np.zeros((num_envs,))

        # Steps elapsed since the last logging event.
        self.log_count = 0
        self.log_every = log_every

    def _on_step(self) -> bool:
        """Accumulate rewards, buffer finished episodes and log when due.

        :return: Always ``True``. Stable-Baselines3 stops training when a
            callback's step hook returns a falsy value.
        """
        self.log_count += 1

        dones = self.locals["dones"]
        infos = self.locals["infos"]
        rewards = np.asarray(self.locals["rewards"])

        # The accumulators were sized from ``num_envs``; adapt them if the
        # vectorized env actually runs a different number of environments.
        if self.episode_rewards.shape != rewards.shape:
            self.episode_rewards = np.zeros(rewards.shape)
            self.episode_lengths = np.zeros(rewards.shape)

        self.episode_rewards += rewards
        self.episode_lengths += 1

        for traj, info in enumerate(infos):
            if not dones[traj]:
                continue
            # Skip rather than raise KeyError if an env does not report coverage.
            if "coverage" in info:
                self.mean_coverage.append(info["coverage"])
            self.mean_ep_task_reward.append(self.episode_rewards[traj])
            self.mean_ep_length.append(self.episode_lengths[traj])

            self.episode_rewards[traj] = 0
            self.episode_lengths[traj] = 0

        if self.log_count > self.log_every:
            self._dump_stats()
            self.log_count = 0

        return True

    def _dump_stats(self) -> None:
        """Record the buffered means and the episode count, then clear the buffers."""
        for key, values in (
            ("train/avg_episode_reward", self.mean_ep_task_reward),
            ("train/avg_episode_length", self.mean_ep_length),
            ("train/coverage", self.mean_coverage),
        ):
            # Skip empty buffers: np.mean([]) yields NaN plus a RuntimeWarning.
            if values:
                self.logger.record(key, np.mean(values))
        # Count every finished episode, not only those that reported coverage.
        self.logger.record("train/num_episodes_in_batch", len(self.mean_ep_length))

        self.mean_ep_task_reward.clear()
        self.mean_coverage.clear()
        self.mean_ep_length.clear()