from collections import defaultdict
from typing import Optional, Type, Union

import gym
import numpy as np
import torch as th
import torch.nn as nn
from custom_minigrid.envs.maze import SimpleMazeEnv
from rl_inapplicable_actions.utils.pre_conditions import \
    CNNGridWorldActionPreConditionWrapper
from rl_inapplicable_actions.utils.wrappers import AccessImageObsWrapper
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import (BaseCallback, CallbackList,
                                                ConvertCallback, EvalCallback)
from stable_baselines3.common.logger import Image
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import (explained_variance,
                                            get_schedule_fn, obs_as_tensor)
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.ppo.ppo import PPO
from torch.nn import functional as F

from .buffer import MaskRolloutBuffer
from .policy import ActorCriticPolicyClassifierWrapper


class CustomEvalCallback(EvalCallback):
    """``EvalCallback`` that also logs per-action classifier prediction maps.

    Every ``eval_freq`` calls (the same cadence as the parent evaluation), a
    copy of the first evaluation env is wrapped, the agent is started on each
    position in ``valid_positions``, and ``model.classifier`` is queried for
    every action. For each action an HWC image is logged in which a cell is
    green when the classifier predicts "applicable", red when it predicts
    "not applicable", and grey (value 100) for positions that were not probed.
    The fraction of predictions agreeing with ``env.is_applicable`` is logged
    as ``eval/classifier_accuracy``.

    NOTE(review): ``self.eval_env.envs[0]`` assumes a ``DummyVecEnv``-style
    eval env exposing ``envs`` — confirm for other VecEnv types.
    """

    # (height, width) of the logged prediction maps. Cells are written at
    # [pos[1], pos[0]], so every valid position must fall inside this grid.
    grid_shape = (7, 7)

    def _on_step(self) -> bool:
        continue_training = super()._on_step()

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            self._log_classifier_maps()

        return continue_training

    def _log_classifier_maps(self) -> None:
        """Query the classifier on every (position, action) pair and log the maps."""
        env = CNNGridWorldActionPreConditionWrapper(
            AccessImageObsWrapper(self.eval_env.envs[0].unwrapped.copy())
        )
        # Enum member -> member name, used to build the log keys (built once).
        action_names = {v: k for k, v in env.actions.__members__.items()}
        # The classifier is trained on rollout tensors living on the model's
        # device, so evaluation inputs must be placed there as well.
        device = self.model.device

        results = defaultdict(lambda: 100 * np.ones((*self.grid_shape, 3)))
        n_correct = 0
        n_total = 0

        # Pure inference: no autograd graph is needed.
        with th.no_grad():
            for pos in env.valid_positions:
                env.unwrapped.agent_start_pos = pos
                obs = env.reset()
                th_obs = th.tensor(obs, dtype=th.float, device=device).unsqueeze(0)
                for action in env.actions:
                    is_applicable = env.is_applicable(obs, action)
                    th_action = th.tensor(
                        action, dtype=th.long, device=device
                    ).unsqueeze(0)
                    y_hat = self.model.classifier(th_action, th_obs)
                    y_hat_pred = (
                        (th.sigmoid(y_hat).squeeze(-1) > 0.5).float().item()
                    )
                    results[action][pos[1], pos[0], :] = (
                        [0., 153., 0.] if y_hat_pred else [204., 0., 0.]
                    )
                    n_correct += int(bool(y_hat_pred) == bool(is_applicable))
                    n_total += 1

        for action, result in results.items():
            self.logger.record(
                f"eval/classifier_results_action_{action_names.get(action)}",
                Image(result / 255., "HWC"),
            )
        if n_total:
            self.logger.record("eval/classifier_accuracy", n_correct / n_total)

        self.logger.dump(self.num_timesteps)

