import numpy as np
import copy
import random
import math
from collections import deque
np.set_printoptions(precision=4,suppress=True)

# Lookup table from zero-based grid coordinates to a flat state index:
# state_idx[x][y] == 6 * x + y (row-major, stored as floats).
state_idx = np.zeros((6, 6))
idx = 0
while idx < state_idx.size:
    # divmod recovers the (row, column) pair of the idx-th cell in a 6-wide grid.
    x_, y_ = divmod(idx, 6)
    state_idx[x_, y_] = idx
    idx += 1

class Policy:
    """Tabular softmax policy with a per-state value baseline.

    The policy keeps one row of action preferences per state (``theta``) and
    turns a row into action probabilities with a softmax.  ``weight`` holds
    one value estimate per state and is used as the baseline in the policy
    gradient.

    A trajectory is a sequence of transitions; every transition is indexed as
    ``item[0]`` = state, ``item[1]`` = action index, ``item[2]`` = reward and
    ``item[3]`` = probability of that action under the policy that collected
    the data.  A state is a pair of 1-based grid coordinates which is mapped
    to a row of ``theta`` through the module-level ``state_idx`` table.
    """

    # Trajectories whose importance-sampling ratio lies outside
    # [IS_RATIO_MIN, IS_RATIO_MAX] are left out of a policy update.
    IS_RATIO_MIN = 0.7
    IS_RATIO_MAX = 1.3
    # reinforce_baseline_gini asks the caller to stop iterating once fewer
    # than this many trajectories pass the ratio filter.
    MIN_SELECTED_TRAJ = 30
    # Upper bound on the number of transitions used for the entropy gradient.
    MAX_ENTROPY_SAMPLES = 400

    def __init__(self, n_s, n_a, bs):
        """Create a policy for ``n_s`` states and ``n_a`` actions.

        Args:
            n_s: Number of states (rows of ``theta``).
            n_a: Number of actions (columns of ``theta``).
            bs: Number of trajectories consumed by one update; also the
                capacity of the ``traj_wd`` window.
        """
        self.n_state = n_s
        self.n_action = n_a
        self.batch_size = bs
        # Action preferences, initialised uniformly at random in [0, 1).
        self.theta = np.random.rand(n_s, n_a)
        # Per-state value estimates used as the baseline.
        self.weight = np.zeros(n_s)
        self.actions = np.arange(n_a)

        # Trajectories waiting to be consumed by the update methods.
        self.traj_buf = []

        # Bounded window of the most recent trajectories; filled by
        # feed_data and not read by any method of this class.
        self.traj_wd = deque(maxlen=bs)

    def put_data(self, traj):
        """Append a trajectory to the update buffer."""
        self.traj_buf.append(traj)

    def feed_data(self, traj):
        """Append a trajectory to the bounded window, dropping the oldest
        one when the window is full."""
        self.traj_wd.append(traj)

    def get_state_index(self, s):
        """Return the row of ``theta``/``weight`` for state ``s``.

        ``s[0]`` and ``s[1]`` are 1-based coordinates; they are shifted to
        zero-based before the lookup in the module-level ``state_idx`` table.
        """
        x = int(s[0]) - 1
        y = int(s[1]) - 1
        return int(state_idx[x][y])

    @staticmethod
    def _softmax(logits):
        """Return the softmax of a 1-D array of preferences."""
        # Subtracting the maximum does not change the result mathematically
        # but keeps np.exp from overflowing to inf (and the ratio from
        # becoming NaN) when preferences grow large.
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    @staticmethod
    def _sample_index(prob, draw):
        """Return the first index whose cumulative probability is >= draw."""
        i = int(np.searchsorted(np.cumsum(prob), draw))
        # Rounding can leave the last cumulative sum slightly below the
        # draw; fall back to the last index instead of running off the end.
        return min(i, len(prob) - 1)

    def get_action_prob(self, s):
        """Return the probability of every action in state ``s``."""
        return self._softmax(self.theta[self.get_state_index(s)])

    def get_action(self, s):
        """Sample an action for state ``s``.

        Returns:
            A tuple ``(action, probability_of_that_action)``.
        """
        prob = self.get_action_prob(s)
        i = self._sample_index(prob, np.random.rand())
        return self.actions[i], prob[i]

    def eval_action(self, s):
        """Return the index of the most probable action in state ``s``."""
        return np.argmax(self.get_action_prob(s))

    def get_derivative(self, s, a):
        """Return the gradient of ``log pi(a|s)`` with respect to ``theta``.

        The result has the shape of ``theta``; only the row of state ``s`` is
        non-zero and equals ``one_hot(a) - pi(.|s)``.
        """
        der_theta = np.zeros((self.n_state, self.n_action))

        prob = self.get_action_prob(s)
        s_idx = self.get_state_index(s)
        der_theta[s_idx][a] = 1.
        der_theta[s_idx] -= prob
        return der_theta

    ###### reinforce + gini, sample n traj, update k times ######
    def reinforce_baseline_gini_k(self, lr_policy, lr_value, gamma, lam, ent_coef, k):
        """Run up to ``k`` updates on the buffered trajectories, then clear
        the buffer.

        Iteration ends early as soon as ``reinforce_baseline_gini`` reports
        that too few trajectories passed the importance-sampling filter.
        The remaining arguments are forwarded unchanged.
        """
        for _ in range(k):
            if self.reinforce_baseline_gini(lr_policy, lr_value, gamma, lam, ent_coef):
                break

        # The buffered data has been used; start the next batch from scratch.
        self.traj_buf = []

    def reinforce_baseline_gini(self, lr_policy, lr_value, gamma, lam, ent_coef):
        """Perform one policy/value update from the first ``batch_size``
        buffered trajectories.

        Args:
            lr_policy: Step size for ``theta``.
            lr_value: Step size for the value estimates in ``weight``.
            gamma: Discount factor.
            lam: Weight of the Gini term that is subtracted from the
                REINFORCE gradient.
            ent_coef: Weight of the entropy gradient; ``None`` or a
                non-positive value disables the entropy step.

        Returns:
            ``True`` when fewer than ``MIN_SELECTED_TRAJ`` trajectories passed
            the importance-sampling filter (the caller should stop
            iterating), ``False`` otherwise.

        Raises:
            IndexError: If the buffer holds fewer than ``batch_size``
                trajectories.
        """
        table_shape = (self.n_state, self.n_action)

        # Per-trajectory return, sum of grad log pi, REINFORCE gradient and
        # importance-sampling ratio.
        ret_lst = []
        sum_der_lst = []
        reinforce_der_lst = []
        is_ratio_lst = []

        for t in range(self.batch_size):
            traj = self.traj_buf[t]
            traj_len = len(traj)

            ret = 0.0
            sum_der = np.zeros(table_shape)
            reinforce_der = np.zeros(table_shape)
            old_pi_lst, current_pi_lst = [], []

            # Walk the episode backwards so the discounted return can be
            # accumulated in a single pass; i counts steps from the end.
            for i, item in enumerate(reversed(traj)):
                state, action = item[0], item[1]
                ret = item[2] + gamma * ret

                old_pi_lst.append(item[3])
                current_pi_lst.append(self.get_action_prob(state)[action])

                s_idx = self.get_state_index(state)
                delta = ret - self.weight[s_idx]
                der_theta = self.get_derivative(state, action)
                sum_der += der_theta

                # traj_len - 1 - i is the forward time step of this
                # transition, so the factor is gamma ** time_step.
                reinforce_der += pow(gamma, traj_len - 1 - i) * der_theta * delta

                # The value table is updated step-wise, so later transitions
                # already see the corrected estimate.
                self.weight[s_idx] += (lr_value / self.batch_size) * delta

            ret_lst.append(ret)
            sum_der_lst.append(sum_der)
            reinforce_der_lst.append(reinforce_der)

            # Trajectory-level importance-sampling ratio, computed in log
            # space as the product of the per-step ratios.
            log_ratio = np.log(np.array(current_pi_lst)) - np.log(np.array(old_pi_lst))
            is_ratio_lst.append(np.exp(log_ratio.sum()))

        # Keep only trajectories whose ratio lies inside the allowed band.
        is_ratio = np.array(is_ratio_lst)
        selected = np.where(
            (is_ratio <= self.IS_RATIO_MAX) & (is_ratio >= self.IS_RATIO_MIN)
        )[0]
        is_ratio_choose = is_ratio[selected]
        ret_choose = np.array(ret_lst)[selected]
        reinforce_der_choose = np.array(reinforce_der_lst)[selected]
        sum_der_choose = np.array(sum_der_lst)[selected]
        choose_size = len(is_ratio_choose)

        # Importance-weighted REINFORCE gradient.
        der_rf = np.zeros(table_shape)
        for grad, ratio in zip(reinforce_der_choose, is_ratio_choose):
            der_rf += grad * ratio

        # Importance-weighted Gini gradient.
        der_gini = self.compute_gini_derivative_is(ret_choose, sum_der_choose, is_ratio_choose)

        # Combine both gradients.  With a single-trajectory batch the Gini
        # gradient is zero and its 1 / (batch_size - 1) scale is undefined,
        # so the term is skipped.
        step = (lr_policy / self.batch_size) * der_rf
        if self.batch_size > 1:
            step = step - lam * (lr_policy / (self.batch_size - 1)) * der_gini
        self.theta += step

        # Optional entropy step, evaluated with the already updated theta.
        if (ent_coef is not None) and (ent_coef > 0):
            # Pool the transitions of the selected trajectories and use a
            # random subset of them.
            all_sample = []
            for traj_i in selected:
                all_sample.extend(self.traj_buf[traj_i])
            random.shuffle(all_sample)

            sample_size = min(self.MAX_ENTROPY_SAMPLES, len(all_sample))
            der_ent = np.zeros(table_shape)
            for item in all_sample[:sample_size]:
                pi = self.get_action_prob(item[0])
                for a in map(int, self.actions):
                    # Gradient of -sum_a pi(a) log pi(a):
                    # -(1 + log pi(a)) * pi(a) * grad log pi(a).
                    der_ent -= (1 + np.log(pi[a])) * pi[a] * self.get_derivative(item[0], a)

            self.theta += lr_policy * ent_coef * der_ent

        # Too few usable trajectories: tell the caller to stop iterating.
        return choose_size < self.MIN_SELECTED_TRAJ

    def compute_gini_derivative_is(self, ret_lst, sum_der_lst, is_ratio_lst):
        """Return the importance-weighted gradient of the Gini term.

        Args:
            ret_lst: Return of every selected trajectory.
            sum_der_lst: Per-trajectory sum of ``grad log pi`` tables, indexed
                like ``ret_lst``.
            is_ratio_lst: Per-trajectory importance-sampling ratio, indexed
                like ``ret_lst``.

        Returns:
            An array shaped like ``theta``.  It is all zeros when fewer than
            two trajectories are given, because there are no return
            differences to weight.
        """
        ret_lst = np.asarray(ret_lst)
        sample_size = len(ret_lst)
        der_gini = np.zeros((self.n_state, self.n_action))
        if sample_size < 2:
            return der_gini

        sort_idx = np.argsort(ret_lst)
        sort_R = ret_lst[sort_idx]

        # Gaps between consecutive sorted returns, weighted by the empirical
        # CDF value rank / sample_size.
        weighted_gap = np.diff(sort_R) * (np.arange(1., sample_size) / sample_size)
        # Entry t is the sum of the weighted gaps from position t onwards.
        tail_sum = weighted_gap + np.sum(weighted_gap) - np.cumsum(weighted_gap)
        coef = 2. * tail_sum + sort_R[:-1] - sort_R[-1]

        # The trajectory with the largest return has no coefficient and
        # therefore does not contribute.
        for t in range(sample_size - 1):
            src = sort_idx[t]
            der_gini += sum_der_lst[src] * (-1) * coef[t] * is_ratio_lst[src]
        return der_gini

    



        



  