import os
import time
import logging
from pyomo.environ import *
import numpy as np
from typing import Callable, Optional, Union
from scipy.optimize import fsolve
from functools import partial
from multiprocessing import Pool

# Candidate location of an IPOPT binary. It is cleared when no such file
# exists, in which case SolverFactory is called without an explicit executable.
executable_path = r"path/to/ipopt.exe"
executable_path = executable_path if os.path.exists(executable_path) else None
# ================================================================== #
# ⚠️ Settings for solver
# ================================================================== #
# One solver object is shared by every optimisation routine in this module.
solver = (
    SolverFactory('ipopt', executable=executable_path)
    if executable_path
    else SolverFactory('ipopt')
)
solver.options['max_iter'] = 10000
solver.options['halt_on_ampl_error'] = 'yes'
logger = logging.getLogger(__name__)


def kl_divergence(model):
    r"""
    Build the KL divergence between ``model.P_sa`` and ``model.P_sa_nominal``:
    D_{\rm KL}(P\|Q) = \sum_i P(i)\log\frac{P(i)}{Q(i)}

    The sum runs over ``model.S``.

    :param model: object exposing ``S``, ``P_sa`` and ``P_sa_nominal`` (both indexed by ``S``)
    :return: the summed divergence expression
    """

    def term(s):
        p = model.P_sa[s]
        # NOTE: P_sa -> 0 would cause a numerical error inside the logarithm,
        # hence the small constant added to the numerator.
        return p * log((p + 1e-8) / model.P_sa_nominal[s])

    return sum(map(term, model.S))


def f_divergence(model):
    r"""
    Build the f-divergence between ``model.P_sa`` and ``model.P_sa_nominal``:
    D_{\rm f}(P\|Q) = \sum_i Q(i)f\left(\frac{P(i)}{Q(i)}\right)

    with the generator used here being
    f(t) = \frac{t^k - k t + k - 1}{k(k-1)}, where k is ``model.k``.

    :param model: object exposing ``S``, ``k``, ``P_sa`` and ``P_sa_nominal``
    :return: the summed divergence expression
    """
    k = model.k

    def ratio(s):
        # Likelihood ratio P(s) / Q(s); rebuilt on every call.
        return model.P_sa[s] / model.P_sa_nominal[s]

    def term(s):
        numerator = ratio(s) ** k - k * ratio(s) + k - 1
        return model.P_sa_nominal[s] * numerator / (k * (k - 1))

    return sum(term(s) for s in model.S)


def constraint_P(model):
    r"""
    Normalisation constraint handed to IPOPT: the entries of ``model.P_sa``
    over ``model.S`` must sum to one.

    :param model: object exposing ``S`` and ``P_sa`` indexed by ``S``
    :return: the relation ``\sum_s P_sa[s] == 1``
    """
    total_mass = sum(model.P_sa[s] for s in model.S)
    return total_mass == 1.0


def constraint_u(model):
    r"""
    Return constraint handed to IPOPT: R(s,a)+\gamma P_{s,a}v\leq u

    Both the reward term and the value term are expectations under
    ``model.P_sa`` taken over ``model.S``.

    :param model: object exposing ``S``, ``R_sa``, ``P_sa``, ``v``, ``gamma`` and ``u``
    :return: the relation ``expected reward + gamma * expected value <= u``
    """
    expected_reward = sum(model.R_sa[s] * model.P_sa[s] for s in model.S)
    expected_value = sum(model.P_sa[s] * model.v[s] for s in model.S)
    return expected_reward + model.gamma * expected_value <= model.u


