import os
from collections import defaultdict
from typing import List, Optional

import wandb
from PIL import Image as PILImage
from stable_baselines3.common.callbacks import (BaseCallback, CallbackList,
                                                CheckpointCallback)
from stable_baselines3.common.utils import safe_mean
from tqdm import tqdm
from wandb.integration.sb3 import WandbCallback


class ProgressCallback(BaseCallback):
    """Display a tqdm progress bar that follows the model's environment timesteps.

    The bar is sized from the ``total_timesteps`` value found in the training
    locals and is advanced by the number of timesteps consumed since the last
    call, so it stays accurate when several environments are stepped per call.
    """

    def _on_training_start(self) -> None:
        super()._on_training_start()
        # ``num_timesteps`` can already be non-zero when training continues from
        # a previous ``learn`` call; start the bar there so it ends at its total.
        self.tqdm = tqdm(
            total=int(self.locals["total_timesteps"]),
            initial=self.num_timesteps,
        )

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        # Advance by the timesteps consumed since the previous update rather than
        # by a fixed 1, since one call may cover several parallel environments.
        self.tqdm.update(self.num_timesteps - self.tqdm.n)
        # A falsy return value (including None) stops training, so be explicit.
        return True

    def _on_training_end(self) -> None:
        super()._on_training_end()
        self.tqdm.close()

class RenderCallback(BaseCallback):
    """Periodically render the training env and save the frames as an animated GIF.

    :param n_steps: Minimum number of timesteps between two captured frames.
    :param fps: Playback rate of the saved GIF in frames per second. Very large
        values (like the default) give the smallest possible frame delay.
    :param loop: Value passed to Pillow's GIF ``loop`` option; ``None`` leaves it unset.
    :param directory: Directory the GIF (and optional PNG) is written to; created if missing.
    :param save_last: Also save the last captured frame as ``<name>_last.png``.
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(
        self,
        n_steps: int = 100,
        fps: int = 100000,
        loop: Optional[int] = None,
        directory: str = "replay",
        save_last: bool = False,
        verbose: int = 0,
    ):
        super().__init__(verbose=verbose)

        self.n_steps = n_steps
        # Timestep count at which the previous frame was captured.
        self.last_time_trigger = 0

        self.directory = directory
        self.fps = fps
        self.loop = loop
        self.save_last = save_last
        # Captured frames, as PIL images, in capture order.
        self.renderings = []

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        if (self.num_timesteps - self.last_time_trigger) >= self.n_steps:
            self.last_time_trigger = self.num_timesteps
            env = self.locals["env"]
            # NOTE(review): assumes render() returns an RGB array (rgb_array render mode) — confirm.
            img = env.render()
            self.renderings.append(PILImage.fromarray(img))

        return True

    def _on_training_end(self) -> None:
        # No frame is captured when training runs for fewer than ``n_steps``
        # timesteps; there is nothing to save (and renderings[0] would raise).
        if not self.renderings:
            return

        os.makedirs(self.directory, exist_ok=True)
        file_name = os.path.join(self.directory, f"{self.locals['self'].policy_class.__name__}")
        kwargs = {
            "format": "GIF",
            "save_all": True,
            # Pillow expects the per-frame display duration in milliseconds.
            "duration": 1000 / self.fps,
            "append_images": self.renderings[1:],
        }
        if self.loop is not None:
            kwargs.update({"loop": self.loop})
        self.renderings[0].save(
            f"{file_name}.gif", **kwargs
        )
        if self.save_last:
            self.renderings[-1].save(f"{file_name}_last.png")

class WANDBRenderCallback(BaseCallback):
    """Render the training env every ``n_steps`` timesteps and log the frame to W&B.

    :param n_steps: Minimum number of timesteps between two logged frames.
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(
        self,
        n_steps: int = 1,
        verbose: int = 0,
    ):
        super().__init__(verbose=verbose)

        self.n_steps = n_steps
        # Timestep count at which the most recent frame was logged.
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            return True

        self.last_time_trigger = self.num_timesteps
        frame = self.locals["env"].render()
        wandb.log({"env" : wandb.Image(PILImage.fromarray(frame))})
        return True

class TBInapplicableActionsCallback(BaseCallback):
    """Record inapplicable-action statistics from the env ``info`` dicts.

    When the first environment reports ``done``, the mean over all environments'
    ``total_inapplicable_actions`` and ``inapplicable_actions`` info entries is
    recorded under the ``rollout/`` namespace.
    """

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        logger = self.locals["self"].logger
        dones = self.locals["dones"]
        infos = self.locals["infos"]

        if not dones[0]:
            return True

        for tag, info_key in (
            ("rollout/total_inapplicable_actions", "total_inapplicable_actions"),
            ("rollout/inapplicable_actions", "inapplicable_actions"),
        ):
            logger.record(tag, safe_mean([info[info_key] for info in infos]))

        return True

