import numpy as np
import argparse
import os
from env.cliffwalking import CliffWalkingEnv
import scipy.optimize as opt

LOG_EPS = 1e-8


def drchi_operator(P, V, discount, radius):
    """Solve a one-dimensional bounded minimisation over a scalar ``eta``.

    The objective minimised is::

        -eta + sqrt(1 + radius) * sqrt(P . max(eta - V, 0)**2)

    over ``eta`` in ``[-1 / (1 - discount), 1 / (1 - discount)]``.
    Going by the function name, this is presumably the dual form of a
    chi-square distributionally robust expectation -- confirm against the
    reference derivation.

    Args:
        P: 1-D array of weights (used as the left operand of ``np.dot``).
        V: 1-D array of values with the same length as ``P``.
        discount: Discount factor; only used to set the search interval.
        radius: Radius parameter entering through ``sqrt(1 + radius)``.

    Returns:
        Tuple ``(eta_opt, value)``: the minimiser found by
        ``scipy.optimize.fminbound`` and the negated minimum of the objective.
    """
    scale = np.sqrt(1 + radius)
    bound = 1 / (1 - discount)

    def negated_dual(eta):
        # Only the part of V lying below eta contributes.
        shortfall = np.maximum(eta - V, 0)
        return -eta + scale * np.sqrt(np.dot(P, shortfall**2))

    eta_opt, minimum = opt.fminbound(negated_dual, -bound, bound, full_output=True)[:2]
    return eta_opt, -minimum


def drkl_operator(P, V, discount, radius):
    """Solve a one-dimensional bounded minimisation over a scalar ``b``.

    The objective minimised is::

        b * log(P . exp(-(V - min(V)) / b) + LOG_EPS) + b * radius - min(V)

    over ``b`` in ``[1e-1, 1e2]``. ``V`` is shifted by its minimum before
    exponentiation so every exponent is non-positive, and ``LOG_EPS`` keeps
    the logarithm's argument strictly positive. Going by the function name,
    this is presumably the dual form of a KL-divergence distributionally
    robust expectation -- confirm against the reference derivation.

    Args:
        P: 1-D array of weights (used as the left operand of ``np.dot``).
        V: 1-D NumPy array of values with the same length as ``P``.
        discount: Unused; kept so all operators share one call signature.
        radius: Radius parameter multiplying ``b`` in the objective.

    Returns:
        Tuple ``(b_opt, value)``: the minimiser found by
        ``scipy.optimize.fminbound`` and the negated minimum of the objective.
    """
    v_min = V.min()
    shifted = V - v_min

    def negated_dual(b):
        log_term = np.log(np.dot(P, np.exp(-shifted / b)) + LOG_EPS)
        return b * log_term + b * radius - v_min

    b_opt, minimum = opt.fminbound(negated_dual, 1e-1, 1e2, full_output=True)[:2]
    return b_opt, -minimum


def nonrobust_operator(P, V, *args):
    """Return the plain expectation ``P . V`` with a placeholder first element.

    Args:
        P: Array of weights; must provide a ``dot`` method.
        V: Array of values compatible with ``P.dot``.
        *args: Ignored; accepted so the call signature matches the other
            operators in this file.

    Returns:
        Tuple ``(0, P.dot(V))``. The leading ``0`` fills the slot where the
        other operators return their optimiser variable.
    """
    expectation = P.dot(V)
    return 0, expectation


def drrenyi_operator(P, V, discount, radius, renyi_k):
    """Solve a one-dimensional bounded minimisation over a scalar ``eta``.

    With ``k = renyi_k``, ``k_star = k / (k - 1)`` and
    ``C = (1 + k * (k - 1) * radius) ** (1 / k)``, the objective minimised is::

        -eta + C * (P . max(eta - V, 0)**k_star) ** (1 / k_star)

    over ``eta`` in ``[-1 / (1 - discount), 1 / (1 - discount)]``.
    Going by the function name, this is presumably the dual form of a
    Renyi-type distributionally robust expectation -- confirm against the
    reference derivation.

    Args:
        P: 1-D array of weights (used as the left operand of ``np.dot``).
        V: 1-D array of values with the same length as ``P``.
        discount: Discount factor; only used to set the search interval.
        radius: Radius parameter entering through the constant ``C``.
        renyi_k: Order ``k``. ``k == 1`` divides by zero when forming
            ``k_star``; the script entry point requires ``k > 1``.

    Returns:
        Tuple ``(eta_opt, value)``: the minimiser found by
        ``scipy.optimize.fminbound`` and the negated minimum of the objective.
    """
    conjugate_order = renyi_k / (renyi_k - 1)
    scale = (1 + renyi_k * (renyi_k - 1) * radius) ** (1 / renyi_k)
    bound = 1 / (1 - discount)

    def negated_dual(eta):
        # Only the part of V lying below eta contributes.
        shortfall = np.maximum(eta - V, 0)
        moment = np.dot(P, shortfall**conjugate_order)
        return -eta + scale * moment ** (1 / conjugate_order)

    eta_opt, minimum = opt.fminbound(negated_dual, -bound, bound, full_output=True)[:2]
    return eta_opt, -minimum


