import numpy as np
import itertools
import pickle
from asymmetric_ac import AsymmetricAC
from npg import NPG
from q_learning import AsymmetricQ
from pomdp_env import sample, POMDP

import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default plotting theme to matplotlib.
sns.set()
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a finite-memory (history window) length used by code elsewhere.
finite_mem = 3

# Font used for figures. NOTE(review): this is set after sns.set(), presumably
# so the seaborn theme does not overwrite the font choice -- keep this order
# unless confirmed otherwise.
plt.rcParams["font.family"] = "Times New Roman"

class AsymmetricNPG:
    """Tabular policy-gradient learner with a state-conditioned critic.

    The policy at step ``h`` is keyed by a truncated observation/action
    history ``(o, a, ..., o)`` holding at most ``Z`` past (observation,
    action) pairs followed by the current observation.  The Q tables are
    keyed by that same history with the environment state appended, so the
    critic conditions on the state while the policy does not.

    Attributes:
        p: Environment object.  This class reads ``H``, ``o_num``, ``a_num``,
            ``s_num``, ``reward`` and calls ``belief``, ``reset``, ``step``.
        Z: Number of past (observation, action) pairs kept in a history key.
        alpha: Stored as given; not used by any method of this class.
        tau: Step size of the exponentiated policy update in ``update_pi``.
        s_space: ``s_space[h]`` lists every history key valid at step ``h``.
        as_space: ``as_space[h]`` lists every (history + state) key at ``h``.
        pi: ``pi[h][history]`` is the action distribution (length ``a_num``).
        est: ``est[(state, a)][state', o']`` counts observed transitions.
        Q_list: One list of per-step Q dicts per call to ``update_Q`` (plus
            the initial all-zero one).  Entries alias the same tables, see
            ``update_Q``.
    """

    def __init__(self, p, Z: int, alpha: float, tau: float):
        """Build the key spaces, a uniform policy and all-zero Q tables.

        Args:
            p: Environment object (see the class docstring for what is used).
            Z: Number of past (observation, action) pairs kept per history.
            alpha: Stored on the instance only.
            tau: Step size for ``update_pi``.
        """
        self.p = p
        self.Z = Z
        self.alpha = alpha
        self.tau = tau
        self.pi = []
        self.as_space = []
        self.s_space = []
        Q = []
        for step in range(p.H + 1):
            # Key layout at this step: min(step, Z) (observation, action)
            # pairs, then the current observation.
            hist_dims = [p.o_num, p.a_num] * min(step, Z) + [p.o_num]
            s_space = list(itertools.product(*(range(d) for d in hist_dims)))
            # Critic keys carry the environment state as one extra trailing axis.
            as_space = list(
                itertools.product(*(range(d) for d in hist_dims + [p.s_num]))
            )
            self.as_space.append(as_space)
            self.s_space.append(s_space)
            Q.append({key: np.zeros(shape=(p.a_num,)) for key in as_space})
            self.pi.append(
                {key: np.ones(shape=(p.a_num,)) / p.a_num for key in s_space}
            )
        # Transition counts: est[(state, a)][next_state, next_obs].
        self.est = {
            (state, a): np.zeros(shape=(p.s_num, p.o_num))
            for state, a in itertools.product(range(p.s_num), range(p.a_num))
        }
        self.Q_list = [Q]

    def learn(self, K: int, eval_episodes: int = 100) -> list:
        """Run ``K`` iterations of collect / evaluate / update.

        Each iteration samples one trajectory, evaluates the current (not yet
        updated) policy, prints that score, then updates the Q tables from
        the trajectory and the policy from the Q tables.

        Args:
            K: Number of learning iterations.
            eval_episodes: Rollouts averaged per evaluation (default 100,
                the value that used to be hard-coded).

        Returns:
            List with the evaluation score recorded at each iteration.
        """
        r_sum_list = []
        for _ in range(K):
            # The trajectory is sampled before evaluating so the order in
            # which the environment is driven stays the same as before.
            traj, _ = self.run_traj()
            r_sum_list.append(self.evaluate(eval_episodes))
            print(r_sum_list[-1])
            self.update_Q(traj)
            self.update_pi()
        return r_sum_list

    def update_Q(self, traj):
        """Update the transition counts from ``traj`` and recompute all Q values.

        Args:
            traj: Flat list ``[s0, o0, a0, r0, s1, o1, a1, r1, ...]`` as
                returned by ``run_traj``.

        Side effects:
            Increments ``self.est``, overwrites the Q arrays in place, sets
            ``self.new_Q`` and appends it to ``self.Q_list``.
        """
        p = self.p
        # list.copy() is shallow: the per-step dicts and their arrays are
        # shared with the previous Q_list entries, so the writes below are in
        # place and every entry of Q_list refers to the same tables.
        self.new_Q = self.Q_list[-1].copy()
        new_Q = self.new_Q

        # Count (s_h, a_h) -> (s_{h+1}, o_{h+1}); each step takes 4 slots.
        for h in range(p.H):
            base = 4 * h
            self.est[(traj[base], traj[base + 2])][traj[base + 4], traj[base + 5]] += 1

        # The bonus and empirical transition model depend only on (state, a),
        # so build them once here instead of once per history key and step.
        model = {}
        for key, counts in self.est.items():
            total = np.sum(counts)
            bonus = 1 / max(total, 1)
            if total < 0.5:
                # Unvisited pair: fall back to a uniform next (state, obs).
                emp_tran = np.ones(shape=(p.s_num, p.o_num)) / (p.s_num * p.o_num)
            else:
                emp_tran = counts / total
            # Successors with probability below 0.01 are dropped.
            successors = [
                (state_prime, o_prime, emp_tran[state_prime, o_prime])
                for state_prime, o_prime in itertools.product(
                    range(p.s_num), range(p.o_num)
                )
                if emp_tran[state_prime, o_prime] >= 0.01
            ]
            model[key] = (bonus, successors)

        # Backward pass over steps: Q[h] is computed from the fresh Q[h + 1].
        for h in reversed(range(p.H)):
            q_here = new_Q[h]
            q_next_table = new_Q[h + 1]
            pi_next_table = self.pi[h + 1]
            for s in self.as_space[h]:
                state = s[-1]
                history = s[:-1]
                for a in range(p.a_num):
                    bonus, successors = model[(state, a)]
                    expectation = 0
                    for state_prime, o_prime, prob in successors:
                        new_s_for_pi = self.truncate(history + (a, o_prime))
                        q_next = q_next_table[new_s_for_pi + (state_prime,)]
                        pi_next = pi_next_table[new_s_for_pi]
                        for a_prime in range(p.a_num):
                            # Next-step Q values below 0.01 are skipped.
                            if q_next[a_prime] < 0.01:
                                continue
                            expectation += prob * pi_next[a_prime] * q_next[a_prime]
                    q_here[s][a] = p.reward[h][state, a] + bonus + expectation
        self.Q_list.append(new_Q)

    def truncate(self, his) -> tuple:
        """Return the last ``2 * Z + 1`` entries of ``his`` as a tuple.

        Histories no longer than that window are returned whole.
        """
        window = self.Z * 2 + 1
        if len(his) > window:
            return tuple(his[-window:])
        return tuple(his)

    def update_pi(self):
        """Apply one exponentiated update to every policy entry.

        For each step ``h < H`` and history ``s`` each action weight is
        multiplied by ``exp(tau * q / (h + 1))``, where ``q`` is the latest Q
        value averaged over states with weights ``p.belief(s, state,
        h <= Z)``; the distribution is then renormalised in place.
        """
        Q = self.Q_list[-1]
        p = self.p
        for h in reversed(range(p.H)):
            q_h = Q[h]
            pi_h = self.pi[h]
            for s in self.s_space[h]:
                policy = pi_h[s]
                for a in range(p.a_num):
                    belief_weighted_q = 0
                    # NOTE(review): belief() is called once per (action, state).
                    # If it is pure it could be computed once per history.
                    for state in range(p.s_num):
                        belief_weighted_q += (
                            p.belief(s, state, h <= self.Z) * q_h[s + (state,)][a]
                        )
                    policy[a] *= np.exp(self.tau * belief_weighted_q / (h + 1))
                policy /= policy.sum()

    def run_traj(self):
        """Roll out one episode of ``p.H`` steps under the current policy.

        Returns:
            ``(extended_hist, r_sum)`` where ``extended_hist`` is the flat
            list ``[s0, o0, a0, r0, s1, o1, ...]`` and ``r_sum`` is the sum
            of the rewards returned by ``p.step``.
        """
        s, o = self.p.reset()
        his = [o]
        extended_hist = [s, o]
        r_sum = 0
        for h in range(self.p.H):
            # The policy only sees the truncated observation/action history.
            a = sample(self.pi[h][self.truncate(his)])
            s, o, r, _ = self.p.step(a)
            his += [a, o]
            extended_hist += [a, r, s, o]
            r_sum += r
        return extended_hist, r_sum

    def evaluate(self, K: int):
        """Return the average reward sum over ``K`` rollouts of ``run_traj``."""
        return sum(self.run_traj()[1] for _ in range(K)) / K