def q_inv_solver(u: float, v: np.ndarray, R_sa: np.ndarray, P_sa_nominal: np.ndarray, gamma: float,
                 divergence: str, k: float = 2.0):
    r"""
    Evaluate q_{s,a}^{-1}(u, v) by solving the following problem with the
    module-level IPOPT solver:
    $$
    q_{s,a}^{-1}(u,v) = \min_{\begin{gathered}P_{s,a}\in\Delta(S)\\R(s,a)+\gamma P_{s,a}v\leq u\end{gathered}}~D_f(P_{s,a}\|\overline{P}_{s,a})
    $$

    :param u: right-hand side of the return constraint
    :param v: value vector, indexed by next state
    :param R_sa: rewards for the fixed (s, a), indexed by next state
    :param P_sa_nominal: nominal transition probabilities for the fixed (s, a)
    :param gamma: discount factor
    :param divergence: 'kl' for the KL divergence, 'f' for the f-divergence
    :param k: exponent of the f-divergence; only read when ``divergence == 'f'``
    :return: value of the model objective after the solve
    """
    assert divergence in {'kl', 'f'}
    # Not used below; retained so that input without a leading axis still fails here.
    n_states = P_sa_nominal.shape[0]

    def from_array(values):
        # Param initialiser that reads the entry of `values` at the set index.
        return lambda _model, idx: values[idx]

    # ================================================================== #
    # Model, parameters and decision variable
    # ================================================================== #
    model = ConcreteModel()
    # P_sa << P_sa_nominal, so indices where the nominal mass is zero are left out
    support = np.where(P_sa_nominal != 0)[0].tolist()
    model.S = Set(initialize=support)
    model.P_sa_nominal = Param(model.S, initialize=from_array(P_sa_nominal))
    model.v = Param(model.S, initialize=from_array(v))
    model.R_sa = Param(model.S, initialize=from_array(R_sa))
    model.u = Param(initialize=u)
    model.gamma = Param(initialize=gamma)
    # the search starts from the nominal distribution
    model.P_sa = Var(model.S, domain=NonNegativeReals, initialize=model.P_sa_nominal)
    if divergence == 'f':
        model.k = Param(initialize=k)
    # ================================================================== #
    # Objective and constraints
    # ================================================================== #
    build_divergence = {'kl': kl_divergence, 'f': f_divergence}.get(divergence)
    if build_divergence is not None:
        model.objective = Objective(expr=build_divergence(model), sense=minimize)
    model.constraint_P = Constraint(rule=constraint_P)
    model.constraint_u = Constraint(rule=constraint_u)
    # Hand the problem to IPOPT; tee=True would print the solver output.
    # NOTE(review): the solver's termination status is not inspected here.
    solver.solve(model, tee=False)
    return value(model.objective)


def compute_u_range_single_action(v: np.ndarray, R_sa: np.ndarray, P_sa_nominal: np.ndarray, gamma: float):
    r"""
    Compute the attainable range of R(s,a)+\gamma P_{s,a}v for a single action.

    q^{-1}(u,v) must not be infeasible, so the range of u is computed first:
    u_min_a = \min_{P_{s,a}} R(s,a)+\gamma P_{s,a}v
    u_max_a = \max_{P_{s,a}} R(s,a)+\gamma P_{s,a}v
    where P_{s,a} is non-negative, sums to one, and lives on the indices where
    ``P_sa_nominal`` is non-zero. ``compute_u_range`` combines these per-action
    values by taking the maximum over actions of each bound.

    :param v: value vector, indexed by next state
    :param R_sa: rewards for the fixed (s, a), indexed by next state
    :param P_sa_nominal: nominal transition probabilities; only its support is used
    :param gamma: discount factor
    :return: tuple ``(u_min_a, u_max_a)``
    """
    # ================================================================== #
    # Model, parameters and decision variable
    # ================================================================== #
    model = ConcreteModel()
    # P_sa << P_sa_nominal, so indices where the nominal mass is zero are left out
    support = np.where(P_sa_nominal != 0)[0].tolist()
    model.S = Set(initialize=support)
    model.v = Param(model.S, initialize=lambda _model, idx: v[idx])
    model.R_sa = Param(model.S, initialize=lambda _model, idx: R_sa[idx])
    model.gamma = Param(initialize=gamma)
    model.P_sa = Var(model.S, domain=NonNegativeReals, initialize=1)

    def expected_return():
        # Expected reward plus discounted expected value under P_sa.
        reward_part = sum(model.R_sa[s] * model.P_sa[s] for s in model.S)
        value_part = sum(model.P_sa[s] * model.v[s] for s in model.S)
        return reward_part + model.gamma * value_part

    def solve_current_objective():
        # Run IPOPT on the model as it stands and read back the objective.
        solver.solve(model, tee=False)
        return value(model.objective)

    # ================================================================== #
    # Lower end of the range: minimise
    # ================================================================== #
    model.objective = Objective(expr=expected_return(), sense=minimize)
    model.constraint_P = Constraint(rule=constraint_P)
    u_min = solve_current_objective()
    # ================================================================== #
    # Upper end of the range: drop the min objective, then maximise
    # ================================================================== #
    model.del_component("objective")
    model.objective = Objective(expr=expected_return(), sense=maximize)
    u_max = solve_current_objective()
    return u_min, u_max


