import numpy as np
import math
from typing import Callable
from collections import Counter
from abc import ABC, abstractmethod

class DR_Q_learning:
    def __init__(self, generative_model, delta, gamma, perform_value_iteration = True, perform_relative_value_iteration=True, v_0 = None, q_0 = None, rv_0 = None, g_0 = None, max_iteration=5000):
        """Build the solver state from a generative model and optionally solve it.

        Args:
            generative_model: object exposing ``r_max``, ``states``,
                ``action_at_state``, ``sa_pairs``, ``is_mdp`` and ``rewards``;
                each of these is copied onto the instance.
            delta: stored as ``self.delta`` (read by ``dual_opt`` and the
                iteration steps).
            gamma: stored as ``self.gamma`` and used by the two default helper
                functions created here.
            perform_value_iteration: accepted but not used by this constructor.
            perform_relative_value_iteration: when true,
                ``relative_value_iteration`` is run on the model's own maps
                (``empirical=False``) before the counters are initialised.
            v_0, q_0: initial arrays; default to zeros with one entry per
                state-action pair.
            rv_0, g_0: initial per-state dicts; default to all zeros.
            max_iteration: iteration cap read by ``relative_value_iteration``.
        """
        # Solver configuration.
        self.max_iteration = max_iteration
        self.delta = delta
        self.gamma = gamma

        # Handles onto the generative model and the fields read from it.
        self.model = generative_model
        self.r_max = generative_model.r_max
        self.states = generative_model.states
        self.action_at_state = generative_model.action_at_state
        self.sa_pairs = generative_model.sa_pairs
        self.is_mdp = generative_model.is_mdp
        self.rewards = generative_model.rewards

        # Iterates: caller-supplied values win, otherwise start from zero.
        self.v = np.zeros(len(self.sa_pairs)) if v_0 is None else v_0
        self.q = np.zeros(len(self.sa_pairs)) if q_0 is None else q_0
        self.rv = {s: 0 for s in self.states} if rv_0 is None else rv_0
        self.g_star = {s: 0 for s in self.states} if g_0 is None else g_0

        # Quantities that only exist once an iteration has run.
        self.r = None
        self.rvq = None
        self.rv_star = None
        self.rv_old = None
        self.v_star = None
        self.q_star = None

        if perform_relative_value_iteration:
            self.relative_value_iteration(
                r_robust=False,
                v_robust=True,
                empirical=False,
                n_sample=10000,
                tolerance=1e-6,
            )

        # Bookkeeping counters.
        self.step_count = 0
        self.sample_used = 0
        self.epoch_used = 0

        # Default helpers: ``default_ml`` reads ``self.gamma`` at call time,
        # ``default_lrfunction`` closes over the constructor argument ``gamma``.
        self.default_ml = lambda level: math.ceil(2*2**level/(1-self.gamma)**2)
        self.default_lrfunction = lambda x: 1/(1+(1-gamma)*x)

        # These are assigned after the optional solve above, so whatever maps
        # that solve installed are replaced by empty dicts here.
        self.dist_of_sa = {}
        self.dist_of_r = {}

    # Compute the optimal value function for a 1-step lookahead under KL divergence
    def dual_opt(self, data, measure, alpha_max, opt_total=5e-6):
        """Maximise a one-dimensional dual objective over the multiplier ``alpha``.

        Only the entries of ``data`` whose ``measure`` is strictly positive are
        used. With ``essinf`` the smallest such entry, the objective is

            f(alpha) = -alpha * log(sum_i measure_i * exp(-(data_i - essinf) / alpha))
                       - alpha * self.delta

        and the value returned in the general case is ``f(alpha) + essinf`` at
        the midpoint of the final search interval. Presumably this is the
        KL-constrained worst-case expectation with radius ``self.delta`` --
        confirm against the derivation this code was written from.

        Args:
            data: values to evaluate; indexed with the boolean mask
                ``measure > 0`` (assumes a 1-D numpy array -- TODO confirm).
            measure: weights aligned with ``data`` (assumes a numpy array of
                probabilities -- TODO confirm).
            alpha_max: upper end of the search interval before padding.
            opt_total: tolerance used by the stopping rule; the interval is
                padded using ``opt_total / 100``.

        Returns:
            ``essinf`` when a single entry carries positive weight or when the
            weight on the minimal entries is at least ``exp(-self.delta)``;
            ``min(pos_data)`` (after printing a message) when ``f`` evaluates to
            ``inf`` at the left probe point; otherwise ``f(opt_alpha) + essinf``.
        """
        # Small offset keeping alpha strictly positive (f divides by alpha).
        barr = opt_total/100
        alpha_l = barr
        alpha_r = alpha_max + 2*barr
        # Probe points sit at fractions r and 1 - r of the current interval.
        r = 0.382

        pos_data = data[measure > 0]
        pos_measure = measure[measure > 0]
        essinf = min(pos_data)

        # f shifts the data by essinf inside the exponent; essinf is added back
        # to the result on return.
        # NOTE(review): differentiating f as written gives -log(S) for the first
        # term (S being the weighted exponential sum), whereas df below uses
        # +log(S). Only abs(df) is used, and only in the stopping rule --
        # confirm whether this sign is intended before changing it.
        f = lambda alpha: -alpha * np.log(np.dot(pos_measure, np.exp((-1) * (pos_data - essinf)/alpha))) - alpha * self.delta
        df = lambda alpha: np.log(np.dot(pos_measure, np.exp((-1) * (pos_data - essinf)/alpha))) + essinf/alpha - self.delta \
            - np.dot(pos_measure, pos_data * np.exp((-1)*(pos_data-essinf)/alpha))/np.dot(pos_measure, np.exp((-1) * (pos_data - essinf)/alpha))/alpha
        # Start above the tolerance so the loop below is entered.
        diff = 2* opt_total

        # check if optimal multiplier alpha* is 0
        if len(pos_measure) == 1:
            return essinf
        # Total weight sitting on the minimal value(s).
        kappa = sum(pos_measure[pos_data == essinf])
        if kappa >= np.exp(-self.delta):
            return essinf
        
        d_alpha = 1
        count = 1
        # NOTE(review): the first condition compares the signed diff, so a
        # negative diff never keeps the loop alive on its own; termination then
        # rests on the d_alpha * width test. Confirm this is intended.
        while diff > opt_total or d_alpha*(alpha_r-alpha_l) > opt_total:
            hat_alpha_l = alpha_l + r*(alpha_r-alpha_l)
            hat_alpha_r = alpha_l + (1-r)*(alpha_r-alpha_l)
            if f(hat_alpha_l) == np.inf:
                print("dual_opt has some numerical issues")
                return min(pos_data)
            diff = f(hat_alpha_l) - f(hat_alpha_r)
            if diff < 0:
                # Right probe is better: raise the lower bound.
                alpha_l = hat_alpha_l
                # d_alpha is refreshed only every fifth iteration.
                if count%5 == 0:
                    d_alpha = abs(df(alpha_l))
            else:
                # Left probe is at least as good: lower the upper bound.
                alpha_r = hat_alpha_r
                if count%5 == 0:
                    d_alpha = abs(df(alpha_r))
            count += 1
        opt_alpha = (alpha_l + alpha_r)/2
        return f(opt_alpha) + essinf

    def value_function_from_q(self, q = None, d_policy = None):
        """Collapse a per state-action vector into a per-state dict.

        Args:
            q: sequence indexed like ``self.sa_pairs``; ``self.q`` when omitted.
            d_policy: optional deterministic policy mapping each state to one
                action. When given, the value of a state is the entry of ``q``
                for ``(state, d_policy[state])``; otherwise it is the maximum
                entry over ``self.action_at_state[state]``.

        Returns:
            Dict keyed by the states of ``self.action_at_state``.

        Raises:
            ValueError: if a required ``(state, action)`` pair is not in
                ``self.sa_pairs``, or a state has no actions.
        """
        if q is None:
            q = self.q

        # Build the pair -> position map once instead of calling
        # ``list.index`` (a linear scan) for every state-action lookup.
        # ``setdefault`` keeps the first position of a repeated pair, matching
        # ``list.index``. Pairs are already used as dict keys elsewhere in this
        # class, so they are hashable.
        first_index = {}
        for position, pair in enumerate(self.sa_pairs):
            first_index.setdefault(pair, position)

        def entry(state, action):
            try:
                position = first_index[(state, action)]
            except KeyError:
                # Same exception type ``list.index`` raised for a missing pair.
                raise ValueError(f"{(state, action)!r} is not in sa_pairs") from None
            return q[position]

        if d_policy is None:
            return {
                s: max([entry(s, a) for a in actions])
                for s, actions in self.action_at_state.items()
            }
        return {s: entry(s, d_policy[s]) for s in self.action_at_state}

    def relative_value_iteration_q(self, r_robust = False, v_robust = True, empirical = False, n_sample = 100, tolerance = 1e-6):
        """Run ``n_sample`` sweeps of ``relative_value_iteration_q_once``.

        Args:
            r_robust: forwarded to each sweep.
            v_robust: forwarded to each sweep.
            empirical: when false the model's ``transition_map`` and
                ``reward_map`` are used directly; when true, distributions are
                drawn per state-action pair with
                ``generate_empirical_distribution_s`` / ``_r``.
            n_sample: sample size for the empirical distributions, the number
                of sweeps performed, and the ``n_sample`` passed to each sweep.
            tolerance: accepted for signature parity with
                ``relative_value_iteration``; no convergence test is performed
                here, so it is unused.

        Raises:
            Exception: if the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.rv_star is None:
            self.rv_star = self.rv

        if not empirical:
            self.dist_of_sa = self.model.transition_map
            self.dist_of_r = self.model.reward_map
        else:
            # Fill fresh dicts: after a non-empirical run ``self.dist_of_sa`` /
            # ``self.dist_of_r`` ARE the model's own maps, and writing into
            # them in place would overwrite the true model with samples.
            # The s/r draws stay interleaved per pair so a model with a shared
            # random stream produces the same samples as before.
            sampled_sa = {}
            sampled_r = {}
            for sa in self.sa_pairs:
                sampled_sa[sa] = self.model.generate_empirical_distribution_s(sa, n_sample)
                sampled_r[sa] = self.model.generate_empirical_distribution_r(sa, n_sample)
            self.dist_of_sa = sampled_sa
            self.dist_of_r = sampled_r

        iteration = 0
        while iteration < n_sample:
            iteration += 1
            self.relative_value_iteration_q_once(r_robust, v_robust, n_sample= n_sample)

    def mlmc_robust_value(self, rv, dist_sa_to_s, phi, alpha_v_max, random_state=None, N_max=5):
        """Randomised multi-level estimate built from four ``dual_opt`` calls.

        A level ``N`` is drawn as ``geometric(phi) - 1`` and clipped to
        ``N_max``; ``2 ** (N + 1)`` next-state indices are then sampled from
        ``dist_sa_to_s``. ``dual_opt`` is evaluated on the point mass at the
        first sample, on the empirical distribution of all samples, and on the
        empirical distributions of the odd- and even-indexed halves.

        Args:
            rv: values passed as ``data`` to ``dual_opt``.
            dist_sa_to_s: sampling probabilities over ``len(dist_sa_to_s)`` states.
            phi: success probability of the geometric level distribution.
            alpha_v_max: forwarded to ``dual_opt`` as ``alpha_max``.
            random_state: random generator to draw from; a fresh
                ``np.random.default_rng()`` when falsy.
            N_max: largest level allowed.

        Returns:
            ``term1 + (term2 - (term3 + term4) / 2) / (phi * (1 - phi) ** N)``.
        """
        rng = random_state or np.random.default_rng()

        # Random level, clipped from above.
        # NOTE(review): when the level is clipped, the weight below still uses
        # phi * (1 - phi) ** N, i.e. the unclipped geometric pmf -- confirm
        # that this is the intended estimator.
        level = min(rng.geometric(phi) - 1, N_max)

        n_states = len(dist_sa_to_s)
        draws = rng.choice(np.arange(n_states), size=2 ** (level + 1), p=dist_sa_to_s)

        def empirical(block):
            # Normalised histogram of the sampled state indices.
            counts = np.bincount(block, minlength=n_states).astype(float)
            counts /= counts.sum()
            return counts

        first_only = np.zeros_like(dist_sa_to_s, dtype=float)
        first_only[draws[0]] = 1.0

        candidates = (
            first_only,
            empirical(draws),
            empirical(draws[1::2]),
            empirical(draws[0::2]),
        )
        base, full, odd_half, even_half = [
            self.dual_opt(rv, dist, alpha_v_max) for dist in candidates
        ]

        weight = 1.0/(phi*(1-phi)**level)
        return base + weight * (full - 0.5*(odd_half + even_half))


    def relative_value_iteration_q_once(self, r_robust = False, v_robust = True, n_sample = 100):
        """Perform one damped update of the relative Q-vector ``self.rvq``.

        For every state-action pair the reward term and the next-value term are
        computed (plain expectations, or ``dual_opt`` values when the matching
        ``*_robust`` flag is set), their sum minus the current ``rvq`` is taken
        as the temporal difference, and ``rvq`` moves towards it with step size
        ``(1 / n_sample) ** 0.95``. The result is shifted so its minimum is zero.

        Side effects: sets ``self.r``, ``self.rvq_new``, ``self.rvq``,
        ``self.g_star`` (per-state value of the temporal difference) and
        ``self.rv`` (per-state value of the shifted ``rvq``).

        Raises:
            ValueError: if the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise ValueError("Relative value iteration is only applicable to MDPs.")

        alpha_v_max = max(abs(value) for value in self.rv.values())/self.delta
        alpha_r_max = self.r_max/self.delta
        # ``self.rv`` is not modified inside the loop, so its sorted-key vector
        # is built once up front.
        rv_vector = np.array([self.rv[state] for state in sorted(self.rv.keys())])

        reward_terms = []
        value_terms = []
        for pair in self.sa_pairs:
            reward_dist = self.dist_of_r[pair]
            next_dist = self.dist_of_sa[pair]

            if r_robust:
                reward_terms.append(self.dual_opt(self.rewards, reward_dist, alpha_r_max))
            else:
                reward_terms.append(np.dot(self.rewards, reward_dist))

            if v_robust:
                # The distribution is fed straight into the dual problem;
                # ``mlmc_robust_value`` is the sampled alternative.
                value_terms.append(self.dual_opt(rv_vector, next_dist, alpha_v_max))
            else:
                value_terms.append(np.dot(rv_vector, next_dist))

        self.r = reward_terms
        if self.rvq is None:
            self.rvq = np.zeros_like(self.r)

        td = np.array(self.r) + np.array(value_terms) - np.array(self.rvq)
        # Step size (1/n_sample)^0.95; a constant step is possible but would
        # need tuning and may not converge.
        step = (1/n_sample) ** 0.95
        self.rvq_new = np.array(self.rvq) + step * td
        self.g_star = self.value_function_from_q(td)
        # Normalise so the smallest entry of the new Q-vector is zero.
        self.rvq = self.rvq_new - min(self.rvq_new)
        self.rv = self.value_function_from_q(self.rvq)

    def relative_value_iteration(self, r_robust = False, v_robust = True, empirical = False, n_sample = 100, tolerance = 1e-6):
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")
        if self.rv_star is None:
            self.rv_star = self.rv
        g_old = self.g_star
        iteration = 0
        flag = True
        if not empirical:
            self.dist_of_sa = self.model.transition_map
            self.dist_of_r = self.model.reward_map
        else:
            for sa in self.sa_pairs:
                self.dist_of_sa[sa] = self.model.generate_empirical_distribution_s(sa, n_sample)
                self.dist_of_r[sa] = self.model.generate_empirical_distribution_r(sa, n_sample)
        while flag and iteration < self.max_iteration:
            iteration += 1
            self.relative_value_iteration_once(r_robust, v_robust)
            # print('self.g_star', self.g_star)
            # print('g_old', g_old)
            flag = self.span_semi_norm_difference(self.g_star, g_old)>tolerance
            g_old = self.g_star     
        if iteration == self.max_iteration:
            print("Relative Value iteration did not converge in finite steps")
        self.rv_star = self.rv
        self.get_relative_reward()
        #print('final g_star', self.g_star)

    def relative_value_iteration_once(self, r_robust = False, v_robust = True):
        """Apply one undamped sweep of relative value iteration.

        For every state-action pair the reward term and the next-value term are
        computed (plain expectations, or ``dual_opt`` values when the matching
        ``*_robust`` flag is set) and summed into ``self.rvq``.

        Side effects: sets ``self.r`` and ``self.rvq``, saves the previous
        ``self.rv`` in ``self.rv_old``, replaces ``self.rv`` with the per-state
        value of ``rvq`` and sets ``self.g_star`` to ``rv`` minus its minimum.

        Raises:
            Exception: if the generative model is not an MDP.
        """
        if not self.is_mdp:
            raise Exception("Cannot perform value iteration: input generative model is not a MDP.")

        alpha_v_max = max(abs(value) for value in self.rv.values())/self.delta
        alpha_r_max = self.r_max/self.delta
        # ``self.rv`` is only replaced after the loop, so its sorted-key vector
        # is built once up front.
        rv_vector = np.array([self.rv[state] for state in sorted(self.rv.keys())])

        reward_terms = []
        value_terms = []
        for pair in self.sa_pairs:
            reward_dist = self.dist_of_r[pair]
            next_dist = self.dist_of_sa[pair]

            if r_robust:
                reward_terms.append(self.dual_opt(self.rewards, reward_dist, alpha_r_max))
            else:
                reward_terms.append(np.dot(self.rewards, reward_dist))

            if v_robust:
                value_terms.append(self.dual_opt(rv_vector, next_dist, alpha_v_max))
            else:
                value_terms.append(np.dot(rv_vector, next_dist))

        self.r = reward_terms
        self.rvq = np.array(self.r) + np.array(value_terms)
        self.rv_old = self.rv
        self.rv = self.value_function_from_q(self.rvq)

        floor = min(self.rv.values())
        self.g_star = {s: self.rv[s] - floor for s in self.states}

    def get_relative_reward(self):
        """Set ``self.g_star`` to the per-state change ``rv[s] - rv_old[s]``."""
        gains = {}
        for s in self.states:
            gains[s] = self.rv[s] - self.rv_old[s]
        self.g_star = gains

    def span_semi_norm_difference(self, v1={}, v2={}):
        """Return the span (max minus min over ``self.states``) of ``v1 - v2``.

        Both arguments are read by state key and are not modified.
        """
        gaps = [v1[s] - v2[s] for s in self.states]
        return max(gaps) - min(gaps)
    

    def reset(self):
        """Restore iterates and counters to their zero/``None`` starting values.

        ``q`` and ``v`` become zero arrays with one entry per state-action
        pair, ``rv`` and ``g_star`` become all-zero per-state dicts, the
        derived quantities are cleared and the counters return to zero.
        ``dist_of_sa`` and ``dist_of_r`` are not touched here.
        """
        pair_count = len(self.sa_pairs)
        self.q = np.zeros(pair_count)
        self.v = np.zeros(pair_count)

        # Everything derived from an iteration run is cleared.
        for cleared in ("v_star", "q_star", "r", "rvq", "rv_star", "rv_old"):
            setattr(self, cleared, None)

        self.rv = {s: 0 for s in self.states}
        self.g_star = {s: 0 for s in self.states}

        self.step_count = self.sample_used = self.epoch_used = 0