import os
import logging
import pickle
import numpy as np
import tensorflow as tf
from pprint import pformat
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.metrics import roc_auc_score, pairwise_distances
from collections import namedtuple, defaultdict
import tensorflow_probability as tfp
tfd = tfp.distributions

from agents.base import BasePolicy
from agents.memory import ReplayMemory
from agents.random_process import OrnsteinUhlenbeckProcess
from agents.networks import convnet, dense
from agents.utils import plot_dict, plot_prob
from utils.visualize import save_image, mosaic

class WolpDDPGPolicy(BasePolicy):
    """Wolpertinger-style DDPG policy over a discrete grid of actions.

    The actor (``proto`` network) outputs one continuous proto-action in
    [-1, 1] per sample. It is mapped to a discrete action index and then to a
    (row, col) grid coordinate. The ``knn`` nearest *available* actions, by
    Euclidean grid distance, are re-scored by the critic, and the
    highest-scoring one is executed.

    Discrete action ``i`` corresponds to ``self.candidates[i]``, the
    row-major ``(row, col)`` coordinate over ``hps.image_shape[:2]``.
    NOTE(review): this assumes ``self.model.num_actions == H * W`` — confirm.
    """

    def __init__(self, hps, env, model=None, detector=None):
        super().__init__(hps, env, model, detector)

        # knn search space: one (row, col) coordinate per discrete action,
        # enumerated row-major so that action index i <-> candidates[i].
        H, W, _ = self.hps.image_shape
        H_idx, W_idx = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
        self.candidates = np.vstack([H_idx.reshape(-1), W_idx.reshape(-1)]).transpose()

        # Ornstein-Uhlenbeck noise added to proto-actions for exploration (see act()).
        self.random_process = OrnsteinUhlenbeckProcess(self.hps.ou_theta, self.hps.ou_mu, self.hps.ou_sigma)

    def export_action(self, action):
        """Map continuous actions in [-1, 1] to integer action indices.

        This is the inverse of ``import_action``. Values are rounded to the
        nearest index and clipped to ``[0, num_actions - 1]``. Truncation would
        turn e.g. 2.9999999 into 2 and break the round trip
        ``export_action(import_action(i)) == i``.
        """
        num_actions = self.model.num_actions
        scaled = (np.asarray(action) + 1) / 2 * (num_actions - 1)
        env_action = np.clip(np.rint(scaled), 0, num_actions - 1).astype(np.int64)

        return env_action

    def import_action(self, env_action):
        """Map integer action indices in [0, num_actions - 1] linearly onto [-1, 1]."""
        num_actions = self.model.num_actions
        action = (env_action / (num_actions - 1) - 0.5) * 2

        return action

    def wolpertinger(self, state, mask, aux, avail, proto, knn, is_target):
        """Turn proto-actions into executable discrete actions.

        Args:
            state, mask, aux: batched observations (first axis is the batch).
                They are fed to the critic only when ``knn > 1``.
            avail: boolean array [B, num_actions]; False marks actions that
                must not be chosen.
            proto: continuous proto-actions, shape [B, 1], in [-1, 1].
            knn: number of nearest available actions to re-score with the
                critic; values below 1 are treated as 1.
            is_target: rank neighbors with the target critic (fed through the
                ``*_next`` placeholders) instead of the primary critic.

        Returns:
            ``(wolp_action, wolp_env_action)``: continuous actions, shape [B],
            and the matching integer action indices, shape [B].
        """
        knn = max(1, int(knn))
        env_proto = self.export_action(proto)  # [B, 1]
        proto_coord = self.candidates[env_proto.squeeze(axis=1)]  # [B, 2]
        dist = pairwise_distances(proto_coord, self.candidates)  # [B, K]
        # Unavailable actions get infinite distance, so they rank last.
        dist = np.where(avail, dist, np.inf)
        wolp_env_action = np.argsort(dist, axis=1)[:, :knn]  # [B, k]
        wolp_action = self.import_action(wolp_env_action)  # [B, k]
        if knn == 1:
            # Nearest available action; no critic evaluation needed.
            return wolp_action.squeeze(axis=1), wolp_env_action.squeeze(axis=1)

        # Score every (sample, neighbor) pair in a single critic call.
        # Row b*k + j holds sample b paired with its j-th neighbor.
        batch_size = len(state)
        flat_action = wolp_action.reshape(-1)  # [B*k]
        state = np.repeat(state, knn, axis=0)
        mask = np.repeat(mask, knn, axis=0)
        aux = np.repeat(aux, knn, axis=0)
        if is_target:
            q = self.sess.run(self.Q_target,
                              feed_dict={self.state_next: state,
                                         self.mask_next: mask,
                                         self.auxiliary_next: aux,
                                         self.action_next: flat_action})
        else:
            q = self.sess.run(self.Q,
                              feed_dict={self.state: state,
                                         self.mask: mask,
                                         self.auxiliary: aux,
                                         self.action: flat_action})
        q = np.reshape(q, [batch_size, knn])
        # A sample with fewer than k available actions gets unavailable
        # neighbors (infinite distance); never let the critic select those.
        neighbor_dist = np.take_along_axis(dist, wolp_env_action, axis=1)
        q = np.where(np.isfinite(neighbor_dist), q, -np.inf)
        best = np.argmax(q, axis=1)
        rows = np.arange(batch_size)

        return wolp_action[rows, best], wolp_env_action[rows, best]

    def act(self, state, mask, aux, avail, eps=0.0, rand=False, is_target=False):
        """Select one discrete action per sample.

        Args:
            state, mask, aux: batched observations.
            avail: boolean [B, num_actions] mask of allowed actions.
            eps: scale of the Ornstein-Uhlenbeck noise added to the primary
                actor's proto-action. It is ignored when ``is_target``.
            rand: if True, pick a uniformly random available action per sample.
            is_target: use the target actor/critic (fed via the ``*_next``
                placeholders), e.g. to compute bootstrap actions.

        Returns:
            ``(action, env_action)``: continuous actions in [-1, 1] and the
            integer action indices passed to the environment.
        """
        if rand:
            # A sampled available index is its own nearest available neighbor,
            # so there is no need to round-trip it through wolpertinger().
            env_action = np.array([np.random.choice(np.flatnonzero(a)) for a in avail], dtype=np.int64)
            return self.import_action(env_action), env_action

        num_actions = self.model.num_actions
        # Always evaluate at least one neighbor, even for a tiny knn_ratio.
        knn = max(1, int(num_actions * self.hps.knn_ratio))
        if is_target:
            proto = self.sess.run(self.proto_target,
                                  feed_dict={self.state_next: state,
                                             self.mask_next: mask,
                                             self.auxiliary_next: aux})
        else:
            proto = self.sess.run(self.proto,
                                  feed_dict={self.state: state,
                                             self.mask: mask,
                                             self.auxiliary: aux})
            proto += eps * self.random_process.sample(proto.shape)
            # Keep the noisy proto-action inside the actor's tanh range.
            proto = np.clip(proto, -1., 1.)

        return self.wolpertinger(state, mask, aux, avail, proto, knn, is_target)

    def _build_nets(self):
        """Create placeholders, the embedding, and the actor/critic graphs.

        Variables live under ``embedding/``, ``primary/{proto,Q}`` and
        ``target/{proto,Q}``. Current and next observations both pass through
        ``convnet(..., name='embed')`` inside the ``embedding`` scope.
        NOTE(review): whether those two calls share weights depends on
        ``convnet``'s reuse behavior — presumably they do, which would mean
        the target networks have no separate embedding copy. Confirm in
        agents.networks.
        """
        num_actions = self.model.num_actions
        self.state = tf.placeholder(tf.float32, shape=[None]+self.hps.image_shape, name='state')
        self.mask = tf.placeholder(tf.float32, shape=[None]+self.hps.image_shape, name='mask')
        self.auxiliary = tf.placeholder(tf.float32, shape=[None]+self.model.auxiliary_shape, name='auxiliary')
        self.avail = tf.placeholder(tf.bool, shape=[None, num_actions], name='avail')
        self.action = tf.placeholder(tf.float32, shape=[None], name='action')
        self.state_next = tf.placeholder(tf.float32, shape=[None]+self.hps.image_shape, name='state_next')
        self.mask_next = tf.placeholder(tf.float32, shape=[None]+self.hps.image_shape, name='mask_next')
        self.auxiliary_next = tf.placeholder(tf.float32, shape=[None]+self.model.auxiliary_shape, name='auxiliary_next')
        self.avail_next = tf.placeholder(tf.bool, shape=[None, num_actions], name='avail_next')
        self.action_next = tf.placeholder(tf.float32, shape=[None], name='action_next')
        self.reward = tf.placeholder(tf.float32, shape=[None], name='reward')
        self.done = tf.placeholder(tf.float32, shape=[None], name='done')

        with tf.variable_scope('embedding'):
            # embedding of current state (pixel values rescaled from [0, 255])
            embed = tf.concat([self.state / 255., self.mask, self.auxiliary], axis=-1)
            embed = convnet(embed, self.hps.embed_layers, name='embed')
            # embedding of next state
            embed_next = tf.concat([self.state_next / 255., self.mask_next, self.auxiliary_next], axis=-1)
            embed_next = convnet(embed_next, self.hps.embed_layers, name='embed')
            # embedding network variables
            self.embed_vars = self.scope_vars('embedding')

        # Critic outputs are [B, 1]; squeeze only axis 1 so that a batch of
        # size 1 keeps shape [1] instead of collapsing to a scalar.
        with tf.variable_scope('primary'):
            # Actor: deterministic policy mu(s) outputs one action vector.
            self.proto = dense(embed, self.hps.actor_layers+[1], output='tanh', name='proto')
            # Critic: action value, Q(s, a)
            self.Q = tf.squeeze(dense(tf.concat([embed, tf.expand_dims(self.action, axis=1)], axis=1), self.hps.critic_layers+[1], output=None, name='Q'), axis=1)
            # We want to train mu network to maximize Q value that is estimated by our critic;
            # this is only used for training.
            self.Qp = tf.squeeze(dense(tf.concat([embed, self.proto], axis=1), self.hps.critic_layers+[1], output=None, name='Q'), axis=1)
            # primary variables
            self.proto_vars = self.scope_vars('primary/proto')
            self.Q_vars = self.scope_vars('primary/Q')
            self.primary_vars = self.proto_vars + self.Q_vars

        with tf.variable_scope('target'):
            # Clone target networks.
            self.proto_target = dense(embed_next, self.hps.actor_layers+[1], output='tanh', name='proto')
            self.Q_target = tf.squeeze(dense(tf.concat([embed_next, tf.expand_dims(self.action_next, axis=1)], axis=1), self.hps.critic_layers+[1], output=None, name='Q'), axis=1)
            # target variables, ordered to pair up with primary_vars
            self.target_vars = self.scope_vars('target/proto') + self.scope_vars('target/Q')

        # sanity check: primary and target variables must pair up one-to-one
        assert len(self.primary_vars) == len(self.target_vars)

    def init_target_net(self):
        """Hard-copy the primary actor/critic weights into the target networks."""
        self.sess.run([v_t.assign(v) for v_t, v in zip(self.target_vars, self.primary_vars)])

    def update_target_net(self, tau=0.01):
        """Soft-update the target weights: ``v_t <- (1 - tau) * v_t + tau * v``.

        The assign ops are built once, with ``tau`` as a placeholder, and then
        cached. Building fresh ``assign`` ops on every call would add new
        nodes to the graph at each training step, growing memory and slowing
        ``Session.run`` over time.
        """
        soft_update = getattr(self, '_soft_update', None)
        if soft_update is None:
            with self.sess.graph.as_default():
                tau_ph = tf.placeholder(tf.float32, shape=[], name='target_update_tau')
                update_op = tf.group(*[v_t.assign((1.0 - tau_ph) * v_t + tau_ph * v)
                                       for v_t, v in zip(self.target_vars, self.primary_vars)])
            soft_update = (tau_ph, update_op)
            self._soft_update = soft_update
        tau_ph, update_op = soft_update
        self.sess.run(update_op, feed_dict={tau_ph: tau})

    def _build_ops(self):
        """Create the critic/actor losses, the optimizers, and the training summaries."""
        self.lr_a = tf.placeholder(tf.float32, shape=None, name='learning_rate_actor')
        self.lr_c = tf.placeholder(tf.float32, shape=None, name='learning_rate_critic')

        with tf.variable_scope('Q_train'):
            self.Q_reg = tf.reduce_mean([tf.nn.l2_loss(x) for x in self.Q_vars])
            # use tf.stop_gradient() because we don't want to update the Q target net yet.
            y = self.reward + self.hps.gamma * self.Q_target * (1.0 - self.done)
            self.Q_loss = tf.reduce_mean(tf.square(tf.stop_gradient(y) - self.Q)) + 0.0001 * self.Q_reg
            self.Q_train_op = tf.train.AdamOptimizer(self.lr_c).minimize(self.Q_loss, var_list=self.Q_vars+self.embed_vars)

        with tf.variable_scope('proto_train'):
            # Actor maximizes the critic's value of its own proto-action.
            self.proto_loss = -tf.reduce_mean(self.Qp)
            self.proto_train_op = tf.train.AdamOptimizer(self.lr_a).minimize(self.proto_loss, var_list=self.proto_vars+self.embed_vars)

        # NOTE(review): both train ops update embed_vars and are grouped without
        # an ordering dependency; confirm this concurrent update is intended.
        self.train_ops = tf.group(self.Q_train_op, self.proto_train_op)

        with tf.variable_scope('summary'):
            self.ep_reward = tf.placeholder(tf.float32, name='episode_reward')  # just for logging.
            self.summary = [
                tf.summary.scalar('loss/Q', self.Q_loss),
                tf.summary.scalar('loss/Q_reg', self.Q_reg),
                tf.summary.scalar('loss/proto', self.proto_loss),
                tf.summary.scalar('episode_reward', self.ep_reward)
            ]

            self.merged_summary = tf.summary.merge_all(key=tf.GraphKeys.SUMMARIES)

    def _generate_rollout(self, buffer, eps=0.0, rand=False):
        """Run one batched episode and push every transition into ``buffer``.

        The episode continues until every sample in the batch reports done.
        NOTE(review): a sample that finishes early is still stepped and
        recorded until the whole batch is done — confirm the env treats this
        as an absorbing state.

        Returns:
            ``(mean_episode_reward, n_transitions_added)``.
        """
        states = []
        masks = []
        auxiliaries = []
        avails = []
        actions = []
        next_states = []
        next_masks = []
        rewards = []
        dones = []

        logging.info('start rollout.')
        s, m = self.env.reset()
        self.random_process.reset_states([len(s), 1])
        episode_reward = np.zeros([s.shape[0]], dtype=np.float32)
        done = np.zeros([s.shape[0]], dtype=bool)
        while not np.all(done):
            aux = self.model.get_auxiliary(s, m)
            avail = self.model.get_availability(s, m)
            action, env_action = self.act(s, m, aux, avail, eps=eps, rand=rand)
            s_next, m_next, re, done = self.env.step(env_action)
            # total reward = env reward + model reward (+ optional detector reward)
            ri = self.model.get_reward(s, m, env_action, s_next, m_next, done, self.env.y)
            r = re + ri
            if self.hps.detector_reward_coef > 0:
                rd = self.detector.get_reward(s, m, env_action, s_next, m_next, done, self.env.y)
                r = r + rd * self.hps.detector_reward_coef
            states.append(s)
            masks.append(m)
            auxiliaries.append(aux)
            avails.append(avail)
            actions.append(action)
            next_states.append(s_next)
            next_masks.append(m_next)
            rewards.append(r)
            dones.append(done)
            episode_reward += r
            s, m = s_next, m_next
        logging.info('rollout finished.')

        # record this batch: one transition per (time step, batch element)
        logging.info('record this batch.')
        B = len(s)
        T = len(rewards)
        n_rec = B * T
        for t in range(T):
            for j in range(B):
                item = buffer.tuple_class(
                    states[t][j],
                    masks[t][j],
                    auxiliaries[t][j],
                    avails[t][j],
                    actions[t][j],
                    next_states[t][j],
                    next_masks[t][j],
                    rewards[t][j],
                    dones[t][j]
                )
                buffer.add(item)
        logging.info(f'record done: {n_rec} transitions added.')

        return np.mean(episode_reward), n_rec

    def _decay_epsilon_fn(self, n_iter):
        """Return the exploration scale for iteration ``n_iter``.

        The scale stays at ``hps.epsilon`` through ``hps.warmup_iters``, then
        decays linearly to 0 over ``hps.epsilon_decay_iters`` iterations. A
        non-positive decay window drops it to 0 right after warmup instead of
        dividing by zero.
        """
        eps = self.hps.epsilon
        if n_iter <= self.hps.warmup_iters:
            return eps
        if self.hps.epsilon_decay_iters <= 0:
            return 0
        eps_drop_per_iter = self.hps.epsilon / self.hps.epsilon_decay_iters
        return max(eps - eps_drop_per_iter * (n_iter - self.hps.warmup_iters), 0)

    def run(self):
        """Main training loop: collect rollouts, train from replay, then log, evaluate and save."""
        self.init_target_net()

        BufferRecord = namedtuple('BufferRecord',
            ['state', 'mask', 'aux', 'avail', 'action',
             'state_next', 'mask_next', 'reward', 'done'])
        buffer = ReplayMemory(tuple_class=BufferRecord, capacity=self.hps.buffer_size)

        reward_history = []
        reward_averaged = []
        best_reward = -np.inf
        step = 0
        total_rec = 0

        for n_iter in range(self.hps.n_iters):
            eps = self._decay_epsilon_fn(n_iter)
            if self.hps.clean_buffer:
                buffer.clean()
            # purely random (available) actions during warmup
            rand = n_iter < self.hps.warmup_iters
            ep_reward, n_rec = self._generate_rollout(buffer, eps=eps, rand=rand)
            reward_history.append(ep_reward)
            reward_averaged.append(np.mean(reward_history[-10:]))
            total_rec += n_rec

            for batch in buffer.loop(self.hps.buffer_batch_size, self.hps.buffer_epochs):
                # bootstrap action a' chosen by the target actor/critic
                aux_next = self.model.get_auxiliary(batch['state_next'], batch['mask_next'])
                avail_next = self.model.get_availability(batch['state_next'], batch['mask_next'])
                action_next, _ = self.act(batch['state_next'],
                                          batch['mask_next'],
                                          aux_next,
                                          avail_next,
                                          is_target=True)
                _, summ_str = self.sess.run(
                    [self.train_ops, self.merged_summary],
                    feed_dict={self.lr_a: self.hps.lr_a,
                               self.lr_c: self.hps.lr_c,
                               self.state: batch['state'],
                               self.mask: batch['mask'],
                               self.auxiliary: batch['aux'],
                               self.action: batch['action'],
                               self.state_next: batch['state_next'],
                               self.mask_next: batch['mask_next'],
                               self.auxiliary_next: aux_next,
                               self.action_next: action_next,
                               self.reward: batch['reward'],
                               self.done: batch['done'],
                               self.ep_reward: np.mean(reward_history[-10:]) if reward_history else 0.0
                               }
                )
                self.update_target_net(self.hps.tau)
                self.writer.add_summary(summ_str, step)
                step += 1

            if self.hps.log_freq > 0 and (n_iter+1) % self.hps.log_freq == 0:
                logging.info("[iteration:{}/step:{}], best:{}, avg:{:.2f}, eps:{:.2f}; {} transitions.".format(
                    n_iter, step, np.max(reward_history), np.mean(reward_history[-10:]), eps, total_rec))

                data_dict = {
                    'reward': reward_history,
                    'reward_smooth10': reward_averaged,
                }
                plot_dict(f'{self.hps.exp_dir}/learning_curve.png', data_dict, xlabel='episode')

            if self.hps.eval_freq > 0 and n_iter % self.hps.eval_freq == 0:
                self.evaluate(load=False)

            if self.hps.save_freq > 0 and (n_iter+1) % self.hps.save_freq == 0:
                self.save()

            # keep the checkpoint with the best 10-iteration moving-average reward
            if np.mean(reward_history[-10:]) > best_reward:
                best_reward = np.mean(reward_history[-10:])
                self.save('best')

        # FINISH
        self.save()
        if reward_history:
            logging.info("[FINAL] episodes: {}, Max reward: {}, Average reward: {}".format(
                len(reward_history), np.max(reward_history), np.mean(reward_history)))

        # Evaluate
        self.evaluate()

    def evaluate(self, load=True, num_episodes=10):
        """Run ``num_episodes`` greedy episodes and log mean reward and accuracy.

        Args:
            load: restore the ``'best'`` checkpoint first.
            num_episodes: number of batched episodes to run.

        Returns:
            A list with the final ``(state, mask)`` of each episode.
        """
        if load:
            self.load('best')

        metrics = defaultdict(list)
        acquisitions = []

        for _ in range(num_episodes):
            s, m = self.env.reset()
            self.random_process.reset_states([len(s), 1])
            episode_reward = np.zeros([s.shape[0]], dtype=np.float32)
            done = np.zeros([s.shape[0]], dtype=bool)
            while not np.all(done):
                aux = self.model.get_auxiliary(s, m)
                avail = self.model.get_availability(s, m)
                # default eps=0.0: no exploration noise during evaluation
                action, env_action = self.act(s, m, aux, avail)
                s_next, m_next, re, done = self.env.step(env_action)
                ri = self.model.get_reward(s, m, env_action, s_next, m_next, done, self.env.y)
                r = re + ri
                if self.hps.detector_reward_coef > 0:
                    rd = self.detector.get_reward(s, m, env_action, s_next, m_next, done, self.env.y)
                    r = r + rd * self.hps.detector_reward_coef
                episode_reward += r
                s, m = s_next, m_next

            metrics['episode_reward'].append(episode_reward)

            # per-sample classification accuracy on the final acquisition
            pred = self.model.predict(s, m)
            acc = (pred == self.env.y).astype(np.float32)
            metrics['accuracy'].append(acc)

            acquisitions.append((s, m))

        # concat per-episode arrays and average them
        average_metrics = defaultdict(float)
        for k, v in metrics.items():
            metrics[k] = np.concatenate(v)
            average_metrics[k] = np.mean(metrics[k])

        # log
        logging.info('#'*20)
        logging.info('evaluate:')
        for k, v in average_metrics.items():
            logging.info(f'{k}: {v}')

        return acquisitions