def compute_u_range(v: np.ndarray, R_s: np.ndarray, P_s: np.ndarray, gamma: float):
    """
    Combine the per-action ranges of a single state into one (u_min, u_max).

    Each action's range comes from ``compute_u_range_single_action``. Both
    returned bounds are maxima over the actions, starting from -inf.

    :param v: value vector, indexed by next state
    :param R_s: rewards of the state, first axis indexed by action
    :param P_s: nominal transitions of the state, shape (A, S)
    :param gamma: discount factor
    :return: tuple ``(u_min, u_max)``
    """
    n_actions, _n_states = P_s.shape
    per_action = [
        compute_u_range_single_action(v, R_s[a], P_s[a], gamma)
        for a in range(n_actions)
    ]
    floor = -float("inf")
    # Seeding with -inf reproduces a running max that starts from -inf.
    u_min = max([floor] + [low for low, _high in per_action])
    u_max = max([floor] + [high for _low, high in per_action])
    return u_min, u_max


def process_state(s, A, v, v_diff, i, eps, gamma, kappa, P, r, divergence, k):
    """
    Bisect on u for a single state ``s``.

    The search interval is the range from ``compute_u_range`` clipped to
    ``[v[s] - v_diff, v[s] + v_diff]``. At each midpoint the q^{-1} values of
    all actions are summed and compared with ``kappa``.

    :param s: state index
    :param A: number of actions
    :param v: current value vector
    :param v_diff: half-width of the window around ``v[s]`` used for clipping
    :param i: outer iteration counter (enters the tolerance as ``gamma ** i``)
    :param eps: absolute tolerance
    :param gamma: discount factor
    :param kappa: budget compared against the summed q^{-1} values
    :param P: nominal transitions, indexed as ``P[s, a]``
    :param r: rewards, indexed as ``r[s, a]``
    :param divergence: 'kl' or 'f', forwarded to ``q_inv_solver``
    :param k: f-divergence exponent, forwarded to ``q_inv_solver``
    :return: ``(s, u_mid, xi_s, s_v_change)``; ``s_v_change`` is False when the
        final midpoint sits within tolerance of a bound that was set by the
        ``v_diff`` clipping
    """
    # ================================================================== #
    # Initial interval, clipped by the v_diff window
    # ================================================================== #
    raw_low, raw_high = compute_u_range(v, r[s], P[s], gamma)
    low = np.maximum(raw_low, v[s] - v_diff)
    high = np.minimum(raw_high, v[s] + v_diff)
    # Remember which ends of the interval were moved by the clipping.
    low_clipped = True if low != raw_low else False
    high_clipped = True if high != raw_high else False
    start_low, start_high = low, high

    # ================================================================== #
    # Bisection loop
    # ================================================================== #
    point_tol = max(eps, gamma ** i)
    width_tol = max(2 * eps, 2 * gamma ** i)
    steps = 0
    started_at = time.time()
    xi_s = np.zeros(A)

    while high - low > width_tol:
        midpoint = low + (high - low) / 2
        steps += 1
        # q^{-1}_{s,a} for every action at the current midpoint
        for a in range(A):
            xi_s[a] = q_inv_solver(midpoint, v, r[s, a], P[s, a], gamma, divergence, k)
        if np.sum(xi_s) <= kappa:
            high = midpoint
        else:
            low = midpoint

    elapsed = time.time() - started_at
    logging.info(f"step{i}, s: {s}, u_step: {steps}, u_time: {elapsed}")
    u_mid = low + (high - low) / 2

    # Decide whether this state allows v_diff to be updated.
    stuck_at_low = np.abs(u_mid - start_low) <= point_tol and low_clipped
    stuck_at_high = np.abs(u_mid - start_high) <= point_tol and high_clipped
    s_v_change = (not stuck_at_low) and (not stuck_at_high)

    return s, u_mid, xi_s, s_v_change


