#from stable_baselines3 import PPO, A2C, DQN
#from stable_baselines3 import SAC, TD3
#from sb3_contrib import TQC

import stable_baselines3
import sb3_contrib
from sb3_contrib import *
from stable_baselines3 import *
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.logger import Image, HParam, Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
from stable_baselines3.common.monitor import Monitor

import gymnasium as gym
import numpy as np

from molecule_movement.cli_parsing import parse, create_parser

from molecule_movement.wrapper import (
    CorridorWrapper,
    NormalizeDistanceRewardWrapper,
    ManipulateCurrentMoleculeWrapper,
    ActionSpaceClipper,
    DistanceRewardBuffWrapper,
    PerStepCostWrapper,
    PositionRewardWrapper,
    MovementRewardWrapper,
    RotationRewardWrapper,
    ReorientationRewardWrapper,
    StatisticsInfoWrapper
    )
from gymnasium.wrappers import TransformReward

import importlib
import sys
import os
import datetime

from loguru import logger
from molecule_movement.logging import log_and_raise

class ImageRecorderCallback(BaseCallback):
    """Log one frame of the environment to TensorBoard when training starts.

    The frame is taken from ``env.unwrapped.get_image()`` and is only recorded
    when that call returns a ``np.ndarray`` (interpreted as HWC). Failures are
    logged as warnings and never abort training.
    """

    def __init__(self, env, verbose: int = 0):
        super().__init__(verbose)
        self.env = env

    def _on_training_start(self):
        try:
            img = self.env.unwrapped.get_image()
            if isinstance(img, np.ndarray):
                # Images are excluded from the text/JSON/CSV outputs, which cannot render them.
                self.logger.record("info/image", Image(img, "HWC"), exclude=("stdout", "log", "json", "csv"))
            else:
                logger.warning(f"Could not log an image: {img=} is not a np.ndarray. Can only log an image if training has been invoked with 'render_mode' == 'human' or 'rgb_array'.")
        except Exception as e:
            # Best effort: a missing frame must not stop training.
            logger.warning(f"Could not log an image: {e=}.")

    def _on_step(self) -> bool:
        return True

class HParamCallback(BaseCallback):
    """
    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.

    The recorded "hyperparameters" are the algorithm class name and a
    comma-separated ``key:value`` summary of the constructor kwargs of every
    entry in ``env.spec.additional_wrappers``.
    """
    def __init__(self, env, verbose: int = 0):
        super().__init__(verbose)
        self.env = env
        self.__collect_wrapper_specs()

    def _on_training_start(self) -> None:
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "specs": self.wrapper_specs
        }
        # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
        # Tensorboard will find & display metrics from the `SCALARS` tab
        metric_dict = {
            "train/value_loss": 0.0,
            "info/mean_movement_travelled": 0.0,
            "info/inside_corridor_mean": 0.0,
            "rollout/ep_rew_mean": 0.0,
            "rollout/ep_len_mean": 0,
        }
        try:
            self.logger.record(
                "hparams",
                HParam(hparam_dict, metric_dict),
                exclude=("stdout", "log", "json", "csv"),
            )
        except Exception as e:
            # Best effort: failing to record hparams must not abort training.
            logger.warning(e)
        self.logger.record("info/wrapper_specs", self.wrapper_specs)

    def _on_step(self) -> bool:
        return True

    def __collect_wrapper_specs(self):
        """Store a ``"k:v,k:v,..."`` summary of all wrapper kwargs in ``self.wrapper_specs``."""
        specs = []
        # gymnasium's Wrapper.spec may be None (e.g. when the wrapped env has no
        # spec); degrade to an empty summary instead of failing in __init__.
        env_spec = getattr(self.env, "spec", None)
        wrapper_specs = getattr(env_spec, "additional_wrappers", None) or ()
        for wrapper_spec in wrapper_specs:
            try:
                for k, v in wrapper_spec.kwargs.items():
                    specs.append(f"{k}:{v}")
            except AttributeError:
                # e.g. kwargs is None for wrappers that do not record their
                # constructor arguments; such wrappers are skipped.
                pass
        self.wrapper_specs = ",".join(specs)




class InfoCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.

    On every step it reads the info dict of the first environment
    (``self.locals["infos"][0]``; other environments of a multi-env VecEnv are
    not inspected). It keeps running counts of steps whose info flags
    ``reached_goal``, ``reached_goal_orientation`` and ``crashed`` are truthy,
    and forwards a fixed set of optional metrics when present.
    """

    # Optional info keys forwarded verbatim to the logger as "info/<key>"
    # whenever the environment reports them.
    _OPTIONAL_INFO_KEYS = (
        "mean_movement_travelled",
        "inside_corridor_mean",
        "mean_movement_reward",
        "mean_corridor_penalty",
        "mean_reorientation_reward",
    )

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.reached = 0
        self.oriented = 0
        self.crashed = 0

    def _on_step(self) -> bool:
        try:
            infos = self.locals["infos"][0]
            reached = infos["reached_goal"]
            oriented = infos["reached_goal_orientation"]
            crashed = infos["crashed"]
        except KeyError:
            # Environment does not report the required flags: nothing to log.
            return True

        if reached:
            self.reached += 1
        if oriented:
            self.oriented += 1
        if crashed:
            self.crashed += 1
        self.logger.record("info/sum_reached_goal", self.reached)
        self.logger.record("info/sum_oriented", self.oriented)
        self.logger.record("info/sum_crashed", self.crashed)

        for key in self._OPTIONAL_INFO_KEYS:
            if key in infos:
                self.logger.record(f"info/{key}", infos[key])
        return True


# Route loguru records to stderr at INFO level and above, and enable the
# log output of the molecule_movement package.
logger.configure(
    handlers=[
        dict(level="INFO", sink=sys.stderr),
    ],
)
logger.enable("molecule_movement")

def make_env(args, human_render: bool = False):
    """Create the environment named by ``args.env`` wrapped in the project's wrapper stack.

    Args:
        args: parsed CLI namespace; reads ``env``, ``render_mode``, ``scale``,
            ``render_grid``, ``render_sensors``, ``sensors``, ``max_steps``
            and ``corridor_width``.
        human_render: when True, ``args.render_mode`` is passed to
            ``gym.make``; otherwise the string ``"none"`` is passed.

    Returns:
        The wrapped environment; the outermost layer is a ``Monitor``.
    """
    # NOTE(review): gymnasium's convention for "no rendering" is render_mode=None;
    # the string "none" is presumably understood by the custom env — confirm.
    chosen_render_mode = args.render_mode if human_render else "none"
    wrapped = gym.make(
        args.env,
        render_mode=chosen_render_mode,
        scale=args.scale,
        render_grid=args.render_grid,
        render_sensors=args.render_sensors,
        num_sensors=args.sensors,
        max_steps=args.max_steps,
    )

    # Applied in order, first entry innermost, Monitor outermost.
    layers = (
        lambda inner: MovementRewardWrapper(inner, weight=1.0),
        lambda inner: ReorientationRewardWrapper(inner, weight=0.5),
        lambda inner: PerStepCostWrapper(inner, 0.1),
        lambda inner: CorridorWrapper(
            inner,
            corridor_width=args.corridor_width,
            parking_buffer=1.0,
            parking_distance=5.0,
            corridor_violation_penalty=-1.5,
        ),
        lambda inner: ManipulateCurrentMoleculeWrapper(inner, (-1.5, 1.5, 0.3), (-1.5, 1.5, 0.3)),
        lambda inner: StatisticsInfoWrapper(inner, columns=[], sliding_window_size=1000, check=False),
        Monitor,
    )
    for wrap in layers:
        wrapped = wrap(wrapped)
    return wrapped

def list_of_algorithms() -> list[str]:
    """Return the algorithm names offered as ``--algo`` choices.

    The names come from ``stable_baselines3.__all__`` followed by
    ``sb3_contrib.__all__``, minus non-algorithm exports and algorithms this
    script does not support.

    A fresh list is built on every call, so the packages' own ``__all__``
    lists are never modified. Mutating them would break later
    ``from ... import *`` statements and make a repeated call fail.
    """
    # Non-algorithm exports plus algorithms excluded from the CLI choices
    # (presumably because they need discrete/masked/recurrent setups this
    # script does not provide — TODO confirm).
    excluded = {
        "HerReplayBuffer",
        "get_system_info",
        "DQN",
        "QRDQN",
        "MaskablePPO",
        "RecurrentPPO",
    }
    return [
        name
        for name in (*stable_baselines3.__all__, *sb3_contrib.__all__)
        if name not in excluded
    ]

def create_training_parser():
    """Extend the shared CLI parser with the training-only options.

    Adds ``--algo`` (restricted to :func:`list_of_algorithms`, default TQC),
    ``--total-timesteps`` (default 200_000) and ``--dir`` (default
    ``./trained_policies``) to the parser returned by ``create_parser``.
    """
    training_parser = create_parser()
    algorithm_choices = list_of_algorithms()

    training_parser.add_argument(
        "--algo",
        default="TQC",
        choices=algorithm_choices,
        type=str,
        help="stable_baselines3 algorithm for training",
    )
    training_parser.add_argument(
        "--total-timesteps",
        default=200_000,
        type=int,
        help="total timesteps for training",
    )
    training_parser.add_argument(
        "--dir",
        default="./trained_policies",
        type=str,
        help="directory for the tensorboard logging output and learned policies.",
    )
    return training_parser

def str_to_class(module_name: str, class_name: str):
    """Return the class (not an instance) ``class_name`` from module ``module_name``.

    Raises:
        ImportError: if ``module_name`` cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    try:
        module_ = importlib.import_module(module_name)
    except ImportError:
        logger.trace('Module does not exist')
        raise
    try:
        return getattr(module_, class_name)
    except AttributeError:
        logger.trace('Class does not exist')
        # Re-raise the real error instead of falling through to an unbound
        # local, which would surface as a misleading UnboundLocalError.
        raise

