import copy
import random
from typing import Union, Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.model import FCEncoder, ConvEncoder
from ding.reward_model.base_reward_model import BaseRewardModel
from ding.torch_utils.data_helper import to_tensor
from ding.utils import RunningMeanStd
from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from easydict import EasyDict


class RNDNetwork(nn.Module):
    """
    Overview:
        Two encoders with the same architecture fed the same input: a frozen ``target`` whose parameters
        never receive gradients, and a trainable ``predictor``.
    """

    def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
        """
        Arguments:
            - obs_shape: Input shape. An ``int`` or a 1-element shape selects ``FCEncoder``; a 3-element
              shape selects ``ConvEncoder``.
            - hidden_size_list: Hidden sizes passed unchanged to both encoders.
        Raises:
            - KeyError: For any other ``obs_shape`` length.
        """
        super().__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".format(
                    obs_shape
                )
            )
        # Construct the target before the predictor (keeps the order in which random weights are drawn).
        self.target = encoder_cls(obs_shape, hidden_size_list)
        self.predictor = encoder_cls(obs_shape, hidden_size_list)
        # Freeze every target parameter: only the predictor is meant to learn.
        self.target.requires_grad_(False)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(predictor(obs), target(obs))``; the target branch runs without building an autograd graph."""
        prediction = self.predictor(obs)
        with torch.no_grad():
            target = self.target(obs)
        return prediction, target


class RNDNetworkRepr(nn.Module):
    """
    Overview:
        The RND reward model class (https://arxiv.org/abs/1810.12894v1) with representation network.
        The frozen ``target`` encodes raw observations, while the trainable ``predictor`` encodes the output of
        ``representation_network`` applied to those same observations.
    """

    def __init__(self, obs_shape: Union[int, SequenceType], latent_shape: Union[int, SequenceType],
                 hidden_size_list: SequenceType, representation_network: nn.Module) -> None:
        """
        Arguments:
            - obs_shape: Observation shape fed to the target. Its length also decides the encoder family for
              BOTH networks: an ``int``/1-element shape selects ``FCEncoder``, a 3-element shape ``ConvEncoder``.
            - latent_shape: Shape of the ``representation_network`` output, fed to the predictor.
            - hidden_size_list: Hidden sizes passed unchanged to both encoders.
            - representation_network: Module mapping observations to the predictor's input; it is registered
              as a child module of this network.
        Raises:
            - KeyError: For any other ``obs_shape`` length.
        """
        super().__init__()
        self.representation_network = representation_network
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            encoder_cls = FCEncoder
        elif len(obs_shape) == 3:
            encoder_cls = ConvEncoder
        else:
            raise KeyError(
                "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".format(
                    obs_shape
                )
            )
        # Construct the target before the predictor (keeps the order in which random weights are drawn).
        self.target = encoder_cls(obs_shape, hidden_size_list)
        self.predictor = encoder_cls(latent_shape, hidden_size_list)
        # Only the target is frozen; the representation network's parameters are left untouched.
        self.target.requires_grad_(False)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(predictor(representation_network(obs)), target(obs))``; the target runs without autograd."""
        latent = self.representation_network(obs)
        prediction = self.predictor(latent)
        with torch.no_grad():
            target = self.target(obs)
        return prediction, target


@REWARD_MODEL_REGISTRY.register('rnd_muzero')
class RNDRewardModel(BaseRewardModel):
    """
    Overview:
        The RND reward model class (https://arxiv.org/abs/1810.12894v1) modified for MuZero.
        The squared error between a trainable predictor and a frozen target network is used as an intrinsic
        reward and merged into the target rewards of MuZero training batches. ``input_type`` selects what the
        two networks see:

        - ``obs``: raw observations.
        - ``latent_state``: latent states produced by ``representation_network``.
        - ``obs_latent_state``: the target sees raw observations, the predictor sees them after a
          representation network.
    Interface:
        ``__init__``, ``train_with_data``, ``estimate``, ``collect_data``, ``clear_old_data``,
        ``clear_data``, ``train``, ``load_state_dict``, ``state_dict``
    Config:
        == ====================  =====  =============  =======================================  =======================
        ID Symbol                Type   Default Value  Description                              Other(Shape)
        == ====================  =====  =============  =======================================  =======================
        1  ``type``              str    rnd_muzero     | Reward model register name, refer      |
                                                       | to registry ``REWARD_MODEL_REGISTRY``  |
        2  | ``intrinsic_``      str    add            | the intrinsic reward type              | including add, new
           | ``reward_type``                           |                                        | , or assign
        3  | ``learning_rate``   float  0.001          | The step size of gradient descent      |
        4  | ``batch_size``      int    64             | Training batch size                    |
        5  | ``hidden``          list   [64, 64,       | the encoder hidden layer sizes         |
           | ``_size_list``      (int)  128]           |                                        |
        6  | ``update_per_``     int    100            | Number of updates per collect          |
           | ``collect``                               |                                        |
        7  | ``input_norm``      bool   True           | Observation normalization              |
        8  | ``input_norm_``     int    -1             | min clip value for obs normalization   |
           | ``clamp_min``                             |                                        |
        9  | ``input_norm_``     int    1              | max clip value for obs normalization   |
           | ``clamp_max``                             |                                        |
        10 | ``intrinsic_``      float  0.01           | the weight of intrinsic reward         | r = w*r_i + r_e
           | ``reward_weight``                         |                                        |
        11 | ``extrinsic_``      bool   True           | Whether to normalize extrinsic reward  | r_e / max
           | ``reward_norm``                           |                                        |
        12 | ``extrinsic_``      int    1              | the upper bound of the reward          |
           | ``reward_norm_max``                       | normalization                          |
        == ====================  =====  =============  =======================================  =======================

        Keys with no default here that must come from the user config: ``input_type``, ``obs_shape`` (always
        read by ``estimate``), ``latent_state_dim`` (for ``latent_state`` / ``obs_latent_state``) and
        ``rnd_buffer_size``. ``weight_decay`` is optional and defaults to 0.
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='rnd_muzero',
        # (str) The intrinsic reward type, including add, new, or assign.
        intrinsic_reward_type='add',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (int) Batch size.
        batch_size=64,
        # (list(int)) Sequence of ``hidden_size`` of reward network.
        # If len(input_shape) == 1, use MLP layers.
        # If len(input_shape) == 3, use conv layers and a final dense layer.
        hidden_size_list=[64, 64, 128],
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (bool) Observation normalization: transform obs to mean 0, std 1.
        input_norm=True,
        # (int) Min clip value for observation normalization.
        input_norm_clamp_min=-1,
        # (int) Max clip value for observation normalization.
        input_norm_clamp_max=1,
        # (float) The relative weight of the RND intrinsic reward:
        # r = intrinsic_reward_weight * r_i + r_e.
        intrinsic_reward_weight=0.01,
        # (bool) Whether to normalize extrinsic reward.
        # The extrinsic reward is divided by extrinsic_reward_norm_max.
        extrinsic_reward_norm=True,
        # (int) The upper bound of the reward normalization.
        extrinsic_reward_norm_max=1,
    )

    def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None,
                 representation_network: nn.Module = None, target_representation_network: nn.Module = None,
                 use_momentum_representation_network: bool = True) -> None:  # noqa
        """
        Overview:
            Build the RND networks, the predictor optimizer, the training buffer and the running statistics.
        Arguments:
            - config (:obj:`EasyDict`): Reward model config (see the class docstring for required keys).
            - device (:obj:`str`): ``'cpu'`` or a device string starting with ``'cuda'``.
            - tb_logger (:obj:`SummaryWriter`): Logger for training / estimation statistics. When ``None``, a
              ``tensorboardX.SummaryWriter`` writing to ``'rnd_reward_model'`` is created.
            - representation_network (:obj:`nn.Module`): Maps observations to latent states. Used for
              ``latent_state`` inputs and, when ``use_momentum_representation_network`` is False, as the
              predictor front-end for ``obs_latent_state``.
            - target_representation_network (:obj:`nn.Module`): Predictor front-end for ``obs_latent_state``
              when ``use_momentum_representation_network`` is True.
            - use_momentum_representation_network (:obj:`bool`): Selects which of the two networks above feeds
              the ``obs_latent_state`` predictor.
        """
        super(RNDRewardModel, self).__init__()
        self.cfg = config
        self.representation_network = representation_network
        self.target_representation_network = target_representation_network
        self.use_momentum_representation_network = use_momentum_representation_network
        self.input_type = self.cfg.input_type
        assert self.input_type in ['obs', 'latent_state', 'obs_latent_state'], self.input_type
        self.device = device
        assert self.device == "cpu" or self.device.startswith("cuda")
        self.rnd_buffer_size = config.rnd_buffer_size
        self.intrinsic_reward_type = self.cfg.intrinsic_reward_type
        assert self.intrinsic_reward_type in ['add', 'new', 'assign']
        if tb_logger is None:
            from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter('rnd_reward_model')
        self.tb_logger = tb_logger

        if self.input_type == 'obs_latent_state':
            repr_net = (
                self.target_representation_network
                if self.use_momentum_representation_network else self.representation_network
            )
            # All but the last entry of hidden_size_list is used for this variant.
            self.reward_model = RNDNetworkRepr(
                self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[0:-1], repr_net
            ).to(self.device)
        else:
            self.input_shape = self.cfg.obs_shape if self.input_type == 'obs' else self.cfg.latent_state_dim
            self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device)

        # RND training inputs: latent states for ``latent_state``, raw observations otherwise.
        if self.input_type == 'latent_state':
            self.train_latent_state = []
        else:
            self.train_obs = []

        # Only the predictor is optimised; the target network stays frozen.
        self._optimizer_rnd = torch.optim.Adam(
            self.reward_model.predictor.parameters(),
            lr=self.cfg.learning_rate,
            # Optional key; 0 matches torch.optim.Adam's own default.
            weight_decay=self.cfg.get('weight_decay', 0.),
        )

        self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
        self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)
        self.estimate_cnt_rnd = 0
        self.train_cnt_rnd = 0

    def _buffer_name(self) -> str:
        """Return the attribute name of the RND training buffer for the configured ``input_type``."""
        return 'train_latent_state' if self.input_type == 'latent_state' else 'train_obs'

    def _normalize_input(self, input_data: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Standardise ``input_data`` with the running observation statistics, then clip the result to
            ``[input_norm_clamp_min, input_norm_clamp_max]``. The operation is out-of-place.
        """
        mean = to_tensor(self._running_mean_std_rnd_obs.mean).to(self.device)
        std = to_tensor(self._running_mean_std_rnd_obs.std).to(self.device)
        return torch.clamp(
            (input_data - mean) / std, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max
        )

    def _train_with_data_one_step(self) -> None:
        """
        Overview:
            Perform one predictor update on a random mini-batch drawn from the training buffer.
            The step is skipped when the buffer is empty.
            When the buffer holds fewer than ``batch_size`` items, all of them are used.
        """
        buffer = getattr(self, self._buffer_name())
        if len(buffer) == 0:
            # Nothing collected yet; ``torch.stack([])`` would raise.
            return
        # ``random.sample`` raises ValueError if asked for more items than the buffer holds.
        train_data = random.sample(buffer, min(self.cfg.batch_size, len(buffer)))
        train_data = torch.stack(train_data).to(self.device)

        if self.cfg.input_norm:
            # Update the running statistics with this batch first, then normalise with the updated values.
            self._running_mean_std_rnd_obs.update(train_data.detach().cpu().numpy())
            train_data = self._normalize_input(train_data)

        predict_feature, target_feature = self.reward_model(train_data)
        loss = F.mse_loss(predict_feature, target_feature)

        self.tb_logger.add_scalar('rnd_reward_model/rnd_mse_loss', loss.item(), self.train_cnt_rnd)
        self._optimizer_rnd.zero_grad()
        loss.backward()
        self._optimizer_rnd.step()

    def train_with_data(self) -> None:
        """
        Overview:
            Run ``update_per_collect`` predictor updates, advancing the training step counter after each one.
            To debug NaN/Inf gradients, enable ``torch.autograd.set_detect_anomaly(True)``.
        """
        for _ in range(self.cfg.update_per_collect):
            self._train_with_data_one_step()
            self.train_cnt_rnd += 1

    def estimate(self, data: list) -> list:
        """
        Overview:
            Rewrite the target rewards of a MuZero training batch with RND-augmented rewards.
        Arguments:
            - data (:obj:`list`): ``[current_batch, target_batch]``. Only ``data[0][0]`` (the original
              observation batch) and ``data[1][0]`` (the target rewards) are read.
              NOTE(review): the observation batch is assumed to be flat vector observations of the
              ``num_steps`` unrolled positions concatenated per sample, i.e. reshapeable to
              ``(batch_size * num_steps, obs_shape)``. An ``obs_shape`` given as a list (image observations)
              is not supported by this reshape.
        Returns:
            - data (:obj:`list`): The same list object, with ``data[1][0]`` replaced by a new array of shape
              ``(batch_size, num_steps, 1)``. The original reward array is not modified.
        """
        obs_batch_orig = data[0][0]
        target_reward = data[1][0]
        batch_size = obs_batch_orig.shape[0]

        # NOTE: deepcopy reward part of data is very important,
        # otherwise the reward of data in the replay buffer will be incorrectly modified.
        target_reward_augmented = np.asarray(copy.deepcopy(target_reward))
        # Reward targets per sample (presumably num_unroll_steps + 1); previously hard-coded as 6.
        num_steps = target_reward_augmented.size // batch_size
        target_reward_augmented = target_reward_augmented.reshape(batch_size * num_steps, 1)

        # One row per (sample, step), e.g. (4, 6 * 2835) -> (24, 2835).
        obs_batch_tmp = np.reshape(obs_batch_orig, (batch_size * num_steps, self.cfg.obs_shape))

        # ``to_tensor`` (unlike ``torch.from_numpy``) converts float64 arrays to float32 tensors.
        obs_tensor = to_tensor(obs_batch_tmp).to(self.device)
        if self.input_type == 'latent_state':
            with torch.no_grad():
                input_data = self.representation_network(obs_tensor)
        else:  # 'obs' or 'obs_latent_state'
            input_data = obs_tensor

        if self.cfg.input_norm:
            input_data = self._normalize_input(input_data)

        with torch.no_grad():
            predict_feature, target_feature = self.reward_model(input_data)
            mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
            self._running_mean_std_rnd_reward.update(mse.detach().cpu().numpy())

            # Min-max normalise within the batch, so the rnd reward lies in [0, 1].
            rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-6)

            self.estimate_cnt_rnd += 1
            for stat in ('max', 'mean', 'min', 'std'):
                self.tb_logger.add_scalar(
                    'rnd_reward_model/rnd_reward_' + stat, getattr(rnd_reward, stat)(), self.estimate_cnt_rnd
                )

        rnd_reward = rnd_reward.unsqueeze(1).cpu().numpy()
        if self.cfg.extrinsic_reward_norm and self.intrinsic_reward_type in ('add', 'new'):
            target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max
        if self.intrinsic_reward_type == 'add':
            target_reward_augmented = target_reward_augmented + rnd_reward * self.cfg.intrinsic_reward_weight
        elif self.intrinsic_reward_type == 'assign':
            target_reward_augmented = rnd_reward
        # NOTE(review): for 'new' the rnd reward is computed and logged but not written into ``data``.

        for stat in ('max', 'mean', 'min', 'std'):
            self.tb_logger.add_scalar(
                'augmented_reward/reward_' + stat, getattr(np, stat)(target_reward_augmented), self.estimate_cnt_rnd
            )

        data[1][0] = np.reshape(target_reward_augmented, (batch_size, num_steps, 1))
        return data

    def collect_data(self, data: list) -> None:
        """
        Overview:
            Append observations from the collected game segments to the training buffer.
            For ``input_type == 'latent_state'``, their latent states are appended instead.
        Arguments:
            - data (:obj:`list`): ``data[0]`` is an iterable of game segments exposing ``obs_segment``.
        """
        # TODO(pu): now we only collect the first 300 steps of each game segment.
        segment_obs = [game_segment.obs_segment[:300] for game_segment in data[0]]
        if not segment_obs:
            # ``np.concatenate`` raises on an empty list.
            return
        collected_transitions = np.concatenate(segment_obs, axis=0)
        inputs = to_tensor(collected_transitions).to(self.device)
        if self.input_type == 'latent_state':
            with torch.no_grad():
                self.train_latent_state.extend(self.representation_network(inputs))
        else:  # 'obs' or 'obs_latent_state'
            self.train_obs.extend(inputs)

    def clear_old_data(self) -> None:
        """
        Overview:
            Keep only the newest ``rnd_buffer_size`` entries of the training buffer.
        """
        buffer_name = self._buffer_name()
        buffer = getattr(self, buffer_name)
        max_size = self.cfg.rnd_buffer_size
        if len(buffer) > max_size:
            # Slice from an explicit start: ``buffer[-0:]`` would keep everything when max_size == 0.
            setattr(self, buffer_name, buffer[len(buffer) - max_size:])

    def state_dict(self) -> Dict:
        """Return the state dict of the underlying RND network."""
        return self.reward_model.state_dict()

    def load_state_dict(self, _state_dict: Dict) -> None:
        """Load ``_state_dict`` into the underlying RND network."""
        self.reward_model.load_state_dict(_state_dict)

    def clear_data(self):
        """No-op; buffer trimming is done by ``clear_old_data``."""
        pass

    def train(self):
        """No-op; training is done by ``train_with_data``."""
        pass