def bisection(P: np.ndarray, r: np.ndarray, gamma: float, v: np.ndarray,
              kappa: float, divergence: str, k: float = 2.0,
              max_iteration=1000, eps=1e-8, processes: int = 8):
    """
    Iterate on the value vector, solving one bisection problem per state.

    In every outer iteration each state is handled by ``process_state`` in a
    worker process. The window half-width ``v_diff`` passed to the workers is
    doubled when at least one state reports ``s_v_change == False``, and is
    otherwise set to the norm of the last update.

    :param P: nominal transition tensor of shape (S, A, S)
    :param r: reward tensor of shape (S, A, S)
    :param gamma: discount factor
    :param v: initial value vector of shape (S,)
    :param kappa: divergence budget forwarded to ``process_state``
    :param divergence: 'kl' or 'f'
    :param k: f-divergence exponent
    :param max_iteration: maximum number of outer iterations
    :param eps: tolerance; the loop stops once the update norm is below ``10 * eps``
    :param processes: number of worker processes in the pool (default 8)
    :return: ``(v, xi)``. When the loop stops on the convergence test, ``v`` is
        the vector that was the *input* of the last sweep, because the break
        happens before ``v`` is replaced. ``xi`` has shape (S, A) and holds the
        last ``xi_s`` returned for every state.
    :raises AssertionError: if the shapes of ``P``, ``r`` or ``v`` do not match
    """
    # ================================================================== #
    # 🔍 Parameter Validation and Initialization
    # ================================================================== #
    assert len(P.shape) == 3
    S, A = P.shape[0], P.shape[1]
    assert P.shape == (S, A, S) and r.shape == (S, A, S)
    assert v.shape == (S,)
    v_diff = float("inf")
    xi = np.zeros(shape=(S, A))
    # ================================================================== #
    # 🔁 Main loop
    # ================================================================== #
    # The pool is created once and reused by every iteration; only the
    # per-iteration arguments (v, v_diff, i) change between the map calls.
    with Pool(processes=processes) as pool:
        for i in range(max_iteration):
            v_change = True
            v_new = np.zeros_like(v)

            # ============================================================== #
            # 🔁 State Space Traversal (all states in parallel)
            # ============================================================== #
            # use partial to fix the parameters shared by all states
            worker = partial(process_state, A=A,
                             v=v, v_diff=v_diff, i=i, eps=eps, gamma=gamma,
                             kappa=kappa, P=P, r=r, divergence=divergence, k=k)
            results = pool.map(worker, range(S))
            for s, u_mid, xi_s, s_v_change in results:
                v_new[s] = u_mid
                xi[s] = xi_s
                v_change &= s_v_change
            # ============================================================== #
            # 📉 Check for convergence
            # ============================================================== #
            step_norm = np.linalg.norm(v_new - v)
            if not v_change:
                # some state ended on a bound imposed by v_diff: widen the window
                print(f"v not change v_diff: {v_diff}, new_diff:{step_norm}")
                v_diff *= 2
            else:
                v_diff = step_norm
            logging.info(f"\033[32mstep{i}, v_diff: {v_diff}, {step_norm}, curr_v={v_new}\033[0m")
            if step_norm < 10 * eps:
                break
            v = v_new
            if i == max_iteration - 1:
                print("reach max iteration")
    return v, xi


