#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
import time
from .BaseAgent import *
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from mpl_toolkits.mplot3d import Axes3D
from collections import deque

class DQNActor(BaseActor):
    """Actor that steps the task with an epsilon-greedy policy over Q-values."""

    def __init__(self, config):
        """Store ``config`` and call ``start()`` right after the base initialiser."""
        BaseActor.__init__(self, config)
        self.config = config
        self.start()

    def compute_q(self, prediction):
        """Return the ``'q'`` entry of ``prediction`` converted with ``to_np``."""
        return to_np(prediction['q'])

    def _transition(self):
        """Advance the task by one step.

        Returns:
            The list ``[state, action, reward, next_state, done, info]`` where
            ``state`` is the observation the action was chosen from.
        """
        cfg = self.config
        if self._state is None:
            self._state = self._task.reset()

        if cfg.noisy_linear:
            self._network.reset_noise()
        # The forward pass runs under the lock stored on the config.
        with cfg.lock:
            prediction, _ = self._network(cfg.state_normalizer(self._state))
        q_values = self.compute_q(prediction)

        # Exploration rate: 0 with noisy layers, 1 during the first
        # ``exploration_steps`` steps, otherwise whatever the config returns.
        # ``random_action_prob`` is only called in the last case.
        if cfg.noisy_linear:
            epsilon = 0
        elif self._total_steps < cfg.exploration_steps:
            epsilon = 1
        else:
            epsilon = cfg.random_action_prob()

        action = epsilon_greedy(epsilon, q_values)
        next_state, reward, done, info = self._task.step(action)

        previous_state = self._state
        self._state = next_state
        self._total_steps += 1
        return [previous_state, action, reward, next_state, done, info]


