import os
import random
from collections import defaultdict
import numpy as np
import uuid
import gymnasium as gym

from build_datasets import flatten, split_to_episodes


def dump_trajectories(savedir, trajectories, goal_pos):
    """Write collected transitions to a uniquely named ``.npz`` file.

    Args:
        savedir: Directory the archive is written to. It is created if it
            does not exist yet.
        trajectories: Mapping with the keys ``"states"``, ``"actions"``,
            ``"rewards"``, ``"terminateds"``, ``"truncateds"`` and
            ``"returns"``, each holding a flat sequence.
        goal_pos: Goal descriptor; embedded in the file name and stored
            under the ``goal`` key.

    Returns:
        Path of the written file
        (``learning_history_<goal_pos>_<uuid4>.npz`` inside ``savedir``).
    """
    if savedir:
        os.makedirs(savedir, exist_ok=True)
    save_name = os.path.join(savedir, f"learning_history_{goal_pos}_{str(uuid.uuid4())}.npz")

    # Force a boolean dtype: an empty list would otherwise become a float64
    # array, for which the bitwise OR below raises TypeError.
    terminated = np.array(trajectories["terminateds"], dtype=bool)
    truncated = np.array(trajectories["truncateds"], dtype=bool)

    np.savez(
        save_name,
        states=np.array(trajectories["states"], dtype=float).reshape(-1, 1),
        actions=np.array(trajectories["actions"]).reshape(-1, 1),
        rewards=np.array(trajectories["rewards"], dtype=float).reshape(-1, 1),
        # A step counts as "done" when the episode terminated or was truncated.
        dones=(terminated | truncated).astype(np.int32).reshape(-1, 1),
        goal=np.array(goal_pos),
        returns=np.array(trajectories["returns"], dtype=float).reshape(-1, 1),
    )
    return save_name


def eval(env, Q):
    """Run one greedy episode under ``Q`` and return the summed reward.

    The environment is reset, then at every step the action with the
    highest value in ``Q[flag, state, :]`` is taken until the environment
    reports termination or truncation.
    """
    (flag, state), _ = env.reset()
    episode_return = 0

    while True:
        greedy_action = np.argmax(Q[flag, state, :])
        (flag, state), reward, terminated, truncated, _ = env.step(greedy_action)
        episode_return += reward
        if terminated or truncated:
            break

    return episode_return