def get_optimal_value(mdp, opertor, discount, *args):
    """Run synchronous Q-value iteration with a pluggable backup operator.

    Each sweep freezes the previous table ``Q_old`` and, for every
    non-terminal ``(s, a)``, calls::

        opertor(mdp.P[s, a], mdp.r[s, a] + discount * max_a' Q_old[:, a'],
                discount, *args)

    storing the second returned element in ``Q[s, a]`` and the first in
    ``B[s, a]``. Rows of the terminal state are pinned to zero. Sweeps repeat
    until the largest absolute change between consecutive tables is at most
    ``1e-4``.

    Args:
        mdp: Object exposing ``N_S``, ``N_A``, ``terminal_state``, ``P``
            (indexed ``[s, a]``), ``r`` (indexed ``[s, a]``) and
            ``initial_state_dist``.
        opertor: Callable ``(P_sa, V, discount, *args) -> (aux, q_value)``.
            (The parameter name's spelling is kept for backward compatibility.)
        discount: Discount factor.
        *args: Extra positional arguments forwarded to ``opertor``.

    Returns:
        Tuple ``(value, dual_var, opt_v, Q)``:
        ``value`` is ``initial_state_dist . max_a Q``;
        ``dual_var`` is ``initial_state_dist`` dotted with ``B`` at the
        greedy action of each state;
        ``opt_v`` is ``max_a Q`` per state;
        ``Q`` is the final ``(N_S, N_A)`` table.
    """
    from itertools import product

    Q = np.zeros((mdp.N_S, mdp.N_A))
    B = np.zeros((mdp.N_S, mdp.N_A))
    # Initialised away from Q so the first convergence check always fails.
    Q_old = np.ones((mdp.N_S, mdp.N_A)) / (1 - discount)

    while np.abs(Q - Q_old).max() > 1e-4:
        Q_old = Q.copy()
        # Q_old is frozen for the whole sweep, so the discounted greedy
        # next-state values are computed once rather than per (s, a) pair.
        discounted_next = discount * Q_old.max(axis=-1)
        for s, a in product(range(mdp.N_S), range(mdp.N_A)):
            if s == mdp.terminal_state:
                Q[s, a] = 0
            else:
                res = opertor(
                    mdp.P[s, a],
                    mdp.r[s, a] + discounted_next,
                    discount,
                    *args,
                )
                Q[s, a] = res[1]
                B[s, a] = res[0]

    # NOTE(review): the diagnostics below hard-code the layout: all states
    # but the last are reshaped to a 4x4 grid, and rows 12 up to (not
    # including) the last are printed. They fail or mislead for any other
    # state count -- confirm against the environment in use.
    print("----------------Policy-----------------")
    print(Q[:-1].argmax(axis=-1).reshape(4, 4))
    print("----------------Q-----------------")
    print(Q[12:-1])

    opt_v = Q.max(axis=-1)
    value = np.dot(mdp.initial_state_dist, opt_v)
    dual_var = np.dot(mdp.initial_state_dist, B[np.arange(0, mdp.N_S), Q.argmax(axis=-1)])
    return value, dual_var, opt_v, Q


if __name__ == "__main__":
    # Command-line entry point: solve the cliff-walking MDP with the selected
    # backup operator, print a summary, and save the results as .npy files.

    parser = argparse.ArgumentParser()
    parser.add_argument("--discount", default=0.9, type=float)  # Discount factor
    parser.add_argument("--radius", default=0.1, type=float)  # Radius of the uncertainty set
    parser.add_argument("--type", default="Renyi", type=str)  # Operator: Chi, KL, Renyi or nonrobust
    parser.add_argument("--renyi_k", default=2.0, type=float)  # Order k, used only with --type Renyi
    args = parser.parse_args()

    # Select the operator and the extra positional arguments it receives
    # after `discount`. An unknown type is rejected here instead of failing
    # later with a NameError on an unassigned result.
    if args.type == "Chi":
        operator, operator_args = drchi_operator, (args.radius,)
    elif args.type == "KL":
        operator, operator_args = drkl_operator, (args.radius,)
    elif args.type == "Renyi":
        # Explicit check instead of `assert`, which is stripped under -O.
        if not args.renyi_k > 1.0:
            parser.error("--renyi_k must be greater than 1")
        operator, operator_args = drrenyi_operator, (args.radius, args.renyi_k)
    elif args.type == "nonrobust":
        operator, operator_args = nonrobust_operator, ()
    else:
        parser.error(f"unknown --type {args.type!r}; expected Chi, KL, Renyi or nonrobust")

    mdp = CliffWalkingEnv()
    # get_optimal_value returns four values for every operator; the Chi and
    # KL branches previously unpacked only two and raised ValueError.
    value, dual_var, opt_v, Q = get_optimal_value(mdp, operator, args.discount, *operator_args)

    if args.type == "Renyi":
        file_name = f"{args.type}({args.renyi_k})_{args.radius}_optimal"
        print("---------------------------------------")
        print(f"Type: {args.type}({args.renyi_k}), Radius: {args.radius}, Value: {value:.4f}, DualVar: {dual_var:.4f}")
        print("---------------------------------------")
    elif args.type == "nonrobust":
        file_name = f"{args.type}_optimal"
        print("---------------------------------------")
        print(f"Type: {args.type}, Value: {value:.4f}")
        print("---------------------------------------")
    else:
        file_name = f"{args.type}_{args.radius}_optimal"
        print("---------------------------------------")
        print(f"Type: {args.type}, Radius: {args.radius}, Value: {value:.4f}, DualVar: {dual_var:.4f}")
        print("---------------------------------------")

    # Both output directories must exist before saving; ./models was
    # previously never created.
    os.makedirs("./results", exist_ok=True)
    os.makedirs("./models", exist_ok=True)
    np.save(f"./results/{file_name}", value)
    # The "_V" file now stores the per-state optimal values rather than a
    # second copy of the scalar `value`.
    # NOTE(review): presumed intent, since opt_v was otherwise unused -- confirm.
    np.save(f"./results/{file_name}_V", opt_v)
    np.save(f"./models/{file_name}_Q", Q)