def DRMDP_objective(model):
    """
    Policy-weighted expected return of one state under the transitions ``model.P_s``.

    For every action the inner sum runs over the next states whose nominal
    probability ``model.P_nominal[a, sp]`` is non-zero.

    :param model: object exposing ``A``, ``S``, ``pi``, ``P_s``, ``P_nominal``,
        ``R``, ``gamma`` and ``v``
    :return: the summed objective expression
    """

    def action_return(a):
        # Expected reward plus discounted value for action `a`.
        return sum(
            model.P_s[a, sp] * (model.R[a, sp] + model.gamma * model.v[sp])
            for sp in model.S
            if model.P_nominal[a, sp] != 0
        )

    return sum(model.pi[a] * action_return(a) for a in model.A)


def constraint_kl(model):
    """
    KL budget constraint: the KL divergences between ``model.P_s`` and
    ``model.P_nominal``, summed over all actions, must not exceed ``model.kappa``.

    Next states with zero nominal probability are skipped, and 1e-8 is added to
    the numerator inside the logarithm.

    :param model: object exposing ``A``, ``S``, ``P_s``, ``P_nominal`` and ``kappa``
    :return: the relation ``total KL <= kappa``
    """

    def kl_for_action(a):
        return sum(
            model.P_s[a, sp] * log((model.P_s[a, sp] + 1e-8) / model.P_nominal[a, sp])
            for sp in model.S
            if model.P_nominal[a, sp] != 0
        )

    total_divergence = sum(kl_for_action(a) for a in model.A)
    return total_divergence <= model.kappa


def check_policy_is_true(v: np.ndarray, P_nominal: np.ndarray, R: np.ndarray, pi: np.ndarray, gamma: float,
                         kappa: float):
    """
    Use a Bellman update to check whether a policy is right.

    For every state the worst-case expected return of policy ``pi`` is computed
    by minimising ``DRMDP_objective`` over the transitions, subject to the KL
    budget ``kappa`` and to each action's distribution summing to one. The
    resulting vector and ``v`` are printed so they can be compared.

    While this runs, the shared solver's ``tol`` option is set to 1e-14. The
    option is removed again on exit, including when a solve raises.

    :param v: value vector to check, indexed by state
    :param P_nominal: nominal transition tensor, indexed as ``[s, a, sp]``
    :param R: reward tensor, indexed as ``[s, a, sp]``
    :param pi: policy, indexed as ``[s, a]``
    :param gamma: discount factor
    :param kappa: KL budget
    :return: array with the recomputed value of every state
    """
    S, A = P_nominal.shape[0], P_nominal.shape[1]
    v_new = []
    solver.options['tol'] = 1e-14
    try:
        for s in range(S):
            model = ConcreteModel()
            model.S = Set(initialize=range(S))
            model.A = Set(initialize=range(A))
            model.R = Param(model.A, model.S, initialize=lambda model, a, sp: R[s, a, sp])
            model.P_nominal = Param(model.A, model.S, initialize=lambda model, a, sp: P_nominal[s, a, sp])
            model.pi = Param(model.A, initialize=lambda model, a: pi[s, a])
            model.v = Param(model.S, initialize=lambda model, sp: v[sp])
            model.gamma = Param(initialize=gamma)
            model.kappa = Param(initialize=kappa)
            # the search starts from the nominal transitions
            model.P_s = Var(model.A, model.S, domain=NonNegativeReals, initialize=model.P_nominal)

            model.objective = Objective(rule=DRMDP_objective, sense=minimize)
            model.constraint_divergence = Constraint(rule=constraint_kl)
            # one normalisation constraint per action, restricted to the nominal support
            model.constraint_P = ConstraintList()
            for a in model.A:
                model.constraint_P.add(
                    expr=sum(model.P_s[a, sp] for sp in model.S if model.P_nominal[a, sp] != 0) == 1)
            solver.solve(model, tee=False)
            v_new.append(value(model.objective))
    finally:
        # The solver object is shared module-wide: never leave the tight
        # tolerance behind, even when a solve fails.
        del solver.options['tol']
    print(np.array(v_new))
    print(v)
    return np.array(v_new)