class DQNAgent(BaseAgent):
    def __init__(self, config):
        """Create the replay buffer, actor, online/target networks and optimizer.

        Args:
            config: run configuration. This constructor calls its
                ``replay_fn``, ``network_fn`` and ``optimizer_fn`` factories
                and stores a newly created lock on it as ``config.lock``.
        """
        BaseAgent.__init__(self, config)
        self.config = config
        # Lock stored on the config: DQNActor._transition holds it around its
        # forward pass and DQNAgent.step holds it around optimizer.step().
        config.lock = mp.Lock()

        self.replay = config.replay_fn()
        # NOTE(review): DQNActor.__init__ calls start(); presumably that
        # launches the actor, which is why config.lock is assigned before this
        # line -- confirm against BaseActor.
        self.actor = DQNActor(config)

        self.network = config.network_fn()
        self.network.share_memory()
        # The target network starts as a copy of the online network's weights.
        self.target_network = config.network_fn()
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = config.optimizer_fn(self.network.parameters())

        # The actor only receives the online network here, after it was built.
        self.actor.set_network(self.network)
        self.total_steps = 0

    def close(self):
        """Pass the replay buffer, then the actor, to ``close_obj``."""
        for attr_name in ('replay', 'actor'):
            close_obj(getattr(self, attr_name))

    def eval_step(self, state):
        """Pick the greedy action for ``state`` with the normalizer in read-only mode.

        Args:
            state: raw observation(s); passed through
                ``config.state_normalizer`` before the forward pass.

        Returns:
            A tuple ``(action, feature)``. ``action`` is ``to_np`` applied to
            the argmax over the last axis of the network's ``'q'`` output;
            ``feature`` is the second value returned by the network.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        try:
            q_, feature = self.network(normalizer(state))
            action = to_np(q_['q'].argmax(-1))
        finally:
            # Restore the normalizer even when normalisation or the forward
            # pass raises, so it is never left stuck in read-only mode.
            normalizer.unset_read_only()
        return action, feature

    def reduce_loss(self, loss):
        return loss.pow(2).mul(0.5).mean()

    def compute_loss(self, transitions):
        """Compute the unreduced TD error ``q_target - q`` for sampled transitions.

        Args:
            transitions: sampled batch exposing ``state``, ``next_state``,
                ``action``, ``reward`` and ``mask``.

        Returns:
            Tensor holding ``q_target - q`` per transition, where ``q`` is the
            online network's value for the stored action and ``q_target`` is
            ``reward + discount ** n_step * q_next * mask``.
        """
        cfg = self.config
        normalize = cfg.state_normalizer
        states = normalize(transitions.state)
        next_states = normalize(transitions.next_state)

        # The bootstrap value is computed without tracking gradients.
        with torch.no_grad():
            target_out, _ = self.target_network(next_states)
            q_next = target_out['q'].detach()
            if cfg.double_q:
                # Double Q: the online network selects the action, the target
                # network supplies its value.
                online_out, _ = self.network(next_states)
                greedy = online_out['q'].argmax(dim=-1, keepdim=True)
                q_next = q_next.gather(1, greedy).squeeze(1)
            else:
                q_next, _ = q_next.max(1)

        masks = tensor(transitions.mask)
        rewards = tensor(transitions.reward)
        discount_n = cfg.discount ** cfg.n_step
        q_target = rewards + discount_n * q_next * masks

        actions = tensor(transitions.action).long()
        online_out, _ = self.network(states)
        q_taken = online_out['q'].gather(1, actions.unsqueeze(-1)).squeeze(-1)
        return q_target - q_taken


    def TSNE(self, data_num=512, samples=200):
        """Collect per-action features with the greedy policy and plot their t-SNE embedding.

        ``config.eval_env`` is stepped with the action returned by
        ``eval_step``. At every step the feature vector and the normalized
        state are filed under the chosen action, until each action holds
        ``data_num`` entries. The normalized states of action ``i`` are saved
        to ``figure/<config.game_file><i>.pt``. ``samples`` feature vectors are
        then drawn at random from each of the first two actions, embedded with
        t-SNE in 2-D and 3-D and written to ``figure/DQN-tSNE-2D.png`` and
        ``figure/DQN-tSNE-3D.png``.

        Args:
            data_num: number of (feature, state) pairs collected per action.
            samples: number of features per action fed to t-SNE; must be in
                ``[1, data_num]``.

        Raises:
            ValueError: if ``samples`` is not in ``[1, data_num]``.

        Note:
            Collection stops only once every action has been chosen
            ``data_num`` times, so it does not terminate if the greedy policy
            never selects one of the actions.
        """
        # Fail before the (long) collection loop rather than at sampling time.
        if not 0 < samples <= data_num:
            raise ValueError('samples must be in [1, data_num]')

        config = self.config
        n_labels = config.action_dim
        normalizer = config.state_normalizer
        env = config.eval_env

        # Per-action lists; stacking once at the end avoids re-copying the
        # whole buffer on every step and needs no fixed feature width.
        features = {i: [] for i in range(n_labels)}
        states = {i: [] for i in range(n_labels)}

        def counts():
            return [len(features[i]) for i in range(n_labels)]

        # Collect data for each action class.
        state = env.reset()
        while min(counts()) < data_num:
            action, feature_vectors = self.eval_step(state)  # batch of one
            chosen = int(np.asarray(action).item())
            if len(features[chosen]) < data_num:
                # float64 keeps the t-SNE input precision independent of the
                # network's dtype.
                features[chosen].append(
                    feature_vectors.detach().cpu().numpy().astype(np.float64))
                # eval_step has already switched read-only mode off again, so
                # enable it just for this call and restore it afterwards.
                normalizer.set_read_only()
                try:
                    normalized = normalizer(state)
                finally:
                    normalizer.unset_read_only()
                # clone() so the stored tensor never aliases a buffer that is
                # reused on a later step.
                states[chosen].append(torch.from_numpy(normalized).clone())

            current = counts()
            print('Current Action: {}, nums:{}, min_num:{}/{}'.format(
                chosen, '/'.join(str(c) for c in current), min(current), data_num))

            state, reward, done, info = env.step(action)
            # A non-None episodic return marks the end of an episode.
            if info[0]['episodic_return'] is not None:
                state = env.reset()

        feature_arrays = {}
        state_tensors = {}
        for i in range(n_labels):
            feature_arrays[i] = np.vstack(features[i])
            state_tensors[i] = torch.cat(states[i], dim=0)
            print('feature shape', feature_arrays[i].shape)
            print('state shape', state_tensors[i].shape)

        os.makedirs('figure', exist_ok=True)

        # Save the collected states, one tensor file per action.
        for i in range(n_labels):
            torch.save(state_tensors[i], 'figure/' + config.game_file + str(i) + '.pt')

        COLOR = ['red', 'blue', 'lime', 'yellow']

        def rescale(result):
            # Map every embedding axis to [0, 1]; a constant axis keeps a unit
            # span so the division is always defined.
            low, high = np.min(result, 0), np.max(result, 0)
            span = high - low
            span[span == 0] = 1
            return (result - low) / span

        def plot_embedding(result, label, title, filename):
            data = rescale(result)
            colors = [COLOR[int(value)] for value in label]
            plt.figure()
            plt.scatter(data[:, 0], data[:, 1], marker='o', c=colors, alpha=0.5)
            plt.title(title)
            plt.savefig('figure/' + filename + '.png', bbox_inches='tight')
            plt.close()

        def plot_embedding_3D(result, label, title, filename):
            data = rescale(result)
            colors = [COLOR[int(value)] for value in label]
            fig = plt.figure()
            # add_subplot attaches the 3-D axes to the figure; a bare
            # Axes3D(fig) is not added automatically on recent matplotlib.
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=colors, alpha=0.5)
            ax.set_title(title)
            fig.savefig('figure/' + filename + '.png', bbox_inches='tight')
            plt.close(fig)

        # Only the first two action classes are embedded.
        num_classes = min(2, n_labels)
        labels = []
        chunks = []
        for i in range(num_classes):
            picked = np.random.choice(data_num, samples, replace=False)
            chunks.append(feature_arrays[i][picked, :])
            labels.extend([i] * samples)
        data = np.vstack(chunks)
        print('Final data shape: ', data.shape)

        # Inside this method ``TSNE`` resolves to the module-level
        # sklearn.manifold.TSNE import, not to this method.
        tsne2 = TSNE(n_components=2, init='pca', random_state=0)
        result2 = tsne2.fit_transform(data)
        tsne3 = TSNE(n_components=3, init='pca', random_state=0)
        result3 = tsne3.fit_transform(data)
        plot_embedding(result2, labels, 't-SNE on DQN Features', 'DQN-tSNE-2D')
        plot_embedding_3D(result3, labels, 't-SNE on DQN Features', 'DQN-tSNE-3D')

    # def TSNE(self):
    #     n_labels = self.config.action_dim
    #     feature_dict = {}
    #     state_dict = {}
    #     data_num = 512
    #     for i in range(n_labels):
    #         feature_dict[i] = np.ones((1, 512))
    #     for i in range(n_labels):
    #         state_dict[i] = torch.load('figure/'+self.config.game_file+str(i)+'.pt', map_location='cpu')
    #         for j in range(state_dict[i].shape[0]):
    #             # print('state.shape', state_dict[i].shape)
    #             if torch.cuda.is_available():
    #                 device = torch.device('cuda')
    #                 state = state_dict[i][j,:].unsqueeze(0).float().to(device) # [4,84,84] -> [1,4,84,84]
    #             else:
    #                 state = state_dict[i][j,:].unsqueeze(0).float() # [4,84,84] -> [1,4,84,84]
    #             action, feature_vectors = self.eval_step(state)
    #             # print(action)
    #             feature_dict[i] = np.vstack((feature_dict[i], feature_vectors.detach().cpu().numpy()))
    #             # feature_dict[int(action[0])] = np.vstack((feature_dict[int(action[0])], feature_vectors.detach().cpu().numpy()))
    #             # feature_dict[int(action)] = np.vstack((feature_dict[int(action)], feature_vectors.detach().cpu().numpy()))
    #
    #     for i in range(n_labels):
    #         feature_dict[i] = feature_dict[i][1:,:]
    #         print(feature_dict[i].shape)
    #
    #     # t-sne
    #     def plot_embedding(result, label, title, filename):  # [1083,2] [1083]
    #         x_min, x_max = np.min(result, 0), np.max(result, 0)
    #         data = (result - x_min) / (x_max - x_min)  # [0-1] scale
    #         plt.figure()
    #         for i in range(data.shape[0]):  # 1083
    #             plt.scatter(data[i, 0], data[i, 1], marker='o', color=plt.cm.Set1(label[i] / (1.0 * n_labels)))
    #         plt.title(title)
    #         plt.savefig('figure/' + filename + '.png', bbox_inches='tight')
    #         plt.close()
    #
    #     def plot_embedding_3D(result, label, title, filename):
    #         x_min, x_max = np.min(result, 0), np.max(result, 0)
    #         data = (result - x_min) / (x_max - x_min)  # [0-1] scale
    #         fig = plt.figure()
    #         ax = Axes3D(fig)
    #         for i in range(data.shape[0]):
    #             ax.scatter(data[i, 0], data[i, 1], data[i, 2], color=plt.cm.Set1(label[i] / (1.0 * n_labels)))
    #         plt.title(title)
    #         plt.savefig('figure/' + filename + '.png', bbox_inches='tight')
    #         plt.close()
    #
    #     if not os.path.exists('figure'):
    #         os.makedirs('figure')
    #
    #     labels = []
    #     data = np.ones((1, 512))
    #     samples = 200
    #     CLASSES = 3 # 4
    #     for i in range(CLASSES):  # only the first two labels
    #         NUMS_index = np.random.choice(data_num, samples, replace=False)
    #         data0 = feature_dict[i][NUMS_index, :]
    #         # data0 = np.array(feature_dict[i])[NUMS_index]
    #         # print(data0.shape)
    #         data = np.vstack((data, data0))
    #         labels.extend(np.ones(samples) * i)
    #     data = data[1:, :]
    #     print('Final data shape: ', data.shape)
    #     tsne2 = TSNE(n_components=2, init='pca', random_state=0)  # n_components: 64 -> 2；
    #     result2 = tsne2.fit_transform(data)
    #     tsne3 = TSNE(n_components=3, init='pca', random_state=0)  # n_components: 64 -> 2；
    #     result3 = tsne3.fit_transform(data)
    #     plot_embedding(result2, labels, 't-SNE on DQN Features', 'DQN-tSNE-2D')
    #     plot_embedding_3D(result3, labels, 't-SNE on DQN Features', 'DQN-tSNE-3D')

    def step(self):
        """Run one agent iteration.

        Feeds every transition produced by the actor into the replay buffer,
        performs one gradient update once ``total_steps`` exceeds
        ``config.exploration_steps``, and copies the online weights into the
        target network when the update-frequency check passes.
        """
        cfg = self.config

        # 1) Store what the actor produced.
        for states, actions, rewards, next_states, dones, info in self.actor.step():
            self.record_online_return(info)
            self.total_steps += 1
            # For LazyFrames entries only the last element is stored.
            stored_states = np.array(
                [s[-1] if isinstance(s, LazyFrames) else s for s in states])
            normalized_rewards = [cfg.reward_normalizer(r) for r in rewards]
            not_done = 1 - np.asarray(dones, dtype=np.int32)
            self.replay.feed(dict(
                state=stored_states,
                action=actions,
                reward=normalized_rewards,
                mask=not_done,
            ))

        # 2) Gradient update after the initial exploration phase.
        if self.total_steps > cfg.exploration_steps:
            batch = self.replay.sample()
            if cfg.noisy_linear:
                self.target_network.reset_noise()
                self.network.reset_noise()
            td_error = self.compute_loss(batch)

            if isinstance(batch, PrioritizedTransition):
                # New priorities come from the unweighted error magnitude.
                priorities = (td_error.abs() + cfg.replay_eps) ** cfg.replay_alpha
                idxs = tensor(batch.idx).long()
                self.replay.update_priorities(zip(to_np(idxs), to_np(priorities)))
                # Importance-sampling weights, scaled so the largest is 1.
                sampling_probs = tensor(batch.sampling_prob)
                scaled_probs = sampling_probs * sampling_probs.size(0) + 1e-6
                weights = scaled_probs ** (-cfg.replay_beta())
                weights = weights / weights.max()
                td_error = td_error * weights

            loss = self.reduce_loss(td_error)
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.network.parameters(), cfg.gradient_clip)
            # The parameter update runs under the lock stored on the config.
            with cfg.lock:
                self.optimizer.step()

        # 3) Periodic target-network synchronisation.
        update_count = self.total_steps / cfg.sgd_update_frequency
        if update_count % cfg.target_network_update_freq == 0:
            self.target_network.load_state_dict(self.network.state_dict())