def get_algorithm(algorithm_name: str):
    """Look up an algorithm class by name, first in stable_baselines3, then in sb3_contrib.

    Args:
        algorithm_name: class name such as ``"PPO"`` or ``"TQC"``.

    Returns:
        The algorithm class (not an instance).

    Raises:
        Delegates to ``log_and_raise`` with the last lookup error when the
        name is found in neither package (presumably it logs and re-raises).
    """
    last_error = None
    for module_name in ("stable_baselines3", "sb3_contrib"):
        try:
            algorithm = str_to_class(module_name, algorithm_name)
        except Exception as e:
            # Deliberately broad: any lookup failure (missing module, missing
            # attribute, or another error raised by str_to_class) falls
            # through to the next package.
            last_error = e
            continue
        if algorithm is not None:
            return algorithm
    if last_error is None:
        last_error = AttributeError(f"{algorithm_name} not found")
    log_and_raise(last_error, f"Could not instantiate algorithm class for {algorithm_name}")


def _write_png(path, pixels):
    """Write a uint8 image array to *path* as an 8-bit, non-interlaced PNG.

    Uses only the standard library (``zlib``/``struct``). *pixels* must be
    ``(H, W)`` grayscale or ``(H, W, C)`` with C in {1, 3, 4}
    (gray, RGB, RGBA), with H and W both non-zero.

    Raises:
        ValueError: on an unsupported dtype or shape.
    """
    import struct
    import zlib

    pixels = np.asarray(pixels)
    if pixels.dtype != np.uint8:
        raise ValueError(f"expected a uint8 image, got dtype {pixels.dtype}")
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    if pixels.ndim != 3 or pixels.shape[2] not in (1, 3, 4) or 0 in pixels.shape[:2]:
        raise ValueError(f"unsupported image shape {pixels.shape}")
    height, width, channels = pixels.shape
    # PNG colour types: 0 = grayscale, 2 = truecolour, 6 = truecolour + alpha.
    color_type = {1: 0, 3: 2, 4: 6}[channels]
    pixels = np.ascontiguousarray(pixels)

    def chunk(tag, data):
        # length + type + data + CRC-32 computed over type and data
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    # Every scanline is prefixed with filter type 0 ("None").
    raw = b"".join(b"\x00" + row.tobytes() for row in pixels)
    header = struct.pack(">IIBBBBB", width, height, 8, color_type, 0, 0, 0)
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n")
        f.write(chunk(b"IHDR", header))
        f.write(chunk(b"IDAT", zlib.compress(raw)))
        f.write(chunk(b"IEND", b""))


