from typing import Tuple, List, Optional
import numpy as np
from tqdm import tqdm
from pprint import pprint
from statistics import mean

# Grid map, indexed as original_board[row][col].
# Cell codes as used by generate_transition_matrix:
#   0 -> ordinary cell the agent can move through
#   2 -> wall (absorbing, and moves into it are blocked)
#   3 -> goal (absorbing)
original_board = [
    [0, 0, 0, 0, 0, 3],
    [0, 2, 2, 2, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 2, 2, 2, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
]

# Base wind probability used for the baseline transition model.
wind_probability = 0.9

# (row, col) of the start cell.
start_position = (0, 0)

# Cells where a wind can push the agent one column to the left:
# columns 1..3 of the wall-free rows 0, 2, 4 and 5.
left_wind_position = [(row, col) for row in (0, 2, 4, 5) for col in (1, 2, 3)]

# Exponent applied to the wind probability, keyed by row.
row_coef = {0: 1, 2: 3, 4: 6, 5: 6}

# Cells of the last column, rows 1..4; (5, 5) is deliberately left out.
# Only referenced by commented-out code in generate_transition_matrix.
bottom_wind_position = [(row, 5) for row in range(1, 5)]

# Exponent table for the (currently disabled) downward wind: row -> 6 - row.
row_coef_bottom = {row: 6 - row for row in range(6)}

# Action id -> (row delta, col delta): right is 0, down is 1, left is 2, up is 3.
move_dict = {
    0: (0, 1),
    1: (1, 0),
    2: (0, -1),
    3: (-1, 0),
}


def generate_transition_matrix(
    agent_position: Tuple[int, int],
    original_board: List[List[int]],
    action: int,
    wind_probability: float,
    left_wind_position: List[Tuple[int, int]],
) -> List[List[float]]:
    """
    Build the next-state distribution for one (position, action) pair.

    The result is a 6x6 grid whose entry [row][col] is the probability of
    ending in that cell after taking `action` from `agent_position`.

    agent_position: (row, col) of the agent before the move
    original_board: grid of cell codes (2 is a wall, 3 is the goal)
    action: 0 right, 1 down, 2 left, 3 up (keys of the module-level move_dict)
    wind_probability: base probability of the wind; it is raised to the power
        given by the module-level row_coef for the agent's row
    left_wind_position: cells in which the wind may push the agent one
        column to the left
    """
    row, col = agent_position[0], agent_position[1]

    # Start from "stay where you are with probability 1".
    transition_matrix = [[0.0] * 6 for _ in range(6)]
    transition_matrix[row][col] = 1

    # Walls and the goal are absorbing states.
    cell_code = original_board[row][col]
    if cell_code == 2 or cell_code == 3:
        return transition_matrix

    # Moving against the outer border of the grid leaves the agent in place
    # (the wind is not applied in that case either).
    hits_border = (
        (action == 0 and col == 5)
        or (action == 1 and row == 5)
        or (action == 2 and col == 0)
        or (action == 3 and row == 0)
    )
    if hits_border:
        return transition_matrix

    # In a windy cell, part of the probability mass is blown one column to
    # the left; the strength depends on the row through row_coef.
    if agent_position in left_wind_position:
        blown_mass = wind_probability ** row_coef[row]
        transition_matrix[row][col - 1] += blown_mass
        transition_matrix[row][col] -= blown_mass

    # Whatever mass is still on the current cell follows the chosen action,
    # unless the target cell is a wall.
    step = move_dict[action]
    target_row = row + step[0]
    target_col = col + step[1]
    if original_board[target_row][target_col] != 2:
        transition_matrix[target_row][target_col] += transition_matrix[row][col]
        transition_matrix[row][col] = 0

    return transition_matrix


def reward_function(state: int):
    """Return the reward attached to a flat state index.

    State 5 is rewarded with 0; every other state costs -1.
    """
    return 0 if state == 5 else -1


# Baseline transition model built with the module-level wind_probability.
# Layout: transition_matrix[action][state][next_state], where a state is the
# flat index row * 6 + col of a board cell.
transition_matrix = []
for action in range(4):
    next_states_tr = []
    for row in range(6):
        for col in range(6):
            tr = generate_transition_matrix(
                (row, col),
                original_board,
                action,
                wind_probability,
                left_wind_position,
            )
            # Flatten the 6x6 grid into one row of 36 probabilities.
            tr_flatten: List[float] = [
                probability for grid_row in tr for probability in grid_row
            ]
            next_states_tr.append(tr_flatten)
    transition_matrix.append(next_states_tr)


def generate_list_transition_matrix(nb_transition_matrix: int = 10):
    """Build a family of transition models with increasing wind probability.

    Model number i (0 <= i < nb_transition_matrix) uses the wind probability
    i / nb_transition_matrix. Each model is laid out as
    [action][state][next_state] with 4 actions and 36 flat states, using the
    module-level board and left-wind cells.

    Returns the list of models, in order of increasing wind probability.
    """
    board_cells = [(row, col) for row in range(6) for col in range(6)]

    list_transition_matrix = []
    for model_idx in range(nb_transition_matrix):
        wind = model_idx / nb_transition_matrix
        model = []
        for action in (0, 1, 2, 3):
            rows_for_action = []
            for cell in board_cells:
                grid = generate_transition_matrix(
                    agent_position=cell,
                    original_board=original_board,
                    action=action,
                    wind_probability=wind,
                    left_wind_position=left_wind_position,
                )
                # Flatten the 6x6 grid into one row of 36 probabilities.
                rows_for_action.append(
                    [probability for grid_row in grid for probability in grid_row]
                )
            model.append(rows_for_action)
        list_transition_matrix.append(model)
    return list_transition_matrix


# Family of transition models built with the default settings: 10 models whose
# wind probability is i / 10 for i = 0..9.
list_transition_matrix = generate_list_transition_matrix()


def assert_transition_matrix(transition_matrix: List[List[List[float]]]):
    """Check that every row of a transition model is a probability distribution.

    transition_matrix: model laid out as [action][state][next_state].

    Raises AssertionError if the probabilities of any (action, state) row do
    not sum to 1 within an absolute tolerance of 1e-10. The deviation is
    checked in both directions, so rows that sum to less than 1 are rejected
    as well as rows that sum to more than 1.
    """
    for action_idx, action in enumerate(transition_matrix):
        for state_idx, state in enumerate(action):
            total = sum(state)
            # Raised explicitly (instead of a bare `assert`) so the check is
            # still performed when Python runs with -O.
            if not abs(total - 1) < 1e-10:
                raise AssertionError(
                    f"row for action {action_idx}, state {state_idx} "
                    f"sums to {total}, expected 1"
                )


# Run assert_transition_matrix on every model of the generated family.
# NOTE: the loop variable rebinds the module-level name `transition_matrix`,
# so after this loop it refers to the last element of
# `list_transition_matrix` and no longer to the matrix built earlier in the
# file.
for transition_matrix in list_transition_matrix:
    assert_transition_matrix(transition_matrix)


def value_iteration(
    states: Tuple[int, ...],
    actions: Tuple[int, ...],
    rewards: Tuple[float, ...],
    transition_matrix: List[List[List[float]]],
    gamma: float = 0.95,
    max_iter: int = 10_000,
    delta: float = 1e-40,
):
    """Solve a single MDP with synchronous value iteration.

    Each sweep applies
        V(s) <- max_a [ rewards[s] + gamma * sum_s' P[a][s][s'] * V(s') ]
    to every state, starting from V = 0.

    states: state indices; they are used directly to index `rewards`, the
        value table and `transition_matrix`
    actions: action indices, used as the first index of `transition_matrix`
    rewards: reward per state
    transition_matrix: probabilities laid out as [action][state][next_state]
    gamma: discount factor
    max_iter: maximum number of sweeps
    delta: stop as soon as the largest absolute change of V in one sweep is
        at most this value; 0 is allowed and means "stop at an exact fixed
        point"

    Returns a dict with:
        "value": the final value table
        "policy": for each state, the first action that attains the maximum
        "total_iter": number of sweeps performed
        "list_v0": the value of state 0 after each sweep
    """
    probs = transition_matrix

    V: List[float] = [0 for _ in states]
    pi: List[Optional[int]] = [None for _ in states]
    total_iter = 0
    list_v0_during_iterations = []

    for _ in range(max_iter):
        V_new = [-float("inf") for _ in states]
        for s in states:
            for a in actions:
                # Immediate reward plus discounted expected downstream value.
                val = rewards[s]
                for s_next in states:
                    val += probs[a][s][s_next] * (gamma * V[s_next])

                # Strict comparison: ties keep the earliest action.
                if val > V_new[s]:
                    V_new[s] = val
                    pi[s] = a

        max_diff = max(abs(v - vn) for v, vn in zip(V, V_new))

        V = V_new
        list_v0_during_iterations.append(V[0])
        total_iter += 1

        # `<=` rather than `<`: with delta == 0 (e.g. a literal such as 1e-400,
        # which underflows to 0.0) a strict test could never succeed and the
        # loop would always run for max_iter sweeps even after V stopped
        # changing.
        if max_diff <= delta:
            break

    return {
        "value": V,
        "policy": pi,
        "total_iter": total_iter,
        "list_v0": list_v0_during_iterations,
    }


def robust_value_iteration(
    sequence_transition_matrix: List[
        List[List[List[float]]]
    ],  # List of transition matrices [a][s][s']
    states: Tuple[int, ...],
    actions: Tuple[int, ...],
    rewards: Tuple[float, ...],
    gamma: float = 0.9,
    max_iter: int = 10_000,
    delta: float = 1e-400,
    robust_value: Optional[List[float]] = None,
):
    """Run robust value iteration over a finite set of transition models.

    Each sweep applies
        V(s) <- max_a min_p [ rewards[s] + gamma * sum_s' p[a][s][s'] * V(s') ]
    where p ranges over `sequence_transition_matrix`, starting from V = 0.

    sequence_transition_matrix: candidate models, each laid out as
        [action][state][next_state]
    states: state indices, used directly to index `rewards`, the value table
        and the models
    actions: action indices
    rewards: reward per state
    gamma: discount factor; must be non-zero because the stopping tolerance
        divides by it (a ZeroDivisionError is raised otherwise)
    max_iter: maximum number of sweeps
    delta: requested accuracy; the stopping tolerance actually used is
        (1 - gamma) / (4 * gamma) * delta. The default literal 1e-400
        underflows to 0.0, i.e. "stop only at an exact fixed point"
    robust_value: optional reference value table; when given, the distance of
        V to it is recorded after every sweep

    Returns a dict with "value" and "policy"; when `robust_value` is given it
    also contains "list_differences_robust_value" (mean absolute difference
    per sweep) and "list_differences_v_s0" (absolute difference on state 0
    per sweep).
    """
    probs_set = sequence_transition_matrix
    tolerance = ((1 - gamma) / (4 * gamma)) * delta
    nb_states = len(states)

    V: List[float] = [0 for _ in range(nb_states)]
    pi: List[Optional[int]] = [None for _ in range(nb_states)]

    list_differences_robust_value = []
    list_differences_v_s0 = []
    for _ in range(max_iter):
        V_new = [-float("inf") for _ in range(nb_states)]
        for s in states:
            for a in actions:
                # Worst case of this action over all candidate models.
                min_val = float("inf")
                for probs in probs_set:
                    val = rewards[s]
                    for s_next in states:
                        val += probs[a][s][s_next] * (gamma * V[s_next])
                    min_val = min(min_val, val)

                # Strict comparison: ties keep the earliest action.
                if V_new[s] < min_val:
                    V_new[s] = min_val
                    pi[s] = a

        max_diff = max(abs(v - vn) for v, vn in zip(V, V_new))

        V = V_new
        if robust_value is not None:
            difference_v_pi_v_star = [
                abs(v_pi - v_star) for v_pi, v_star in zip(V, robust_value)
            ]
            list_differences_robust_value.append(mean(difference_v_pi_v_star))
            list_differences_v_s0.append(abs(V[0] - robust_value[0]))

        # `<=` rather than `<`: the default tolerance is exactly 0.0, which a
        # strict test can never satisfy, so the loop used to run for max_iter
        # sweeps even after V had stopped changing.
        if max_diff <= tolerance:
            break

    if robust_value is not None:
        return {
            "value": V,
            "policy": pi,
            "list_differences_robust_value": list_differences_robust_value,
            "list_differences_v_s0": list_differences_v_s0,
        }

    return {"value": V, "policy": pi}


def iwocs(
    sequence_transition_matrix: List[List[List[List[float]]]],
    states: Tuple[int, ...],
    actions: Tuple[int, ...],
    rewards: Tuple[float, ...],
    gamma: float = 0.9,
    max_iter: int = 10_000,
    delta: float = 1e-400,
    robust_value: Optional[List[float]] = None,
    nb_iteration: int = 10,
):
    """Iteratively solve the current worst-case model and re-select it.

    Each outer iteration:
      1. runs value_iteration on the current worst-case model (initially the
         first model of `sequence_transition_matrix`);
      2. stores the resulting policy and value table and rebuilds the
         aggregated policy/value (pi_bar, v_bar) with build_pi_bar_and_vbar;
      3. evaluates pi_bar on every candidate model with 300 Monte Carlo
         roll-outs from state 0 (terminal state 5) and keeps the model with
         the lowest mean return as the next worst-case model.

    sequence_transition_matrix: candidate models, each [action][state][next_state]
    states, actions, rewards, gamma, max_iter, delta: forwarded to
        value_iteration
    robust_value: optional reference value table used only for monitoring
    nb_iteration: number of outer iterations

    Returns a dict with "value" (v_bar), "policy" (pi_bar), "argmin" (index
    of the stored solution providing each state's value) and "worst_mdp_idx"
    (index of the worst model chosen after each outer iteration). When
    `robust_value` is given the dict also contains the monitoring series
    "list_differences_robust_value", "list_differences_v_s0",
    "list_differences_v_vi_v_star" and "info_v0_s0".
    """
    nb_states = len(states)
    list_pi = []
    list_value = []
    list_word_idx = []
    worst_mdp = sequence_transition_matrix[0]
    v_bar = [0 for _ in range(nb_states)]
    # Defined up front so that the function still returns a result when
    # nb_iteration is 0 (they used to be bound only inside the loop).
    pi_bar: List[Optional[int]] = [None for _ in range(nb_states)]
    argmin_list_v: List[Optional[int]] = [None for _ in range(nb_states)]

    # Monitoring series, filled only when a reference value is supplied.
    list_differences_robust_value = []
    list_differences_v_s0 = []
    list_differences_v_vi_v_star = []
    info_v0_s0 = {}

    total_iteration = 0
    for _ in tqdm(range(nb_iteration)):
        if robust_value is not None:
            # Infinity-norm error of the current v_bar (the series keeps its
            # historical "robust_value" name).
            mean_difference_v_pi_v_star = max(
                abs(v_pi - v_star) for v_pi, v_star in zip(v_bar, robust_value)
            )
            difference_v_s0 = abs(v_bar[0] - robust_value[0])
            # Recorded inside the guard: without a reference value there is
            # nothing to record, and the unguarded access used to fail with
            # an UnboundLocalError.
            info_v0_s0[total_iteration] = difference_v_s0

        output = value_iteration(
            states=states,
            actions=actions,
            rewards=rewards,
            transition_matrix=worst_mdp,
            gamma=gamma,
            max_iter=max_iter,
            delta=delta,
        )

        vi_iter = output["total_iter"]
        total_iteration += vi_iter

        pi = output["policy"]
        value = output["value"]
        list_v0_vi = output["list_v0"]
        if robust_value is not None:
            # One monitoring sample per inner value-iteration sweep; the
            # v_bar errors are constant during the inner solve.
            for sweep in range(vi_iter):
                list_differences_v_vi_v_star.append(
                    abs(list_v0_vi[sweep] - robust_value[0])
                )
                list_differences_robust_value.append(mean_difference_v_pi_v_star)
                list_differences_v_s0.append(difference_v_s0)

        list_pi.append(pi)
        list_value.append(value)

        pi_bar, v_bar, argmin_list_v = build_pi_bar_and_vbar(list_pi, list_value)

        # Pick the candidate model on which pi_bar performs worst.
        worst_mdp = None  # type: ignore
        worst_performance = float("inf")
        worst_idx = None
        for idx_mdp, transition_matrix in enumerate(sequence_transition_matrix):
            reward_mdp = [
                monte_carlo_pi_evaluation(
                    pi_bar,
                    transition_matrix,
                    0,
                    reward_function,
                    nb_states,
                    terminal_states=5,
                )
                for _ in range(300)
            ]
            mean_reward = np.mean(reward_mdp)
            if mean_reward < worst_performance:
                worst_mdp = transition_matrix
                worst_performance = mean_reward
                worst_idx = idx_mdp
        list_word_idx.append(worst_idx)

    result = {
        "value": v_bar,
        "policy": pi_bar,
        "argmin": argmin_list_v,
        "worst_mdp_idx": list_word_idx,
    }
    if robust_value is not None:
        result["list_differences_robust_value"] = list_differences_robust_value
        result["list_differences_v_s0"] = list_differences_v_s0
        result["list_differences_v_vi_v_star"] = list_differences_v_vi_v_star
        result["info_v0_s0"] = info_v0_s0
    return result


def build_pi_bar_and_vbar(list_pi, list_value):
    """Aggregate several (policy, value) solutions state by state.

    For each state, the solution with the smallest value is selected; on ties
    the earliest solution in the list wins.

    list_pi: list of policies (one action per state)
    list_value: list of value tables, aligned with `list_pi`

    Returns (pi_bar, v_bar, argmin): the selected action, the selected value
    and the index of the selected solution, for every state.
    """
    nb_policy_states = len(list_pi[0])
    nb_value_states = len(list_value[0])

    new_pi_bar = [None] * nb_policy_states
    new_v_bar = [float("inf")] * nb_value_states
    argmin_list_v = [None] * nb_value_states

    for candidate_idx, candidate_values in enumerate(list_value):
        for state, candidate_value in enumerate(candidate_values):
            if not candidate_value < new_v_bar[state]:
                continue
            argmin_list_v[state] = candidate_idx
            new_v_bar[state] = candidate_value
            new_pi_bar[state] = list_pi[candidate_idx][state]

    return new_pi_bar, new_v_bar, argmin_list_v


def monte_carlo_pi_evaluation(
    pi, transition_matrix, inital_state, reward_function, nb_states, terminal_states
):
    """Simulate one episode under policy `pi` and return its summed reward.

    pi: action to take in each state
    transition_matrix: probabilities laid out as [action][state][next_state]
    inital_state: state the episode starts from
    reward_function: callable giving the reward of the state just reached
    nb_states: number of states to sample the next state from
    terminal_states: the episode stops once this state is reached

    At most 10000 transitions are simulated. The reward is collected for the
    state reached by each transition; the rewards are summed without
    discounting.
    """
    max_steps = 10000
    current_state = inital_state
    total_reward = 0
    steps_done = 0
    while steps_done < max_steps:
        next_state_distribution = transition_matrix[pi[current_state]][current_state]
        current_state = np.random.choice(nb_states, p=next_state_distribution)
        total_reward += reward_function(current_state)
        steps_done += 1
        if current_state == terminal_states:
            break
    return total_reward


def render(policy: List[Optional[int]], goal: int):
    """Turn a flat policy into rows of printable symbols.

    policy: one action per state (0 right, 1 down, 2 left, 3 up, None for
        "no action")
    goal: flat index of the goal state, displayed as "G"

    Returns a list of rows, each holding up to 6 symbols.
    """
    symbol_for_action = {0: "→", 1: "↓", 2: "←", 3: "↑", None: "X"}

    symbols = []
    for action in policy:
        symbols.append(symbol_for_action[action])
    symbols[goal] = "G"

    # Cut the flat list into rows of 6 cells.
    grid = []
    for row_start in range(0, len(symbols), 6):
        grid.append(symbols[row_start : row_start + 6])
    return grid


if __name__ == "__main__":
    # Generate 50 transition models (wind probability i / 50) and keep only the
    # first 25, i.e. wind probabilities from 0 to 0.48.
    list_transition_matrix_half = generate_list_transition_matrix(
        nb_transition_matrix=50
    )[:25]

    # Flat state indices of the 6x6 board and their rewards.
    states = tuple([i for i in range(36)])
    rewards = [reward_function(state) for state in states]

    # Step 1: plain value iteration on the first model (wind probability 0).
    output = value_iteration(
        states=states,
        actions=(0, 1, 2, 3),
        rewards=tuple(rewards),
        transition_matrix=list_transition_matrix_half[0],
        gamma=0.95,
        max_iter=100_000,
        delta=1e-40,
    )

    print(output["value"])
    pprint(render(output["policy"], goal=5))

    # Reference value table passed as `robust_value` to the two runs below.
    # NOTE(review): this table comes from a single model (index 0), not from a
    # robust solve over the whole family - confirm this is the intended
    # reference.
    robust_value_table = output["value"]

    # Step 2: robust value iteration over the 25 retained models, tracking the
    # distance to the reference table.
    output = robust_value_iteration(
        states=states,
        actions=(0, 1, 2, 3),
        rewards=tuple(rewards),
        sequence_transition_matrix=list_transition_matrix_half,
        gamma=0.95,
        delta=1e-3,
        robust_value=robust_value_table,
    )

    print(output["value"])
    pprint(render(output["policy"], goal=5))
    # Per-sweep absolute error on state 0; not used further in this block.
    error_robust_vi = output["list_differences_v_s0"]

    # Step 3: iterative worst-case search with 3 outer iterations over the
    # same models.
    output = iwocs(
        states=states,
        actions=(0, 1, 2, 3),
        rewards=tuple(rewards),
        sequence_transition_matrix=list_transition_matrix_half,
        gamma=0.95,
        delta=1e-30,
        robust_value=robust_value_table,
        nb_iteration=3,
    )

    print(output["value"])
    # Alternative monitoring series are kept below as commented-out options.
    # error_ravi = output["list_differences_robust_value"]
    error_ravi = output["list_differences_v_s0"]
    # error_ravi = output["list_differences_v_vi_v_star"]
    # Index of the worst-case model selected after each outer iteration.
    worst_mdp_idx = output["worst_mdp_idx"]
    pprint(render(output["policy"], goal=5))
