#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
from .BaseAgent import *
from .DQN_agent import *


class QuantileRegressionDQNActor(DQNActor):
    """Actor that turns quantile predictions into per-action Q-values."""

    def __init__(self, config):
        super().__init__(config)

    def compute_q(self, prediction):
        """Return Q-values as the mean over the last axis of the quantiles.

        Args:
            prediction: Either the mapping that holds the ``'quantile'``
                tensor, or the ``(mapping, feature)`` pair that the agent in
                this file unpacks from its network's output.

        Returns:
            The result of ``to_np`` applied to the averaged quantiles.
        """
        # The agent's eval_step/compute_loss unpack the network output as
        # ``(dict, feature)``; accept that form here too so a raw network
        # output does not fail on the string index below.
        if isinstance(prediction, (tuple, list)):
            prediction = prediction[0]
        q_values = prediction['quantile'].mean(-1)
        return to_np(q_values)


class QuantileRegressionDQNAgent(DQNAgent):
    def __init__(self, config):
        """Build replay, actor, online/target networks and quantile constants.

        Args:
            config: Configuration object; this constructor reads its
                ``replay_fn``, ``network_fn``, ``optimizer_fn``,
                ``batch_size`` and ``num_quantiles`` attributes and stores a
                fresh ``mp.Lock()`` on it as ``config.lock``.
        """
        BaseAgent.__init__(self, config)
        self.config = config
        config.lock = mp.Lock()

        self.replay = config.replay_fn()
        self.actor = QuantileRegressionDQNActor(config)

        # Online network, then a target network initialised with the same weights.
        self.network = online_net = config.network_fn()
        online_net.share_memory()
        self.target_network = target_net = config.network_fn()
        target_net.load_state_dict(online_net.state_dict())
        self.optimizer = config.optimizer_fn(online_net.parameters())

        # The actor uses the online network.
        self.actor.set_network(online_net)

        self.total_steps = 0
        self.batch_indices = range_tensor(config.batch_size)

        num_quantiles = config.num_quantiles
        self.quantile_weight = 1.0 / num_quantiles
        # Values (2k + 1) / (2N) for k = 0..N-1, stored as a single row.
        midpoints = (2 * np.arange(num_quantiles) + 1) / (2.0 * num_quantiles)
        self.cumulative_density = tensor(midpoints).view(1, -1)

    def eval_step(self, state):
        """Pick the greedy action for ``state`` and return the network feature.

        The state normalizer has ``set_read_only()`` called before the forward
        pass and ``unset_read_only()`` called afterwards, even if the forward
        pass raises.

        Args:
            state: Observation passed to ``config.state_normalizer`` and then
                to the online network.

        Returns:
            A pair ``([action], feature)``: ``action`` is the argmax of the
            flattened per-action quantile means, and ``feature`` is the second
            element of the network's output, returned untouched.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        try:
            state = normalizer(state)
            prediction, feature = self.network(state)
            q = prediction['quantile'].mean(-1)
            # The argmax runs over the flattened array, so it is an action
            # index only when ``q`` holds a single state's values.
            action = np.argmax(to_np(q).flatten())
        finally:
            # Always undo set_read_only(), otherwise an exception above would
            # leave the shared normalizer in read-only mode.
            normalizer.unset_read_only()
        return [action], feature

    def compute_loss(self, transitions):
        """Compute the quantile-regression Huber loss for sampled transitions.

        Args:
            transitions: Batch exposing ``state``, ``next_state``, ``reward``,
                ``mask`` and ``action``.

        Returns:
            ``loss.sum(-1).mean(1)`` of the weighted Huber terms.
            NOTE(review): because the target is transposed before the
            subtraction, dim 0 of the result follows the target-quantile
            index rather than the batch index - confirm callers expect that.
        """
        cfg = self.config
        states = cfg.state_normalizer(transitions.state)
        next_states = cfg.state_normalizer(transitions.next_state)

        # Target side: quantiles of the action with the largest quantile sum
        # under the target network, detached from the graph.
        target_out, _ = self.target_network(next_states)
        next_quantiles = target_out['quantile'].detach()
        greedy_next = torch.argmax(next_quantiles.sum(-1), dim=-1)
        next_quantiles = next_quantiles[self.batch_indices, greedy_next, :]

        rewards = tensor(transitions.reward).unsqueeze(-1)
        masks = tensor(transitions.mask).unsqueeze(-1)
        n_step_discount = cfg.discount ** cfg.n_step
        target_quantiles = rewards + n_step_discount * masks * next_quantiles

        # Online side: quantiles of the actions stored in the transitions.
        online_out, _ = self.network(states)
        taken = tensor(transitions.action).long()
        quantiles = online_out['quantile'][self.batch_indices, taken, :]

        # Transpose and add a trailing axis so the subtraction broadcasts
        # every target quantile against every predicted quantile.
        target_quantiles = target_quantiles.t().unsqueeze(-1)
        diff = target_quantiles - quantiles
        negative_side = (diff.detach() < 0).float()
        weight = (self.cumulative_density - negative_side).abs()
        loss = huber(diff) * weight
        return loss.sum(-1).mean(1)

    def reduce_loss(self, loss):
        return loss.mean()

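    # NOTE: retired variant of TSNE(), kept commented out for reference. It rolled
    # out the eval env to gather per-action features and normalized states, and
    # saved the state tensors to 'figure/<game_file><action>.pt' - the files the
    # active TSNE() further down loads.
    # def TSNE(self):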
    #     # NUMBER_ENV = self.config.batch_size # 32
    #     n_labels = self.config.action_dim
    #     # n_labels = 3 # to be changed !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    #     feature_dict = {}
    #     state_dict = {}
    #     # data_num = 512 * int((40) / n_labels)
    #     data_num = 512
    #
    #     def current_num():
    #         num = [0] * n_labels
    #         for i in range(n_labels):
    #             num[i] = len(feature_dict[i])
    #         return num
    #
    #     env = self.config.eval_env
    #     state = env.reset()
    #
    #     self.config.state_normalizer.set_read_only()
    #     state0 = self.config.state_normalizer(state)  # numpy, [1,4,84,84]
    #     # print(type(state0), state0.shape, state0)
    #     temp = torch.ones_like(torch.from_numpy(state0))
    #     for i in range(n_labels):
    #         feature_dict[i] = np.ones((1, 512))
    #         state_dict[i] = temp
    #
    #         # collect data for each action class
    #     # data_num = 50 # to be changed !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    #     while min(current_num()) <= data_num:
    #         action, feature_vectors = self.eval_step(state)  # batch = 1
    #         # print(action, feature_vectors.shape) # [1], [1, 512]
    #         # for i in range(NUMBER_ENV):
    #         # print(feature_vectors.shape, type(feature_vectors))
    #         if feature_dict[int(action[0])].shape[0] <= data_num:  # extra 1
    #             feature_dict[int(action[0])] = np.vstack((feature_dict[int(action[0])], feature_vectors.detach().cpu().numpy()))
    #             # save state !!!!!!!!!!!!!!!!
    #             self.config.state_normalizer.set_read_only()
    #             state_ = self.config.state_normalizer(state)
    #             # state_dict[int(action)] = torch.cat((state_dict[int(action)], torch.from_numpy(state_)), dim=0)
    #             state_dict[int(action[0])] = torch.cat((state_dict[int(action[0])], torch.from_numpy(state_)), dim=0)
    #
    #         # feature_dict[int(action)].append(feature_vectors)  # feature [j] <- current features [i]
    #         # print('Current Action: {}, nums:{}/{}/{}/{}, min_num:{}/{}'.format(int(action), current_num()[0],
    #         print('Current Action: {}, nums:{}/{}/{}/{}, min_num:{}/{}'.format(int(action[0]), current_num()[0],
    #                                                                            current_num()[1], current_num()[2],
    #                                                                            current_num()[3], min(current_num()),
    #                                                                            data_num))
    #         # print('Current Action: {}, nums:{}/{}/{}, min_num:{}/{}'.format(int(action), current_num()[0], current_num()[1],current_num()[2], min(current_num()),data_num))
    #         # print('while condition', min(current_num()), data_num)
    #         # print('current shape', np.array(feature_dict[int(action)]).shape)
    #         state, reward, done, info = env.step(action)
    #         ret = info[0]['episodic_return']
    #         if ret is not None:
    #             state = env.reset()
    #
    #     for i in range(n_labels):
    #         feature_dict[i] = feature_dict[i][1:, :]
    #         state_dict[i] = state_dict[i][1:, :]
    #         print('feature shape', feature_dict[i].shape)
    #         print('state shape', state_dict[i].shape)
    #
    #     # t-sne
    #     def plot_embedding(result, label, title, filename):  # [1083,2] [1083]
    #         x_min, x_max = np.min(result, 0), np.max(result, 0)
    #         data = (result - x_min) / (x_max - x_min)  # [0-1] scale
    #         plt.figure()
    #         for i in range(data.shape[0]):  # 1083
    #             plt.scatter(data[i, 0], data[i, 1], marker='o', color=plt.cm.Set1(label[i] / (1.0 * n_labels)))
    #         plt.title(title)
    #         plt.savefig('figure/' + filename + '.png', bbox_inches='tight')
    #         plt.close()
    #
    #     def plot_embedding_3D(result, label, title, filename):
    #         x_min, x_max = np.min(result, 0), np.max(result, 0)
    #         data = (result - x_min) / (x_max - x_min)  # [0-1] scale
    #         fig = plt.figure()
    #         ax = Axes3D(fig)
    #         for i in range(data.shape[0]):
    #             ax.scatter(data[i, 0], data[i, 1], data[i, 2], color=plt.cm.Set1(label[i] / (1.0 * n_labels)))
    #         plt.title(title)
    #         plt.savefig('figure/' + filename + '.png', bbox_inches='tight')
    #         plt.close()
    #
    #     if not os.path.exists('figure'):
    #         os.makedirs('figure')
    #
    #     ############## save data: state tensor in each action file
    #     for i in range(n_labels):
    #         torch.save(state_dict[i], 'figure/' + self.config.game_file + str(i) + '.pt')
    #
    #     ############
    #     labels = []
    #     data = np.ones((1, 512))
    #     samples = 200
    #     CLASSES = 3
    #     for i in range(CLASSES):  # only the first two labels
    #         NUMS_index = np.random.choice(data_num, samples, replace=False)
    #         data0 = feature_dict[i][NUMS_index, :]
    #         # data0 = np.array(feature_dict[i])[NUMS_index]
    #         # print(data0.shape)
    #         data = np.vstack((data, data0))
    #         labels.extend(np.ones(samples) * i)
    #     data = data[1:, :]
    #     print('Final data shape: ', data.shape)
    #     tsne2 = TSNE(n_components=2, init='pca', random_state=0)  # n_components: 64 -> 2；
    #     result2 = tsne2.fit_transform(data)
    #     tsne3 = TSNE(n_components=3, init='pca', random_state=0)  # n_components: 64 -> 2；
    #     result3 = tsne3.fit_transform(data)
    #     plot_embedding(result2, labels, 't-SNE on QRDQN Features', 'QRDQN-tSNE-2D')
    #     plot_embedding_3D(result3, labels, 't-SNE on QRDQN Features', 'QRDQN-tSNE-3D')

    def TSNE(self, samples=200, num_classes=2):
        """Project saved per-action states with t-SNE and save scatter plots.

        For every action index ``i`` in ``range(self.config.action_dim)`` the
        tensor stored at ``'figure/<game_file><i>.pt'`` is loaded and each of
        its rows is passed through ``eval_step`` one at a time. The returned
        feature is filed under the file's index ``i``, not under the action
        ``eval_step`` selects. A random subset of the features of the first
        ``num_classes`` actions is then embedded in 2 and 3 dimensions and
        written to ``figure/QRDQN-tSNE-2D.png`` and
        ``figure/QRDQN-tSNE-3D.png``.

        Args:
            samples: Feature vectors drawn without replacement per plotted
                action; capped at the number actually available.
            num_classes: How many leading action indices to plot.

        Raises:
            ValueError: If ``samples`` is not positive, ``num_classes`` is
                outside ``1..min(action_dim, number of colours)``, or a
                plotted action has no saved states.
        """
        colors = ['red', 'blue', 'lime', 'yellow']
        n_labels = self.config.action_dim
        max_classes = min(n_labels, len(colors))
        if samples <= 0:
            raise ValueError('samples must be positive, got {}'.format(samples))
        if not 0 < num_classes <= max_classes:
            raise ValueError(
                'num_classes must be between 1 and {}, got {}'.format(max_classes, num_classes))

        feature_dict = {i: self._tsne_features_for_action(i) for i in range(n_labels)}
        for i in range(n_labels):
            print(feature_dict[i].shape)

        os.makedirs('figure', exist_ok=True)

        labels = []
        chunks = []
        for i in range(num_classes):
            features = feature_dict[i]
            available = features.shape[0]
            if available == 0:
                raise ValueError('no saved states for action {}'.format(i))
            # Draw from the rows that actually exist instead of assuming a
            # fixed count, so shorter files do not index out of range.
            picked = np.random.choice(available, min(samples, available), replace=False)
            chunks.append(features[picked, :])
            labels.extend([i] * len(picked))
        data = np.vstack(chunks)
        print('Final data shape: ', data.shape)

        # ``TSNE`` here resolves to the module-level name, not this method;
        # presumably sklearn.manifold.TSNE brought in by the star imports.
        result2 = TSNE(n_components=2, init='pca', random_state=0).fit_transform(data)
        result3 = TSNE(n_components=3, init='pca', random_state=0).fit_transform(data)
        self._save_tsne_scatter(result2, labels, 't-SNE on QR-DQN Features', 'QRDQN-tSNE-2D', colors)
        self._save_tsne_scatter(result3, labels, 't-SNE on QR-DQN Features', 'QRDQN-tSNE-3D', colors)

    def _tsne_features_for_action(self, action_index):
        """Return the stacked ``eval_step`` features of one action's saved states."""
        states = torch.load('figure/' + self.config.game_file + str(action_index) + '.pt',
                            map_location='cpu')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        rows = []
        for state in states:
            # Add a leading batch axis of size 1 before the forward pass.
            # NOTE(review): eval_step applies config.state_normalizer again; if
            # the saved tensors already hold normalizer output (the
            # commented-out collector in this class saves exactly that), the
            # states are normalized twice here - confirm this is intended.
            batch = state.unsqueeze(0).float().to(device)
            _, feature = self.eval_step(batch)
            rows.append(feature.detach().cpu().numpy())
        if not rows:
            # 512 is the feature width the previous placeholder rows assumed.
            return np.empty((0, 512))
        # float64 keeps the dtype the earlier float64-placeholder stacking produced.
        return np.vstack(rows).astype(np.float64)

    @staticmethod
    def _save_tsne_scatter(result, label, title, filename, colors):
        """Min-max scale a 2-D or 3-D embedding and save it as a scatter plot."""
        x_min, x_max = np.min(result, 0), np.max(result, 0)
        span = x_max - x_min
        # A constant coordinate would otherwise divide by zero.
        span = np.where(span == 0, 1.0, span)
        data = (result - x_min) / span  # each column scaled to [0, 1]

        fig = plt.figure()
        if data.shape[1] == 3:
            # Ask the figure for a 3-D axes: constructing Axes3D(fig) directly
            # no longer attaches the axes on recent matplotlib releases.
            ax = fig.add_subplot(111, projection='3d')
        else:
            ax = fig.add_subplot(111)
        for point, cls in zip(data, label):
            ax.scatter(*point, marker='o', color=colors[int(cls)], alpha=0.5)
        ax.set_title(title)
        fig.savefig('figure/' + filename + '.png', bbox_inches='tight')
        plt.close(fig)