def q_learning(
    env: gym.Env,
    lr=0.01,
    discount=0.9,
    num_episodes=200,
    savedir=None,
    seed=None,
    return_history=False,
    eps_coef=0.7,
    random_data=False,
):
    """Train a tabular Q-function with epsilon-greedy Q-learning.

    Observations are ``(flag, state)`` pairs and the table has shape
    ``(2, size * size, n_actions)``, initialised uniformly at random.
    Epsilon starts at 1.0 and is lowered linearly after every episode so
    that it reaches 0 after 90% of ``num_episodes``.

    Args:
        env: Environment whose ``unwrapped`` object exposes ``size`` and,
            when ``savedir`` is given, either ``key_pos``/``door_pos`` or
            ``goal_pos``.
        lr: Step size of the Q-update.
        discount: Discount factor used in the update and for the
            discounted episode returns.
        num_episodes: Number of training episodes.
        savedir: If not ``None``, the visited transitions are written there
            via ``dump_trajectories``.
        seed: Seeds the Q-table initialisation, the exploration draws, the
            first environment reset and (when not ``None``) the action space.
        return_history: If true, also return the per-episode reward sums.
        eps_coef: Currently unused; the decay horizon is fixed at 0.9.
        random_data: If true, every action is sampled from the action
            space, and the saved episodes are re-ordered by ascending
            discounted return before being written back to the same file.

    Returns:
        ``Q``, or ``(Q, rewards_history)`` when ``return_history`` is true.
        ``rewards_history[k]`` is the undiscounted reward sum of episode ``k``.
    """
    trajectories = defaultdict(list)
    rng = np.random.default_rng(seed)
    Q = rng.uniform(size=(2, env.unwrapped.size * env.unwrapped.size, env.action_space.n))

    rewards_history = []

    eps = 1.0
    # Guard against num_episodes == 0, which would otherwise divide by zero.
    eps_diff = 1.0 / (0.9 * num_episodes) if num_episodes > 0 else 0.0

    env.reset(seed=seed)
    if seed is not None:
        # Resetting with a seed does not seed the action space; do it
        # explicitly so that sampled actions are reproducible as well.
        env.action_space.seed(seed)

    for _ in range(num_episodes):
        (flag, state), _ = env.reset()
        episode_reward = 0
        discounted_episode_reward = 0
        gamma = 1.0
        term, trunc = False, False

        while not (term or trunc):
            # Draw from the seeded generator (not the global `random`
            # module) so that a given seed reproduces the whole run.
            if random_data or rng.random() < eps:
                a = env.action_space.sample()
            else:
                a = np.argmax(Q[flag, state, :])

            (next_flag, next_state), r, term, trunc, _ = env.step(a)
            episode_reward += r
            discounted_episode_reward += gamma * r
            gamma *= discount

            # Collect trajectories with exploratory actions
            if savedir is not None:
                trajectories["states"].append(state)
                trajectories["actions"].append(a)
                trajectories["rewards"].append(r)
                trajectories["terminateds"].append(term)
                trajectories["truncateds"].append(trunc)

            # Update Q-Table with new knowledge.
            # NOTE(review): truncation is treated like termination here
            # (no bootstrapping from the next state) - confirm intended.
            Q[flag, state, a] += lr * (
                r
                + (1 - (term or trunc)) * discount * np.max(Q[next_flag, next_state, :])
                - Q[flag, state, a]
            )

            state = next_state
            flag = next_flag

        # Record the finished episode (one entry per episode, in order).
        rewards_history.append(episode_reward)
        trajectories["returns"].append(episode_reward)
        trajectories["disc_returns"].append(discounted_episode_reward)
        eps = max(0, eps - eps_diff)

    if savedir is not None:
        unwrapped = env.unwrapped
        goal_pos = (
            (unwrapped.key_pos, unwrapped.door_pos)
            if hasattr(unwrapped, "key_pos")
            else unwrapped.goal_pos
        )
        save_name = dump_trajectories(savedir, trajectories, goal_pos)

        if random_data:
            # Reload what was just written so that it can be split into
            # episodes and re-saved ordered by discounted return.
            with np.load(save_name, allow_pickle=True) as f:
                learning_history = {
                    "states": f["states"],
                    "actions": f["actions"],
                    "rewards": f["rewards"],
                    "dones": f["dones"],
                    "goal": f["goal"],
                    "returns": f["returns"],
                }
            hist_trajectories, _ = split_to_episodes(learning_history)

            # Episode indices from lowest to highest discounted return.
            # Assumes split_to_episodes yields one entry per recorded
            # episode, in collection order - TODO confirm.
            disc_returns = trajectories["disc_returns"]
            sorted_indices = sorted(range(len(disc_returns)), key=disc_returns.__getitem__)

            hist_trajectories = flatten([hist_trajectories[i] for i in sorted_indices], goal_pos)
            np.savez(
                save_name,
                states=np.array(hist_trajectories["states"], dtype=float).reshape(-1, 1),
                actions=np.array(hist_trajectories["actions"]).reshape(-1, 1),
                rewards=np.array(hist_trajectories["rewards"], dtype=float).reshape(-1, 1),
                dones=np.int32(hist_trajectories["dones"]).reshape(-1, 1),
                goal=np.array(goal_pos),
                returns=np.array(hist_trajectories["returns"], dtype=float).reshape(-1, 1),
            )

    if return_history:
        return Q, rewards_history
    return Q