def kl_dual_problem(model):
    r"""
    Objective of the KL dual problem for one state:

    -\alpha \sum_a \log\Big(\sum_{s'} \overline{P}(a,s')
        \exp\big(-\pi(a)\,(R(a,s') + \gamma v(s')) / \alpha\big)\Big) - \alpha\kappa

    :param model: object exposing ``A``, ``S``, ``alpha``, ``kappa``, ``gamma``,
        ``pi_s``, ``P_s_nominal``, ``R_s`` and ``v``
    :return: the objective expression
    """

    def log_expectation(a):
        # log of the nominal expectation of the exponentiated, scaled return
        return log(sum(
            model.P_s_nominal[a, sp] * exp(
                -model.pi_s[a] * (model.R_s[a, sp] + model.gamma * model.v[sp]) / model.alpha)
            for sp in model.S
        ))

    log_terms = sum(log_expectation(a) for a in model.A)
    return -model.alpha * log_terms - model.alpha * model.kappa


def solve_policy(v: np.ndarray, P_nominal: np.ndarray, R: np.ndarray, gamma: float, kappa: float):
    """
    Recover a policy from a value vector by maximising ``kl_dual_problem`` state by state.

    For each state the decision variables are the dual variable ``alpha`` and
    the action probabilities ``pi_s``, which are constrained to sum to one.

    :param v: value vector of shape (S,)
    :param P_nominal: nominal transition tensor of shape (S, A, S)
    :param R: reward tensor of shape (S, A, S)
    :param gamma: discount factor
    :param kappa: KL budget
    :return: array of shape (S, A) holding the solved ``pi_s`` of every state
    :raises AssertionError: if the shapes do not match
    """
    S, A = P_nominal.shape[0], P_nominal.shape[1]
    assert P_nominal.shape == R.shape == (S, A, S) and v.shape == (S,)
    pi = np.zeros(shape=(S, A))
    # Per-state optimal objective values; evaluated but not returned.
    dual_values = np.zeros_like(v)
    for s in range(S):
        nominal_s = P_nominal[s]
        reward_s = R[s]
        model = ConcreteModel()
        # index sets
        model.S = Set(initialize=range(S))
        model.A = Set(initialize=range(A))
        # data of the current state
        model.P_s_nominal = Param(model.A, model.S, initialize=lambda _model, a, sp: nominal_s[a, sp])
        model.R_s = Param(model.A, model.S, initialize=lambda _model, a, sp: reward_s[a, sp])
        model.v = Param(model.S, initialize=lambda _model, sp: v[sp])
        model.gamma = Param(initialize=gamma)
        model.kappa = Param(initialize=kappa)
        # decision variables: dual variable and action probabilities
        model.alpha = Var(domain=NonNegativeReals, initialize=1.0, bounds=(None, 1e10))
        model.pi_s = Var(model.A, domain=NonNegativeReals, initialize=1 / A)
        # maximise the dual objective
        model.objective = Objective(rule=kl_dual_problem, sense=maximize)
        # action probabilities sum to one
        model.constraint_pi = Constraint(expr=sum(model.pi_s[a] for a in model.A) == 1)
        solver.solve(model, tee=False)
        dual_values[s] = value(model.objective)
        pi[s] = [value(model.pi_s[a]) for a in model.A]
    return pi


