# type: ignore[attr-defined]
from typing import Optional
import time
from enum import Enum

import typer
from rich import print
from gridworld import SimpleGridWorld, plot_state_action_function
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import random
from dqn import GymWrapper, DQN


def seed_everything(seed):
    """
    Seed the global random number generators used by this module.

    Both Python's ``random`` module and NumPy's global generator are seeded
    with the same integer so that runs are reproducible.

    :param seed: Any value accepted by ``int()``.
    :return: None
    """
    seed_value = int(seed)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed_value)


def optimal_policy_from_value(world, value):
    """
    Derive the greedy (deterministic) policy implied by a value function.

    For every state, each action is scored by the value of the state that
    `world.step(state, action)` returns, and the highest-scoring action is
    selected (ties resolve to the lowest action index, as with `np.argmax`).

    Args:
        world: The grid-world object; must expose `num_states`,
            `num_actions` and `step(state, action) -> next_state`.
        value: The value function as table
            `[state: Integer] -> value: Float`.

    Returns:
        The greedy policy as table `[state: Integer] -> action: Integer`.
    """
    greedy_actions = []
    for state in range(world.num_states):
        successor_values = [
            value[world.step(state, action)]
            for action in range(world.num_actions)
        ]
        greedy_actions.append(np.argmax(successor_values))

    return np.array(greedy_actions)


def value_iteration(p, rewards, discount, eps=1e-3, random=False):
    """
    Solve a tabular MDP by value iteration, or evaluate the uniform policy.

    Each sweep applies the Bellman backup
        Q[a, s] = rewards[s, a] + discount * sum_t p[s, a, t] * v[t]
    and then reduces over actions: `max` for the optimal value function, or
    `mean` (expectation under the uniformly random policy) when `random` is
    true.

    Args:
        p: The transition probabilities of the MDP as table
            `[from: Integer, action: Integer, to: Integer] -> probability: Float`.
        rewards: The reward signal as table
            `[state: Integer, action: Integer] -> reward: Float`.
        discount: The discount (gamma) applied during value-iteration.
        eps: The threshold to be used as convergence criterion. Convergence
            is assumed if the value-function changes by no more than the
            threshold on all states in a single iteration.
        random: If true, compute the value function of the uniformly random
            policy instead of the optimal one. (The name shadows the `random`
            module inside this function; it is kept for caller
            compatibility.)

    Returns:
        A tuple `(v, q)` where `v` is the value function with shape
        `(n_states,)` and `q` holds the state-action values of the final
        sweep with shape `(n_actions, 1, n_states)`. The singleton middle
        axis is kept so that callers indexing `q[a, 0, s]` or using
        `np.max(q, axis=0)[0]` continue to work.
    """
    p = np.asarray(p)
    rewards = np.asarray(rewards)
    n_states, n_actions, _ = p.shape
    v = np.zeros(n_states)

    delta = np.inf
    total_iters = 0

    print(f"Finding {'random' if random else 'optimal'} value function...")

    # NOTE: terminal states get no special treatment here; absorbing
    # behaviour has to be encoded in `p` and `rewards` by the caller.
    while delta > eps:  # iterate until convergence
        v_old = v

        # `p @ v` contracts the `to` axis, giving expected next-state values
        # with shape (n_states, n_actions). Rewards are added for EVERY
        # action before reducing, so the reduction sees complete Q-values.
        # Transposed so that we hold Q[a, s].
        q = (rewards + discount * (p @ v)).T

        # Reduce over actions: expectation under the uniform policy, or the
        # greedy maximum.
        v = q.mean(axis=0) if random else q.max(axis=0)

        delta = np.max(np.abs(v_old - v))
        total_iters += 1

    print(f"Total iterations: {total_iters}")

    # Pretty-print as a square grid when the state count allows it (e.g. 25
    # states -> 5x5); otherwise fall back to the flat vector.
    side = int(round(np.sqrt(n_states)))
    printable = v.reshape((side, side)) if side * side == n_states else v
    print("Value:\n", printable)

    return v, q[:, None, :]


# Typer CLI application object; commands are registered via `@app.command`.
app = typer.Typer(add_completion=False, name="research_project")


