import numpy as np
import math
from typing import Callable
from abc import ABC, abstractmethod
from generative_model import Generative_Model,Hard_MDP_Wang
import multiprocessing
class RL_MB_AMDP:
    """Model-based planning helpers built around a generative model.

    The wrapped ``generative_model`` must expose ``r_max``, ``states``, ``rewards``,
    ``action_at_state``, ``get_sa_pairs()`` and ``is_mdp``. The static solvers
    additionally read ``sa_pairs``, ``transition_map`` and ``reward_map`` from the
    model they are given.

    Q-functions are flat numpy arrays indexed in the same order as the model's
    list of (state, action) pairs.
    """

    def __init__(self, generative_model, tmix, perform_value_iteration=True, Q_0=None):
        """Store the model and optionally solve it with discount factor 0.9999.

        Args:
            generative_model: environment model (see the class docstring).
            tmix: mixing-time parameter, used by ``pmbp``.
            perform_value_iteration: when True, run ``value_iteration`` with
                gamma = 0.9999. The result is stored in ``self.Q_star``, and
                ``self.alpha_bar`` is set to ``(1 - 0.9999)`` times the mean optimal
                state value (presumably a proxy for the optimal average reward).
            Q_0: initial Q array. Defaults to zeros over all state-action pairs.
        """
        self.model = generative_model
        self.tmix = tmix
        self.r_max = generative_model.r_max
        self.states = generative_model.states
        self.rewards = generative_model.rewards
        self.action_at_state = generative_model.action_at_state
        self.sa_pairs = generative_model.get_sa_pairs()
        self.is_mdp = self.model.is_mdp
        self.Q = Q_0
        self.R = None  # kept for attribute compatibility; not read anywhere in this class
        self.alpha_bar = None
        if self.Q is None:
            self.Q = np.zeros(len(self.sa_pairs))
        self.Q_star = None
        if perform_value_iteration:
            gamma_bar = 0.9999
            # Keep the optimal Q on the instance so callers can reuse it.
            self.Q_star = self.value_iteration(self.model, gamma_bar)
            mean_value = np.mean(list(self.value_func_from_Q(self.Q_star).values()))
            self.alpha_bar = mean_value * (1 - gamma_bar)

    class Hard_MDP_Wang_rwd_purturbed_empirical(Generative_Model):
        """Empirical, reward-perturbed copy of the two-state ``Hard_MDP_Wang``.

        For each state-action pair:
        - the transition row is estimated from ``samp_size`` next-state samples of
          ``Hard_MDP_Wang(1 / tmix)``;
        - the reward vector gets an independent ``Uniform[0, purturb_size)`` offset
          added to each of its two entries.
        """

        @property
        def is_mdp(self):
            return True

        @property
        def states(self):
            return np.array([1, 2])

        @property
        def rewards(self):
            return np.array([0, 1])

        @property
        def sa_pairs(self):
            return [(1, 0), (1, 1), (2, 0), (2, 1)]

        @property
        def action_at_state(self):
            return self.generate_action_at_state()

        def __init__(self, tmix, purturb_size, samp_size):
            mdp = Hard_MDP_Wang(1 / tmix)
            sa_pairs = self.sa_pairs
            self.transition_map = {}
            for sa in sa_pairs:
                # Assumes generate_next_states returns an array of next states in {1, 2}.
                # Subtracting 1 turns each sample into a 0/1 indicator of landing in
                # state 2, so the mean is the empirical probability of state 2.
                q = np.mean(mdp.generate_next_states(sa, samp_size) - 1)
                self.transition_map[sa] = np.array([1 - q, q])
            # Draw the perturbations after sampling transitions (same RNG order as before).
            purturbs = np.random.rand(4, 2) * purturb_size
            # Unperturbed reward vectors. value_iteration dots them with `rewards` ([0, 1]).
            base = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
            self.reward_map = {sa: base[i] + purturbs[i] for i, sa in enumerate(sa_pairs)}

    def pmbp(self, eps, Jin=False):
        """Estimate ``(1 - gamma)`` times the mean optimal value of a sampled hard MDP.

        The empirical model is always built from ``Hard_MDP_Wang(1 / tmix)``, not
        from ``self.model``. Its parameters are:
        - discount ``gamma = 1 - eps / (3 * tmix)``;
        - reward perturbation size ``(1 - gamma) * tmix / 4``;
        - ``samp_size`` samples per state-action pair.

        The model is solved with ``value_iteration``. Note that ``gamma`` lies in
        (0, 1) only when ``0 < eps < 3 * tmix``.

        Args:
            eps: accuracy parameter that sets the discount factor.
            Jin: if True, use the sample size ``5 / ((1 - gamma)**3 * tmix**2)``;
                otherwise use ``5 / ((1 - gamma)**2 * tmix)``.

        Returns:
            ``(alpha_est, samp_size)``.
        """
        gamma = 1 - eps / 3 / self.tmix
        purturb_size = (1 - gamma) * self.tmix / 4
        if Jin:
            samp_size = int(5 / (1 - gamma) ** 3 / self.tmix ** 2)
        else:
            samp_size = int(5 / (1 - gamma) ** 2 / self.tmix)
        emdp = self.Hard_MDP_Wang_rwd_purturbed_empirical(self.tmix, purturb_size, samp_size)
        dmdp_q = self.value_iteration(emdp, gamma)
        # dmdp_q is indexed by emdp's pairs, so read it through emdp's own layout.
        state_values = [dmdp_q[idx].max() for idx in self._state_action_indices(emdp)]
        alpha_est = np.mean(state_values) * (1 - gamma)
        return alpha_est, samp_size

    @staticmethod
    def _sa_index(sa_pairs):
        """Map each (state, action) tuple to its first position in ``sa_pairs``."""
        index = {}
        for pos, sa in enumerate(sa_pairs):
            index.setdefault(tuple(sa), pos)
        return index

    @staticmethod
    def _state_action_indices(mdp):
        """For each state of ``mdp.states`` (in order), return the Q-positions of its actions."""
        sa_index = RL_MB_AMDP._sa_index(mdp.sa_pairs)
        action_at_state = mdp.action_at_state  # read once; may be a regenerating property
        return [
            np.array([sa_index[(s, a)] for a in action_at_state[s]], dtype=int)
            for s in mdp.states
        ]

    def value_func_from_Q(self, Q=None, dpolicy=None):
        """Return a dict mapping each state of ``self.action_at_state`` to its value.

        Args:
            Q: Q array indexed like ``self.sa_pairs``. Defaults to ``self.Q``.
            dpolicy: optional mapping from state to action. When given, a state's
                value is Q at ``(state, dpolicy[state])``. Otherwise it is the
                maximum of Q over that state's actions.
        """
        if Q is None:
            Q = self.Q
        sa_index = self._sa_index(self.sa_pairs)
        value_func = {}
        for state, actions in self.action_at_state.items():
            if dpolicy is None:
                value_func[state] = max(Q[sa_index[(state, a)]] for a in actions)
            else:
                value_func[state] = Q[sa_index[(state, dpolicy[state])]]
        return value_func

    def policy_from_Q(self, Q=None):
        """Return the greedy deterministic policy of ``Q`` as a ``{state: action}`` dict.

        The returned actions are elements of ``self.action_at_state[state]``, not
        positions, so the result can be passed as ``dpolicy`` to
        ``value_func_from_Q`` or ``dpolicy_evaluation``. Ties go to the first
        action listed, as with ``np.argmax``.
        """
        if Q is None:
            Q = self.Q
        sa_index = self._sa_index(self.sa_pairs)
        policy = {}
        for state, actions in self.action_at_state.items():
            actions = list(actions)
            q_values = [Q[sa_index[(state, a)]] for a in actions]
            policy[state] = actions[int(np.argmax(q_values))]
        return policy

    @staticmethod
    def value_iteration(mdp, gamma, tol=1e-6, max_iter=20000000):
        """Solve the discounted Bellman optimality equation for Q by fixed-point iteration.

        Starting from Q = 0, repeat ``Q <- R + gamma * P @ V(Q)``, where:
        - ``R[i] = mdp.reward_map[sa_i] @ mdp.rewards``;
        - row ``P[i]`` is ``mdp.transition_map[sa_i]`` over ``mdp.states`` (in order);
        - ``V(Q)[s]`` is the maximum of Q over the actions available at ``s``.

        Args:
            mdp: generative model with ``is_mdp`` true.
            gamma: discount factor.
            tol: stop once the sup-norm change between iterates is not above ``tol``.
            max_iter: iteration cap. A message is printed if it is reached without
                converging.

        Returns:
            numpy array of Q-values indexed like ``mdp.sa_pairs``.

        Raises:
            Exception: if ``mdp.is_mdp`` is false.
        """
        if not mdp.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        sa_pairs = mdp.sa_pairs
        rewards = mdp.rewards
        # R, P and the per-state action positions do not depend on Q: build them once.
        R = np.array([mdp.reward_map[sa] @ rewards for sa in sa_pairs])
        P = np.array([mdp.transition_map[sa] for sa in sa_pairs])
        groups = RL_MB_AMDP._state_action_indices(mdp)

        Q = np.zeros(len(mdp.get_sa_pairs()))
        converged = False
        for _ in range(max_iter):
            V = np.array([Q[idx].max() for idx in groups])
            Q_new = R + gamma * (P @ V)
            still_changing = np.max(np.abs(Q - Q_new)) > tol
            Q = Q_new
            if not still_changing:
                converged = True
                break
        if not converged:
            print('value_iteration doesnt converge')
        return Q

    @staticmethod
    def dpolicy_evaluation(mdp, gamma, dpolicy, tol=1e-6, max_iter=50000):
        """Evaluate a deterministic policy by iterating its discounted Bellman equation.

        Starting from v = 0, repeat ``v <- R + gamma * P @ v``, where for each state
        ``s`` in ``mdp.states``:
        - ``R[s] = mdp.reward_map[(s, dpolicy[s])] @ mdp.rewards``;
        - ``P[s] = mdp.transition_map[(s, dpolicy[s])]``.

        Args:
            mdp: generative model with ``is_mdp`` true.
            gamma: discount factor.
            dpolicy: mapping from each state to the action it takes.
            tol: stop once the sup-norm change between iterates is not above ``tol``.
            max_iter: iteration cap. A message is printed if it is reached without
                converging.

        Returns:
            dict mapping each state of ``mdp.states`` to its value under ``dpolicy``.

        Raises:
            Exception: if ``mdp.is_mdp`` is false.
        """
        if not mdp.is_mdp:
            raise Exception("Cannot perform policy evaluation: input generative model is not a MDP.")
        states = list(mdp.states)
        chosen = [(s, dpolicy[s]) for s in states]
        rewards = mdp.rewards
        R = np.array([mdp.reward_map[sa] @ rewards for sa in chosen])
        P = np.array([mdp.transition_map[sa] for sa in chosen])

        v = np.zeros(len(states))
        converged = False
        for _ in range(max_iter):
            v_new = R + gamma * (P @ v)
            still_changing = np.max(np.abs(v - v_new)) > tol
            v = v_new
            if not still_changing:
                converged = True
                break
        if not converged:
            print('dpolicy_evaluation doesnt converge')
        return {s: v[i] for i, s in enumerate(states)}
    

