from gym.spaces import Discrete

from rlkit.data_management.simple_replay_buffer import SimpleReplayBuffer, CustomReplayBuffer
from rlkit.envs.env_utils import get_dim
import rlkit.torch.pytorch_util as ptu

import matplotlib.pyplot as plt

import numpy as np
import torch

from tqdm import tqdm
from collections import OrderedDict


class EnvReplayBuffer(SimpleReplayBuffer):
    """Replay buffer whose observation/action dimensions come from an environment.

    Actions from a ``Discrete`` action space are stored one-hot encoded; any other
    action is stored exactly as given.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: capacity forwarded to ``SimpleReplayBuffer``.
        :param env: environment exposing ``observation_space`` and ``action_space``
            (and, optionally, ``info_sizes``).
        :param env_info_sizes: forwarded to ``SimpleReplayBuffer`` as
            ``env_info_sizes``; when None, ``env.info_sizes`` is used if the env has
            that attribute, otherwise an empty dict.
        """
        self.env = env
        obs_space, act_space = env.observation_space, env.action_space
        self._ob_space = obs_space
        self._action_space = act_space

        if env_info_sizes is None:
            env_info_sizes = env.info_sizes if hasattr(env, 'info_sizes') else dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(obs_space),
            action_dim=get_dim(act_space),
            env_info_sizes=env_info_sizes,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        """Store one transition, one-hot encoding ``action`` for Discrete spaces.

        Extra keyword arguments are passed through unchanged; the return value is
        whatever ``SimpleReplayBuffer.add_sample`` returns.
        """
        stored_action = action
        if isinstance(self._action_space, Discrete):
            # Index ``action`` of a zero vector of length action_dim is set to 1.
            stored_action = np.zeros(self._action_dim)
            stored_action[action] = 1
        return super().add_sample(
            observation=observation,
            action=stored_action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            **kwargs
        )

class EnvReplayBuffer2(CustomReplayBuffer):
    """Environment-aware replay buffer built on ``CustomReplayBuffer``.

    Construction and sample insertion mirror ``EnvReplayBuffer``: dimensions are
    derived from the env's spaces and Discrete actions are stored one-hot encoded.
    Additionally provides ``calculate_weight``, which evaluates a Q-network on the
    stored (observation, action) pairs of each sample's cluster members.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: capacity forwarded to ``CustomReplayBuffer``.
        :param env: environment exposing ``observation_space`` and ``action_space``
            (and, optionally, ``info_sizes``).
        :param env_info_sizes: forwarded to ``CustomReplayBuffer`` as
            ``env_info_sizes``; when None, ``env.info_sizes`` is used if the env has
            that attribute, otherwise an empty dict.
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        """Store one transition, one-hot encoding ``action`` for Discrete spaces.

        Extra keyword arguments are passed through unchanged; the return value is
        whatever ``CustomReplayBuffer.add_sample`` returns.
        """
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )

    def calculate_weight(self, cluster_dist_list, cluster_idx_list, clip_val, path=None, ver=1, diff_bound=1.0,
                         behavior_policy=None, q_data=None, batch_size=10000):
        """Evaluate ``q_data[0]`` on the stored (obs, action) pairs of each cluster.

        Row ``j`` of ``cluster_idx_list`` (for ``j < self._size``) holds buffer
        indices; the Q-network is evaluated on ``self._observations`` /
        ``self._actions`` at those indices. Evaluation runs on the CPU without
        autograd, in chunks of ``batch_size`` rows, and the Q-network is moved back
        to the device it started on afterwards.

        :param cluster_dist_list: unused; kept for interface compatibility.
        :param cluster_idx_list: integer array of buffer indices, shape
            ``(n_rows, k)``; a 1-D array is treated as ``k == 1``. At most
            ``self._size`` rows are used.
        :param clip_val: unused; kept for interface compatibility.
        :param path: unused; kept for interface compatibility.
        :param ver: only printed.
        :param diff_bound: unused; kept for interface compatibility.
        :param behavior_policy: unused (and no longer moved between devices).
        :param q_data: sequence whose first element is a Q-network called as
            ``q(obs, act)`` and returning ``n_obs`` values.
        :param batch_size: number of index rows evaluated per forward pass.
        :return: array of shape ``(min(self._size, n_rows), k)`` with the Q-values.
        :raises ValueError: if ``q_data`` is None.
        """
        print('replay buffer ver: %d'%ver)
        if q_data is None:
            raise ValueError('calculate_weight requires q_data (a sequence of Q-networks)')

        idx = np.asarray(cluster_idx_list)
        if idx.ndim == 1:
            idx = idx[:, None]
        n_rows = min(self._size, idx.shape[0])
        n_neighbors = idx.shape[1]
        obs_dim = self._observations.shape[-1]
        act_dim = self._actions.shape[-1]
        q_net = q_data[0]

        # Remember where the network lives so it can be restored afterwards;
        # leaving it on the CPU would break subsequent GPU training steps.
        try:
            original_device = next(q_net.parameters()).device
        except (AttributeError, StopIteration):
            original_device = None

        # Sized to the filled part of the buffer (not its capacity) and to the
        # neighbour count, so every entry is written and the mean is meaningful.
        q_values = np.zeros((n_rows, n_neighbors))
        q_net.cpu()
        try:
            with torch.no_grad():
                for start in tqdm(range(0, n_rows, batch_size), desc='Calculate the probability'):
                    end = min(n_rows, start + batch_size)
                    chunk_idx = idx[start:end]
                    # Gather only this chunk's cluster members instead of
                    # materialising the full (n_rows, k, dim) arrays up front.
                    obs = torch.from_numpy(
                        self._observations[chunk_idx].reshape(-1, obs_dim)).float()
                    act = torch.from_numpy(
                        self._actions[chunk_idx].reshape(-1, act_dim)).float()
                    q_val = q_net(obs, act).reshape(end - start, n_neighbors)
                    q_values[start:end] = q_val.numpy()
        finally:
            if original_device is not None:
                q_net.to(original_device)

        if q_values.size:
            print(q_values.mean())
        return q_values