def compute_q_sa(r_sa: np.ndarray, P_sa_nominal: np.ndarray, gamma: float, v: np.ndarray,
                 kappa: float, divergence: str, k: float = 2.0):
    """
    Worst-case value of one state-action pair.

    Minimises the expected ``r + gamma * v`` over distributions on the support
    of ``P_sa_nominal``. The chosen divergence from the nominal distribution
    must be at most ``kappa``.

    :param r_sa: rewards for the fixed (s, a), indexed by next state
    :param P_sa_nominal: nominal transition probabilities for the fixed (s, a)
    :param gamma: discount factor
    :param v: value vector, indexed by next state
    :param kappa: divergence budget
    :param divergence: 'kl' or 'f'; any other value adds no divergence constraint
    :param k: f-divergence exponent; only read when ``divergence == 'f'``
    :return: value of the model objective after the solve
    """

    def from_array(values):
        # Param initialiser that reads the entry of `values` at the set index.
        return lambda _model, idx: values[idx]

    model = ConcreteModel()
    support = np.where(P_sa_nominal != 0)[0].tolist()
    model.S = Set(initialize=support)
    model.P_sa_nominal = Param(model.S, initialize=from_array(P_sa_nominal))
    model.R_sa = Param(model.S, initialize=from_array(r_sa))
    model.v = Param(model.S, initialize=from_array(v))
    model.gamma = Param(initialize=gamma)
    # the search starts from the nominal distribution
    model.P_sa = Var(model.S, domain=NonNegativeReals, initialize=model.P_sa_nominal)
    if divergence == 'f':
        model.k = Param(initialize=k)
    expected_return = sum(model.P_sa[sp] * (model.R_sa[sp] + model.gamma * model.v[sp]) for sp in model.S)
    model.objective = Objective(expr=expected_return, sense=minimize)
    build_divergence = {'kl': kl_divergence, 'f': f_divergence}.get(divergence)
    if build_divergence is not None:
        model.constraint_divergence = Constraint(expr=build_divergence(model) <= kappa)
    model.constraint_P = Constraint(rule=constraint_P)
    # ================================================================== #
    # Solve for P_sa and report the optimum
    # ================================================================== #
    solver.solve(model, tee=False)
    return value(model.objective)


def robust_value_iteration(P: np.ndarray, r: np.ndarray, gamma: float, v: np.ndarray, kappa: float,
                           divergence: str, k: float = 2.0,
                           max_iteration=1000, eps=1e-8):
    """
    Value iteration in which every Q(s, a) is the worst-case value from ``compute_q_sa``.

    Each sweep recomputes the whole Q table and takes the row-wise maximum as
    the new value vector. The loop stops once the update norm falls below
    ``10 * eps`` or after ``max_iteration`` sweeps.

    :param P: nominal transition tensor, indexed as ``P[s, a]``
    :param r: reward tensor of shape (S, A, S)
    :param gamma: discount factor, strictly between 0 and 1
    :param v: initial value vector of shape (S,)
    :param kappa: divergence budget
    :param divergence: 'f' or 'kl'
    :param k: f-divergence exponent
    :param max_iteration: maximum number of sweeps
    :param eps: convergence tolerance
    :return: ``(v, pi)``; ``pi`` is one-hot on the arg-max action of the final Q
        table. When the loop stops on the convergence test, ``v`` is the vector
        that was the input of the last sweep.
    :raises AssertionError: if shapes, ``gamma`` or ``divergence`` are invalid
    """
    assert len(P.shape) == 3
    S, A = P.shape[0], P.shape[1]
    assert r.shape == (S, A, S)
    assert v.shape == (S,)
    assert 0 < gamma < 1
    assert divergence in ['f', 'kl']
    print(S, A)
    Q = np.zeros(shape=(S, A))
    last_sweep = max_iteration - 1
    for i in range(max_iteration):
        sweep_started = time.time()
        # refresh the Q table row by row, actions in increasing order
        for s in range(S):
            Q[s] = [compute_q_sa(r[s, a], P[s, a], gamma, v, kappa, divergence, k) for a in range(A)]
        v_new = Q.max(axis=1)
        v_diff = np.linalg.norm(v_new - v)
        sweep_time = time.time() - sweep_started
        logging.info(f"step{i}, v_diff: {v_diff}, curr_v={v_new}, time: {sweep_time}")
        if v_diff < 10 * eps:
            break
        v = v_new
        if i == last_sweep:
            print("reach max iteration")
    # greedy deterministic policy taken from the final Q table
    greedy_actions = np.argmax(Q, axis=1)
    pi = np.zeros_like(Q)
    pi[np.arange(S), greedy_actions] = 1
    return v, pi


