import os

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
import wandb

import utils



class EvalAndSaveCallback(BaseCallback):
    """Periodically evaluate the current policy and save named model checkpoints.

    Every ``eval_freq`` calls of ``_on_step`` (counted by ``self.n_calls``) the model is
    evaluated on ``env`` for ``n_eval_episodes`` episodes. The raw and normalized reward
    statistics are printed (when ``verbose``) and/or logged with ``wandb.log`` (when
    ``use_wandb``).

    Every ``save_freq`` calls the model is saved under ``<save_path>/model_checkpoints/``.
    The step count and the mean evaluation reward are encoded in the file name.

    :param save_freq: Save a checkpoint every ``save_freq`` calls; a value ``<= 0`` disables saving.
    :param eval_freq: Evaluate every ``eval_freq`` calls; a value ``<= 0`` disables evaluation.
    :param env: Environment the policy is evaluated on.
    :param env_type: Identifier passed to ``utils.get_eval_statistics`` to compute normalized scores.
    :param save_path: Root directory; checkpoints are written to its ``model_checkpoints/`` subfolder.
    :param name_prefix: Prefix of the checkpoint file names.
    :param verbose: Print evaluation results and checkpoint paths.
    :param use_wandb: Log evaluation statistics to Weights & Biases.
    :param n_eval_episodes: Episodes per evaluation. The same evaluation also provides the
        reward used in checkpoint names.
    """

    def __init__(self, save_freq: int, eval_freq, env, env_type, save_path: str, name_prefix: str = "rl_model",
                 verbose: bool = False, use_wandb=True, n_eval_episodes: int = 10):
        super(EvalAndSaveCallback, self).__init__(verbose)
        self.save_freq = save_freq
        self.eval_freq = eval_freq
        # Checkpoints live in a dedicated sub-directory of ``save_path``.
        self.save_path = os.path.join(save_path, 'model_checkpoints/')
        self.name_prefix = name_prefix
        self.env = env

        self.use_wandb = use_wandb
        self.env_type = env_type
        self.n_eval_episodes = n_eval_episodes

    def _init_callback(self) -> None:
        # Create the checkpoint folder up front so ``model.save`` does not hit a missing directory.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Run the periodic evaluation and/or checkpoint save.

        :return: Always ``True``, so this callback never stops training.
        """
        # A non-positive frequency disables the action (as SB3's EvalCallback does).
        # Without this guard, ``n_calls % 0`` would raise ZeroDivisionError.
        eval_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        save_due = self.save_freq > 0 and self.n_calls % self.save_freq == 0
        if not (eval_due or save_due):
            return True

        # A single evaluation serves both purposes. Steps where eval and save coincide no
        # longer pay for two full evaluations, and the checkpoint name matches the logged score.
        rewards, _ = evaluate_policy(self.model, self.env, n_eval_episodes=self.n_eval_episodes,
                                     deterministic=False, return_episode_rewards=True)

        if eval_due:
            self._log_evaluation(rewards)
        if save_due:
            self._save_checkpoint(rewards)
        return True

    def _log_evaluation(self, rewards) -> None:
        """Print and/or log summary statistics of one evaluation's per-episode rewards."""
        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(rewards, self.env_type)
        if self.verbose:
            print(
                f"Step {self.n_calls}: Evaluation over {self.n_eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
        if self.use_wandb:
            wandb.log(
                {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward, 'eval std reward': std_reward,
                 'epochs': self.n_calls})

    def _save_checkpoint(self, rewards) -> None:
        """Save the model, encoding the timestep count and mean episode reward in the file name."""
        mean_reward = sum(rewards) / len(rewards)
        # ``%d`` truncates the mean reward toward zero, keeping the existing naming scheme.
        path = os.path.join(self.save_path,
                            "%s_%d_steps_%d_reward" % (self.name_prefix, self.num_timesteps, mean_reward))
        self.model.save(path)
        if self.verbose:
            print(f"Saving model checkpoint to {path}")


class SB3Wrapper:
    """Build, train and evaluate a Stable-Baselines3 agent from an experiment config.

    Only SAC is currently supported (``config.train.algorithm == 'sac'``).

    :param env: Training environment handed to the SB3 algorithm. ``test_policy`` also evaluates on it.
    :param eval_env: Environment used by the periodic-evaluation callback during training.
        If ``config.simulator.transform_list`` contains an ``obs_hidden_dims`` transform, that
        transform is applied to this env as well.
    :param config: Experiment config. Reads ``train.*``, ``gym_type``,
        ``simulator.transform_list``, ``wandb.enable``, ``env.eval_env`` and ``env.type``.
    :param agent_path: Path the final model is saved to and loaded from. Intermediate
        checkpoints are written under ``<agent_path>/model_checkpoints/``.
    :param evaluations_path: Stored on the instance; not used within this class.
    :raises ValueError: If ``config.train.algorithm`` is not a supported algorithm.
    """

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        self.eval_env = eval_env
        self.env = env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path
        self.algorithm = config.train.algorithm
        self.gym_type = config.gym_type

        # Evaluate on an env with hidden dims as well: when the configured simulator transforms
        # include ``obs_hidden_dims``, wrap the evaluation env with the same transform.
        # NOTE(review): each ``transform`` appears to be a (name, params) pair; confirm against the config schema.
        for transform in config.simulator.transform_list:
            if 'obs_hidden_dims' in transform[0]:
                if self.gym_type == 'gym':
                    import sim
                    self.eval_env, _ = sim.get_hidden_dims_env(eval_env, transform[1], '')
                else:
                    import highway
                    self.eval_env, _ = highway.get_hidden_dims_env(eval_env.env, transform[1], '')
                    self.eval_env = highway.HighwayWrapper(self.eval_env)
                    print(self.eval_env)
                break

        self.checkpoint_callback = EvalAndSaveCallback(save_freq=int(self.config.train.save_freq),
                                                       eval_freq=int(self.config.train.eval_freq),
                                                       env=self.eval_env,
                                                       save_path=self.agent_path, verbose=True,
                                                       use_wandb=self.config.wandb.enable,
                                                       env_type=self.config.env.eval_env)

        if self.algorithm == 'sac':
            # Step counts are coerced to int, like save_freq/eval_freq above, in case the
            # config stores them as floats (e.g. 1e6).
            self.agent = SAC(policy='MlpPolicy', learning_starts=int(self.config.train.start_timesteps), env=env,
                             verbose=0, learning_rate=self.config.train.learning_rate,
                             gamma=self.config.train.gamma)
        else:
            raise ValueError(f'StableBaselines3 algorithm {self.algorithm} not supported')

        print(f'Created SB3 agent with algorithm: {self.algorithm}')

    def train(self):
        """Train for ``config.train.max_timesteps`` steps with periodic eval/checkpointing, then save to ``agent_path``."""
        self.agent.learn(total_timesteps=int(self.config.train.max_timesteps), callback=self.checkpoint_callback)
        self.agent.save(self.agent_path)

    def test_policy(self, n_eval_episodes: int = 10):
        """Load the weights stored at ``agent_path`` and print evaluation statistics on the training env.

        :param n_eval_episodes: Number of (stochastic-policy) evaluation episodes to run.
        """
        self.agent.set_parameters(self.agent_path)
        rewards, _ = evaluate_policy(self.agent, self.env, n_eval_episodes=n_eval_episodes,
                                     deterministic=False, return_episode_rewards=True)

        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(rewards,
                                                                                             self.config.env.type)
        print("---------------------------------------")
        print(f"Evaluation done. Computed on {len(rewards)} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
        print("---------------------------------------")