class TBInapplicableActionsTypesCallback(BaseCallback):
    """Record per-type inapplicable-action statistics from the env ``info`` dicts.

    When the first environment reports ``done``, every info entry whose key has
    the form ``inapplicable/<type>`` is collected across all environments. Its
    mean is recorded as ``rollout/inapplicable/<type>``.
    """

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        logger = self.locals["self"].logger

        dones = self.locals["dones"]
        infos = self.locals["infos"]

        if dones[0]:
            prefix = "inapplicable/"
            data = defaultdict(list)
            for info in infos:
                for k, v in info.items():
                    if k.startswith(prefix):
                        # Strip only the leading prefix; str.replace would also
                        # delete later occurrences inside the type name.
                        data[k[len(prefix):]].append(v)
            for k, v in data.items():
                logger.record(f"rollout/inapplicable/{k}", safe_mean(v))

        return True

class WANDBCallbackList(CallbackList):
    """A ``CallbackList`` whose last entry is a ``WandbCallback``.

    :param callbacks: Callbacks to run before the W&B callback. The given list is
        copied, not modified.
    :param args: Positional arguments forwarded to ``WandbCallback``.
    :param kwargs: Keyword arguments forwarded to ``WandbCallback``.
    """

    def __init__(self, callbacks: List[BaseCallback], *args, **kwargs):
        # Build a fresh list: appending to ``callbacks`` would mutate the caller's
        # list and stack another WandbCallback each time the list is reused.
        super(WANDBCallbackList, self).__init__([*callbacks, WandbCallback(*args, **kwargs)])

class ActionCountCallback(BaseCallback):
    """Record how often each action was taken, from the env ``actions_counter`` infos.

    When the first environment reports ``done``, the counts in each env's
    ``actions_counter`` mapping are grouped by action. For every action ``a``,
    the mean over the envs that report it is recorded as
    ``rollout/action_{a}_count``.
    """

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        logger = self.locals["self"].logger
        dones = self.locals["dones"]
        infos = self.locals["infos"]

        if not dones[0]:
            return True

        counts_by_action = defaultdict(list)
        for info in infos:
            for action, count in info["actions_counter"].items():
                counts_by_action[action].append(count)

        for action, counts in counts_by_action.items():
            logger.record(f"rollout/action_{action}_count", safe_mean(counts))

        return True

class WastefulActionsCallback(BaseCallback):
    """Record the gap between bad-action and inapplicable-action counts.

    When the first environment reports ``done``, two values are computed for each
    env and averaged across all envs. The first is
    ``total_bad_actions - total_inapplicable_actions``, recorded as
    ``rollout/total_wasteful_actions``. The second is
    ``bad_actions - inapplicable_actions``, recorded as ``rollout/wasteful_actions``.
    Presumably these are bad actions that were still applicable — confirm with
    the env.
    """

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        logger = self.locals["self"].logger
        dones = self.locals["dones"]
        infos = self.locals["infos"]

        if not dones[0]:
            return True

        total_wasteful = [
            info["total_bad_actions"] - info["total_inapplicable_actions"]
            for info in infos
        ]
        logger.record("rollout/total_wasteful_actions", safe_mean(total_wasteful))

        wasteful = [
            info["bad_actions"] - info["inapplicable_actions"]
            for info in infos
        ]
        logger.record("rollout/wasteful_actions", safe_mean(wasteful))

        return True

class BadActionsCallback(BaseCallback):
    """Record bad-action statistics from the env ``info`` dicts.

    When the first environment reports ``done``, the mean over all environments'
    ``total_bad_actions`` and ``bad_actions`` info entries is recorded under the
    ``rollout/`` namespace.
    """

    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        logger = self.locals["self"].logger
        dones = self.locals["dones"]
        infos = self.locals["infos"]

        if not dones[0]:
            return True

        for tag, info_key in (
            ("rollout/total_bad_actions", "total_bad_actions"),
            ("rollout/bad_actions", "bad_actions"),
        ):
            logger.record(tag, safe_mean([info[info_key] for info in infos]))

        return True

class ClassifierCheckpointCallback(CheckpointCallback):
    """``CheckpointCallback`` that also saves ``model.classifier``.

    Whenever ``n_calls`` is a multiple of ``save_freq``, the classifier is saved to
    ``<save_path>/<name_prefix>_classifier_<num_timesteps>_steps``.
    """

    def _on_step(self) -> bool:
        """
        :return: Always True, so this callback never stops training.
        """
        # Run the parent checkpoint logic first; its return value is not used.
        super()._on_step()

        if self.n_calls % self.save_freq != 0:
            return True

        classifier_path = os.path.join(
            self.save_path,
            f"{self.name_prefix}_classifier_{self.num_timesteps}_steps",
        )
        self.model.classifier.save(classifier_path)
        if self.verbose > 1:
            print(f"Saving classifier model checkpoint to {classifier_path}")

        return True