class MDP:
    """
    Container for a robust MDP together with the settings used to solve it.

    ``solve`` dispatches to ``bisection`` (followed by ``solve_policy``) for
    ``rectangular='s'`` and to ``robust_value_iteration`` for
    ``rectangular='sa'``.
    """

    def __init__(self, P: np.ndarray, r: np.ndarray, gamma: float,
                 kappa: float, divergence: str, k: float = 2.0,
                 max_iteration=1000, eps=1e-8):
        """
        :param P: nominal transition tensor of shape (S, A, S)
        :param r: reward tensor of shape (S, A, S)
        :param gamma: discount factor, strictly between 0 and 1
        :param kappa: non-negative divergence budget
        :param divergence: 'kl' or 'f'
        :param k: f-divergence exponent; must differ from 0 and 1
        :param max_iteration: iteration cap forwarded to the solvers
        :param eps: tolerance forwarded to the solvers
        :raises AssertionError: if any of the checks below fails
        """
        assert len(P.shape) == 3
        S, A = P.shape[0], P.shape[1]
        assert P.shape == (S, A, S) and r.shape == (S, A, S)
        assert kappa >= 0
        assert divergence in {'kl', 'f'}
        assert k != 0 and k != 1
        assert 0 < gamma < 1

        self.P = P
        self.r = r
        # initial value vector handed to the solvers
        self.v = np.zeros(S)
        self.gamma = gamma
        self.kappa = kappa
        self.divergence = divergence
        self.k = k
        self.max_iteration = max_iteration
        self.eps = eps
        # Filled in by solve(); defined here so it can be read before a solve.
        self.xi = None

    def solve(self, rectangular='s'):
        """
        Solve the robust MDP.

        :param rectangular: 's' runs ``bisection`` and then tries to recover a
            policy with ``solve_policy``; 'sa' runs ``robust_value_iteration``
        :return: ``(v, pi)``. In the 's' branch ``pi`` is None when
            ``solve_policy`` raised.
        :raises ValueError: if ``rectangular`` is neither 's' nor 'sa'
        """
        if rectangular == 's':
            v, xi = bisection(self.P, self.r, self.gamma, self.v, self.kappa,
                              self.divergence, self.k, self.max_iteration, self.eps)
            pi = None
            try:
                pi = solve_policy(v, self.P, self.r, self.gamma, self.kappa)
            except Exception as e:
                # Best effort: the value vector is still returned when policy
                # recovery fails; the error is only reported.
                print(e)
            self.xi = xi
            return v, pi
        elif rectangular == 'sa':
            v, pi = robust_value_iteration(self.P, self.r, self.gamma, self.v, self.kappa,
                                           self.divergence, self.k, self.max_iteration, self.eps)
            self.xi = None
            return v, pi
        else:
            raise ValueError(f"rectangular must be 's' or 'sa', got {rectangular!r}")