def add_row(df, row):
    """
    Return a new DataFrame with `row` placed in front of the rows of `df`.

    The input frame is not modified. The new row is inserted at the top
    (position 0) and the index of the result is renumbered from 0.

    :param df: The DataFrame whose columns define the layout of `row`.
    :param row: Sequence of values, one per column of `df`.
    :return: A new DataFrame containing `row` followed by the rows of `df`.
    """
    new_frame = pd.DataFrame([row], columns=df.columns)
    stacked = pd.concat([new_frame, df], ignore_index=True)
    return stacked


def simple_moving_average(data, window_size):
    """
    Smooth a 1D array with an equally weighted sliding window.

    :param data: The input 1D NumPy array.
    :param window_size: Number of consecutive elements averaged per output
        value. A window of 1 or less disables smoothing and `data` is
        returned unchanged (the same object, not a copy).
    :return: The smoothed values. Only fully overlapping windows are kept
        (`mode="valid"`), so the result is shorter than `data`.
    """
    if window_size > 1:
        # Uniform kernel: every element in the window contributes equally.
        kernel = np.full(window_size, 1.0 / window_size)
        return np.convolve(data, kernel, mode="valid")

    return data


@app.command(name="")
def main() -> None:
    """
    Run the reward-shaping experiment end to end and plot the results.

    Steps:
      1. Build a 5x5 `SimpleGridWorld` and compute the optimal and the
         random-policy value functions with `value_iteration`.
      2. Build a potential `phi` from both and derive a shaped reward table.
      3. Train a DQN for 50 seeds on each of three training rewards (shaped,
         true, cached MaxEnt), writing a CSV checkpoint after every run.
      4. Aggregate the smoothed evaluation returns with rliable's IQM and
         bootstrap confidence intervals, then save and show the plot.
    """
    # Imported up front so that a missing optional dependency is reported
    # before the (long) training loop rather than after it.
    import torch
    from rliable import library as rly
    from rliable import metrics

    seed_everything(1)
    grid_size = 5
    world = SimpleGridWorld(grid_size, debug=True)
    discount = 0.99

    value, Q = value_iteration(
        world.transitions, world.rewards, discount=discount, random=False
    )
    value_random, Q_random = value_iteration(
        world.transitions, world.rewards, discount=discount, random=True
    )

    pi_star = optimal_policy_from_value(world, value)
    print("Pi*\n", pi_star.reshape((grid_size, grid_size)))

    onstep_pi = optimal_policy_from_value(world, value_random)
    print("Onestep-Pi\n", onstep_pi.reshape((grid_size, grid_size)))

    ##########
    # REWARD SHAPING
    #########
    print("\n\n\n\n\n", "Reward shaping:")
    # Potential: average of the optimal values and the best random-policy
    # state-action value per state.
    phi = 0.5 * (value + np.max(Q_random, axis=0)[0])
    rewards = world.rewards.copy()
    shapes = np.zeros_like(rewards)
    for s in range(world.num_states):
        for a in range(world.num_actions):
            # Potential-based shaping term: gamma * phi(s') - phi(s).
            shapes[s, a] += phi[world.step(s, a)] * discount - phi[s]

    # Rescale the shaping terms so their standard deviation is 0.5.
    shapes /= np.std(shapes) * 2
    rewards += shapes

    # Run only for the value table it prints for the shaped rewards; the
    # returned values are not used below.
    value_iteration(world.transitions, rewards, discount=discount, random=False)

    # Cached per-state MaxEnt rewards (to save time), repeated across the 4
    # actions to form a [state, action] table.
    reward_maxent = np.array(
        [
            0.21469736, 0.51637294, 0.66777695, 0.80923597, 0.43704773,
            0.28808798, 0.32477309, 0.30944462, 0.85650858, 0.35318616,
            0.45513115, 0.45304895, 0.39785857, 0.85321309, 0.70343047,
            0.56956322, 0.54992234, 0.49894885, 0.34162235, 0.71665543,
            0.63246331, 0.60327772, 0.54534266, 0.46709917, 1.14929378,
        ]
    )[:, None].repeat(repeats=4, axis=1)

    ##########
    # DQN
    #########
    all_performance = {
        "DQN (True)": [],
        "DQN (Ours)": [],
        "DQN (MaxEnt)": [],
    }

    # TODO: make sure discount you use for DQN matches what you use for value iteration
    all_eval = pd.DataFrame(columns=["reward", "eval_returns", "losses"])
    for seed in range(50):
        seed_everything(seed)
        torch.manual_seed(seed)

        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        for r_name, r in zip(
            ["DQN (Ours)", "DQN (True)", "DQN (MaxEnt)"],
            [rewards, world.rewards, reward_maxent],
        ):
            env = GymWrapper(
                world,
                train_reward=r,
                true_reward=world.rewards,
                terminal=world.num_states - 1,
                time_limit=100,
            )
            dqn = DQN(env.observation_space.n, env.action_space.n, learning_rate=1e-3)
            if torch.cuda.is_available():
                dqn = dqn.cuda()

            eval_returns, losses = dqn.run(
                env,
                max_epsilon=0.01,
                min_epsilon=0.01,
                decay=0,
                train_steps=50_000,
                update_target_every=500,
                tau=1,
            )
            print(len(eval_returns))
            all_eval = add_row(all_eval, [r_name, eval_returns, losses])
            all_performance[r_name].append(eval_returns)

            # Checkpoint everything collected so far; each write goes to a
            # fresh, timestamped file.
            all_eval.to_csv(f"all_eval_{time.time()}.csv")

    # graph it! #

    window_size = 10
    # Only the first `max_evals` evaluation points of each run are plotted.
    # The same cut is applied to the scores and to the x-axis so that their
    # lengths always agree.
    max_evals = 150

    def get_perf(runs):
        """Smooth each run and return an array of shape (runs, 1, frames)."""
        smoothed = np.array(
            [
                simple_moving_average(
                    np.array(run[:max_evals]), window_size=window_size
                )
                for run in runs
            ]
        )
        return smoothed[:, None]

    def iqm(scores):
        """Interquartile mean across runs, computed per evaluation frame."""
        return np.array(
            [
                metrics.aggregate_iqm(scores[..., frame])
                for frame in range(scores.shape[-1])
            ]
        )

    performance = {name: get_perf(runs) for name, runs in all_performance.items()}
    frames = np.arange(performance["DQN (True)"].shape[-1])

    iqm_scores, iqm_cis = rly.get_interval_estimates(
        performance,
        iqm,
        reps=50,
    )

    # One colour per entry of `all_performance`.
    colors = ["purple", "orange", "green"]
    print(iqm_scores)
    for name, color in zip(all_performance, colors):
        print(name)
        # Scores are shifted by +100 for display.
        scores = iqm_scores[name] + 100
        # rliable reports confidence intervals as absolute (lower, upper)
        # bounds, whereas `errorbar` expects non-negative distances from the
        # plotted point, so convert (the +100 shift cancels out).
        lower, upper = iqm_cis[name] + 100
        yerr = np.maximum(np.stack([scores - lower, upper - scores]), 0)
        plt.errorbar(
            frames,
            scores,
            yerr,
            alpha=0.7,
            color=color,
        )
        plt.plot(frames, scores, label=name, alpha=0.8, color=color)

    # Adding labels and legend
    plt.xlabel("Number of Samples (x100)")
    plt.ylabel("Return")
    plt.legend()

    # Save before showing: once `show()` returns the figure may already have
    # been closed, in which case `savefig` would write an empty image.
    plt.savefig("lol.png")
    plt.show()


# Script entry point: hand control to the Typer CLI when executed directly.
if __name__ == "__main__":
    app()
# NOTE(review): the statements below sit at module level outside the guard,
# so they also run whenever this file is imported. They read a hard-coded
# absolute path that presumably exists only on the original author's
# machine -- confirm whether this is leftover analysis code.
k = pd.read_csv("/home/ezipe/git/282r/all_eval_1701886015.6951904.csv")
# SECURITY: `eval` executes whatever expression is stored in the CSV cell.
# Only load files you trust.
# First three "DQN (True)" rows, each `eval_returns` cell parsed from text.
dqn_vals = [
    eval((k[k["reward"] == "DQN (True)"]["eval_returns"]).to_numpy()[i])
    for i in range(3)
]
# Same extraction for the first three "DQN (Ours)" rows.
our_vals = [
    eval((k[k["reward"] == "DQN (Ours)"]["eval_returns"]).to_numpy()[i])
    for i in range(3)
]

# Module-level mapping from algorithm label to its list of parsed runs.
all_performance = {"DQN (True)": dqn_vals, "DQN (Ours)": our_vals}
