import numpy as np
import math
from typing import Callable
from collections import Counter
from abc import ABC, abstractmethod

class DR_RL_empirical_kl:
    def __init__(self, generative_model, delta, gamma, perform_value_iteration = True, perform_relative_value_iteration=True, v_0 = None, q_0 = None, rv_0 = None, g_0 = None):
        """Set up the solver state and optionally run the two planners.

        generative_model: object exposing r_max, states, action_at_state,
            sa_pairs, is_mdp and rewards; kept as self.model.
        delta: stored as self.delta; the dual searches of this class divide
            their bounds by it and subtract alpha * delta in the KL objective.
        gamma: discount factor stored as self.gamma.
        perform_value_iteration: when truthy, run value_iteration on
            empirical distributions built from 10000 samples.
        perform_relative_value_iteration: when truthy, run
            relative_value_iteration on the model's own maps.
        v_0, q_0: optional initial arrays; zeros of length len(sa_pairs)
            are used when omitted.
        rv_0, g_0: optional initial dicts keyed by state; all-zero dicts are
            used when omitted.
        """
        self.delta, self.gamma = delta, gamma
        self.model = generative_model

        # Mirror the model's description onto the solver.
        for field in ("r_max", "states", "action_at_state", "sa_pairs", "is_mdp", "rewards"):
            setattr(self, field, getattr(generative_model, field))

        # Discounted quantities are arrays with one entry per (state, action) pair.
        self.v = v_0 if v_0 is not None else np.zeros(len(self.sa_pairs))
        self.q = q_0 if q_0 is not None else np.zeros(len(self.sa_pairs))
        self.v_star = None
        self.q_star = None
        self.r = None

        # Relative (average-reward) quantities are dicts keyed by state.
        self.rvq = None
        self.rv = rv_0 if rv_0 is not None else {s: 0 for s in self.states}
        self.rv_star = None
        self.rv_old = None
        self.g_star = g_0 if g_0 is not None else {s: 0 for s in self.states}

        if perform_value_iteration:
            self.value_iteration(r_robust=False, v_robust=True, empirical=True, n_sample=10000, tolerance=1e-6)
        if perform_relative_value_iteration:
            self.relative_value_iteration(r_robust=False, v_robust=True, empirical=False, n_sample=10000, tolerance=1e-6)

        # Bookkeeping is initialised after the planners have run, so the
        # distribution dicts below start out empty even if a planner set them.
        self.step_count = self.sample_used = self.epoch_used = 0

        def default_ml(l):
            # Reads self.gamma at call time.
            return math.ceil(2*2**l/(1-self.gamma)**2)

        def default_lrfunction(x):
            # Uses the gamma passed to the constructor.
            return 1/(1+(1-gamma)*x)

        self.default_ml = default_ml
        self.default_lrfunction = default_lrfunction
        self.dist_of_sa = {}
        self.dist_of_r = {}


    def one_vr_epoc(self, n_sample, kstar, ml: Callable = None, lrfunction: Callable = None, disp = False):
        """Begin one variance-reduced epoch (the inner loop is not implemented).

        The method advances the epoch counter and evaluates the empirical
        Bellman operator on the current q-function with ml(epoch) samples,
        which also updates self.sample_used. The kstar inner steps that
        should follow were only a placeholder in the original code, so the
        method returns None.

        n_sample: rounded up with math.ceil; currently unused afterwards.
        kstar: rounded up with math.ceil; currently unused afterwards.
        ml: maps the epoch index to a sample count; defaults to
            self.default_ml.
        lrfunction: learning-rate schedule; defaults to
            self.default_lrfunction. Currently unused.
        disp: currently unused.
        """
        n_sample = math.ceil(n_sample)
        kstar = math.ceil(kstar)
        if ml is None:
            ml = self.default_ml
        if lrfunction is None:
            lrfunction = self.default_lrfunction
        self.epoch_used += 1
        epoch = self.epoch_used
        # The q-function lives in self.q; the former self.Q never existed and
        # made every call fail with AttributeError.
        q_prev = self.q
        # Kept for its sample accounting; the recentering term is not used yet.
        _recenter = self.apply_empirical_bellman(ml(epoch), q_prev)
        # TODO: implement the kstar inner steps; the placeholder loop returned
        # on its first iteration, which is equivalent to returning here.
        return None
        
    def empirical_value_function(self, lr_function: Callable, n_sample, n=1, disp = False):
        """Run n stochastic-approximation steps and report the greedy values.

        Returns a tuple (value function dict derived from self.q, total
        samples used so far, total steps taken so far).
        """
        # The q-function is updated in place on self; the return value of
        # n_step_sa is not needed here.
        self.n_step_sa(lr_function, n_sample, n=n, disp=disp)
        greedy_values = self.value_function_from_q()
        summary = (greedy_values, self.sample_used, self.step_count)
        return summary


    def n_step_sa(self, lr_function: Callable, n_sample, n=1, disp = False):
        """Take n stochastic-approximation steps on self.q.

        Each step mixes the current q-function with a fresh empirical Bellman
        evaluation, using lr_function(step_count) as the mixing weight.

        lr_function: maps the global step index to a learning rate.
        n_sample: samples per Bellman evaluation (rounded up).
        n: number of steps (rounded up).
        disp: when truthy, print the q-function once after the last step.

        Returns (self.q, self.sample_used, self.step_count).
        """
        n_sample = math.ceil(n_sample)
        n = math.ceil(n)
        for _ in range(n):
            self.step_count += 1
            step = self.step_count
            # The schedule is queried once for each of the two weights, as
            # in the single-expression form of this update.
            retained = (1 - lr_function(step))*self.q
            gain = lr_function(step)
            target = self.apply_empirical_bellman(n_sample)
            self.q = retained + gain*target
        if disp:
            print("At step {}, the q-function for sa pairs is {}.".format(self.step_count, self.q))
        return self.q, self.sample_used, self.step_count

    def one_step_sa(self, lr, n_sample, disp = False):
        """Take a single stochastic-approximation step with a fixed rate.

        lr: mixing weight given to the fresh empirical Bellman evaluation.
        n_sample: samples used for that evaluation (rounded up).
        disp: when truthy, print the updated q-function.

        Returns (self.q, self.step_count); unlike n_step_sa, the sample
        counter is not part of the return value.
        """
        n_sample = math.ceil(n_sample)
        self.step_count += 1
        retained = (1 - lr)*self.q
        target = self.apply_empirical_bellman(n_sample)
        self.q = retained + lr*target
        if disp:
            print("At step {}, the q-function for sa pairs is {}.".format(self.step_count, self.q))
        return self.q, self.step_count
    
    # This function returns the robust empirical Bellman operator for each sa pair
    # n_sample: number of samples to generate
    # q: q-function to use
    # r_robust, v_robust: whether to apply the robust setting
    def apply_empirical_bellman(self, n_sample, q = None, r_robust = False, v_robust = True):
        """Evaluate the (robust) empirical Bellman operator on every sa pair.

        n_sample: number of samples requested from the model per draw; the
            sample counter is increased by 2 * n_sample per call.
        q: q-function to evaluate, indexed like self.sa_pairs; defaults to
            self.q.
        r_robust: when truthy, the reward term is the value returned by
            dual_opt on the model's (data, measure) pair; otherwise it is the
            mean of the drawn rewards.
        v_robust: same choice for the next-state value term.

        Returns a numpy array with one entry r + gamma * v per sa pair, in
        the order of self.sa_pairs.
        """
        if q is None:
            q = self.q
        self.sample_used += n_sample * 2
        # Use the q that was passed in; previously the argument was ignored
        # and self.q was always evaluated.
        vf = self.value_function_from_q(q)
        # Neither bound depends on the pair, so compute them once.
        alpha_r_max = self.r_max / self.delta
        alpha_v_max = max((abs(v) for v in vf.values()), default=0.0) / self.delta
        bell_q = []
        for sa in self.sa_pairs:
            if not r_robust:
                r_samp = self.model.generate_reward(sa, n_sample, False)
                r_sa = np.mean(r_samp)
            else:
                r_samp = self.model.generate_reward(sa, n_sample, True)
                r_sa = self.dual_opt(r_samp[0], r_samp[1], alpha_r_max)
            if not v_robust:
                s_samp = self.model.generate_state(sa, n_sample, False)
                # A list is required here: np.array over a bare generator
                # yields a 0-d object array that np.mean cannot average.
                # NOTE(review): s_samp[0] is assumed to hold the sampled
                # states for the non-robust flag as well - confirm against
                # the model's generate_state.
                v_sa = np.mean([vf[s] for s in s_samp[0]])
            else:
                s_samp = self.model.generate_state(sa, n_sample, True)
                v_data = np.array([vf[s] for s in s_samp[0]])
                v_sa = self.dual_opt(v_data, s_samp[1], alpha_v_max)
            bell_q.append(r_sa + self.gamma * v_sa)
        return np.array(bell_q)
    
    # Compute the optimal value function for 1-step lookahead under KL divergence
    def dual_opt(self, data, measure, alpha_max, opt_total=5e-6):
        """Maximise the KL dual objective over the multiplier alpha.

        With e the smallest atom carrying positive measure, the objective is
            f(alpha) = -alpha * log(sum_i p_i * exp(-(x_i - e) / alpha))
                       - alpha * self.delta
        and the method returns f(alpha*) + e, where alpha* is located by an
        interval search on [opt_total / 100, alpha_max + 2 * opt_total / 100].

        data: atom values x_i (array-like).
        measure: weights p_i of the atoms (array-like, same length); atoms
            with non-positive weight are ignored.
        alpha_max: upper end of the search interval for alpha.
        opt_total: tolerance used by the stopping rule.

        Returns e directly when only one atom has positive weight, when the
        weight sitting on e is at least exp(-self.delta), or when the
        objective stops being finite during the search.

        Raises ValueError when no atom has positive weight.
        """
        # Accept lists as well as arrays; boolean masking needs ndarrays.
        data = np.asarray(data, dtype=float)
        measure = np.asarray(measure, dtype=float)
        support = measure > 0
        pos_data = data[support]
        pos_measure = measure[support]
        if pos_data.size == 0:
            raise ValueError("dual_opt needs at least one atom with positive measure")
        essinf = pos_data.min()

        # Cases in which the optimal multiplier alpha* is 0.
        if len(pos_measure) == 1:
            return essinf
        kappa = pos_measure[pos_data == essinf].sum()
        if kappa >= np.exp(-self.delta):
            return essinf

        # Shifting by essinf keeps every exponent non-positive, so exp()
        # cannot overflow for small alpha.
        shifted = pos_data - essinf

        def f(alpha):
            return -alpha * np.log(np.dot(pos_measure, np.exp(-shifted / alpha))) - alpha * self.delta

        def df(alpha):
            # Derivative of f. The log term enters with a minus sign; the
            # earlier expression had it with a plus sign and therefore was
            # not the derivative of f.
            weights = pos_measure * np.exp(-shifted / alpha)
            total = weights.sum()
            return -np.log(total) - self.delta - np.dot(weights, shifted) / total / alpha

        barr = opt_total / 100
        alpha_l = barr
        alpha_r = alpha_max + 2 * barr
        r = 0.382
        diff = 2 * opt_total
        d_alpha = 1
        count = 1
        # Stop once the two probes agree in value and slope * width is small.
        # abs() keeps a large negative difference from counting as agreement.
        while abs(diff) > opt_total or d_alpha * (alpha_r - alpha_l) > opt_total:
            hat_alpha_l = alpha_l + r * (alpha_r - alpha_l)
            hat_alpha_r = alpha_l + (1 - r) * (alpha_r - alpha_l)
            f_l = f(hat_alpha_l)
            if not np.isfinite(f_l):
                print("dual_opt has some numerical issues")
                return essinf
            diff = f_l - f(hat_alpha_r)
            if diff < 0:
                # The objective is still rising at the left probe, so the
                # maximiser lies to its right.
                alpha_l = hat_alpha_l
                if count % 5 == 0:
                    d_alpha = abs(df(alpha_l))
            else:
                alpha_r = hat_alpha_r
                if count % 5 == 0:
                    d_alpha = abs(df(alpha_r))
            count += 1
        opt_alpha = (alpha_l + alpha_r) / 2
        return f(opt_alpha) + essinf
    
    def dual_opt_chi_square(self, data, measure, alpha_max, opt_total=5e-6):
        """Maximise the chi-square dual objective over alpha.

        The objective is
            f(alpha) = alpha - sqrt(1 + self.delta)
                       * sqrt(sum_i p_i * max(alpha - x_i, 0) ** 2)
        and alpha is searched on [opt_total / 100, alpha_max + 2 * opt_total / 100].

        data, measure: numpy arrays of atom values and weights; atoms with
            non-positive weight are ignored.
        alpha_max: upper end of the search interval.
        opt_total: tolerance used by the stopping rule.

        Returns f evaluated at the midpoint of the final interval.
        """
        ratio = 0.382
        floor = opt_total/100
        lo = floor
        hi = alpha_max + 2*floor

        scale = np.sqrt(1+self.delta)
        atoms = data[measure > 0]
        weights = measure[measure > 0]

        def objective(alpha):
            shortfall = np.maximum(alpha - atoms, 0)
            return alpha - scale * np.sqrt(np.dot(weights, shortfall**2))

        def slope(alpha, h=1e-6):
            # Central finite difference of the objective.
            return (objective(alpha+h) - objective(alpha-h))/(2*h)

        gap = 2 * opt_total
        slope_mag = 1
        step = 1

        while gap > opt_total or slope_mag*(hi-lo) > opt_total:
            width = hi - lo
            probe_lo = lo + ratio*width
            probe_hi = lo + (1-ratio)*width
            value_lo = objective(probe_lo)
            if np.isinf(value_lo) or np.isnan(value_lo):
                print("Numerical issues detected at alpha =", probe_lo)
                probe_lo = np.clip(probe_lo, lo, hi)
                value_lo = objective(probe_lo)
            gap = value_lo - objective(probe_hi)

            # The slope estimate is refreshed only on every fifth pass.
            refresh = step % 5 == 0
            if gap < 0:
                lo = probe_lo
                if refresh:
                    slope_mag = abs(slope(lo))
            else:
                hi = probe_hi
                if refresh:
                    slope_mag = abs(slope(hi))
            step += 1
        return objective((lo + hi)/2)
    
    def apply_robust_empirical_bellman_to_list(self, n_sample, qs:list, r_robust = False, v_robust = True):
        """Evaluate the empirical Bellman operator on several q-functions.

        One reward draw and one next-state draw are made per sa pair and
        shared by every q-function in qs.

        n_sample: number of samples requested per draw; the sample counter
            is increased by 2 * n_sample per call.
        qs: list of q-functions, each indexed like self.sa_pairs.
        r_robust, v_robust: when truthy, the respective term is the value
            returned by dual_opt; otherwise it is a plain sample mean.

        Returns a list with one numpy array per entry of qs, each holding
        r + gamma * v for every sa pair in the order of self.sa_pairs.
        """
        self.sample_used += n_sample*2
        # The greedy values and the dual search bound depend only on the
        # q-function, so they are computed once instead of once per pair.
        value_fns = [self.value_function_from_q(q) for q in qs]
        alpha_v_maxes = [
            max((abs(v) for v in vf.values()), default=0.0)/self.delta
            for vf in value_fns
        ]
        alpha_r_max = self.r_max/self.delta
        columns = [[] for _ in qs]
        for sa in self.sa_pairs:
            if not r_robust:
                r_samp = self.model.generate_reward(sa, n_sample, False)
                r_sa = np.mean(r_samp)
            else:
                r_samp = self.model.generate_reward(sa, n_sample, True)
                r_sa = self.dual_opt(r_samp[0], r_samp[1], alpha_r_max)
            if not v_robust:
                s_samp = self.model.generate_state(sa, n_sample, False)
            else:
                s_samp = self.model.generate_state(sa, n_sample, True)
            for column, vf, alpha_v_max in zip(columns, value_fns, alpha_v_maxes):
                # Lists are required in both branches: np.array over a bare
                # generator yields a 0-d object array, which broke np.mean
                # and the masking inside dual_opt.
                if not v_robust:
                    v_sa = np.mean([vf[s] for s in s_samp[0]])
                else:
                    v_data = np.array([vf[s] for s in s_samp[0]])
                    v_sa = self.dual_opt(v_data, s_samp[1], alpha_v_max)
                column.append(r_sa + self.gamma*v_sa)
        return [np.array(column, dtype=float) for column in columns]

    def strive_optimal_policy_from_q(self, q = None):
        """Return the greedy deterministic policy for a q-function.

        q: q-function indexed like self.sa_pairs; defaults to self.q.

        Returns a dict mapping every state in self.action_at_state to the
        action with the largest q-value at that state (the first such action
        on ties). The mapping is state -> action so it can be passed as
        d_policy to value_function_from_q.
        """
        if q is None:
            q = self.q
        opt_policy = {}
        for s, available in self.action_at_state.items():
            actions = list(available)
            q_values = [q[self.sa_pairs.index((s, a))] for a in actions]
            # Record the choice per state; previously the dict was replaced
            # by the bare argmax index of whichever state came last.
            opt_policy[s] = actions[int(np.argmax(q_values))]
        return opt_policy
    
    def reset(self):
        """Restore the learning state to its all-zero starting point.

        The q and v arrays, the relative value dicts, the cached rewards and
        the three counters are reset. The model, delta, gamma and the stored
        distributions are left untouched.
        """
        n_pairs = len(self.sa_pairs)
        self.q = np.zeros(n_pairs)
        self.v = np.zeros(n_pairs)

        # Everything that is derived by a planner is forgotten.
        for derived in ("v_star", "q_star", "r", "rvq", "rv_star", "rv_old"):
            setattr(self, derived, None)

        self.rv = {s: 0 for s in self.states}
        self.g_star = {s: 0 for s in self.states}

        self.step_count = self.sample_used = self.epoch_used = 0

    # value_iteration: used to compute the optimal value function for the nominal MDP with the transition map.
    def value_iteration(self, r_robust = False, v_robust = True, empirical = False, n_sample = 100, tolerance = 1e-6):
        """Iterate value_iteration_once until the q-function stops changing.

        r_robust, v_robust: forwarded to value_iteration_once.
        empirical: when falsy, the model's transition_map and reward_map are
            used; otherwise one empirical distribution per sa pair is
            requested from the model with n_sample samples.
        n_sample: sample count for the empirical distributions.
        tolerance: the loop stops once the largest absolute change of
            self.q_star between two sweeps is not above this value.

        At most 5000 sweeps are made; a message is printed if the last sweep
        still changed the q-function by more than the tolerance. Results are
        left in self.q_star, self.v_star and self.r.

        Raises Exception when the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.q_star is None:
            self.q_star = self.q
        if not empirical:
            dist_of_sa = self.model.transition_map
            dist_of_r = self.model.reward_map
        else:
            dist_of_sa = {}
            dist_of_r = {}
            # Draws stay interleaved per pair (state first, then reward).
            for sa in self.sa_pairs:
                dist_of_sa[sa] = self.model.generate_empirical_distribution_s(sa, n_sample)
                dist_of_r[sa] = self.model.generate_empirical_distribution_r(sa, n_sample)
        # value_iteration_once reuses self.r whenever it is already set.
        # Clearing it here makes the rewards match this call's distributions
        # and r_robust flag instead of whatever an earlier run left behind.
        self.r = None
        q_old = self.q_star
        flag = True
        iteration = 0
        while flag and iteration < 5000:
            iteration += 1
            self.value_iteration_once(dist_of_sa, dist_of_r, r_robust, v_robust)
            flag = np.max(np.abs(q_old - self.q_star)) > tolerance
            q_old = self.q_star
        # Report on the convergence flag rather than the sweep count, so a
        # run that converges exactly on the last sweep is not misreported.
        if flag:
            print("Value iteration did not converge in finite steps")

    def value_iteration_once(self,dist_of_sa, dist_of_r, r_robust = False, v_robust = True):
        """Perform one sweep of (robust) value iteration on self.q_star.

        dist_of_sa: mapping from sa pair to a distribution over states,
            ordered like the sorted state keys.
        dist_of_r: mapping from sa pair to a distribution over self.rewards.
        r_robust, v_robust: when truthy, the respective expectation is
            replaced by the value returned by dual_opt.

        The expected rewards are computed only while self.r is None; an
        existing self.r is reused as is. Updates self.q_star and self.v_star.

        Raises Exception when the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.q_star is None:
            self.q_star = self.q
        vf = self.value_function_from_q(self.q_star)
        alpha_v_max = max([abs(v) for v in vf.values()])/self.delta
        alpha_r_max = self.r_max/self.delta
        # The value vector is the same for every pair, so build it once.
        value_vector = np.array([vf[key] for key in sorted(vf.keys())])

        need_rewards = self.r is None
        reward_terms = []
        next_value_terms = []
        for sa in self.sa_pairs:
            reward_dist = dist_of_r[sa]
            state_dist = dist_of_sa[sa]
            if need_rewards:
                if r_robust:
                    reward_terms.append(self.dual_opt(self.rewards, reward_dist, alpha_r_max))
                else:
                    reward_terms.append(np.dot(self.rewards, reward_dist))
            if v_robust:
                next_value_terms.append(self.dual_opt(value_vector, state_dist, alpha_v_max))
            else:
                next_value_terms.append(np.dot(value_vector, state_dist))

        if need_rewards:
            self.r = np.array(reward_terms)
        self.q_star = self.r + self.gamma * np.array(next_value_terms)
        self.v_star = self.value_function_from_q(self.q_star)
        

    def value_function_from_q(self, q = None, d_policy = None):
        """Collapse a q-function into a per-state value dict.

        q: q-function indexed like self.sa_pairs; defaults to self.q.
        d_policy: optional deterministic policy mapping state -> action.
            When given, each state's value is the q-value of the chosen
            action; otherwise it is the maximum over the state's actions.

        Returns a dict keyed by the states of self.action_at_state.
        """
        if q is None:
            q = self.q
        position = self.sa_pairs.index
        if d_policy is not None:
            return {
                s: q[position((s, d_policy[s]))]
                for s in list(self.action_at_state.keys())
            }
        greedy = {}
        for s in list(self.action_at_state.keys()):
            candidates = [q[position((s, a))] for a in self.action_at_state[s]]
            greedy[s] = max(candidates)
        return greedy


    def relative_value_iteration(self, r_robust = False, v_robust = True, empirical = False, n_sample = 100, tolerance = 1e-6):
        # print('relatvie value_iteration with parameters r_robust = {}, v_robust = {}, empirical = {}, n_sample = {}, tolerance = {}'.format(r_robust, v_robust, empirical, n_sample, tolerance))
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.rv_star is None:
            self.rv_star = self.rv
        g_old = self.g_star
        iteration = 0
        flag = True
        if not empirical:
            self.dist_of_sa = self.model.transition_map
            self.dist_of_r = self.model.reward_map
        else:
            for sa in self.sa_pairs:
                self.dist_of_sa[sa] = self.model.generate_empirical_distribution_s(sa, n_sample)
                self.dist_of_r[sa] = self.model.generate_empirical_distribution_r(sa, n_sample)
        while flag and iteration < 5000:
            iteration += 1
            self.relative_value_iteration_once(r_robust, v_robust)
            flag = self.span_semi_norm_difference(self.g_star, g_old)>tolerance
            g_old = self.g_star
        if iteration == 5000:
            print("Relative Value iteration did not converge in finite steps")
        self.rv_star = self.rv
        self.get_relative_reward()
        #print('this is relative value iteration final g_star is ', self.g_star)
        
    def relative_value_iteration_once(self, r_robust = False, v_robust = True):
        """Perform one undiscounted sweep on the relative value function.

        Reads self.dist_of_r / self.dist_of_sa for every sa pair and combines
        the expected reward with the expected relative value of the next
        state (no discount factor is applied).

        r_robust, v_robust: when truthy, the respective expectation is
            replaced by the value returned by dual_opt.

        Updates self.r (a list), self.rvq, self.rv_old and self.rv, and sets
        self.g_star to self.rv shifted so that its minimum is zero.

        Raises Exception when the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        current_rv = self.rv
        alpha_v_max = max([abs(v) for v in current_rv.values()])/self.delta
        alpha_r_max = self.r_max/self.delta
        # The relative value vector is the same for every pair; build it once.
        rv_vector = np.array([current_rv[key] for key in sorted(current_rv.keys())])

        expected_rewards = []
        expected_next = []
        for sa in self.sa_pairs:
            reward_dist = self.dist_of_r[sa]
            state_dist = self.dist_of_sa[sa]
            if r_robust:
                expected_rewards.append(self.dual_opt(self.rewards, reward_dist, alpha_r_max))
            else:
                expected_rewards.append(np.dot(self.rewards, reward_dist))
            if v_robust:
                expected_next.append(self.dual_opt(rv_vector, state_dist, alpha_v_max))
            else:
                expected_next.append(np.dot(rv_vector, state_dist))

        self.r = expected_rewards
        self.rvq = np.array(self.r) + np.array(expected_next)
        self.rv_old = current_rv
        self.rv = self.value_function_from_q(self.rvq)
        lowest = min(self.rv.values())
        self.g_star = {s: self.rv[s] - lowest for s in self.states}

    def span_semi_norm_difference(self, v1=None, v2=None):
        """Return the span semi-norm of v1 - v2 over self.states.

        v1, v2: mappings from state to number. None is treated as an empty
            mapping (the defaults used to be shared mutable dicts).

        Returns max_s (v1[s] - v2[s]) - min_s (v1[s] - v2[s]).

        Raises KeyError when a state of self.states is missing from either
        mapping, and ValueError when self.states is empty.
        """
        v1 = {} if v1 is None else v1
        v2 = {} if v2 is None else v2
        gaps = [v1[s] - v2[s] for s in self.states]
        return max(gaps) - min(gaps)
    
    def get_relative_reward(self):
        """Set self.g_star to the per-state increment self.rv - self.rv_old."""
        latest, previous = self.rv, self.rv_old
        increments = {}
        for s in self.states:
            increments[s] = latest[s] - previous[s]
        self.g_star = increments


    # The code below implements anchored relative value iteration.
    def anchored_relative_value_iteration(self, r_robust = False, v_robust = True, empirical = False, n_sample = 100, tolerance = 1e-6, xi = 0.1):
        """Iterate anchored_relative_value_iteration_once until rv stabilises.

        r_robust, v_robust: forwarded to anchored_relative_value_iteration_once.
        empirical: when falsy, self.dist_of_sa / self.dist_of_r are set to
            the model's transition_map / reward_map, without any anchoring.
            Otherwise empirical distributions are built (n_sample samples per
            sa pair) and every state distribution p is replaced by
            (1 - xi) * p with an extra mass xi on its first entry.
        n_sample: sample count for the empirical distributions.
        tolerance: the loop stops once the span semi-norm of the change of
            self.rv between two sweeps is not above this value.
        xi: anchoring weight used in the empirical branch.

        At most 5000 sweeps are made; a message is printed if the last sweep
        still exceeded the tolerance. Afterwards self.rv_star is set to
        self.rv, self.g_star to the last per-state increment of rv, and the
        final g_star is printed.

        Raises Exception when the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.rv_star is None:
            self.rv_star = self.rv
        if not empirical:
            # NOTE(review): xi is not applied on this path - confirm that the
            # model's own maps are meant to be used without anchoring.
            self.dist_of_sa = self.model.transition_map
            self.dist_of_r = self.model.reward_map
        else:
            # Build new dicts instead of writing into the current ones: after
            # a non-empirical run they are the model's own maps, and in-place
            # writes would overwrite the model's distributions.
            dist_of_sa = {}
            dist_of_r = {}
            for sa in self.sa_pairs:
                # Work in floating point: with an integer-typed distribution
                # the anchor mass xi would be truncated to 0 on assignment.
                empirical_dist = np.asarray(self.model.generate_empirical_distribution_s(sa, n_sample), dtype=float)
                anchor = np.zeros_like(empirical_dist)
                anchor[0] = xi
                dist_of_sa[sa] = (1-xi)*empirical_dist + anchor
                dist_of_r[sa] = self.model.generate_empirical_distribution_r(sa, n_sample)
            self.dist_of_sa = dist_of_sa
            self.dist_of_r = dist_of_r
        iteration = 0
        flag = True
        while flag and iteration < 5000:
            iteration += 1
            rv_old = self.rv
            self.anchored_relative_value_iteration_once(r_robust, v_robust)
            flag = self.span_semi_norm_difference(self.rv, rv_old)>tolerance
        # Report on the convergence flag rather than the sweep count, so a
        # run that converges exactly on the last sweep is not misreported.
        if flag:
            print("Anchored Relative Value iteration did not converge in finite steps")
        self.rv_star = self.rv
        self.get_relative_reward()
        print('anchored relative value iteration final g_star is ', self.g_star)

    def anchored_relative_value_iteration_once(self, r_robust = False, v_robust = True):
        """Perform one undiscounted sweep on the relative value function.

        Reads self.dist_of_r / self.dist_of_sa for every sa pair and combines
        the expected reward with the expected relative value of the next
        state (no discount factor is applied).

        r_robust, v_robust: when truthy, the respective expectation is
            replaced by the value returned by dual_opt.

        Updates self.r (a list), self.rvq, self.rv_old and self.rv;
        self.g_star is not touched.

        Raises Exception when the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        current_rv = self.rv
        alpha_v_max = max([abs(v) for v in current_rv.values()])/self.delta
        alpha_r_max = self.r_max/self.delta
        # The relative value vector is the same for every pair; build it once.
        rv_vector = np.array([current_rv[key] for key in sorted(current_rv.keys())])

        expected_rewards = []
        expected_next = []
        for sa in self.sa_pairs:
            reward_dist = self.dist_of_r[sa]
            state_dist = self.dist_of_sa[sa]
            if r_robust:
                expected_rewards.append(self.dual_opt(self.rewards, reward_dist, alpha_r_max))
            else:
                expected_rewards.append(np.dot(self.rewards, reward_dist))
            if v_robust:
                expected_next.append(self.dual_opt(rv_vector, state_dist, alpha_v_max))
            else:
                expected_next.append(np.dot(rv_vector, state_dist))

        self.r = expected_rewards
        self.rvq = np.array(self.r) + np.array(expected_next)
        self.rv_old = current_rv
        self.rv = self.value_function_from_q(self.rvq)