def save_image(env, path="foobar.png"):
    """Reset *env* and save its current frame as a PNG file.

    Best effort: any failure is logged as a warning instead of raised.

    Args:
        env: environment whose ``unwrapped.get_image()`` returns the frame
            as a uint8 ``np.ndarray``.
        path: output file; defaults to ``"foobar.png"`` in the working directory.
    """
    env.reset()
    try:
        img = env.unwrapped.get_image()
        if not isinstance(img, np.ndarray):
            raise TypeError(f"get_image() returned {type(img).__name__}, not np.ndarray")
        # The file-level ``Image`` is stable-baselines3's logging container, not a
        # PIL image, so the PNG is encoded here with the standard library.
        _write_png(path, img)
    except Exception as e:
        logger.warning(f"Could not save an image: {e!r}")


def main():
    """Parse CLI args, train the selected algorithm and save the final policy.

    Everything is written to a fresh time-stamped directory below ``args.dir``:
    TensorBoard events, ``out.log``, ``plot.csv``, the best evaluation model
    (via ``EvalCallback``) and the final model.
    """
    parser = create_training_parser()
    args = parser.parse_args()
    env = make_env(args, human_render=True)
    eval_env = make_env(args, human_render=False)
    env.reset(seed=args.seed)

    run_name = (
        f"{args.algo}_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}"
        f"_{args.env.replace('/','_')}_{args.corridor_width}"
    )
    log_dir = os.path.join(args.dir, run_name)
    os.makedirs(log_dir, exist_ok=True)

    sb3_logger = Logger(
        log_dir,
        output_formats=[
            TensorBoardOutputFormat(log_dir),
            HumanOutputFormat(sys.stdout),
            HumanOutputFormat(os.path.join(log_dir, "out.log")),
            CSVOutputFormat(os.path.join(log_dir, "plot.csv")),
        ],
    )
    try:
        # NOTE(review): render=True while eval_env was built with
        # human_render=False — confirm which env is meant to be rendered.
        eval_callback = EvalCallback(eval_env, best_model_save_path=log_dir,
                                     log_path=log_dir, eval_freq=10000,
                                     n_eval_episodes=5, deterministic=True,
                                     render=True)

        # Passing the seed to the model also seeds SB3's own RNGs (policy
        # initialisation, exploration); env.reset(seed=...) above only seeds the env.
        model = get_algorithm(args.algo)("MlpPolicy", env, device="auto", verbose=1, seed=args.seed)
        model.set_logger(sb3_logger)
        model.learn(total_timesteps=args.total_timesteps, callback=[eval_callback, InfoCallback(), HParamCallback(env)])
        model.save(os.path.join(log_dir, f"{args.algo}_{args.total_timesteps}_{args.corridor_width}"))
    finally:
        # Flush/close the log files (out.log, plot.csv, TensorBoard) and the envs.
        sb3_logger.close()
        eval_env.close()
        env.close()


if __name__ == "__main__":
    main()