class PPOClassifier(PPO):
    """PPO that also trains an action-applicability classifier on its rollouts.

    Each collected transition is labelled "applicable" when the action changed
    the observation (see ``_action_validity_labels``). The label is stored in
    the rollout buffer together with the policy's action mask. When the
    classifier is an ``nn.Module`` and ``train_classifier`` is set, the
    classifier loss is added to the PPO loss in ``train``.

    :param policy: Policy class or name; it is wrapped with
        ``ActorCriticPolicyClassifierWrapper`` before being passed to PPO.
    :param buffer_class: Rollout buffer class. ``None`` selects
        ``DictRolloutBuffer`` or ``RolloutBuffer`` from the observation space.
        NOTE(review): ``collect_rollouts``/``train`` pass validity labels and
        masks to ``add`` and call ``get(start_index=..., buffer_size=...)``,
        which the stock SB3 buffers presumably do not support — confirm.
    :param classifier: Applicability classifier. Only ``nn.Module`` instances
        can be trained (``train`` raises ``NotImplementedError`` otherwise
        when ``train_classifier`` is set).
    :param train_classifier: Whether ``train`` updates the classifier.
    :param balanced_sampling: Whether classifier batches are drawn with
        inverse label-frequency weights.
    :param rollout_buffer_size: First positional argument given to
        ``buffer_class`` (``n_steps`` in stock SB3). Defaults to the value
        5000 that was previously hard-coded.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        *args,
        buffer_class=MaskRolloutBuffer,
        classifier=None,
        train_classifier=True,
        balanced_sampling=True,
        rollout_buffer_size: int = 5000,
        **kwargs,
    ):
        ## Classifier
        self.classifier = classifier
        self.train_classifier = train_classifier
        self.balanced_sampling = balanced_sampling
        # A ``None`` buffer class is resolved in ``_setup_model``: the
        # observation space is only set by ``super().__init__`` below.
        self.buffer_class = buffer_class
        self.rollout_buffer_size = rollout_buffer_size
        ## -

        self.is_cnn_classifier = isinstance(self.classifier, nn.Module)
        policy = ActorCriticPolicyClassifierWrapper(
            policy, self.is_cnn_classifier
        )

        super().__init__(policy, *args, **kwargs)

    def _setup_model(self) -> None:
        """Build the rollout buffer, the classifier-aware policy and PPO schedules."""
        # On-Policy._setup_model
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        if self.buffer_class is None:
            self.buffer_class = (
                DictRolloutBuffer
                if isinstance(self.observation_space, gym.spaces.Dict)
                else RolloutBuffer
            )

        # Deliberately not ``self.n_steps``: ``train`` reads the window
        # [reset_index, pos) of this buffer, presumably so it can hold more
        # than one rollout — TODO confirm against MaskRolloutBuffer.
        self.rollout_buffer = self.buffer_class(
            self.rollout_buffer_size,
            self.observation_space,
            self.action_space,
            device=self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
            n_envs=self.n_envs,
        )
        self.policy = self.policy_class(  # pytype:disable=not-instantiable
            self.observation_space,
            self.action_space,
            self.lr_schedule,
            use_sde=self.use_sde,
            classifier=self.classifier,
            **self.policy_kwargs,  # pytype:disable=not-instantiable
        )
        self.policy = self.policy.to(self.device)

        # PPO._setup_model
        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, (
                    "`clip_range_vf` must be positive, "
                    "pass `None` to deactivate vf clipping"
                )

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def _init_callback(
        self,
        callback: MaybeCallback,
        eval_env: Optional[VecEnv] = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
    ) -> BaseCallback:
        """
        Same as the SB3 base implementation, but uses ``CustomEvalCallback``.

        :param callback: Callback(s) called at every step with state of the algorithm.
        :param eval_env: Environment to evaluate on; if None, do not evaluate.
        :param eval_freq: How many steps between evaluations.
        :param n_eval_episodes: Number of episodes to rollout during evaluation.
        :param log_path: Path to a folder where the evaluations will be saved
        :return: A hybrid callback calling `callback` and performing evaluation.
        """
        # Convert a list of callbacks into a callback
        if isinstance(callback, list):
            callback = CallbackList(callback)

        # Convert functional callback to object
        if not isinstance(callback, BaseCallback):
            callback = ConvertCallback(callback)

        # Create eval callback in charge of the evaluation
        if eval_env is not None:
            eval_callback = CustomEvalCallback(
                eval_env,
                best_model_save_path=log_path,
                log_path=log_path,
                eval_freq=eval_freq,
                n_eval_episodes=n_eval_episodes,
            )
            callback = CallbackList([callback, eval_callback])

        callback.init_callback(self)
        return callback

    @staticmethod
    def _action_validity_labels(last_obs, new_obs, infos) -> np.ndarray:
        """Label each env's last action as applicable iff it changed the observation.

        For envs whose episode just ended, ``new_obs`` already holds the
        reset observation. For those envs the ``terminal_observation`` stored
        in ``infos`` is compared instead.

        :return: Boolean array with one entry per env.
        """
        return np.array(
            [
                not np.array_equal(
                    last_obs[env_idx],
                    info.get("terminal_observation", new_obs[env_idx]),
                )
                for env_idx, info in enumerate(infos)
            ],
            dtype=bool,
        )

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect ``n_rollout_steps`` transitions per env into ``rollout_buffer``.

        Mirrors SB3's ``OnPolicyAlgorithm.collect_rollouts``. In addition, it
        stores the per-env action-validity labels and the action mask
        returned by ``policy.forward``.

        :return: False if a callback requested to stop training, True otherwise.
        """
        assert (
            self._last_obs is not None
        ), "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if (
                self.use_sde
                and self.sde_sample_freq > 0
                and n_steps % self.sde_sample_freq == 0
            ):
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                actions, values, log_probs, mask = self.policy.forward(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(
                    actions, self.action_space.low, self.action_space.high
                )

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            ## Classifier
            # One label per env: the action is deemed applicable iff it
            # changed that env's observation.
            is_action_valids = self._action_validity_labels(
                self._last_obs, new_obs, infos
            )
            ## -

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(
                        infos[idx]["terminal_observation"]
                    )[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(
                            terminal_obs
                        )[0]
                    rewards[idx] += self.gamma * terminal_value

            ## Classifier
            rollout_buffer.add(
                self._last_obs,
                actions,
                rewards,
                self._last_episode_starts,
                values,
                log_probs,
                is_action_valids,
                mask.float()
            )
            ## -
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(
                obs_as_tensor(new_obs, self.device)
            )

        rollout_buffer.compute_returns_and_advantage(
            last_values=values, dones=dones
        )

        callback.on_rollout_end()

        return True

    @staticmethod
    def _balanced_sample(actions, observations, labels, batch_size=None):
        """Draw a label-balanced subset of ``(actions, observations, labels)``.

        Each sample is weighted by the inverse frequency of its label, so
        every label value carries the same total weight. ``batch_size``
        distinct indices (all of them when ``None``) are then drawn without
        replacement. Everything stays in torch on the inputs' device, so CUDA
        tensors are supported.

        :return: The three inputs indexed by the drawn indices.
        """
        flat_labels = labels.reshape(-1)
        n_samples = flat_labels.numel()
        if n_samples == 0:
            return actions, observations, labels
        _, inverse, counts = th.unique(
            flat_labels, return_inverse=True, return_counts=True
        )
        weights = (n_samples / counts.double())[inverse]
        n_draw = n_samples if batch_size is None else min(batch_size, n_samples)
        idx = th.multinomial(weights, n_draw, replacement=False)
        return actions[idx], observations[idx], labels[idx]

    def _classifier_loss(self):
        """Compute the classifier loss on a sample drawn from the rollout buffer.

        :return: ``(loss, class_balance)``. ``loss`` is the classifier's loss
            tensor; ``class_balance`` is the mean label of the batch used, as a
            Python float.
        """
        rollout_data = next(self.rollout_buffer.get(None))

        actions = rollout_data.actions.detach().clone()
        labels = rollout_data.is_action_valids.detach().clone()
        if self.policy.use_policy_features:
            # Not detached: the classifier loss also backpropagates into the
            # policy's feature extractor.
            observations = self.policy.extract_features(
                rollout_data.observations.detach().clone()
            )
        else:
            observations = rollout_data.observations.detach().clone().float()

        if self.balanced_sampling:
            actions, observations, labels = self._balanced_sample(
                actions, observations, labels, batch_size=self.batch_size
            )

        pred = self.policy.classifier(
            actions=actions,
            observations=observations,
        ).squeeze()
        loss = self.policy.classifier.loss(pred, labels)
        return loss, labels.float().mean().item()

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs the PPO clipped-objective update over the window
        ``[reset_index, pos)`` of the rollout buffer, wrapping around its end.
        When an ``nn.Module`` classifier is trained, every mini-batch adds the
        classifier loss to the PPO loss, and the classifier's own optimizer is
        stepped alongside the policy optimizer.

        :raises NotImplementedError: if ``train_classifier`` is set but the
            classifier is not an ``nn.Module``. This is checked before any
            update, so a failing call leaves the model unchanged.
        """
        ## Classifier
        if not self.is_cnn_classifier and self.train_classifier:
            raise NotImplementedError("Non CNN Classifier resampling")
        train_cnn_classifier = self.is_cnn_classifier and self.train_classifier
        ## -

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        ## Classifier
        self.policy.update_exploration_rate(self._current_progress_remaining)
        ## -

        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        ## Classifier
        classifier_losses = []
        class_balances = []
        ## -
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Number of entries between reset_index and pos. The buffer is
            # used circularly, so a negative difference means it wrapped.
            window = self.rollout_buffer.pos - self.rollout_buffer.reset_index
            if window < 0:
                window += self.rollout_buffer.buffer_size
            # Do a complete pass on the rollout buffer window
            for rollout_data in self.rollout_buffer.get(
                self.batch_size,
                start_index=self.rollout_buffer.reset_index,
                buffer_size=window,
            ):
                actions = rollout_data.actions
                if isinstance(self.action_space, gym.spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(
                    rollout_data.observations, actions, rollout_data.masks
                )
                values = values.flatten()
                # Normalize advantage
                advantages = rollout_data.advantages
                advantages = (advantages - advantages.mean()) / (
                    advantages.std() + 1e-8
                )

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(
                    ratio, 1 - clip_range, 1 + clip_range
                )
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean(
                    (th.abs(ratio - 1) > clip_range).float()
                ).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the different between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values,
                        -clip_range_vf,
                        clip_range_vf,
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                ## Classifier
                classifier_loss = 0.0
                if train_cnn_classifier:
                    classifier_loss, class_balance = self._classifier_loss()
                    classifier_losses.append(classifier_loss.item())
                    class_balances.append(class_balance)
                ## -

                loss = (
                    policy_loss
                    + self.ent_coef * entropy_loss
                    + self.vf_coef * value_loss
                    + classifier_loss
                )

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = (
                        th.mean((th.exp(log_ratio) - 1) - log_ratio)
                        .cpu()
                        .numpy()
                    )
                    approx_kl_divs.append(approx_kl_div)

                if (
                    self.target_kl is not None
                    and approx_kl_div > 1.5 * self.target_kl
                ):
                    continue_training = False
                    if self.verbose >= 1:
                        print(
                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}"
                        )
                    break

                # Optimization step
                # NOTE(review): if the classifier is registered as a submodule
                # of the policy, policy.optimizer may also own its parameters,
                # in which case they are stepped by both optimizers — confirm.
                self.policy.optimizer.zero_grad()
                if train_cnn_classifier:
                    self.policy.classifier.optimizer.zero_grad()

                loss.backward()

                # Clip grad norm
                th.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.max_grad_norm
                )
                if train_cnn_classifier:
                    self.policy.classifier.optimizer.step()
                self.policy.optimizer.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(),
            self.rollout_buffer.returns.flatten(),
        )

        # Logs

        ## Classifier
        if classifier_losses:
            self.logger.record(
                "train/classifier_loss", np.mean(classifier_losses)
            )
        if class_balances:
            self.logger.record(
                "train/class_balance", np.mean(class_balances)
            )
        self.logger.record("train/exploration_rate", self.policy._exploration_rate)
        ## -

        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record(
                "train/std", th.exp(self.policy.log_std).mean().item()
            )

        self.logger.record(
            "train/n_updates", self._n_updates, exclude="tensorboard"
        )
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPOClassifier",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "PPOClassifier":
        """Delegate to ``PPO.learn``, using ``"PPOClassifier"`` as the default log name."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
