import wandb
import argparse
import os
import imageio
import matplotlib.pyplot as plt
from collections import defaultdict

from envs.make_env import *
from utils import *

from algos.grpo import GRPO
from sb3_contrib import RecurrentPPO
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.utils import set_random_seed


# Values accepted by --mask_type (enforced by argparse via `choices`, so the
# check still runs under `python -O`, unlike an `assert`).
MASK_TYPES = [
    "fully_obs", "no_stack", "framestack", "masked", "ca_masked", "all_masked",
    "ca_all_masked", "all_history_masked", "ca_all_history_masked", "demir",
]

parser = argparse.ArgumentParser()
parser.add_argument('--env', default="tmaze-v0", help="Environment name")
parser.add_argument('--algo', default="PPO", help="RL learning algorithm: QL, PPO, RecurrentPPO, GRPO")
parser.add_argument('--mask_type', default="fully_obs", choices=MASK_TYPES, help="Observation mask type")
parser.add_argument('--cube_cam', default="orthographic", help="full, face, orthographic")
parser.add_argument('--scramble_steps', type=int, default=5, help="Scramble steps for cube env")
parser.add_argument('--maze_length', type=int, default=1, help="Maze length for tmaze")
parser.add_argument("--random_length", help="", action='store_true', default=False)
parser.add_argument('--active', action='store_true', default=False, help="Active tmaze mode")
parser.add_argument('--continual', action='store_true', default=False, help="Continual setting")
parser.add_argument('--visible_goal_steps', type=int, default=2, help="Number of steps where the environment goal is visible in GCRL tasks")
parser.add_argument('--fetchhideblock', action='store_true', default=False, help="Hide the block in the fetch tasks")
parser.add_argument('--max_episode_steps', type=int, default=50, help="Max number of steps per episode")
parser.add_argument('--num_stack', type=int, default=1, help="Memory length (sequence length)")
# argparse does not apply `type` to defaults, so the default must already be an int.
parser.add_argument('--maxiter', type=int, default=1_000_000, help="Max training timesteps")
parser.add_argument('--features_dim', type=int, default=256, help="Input dim of policy layer")
parser.add_argument('--hidden_size', type=int, default=128, help="Hidden dim of memory architecture layer")
parser.add_argument('--run', type=int, default=None, help="Random seed / run id")
parser.add_argument('--nenvs', type=int, default=1, help="Number of envs/processes")
parser.add_argument('--path', default="./data/", help="Save path for logs and models")
parser.add_argument('--device', default="cuda", help="Device for Pytorch")
parser.add_argument('--arch', choices=['cnn', 'mlp', 'transformer', 'lstm'], default='mlp', help="Policy architecture")
parser.add_argument('--render_mode', default='human', help="Render mode")
args = parser.parse_args()


def load_agent(algorithm, model_path, env, device="cpu", deterministic=True):
    """Load a trained agent and wrap it as a callable policy.

    Args:
        algorithm: one of "QL", "PPO", "GRPO" or "RecurrentPPO".
        model_path: for "QL", a ``.npy`` file holding a pickled dict that maps
            states to action-value arrays; otherwise the path handed to the
            algorithm's ``load`` classmethod.
        env: environment the model is loaded against. For "QL" only
            ``env.action_space.n`` is used.
        device: torch device used for neural-network models.
        deterministic: passed through to ``model.predict`` (ignored for "QL",
            which is always greedy).

    Returns:
        For "QL", "PPO" and "GRPO": ``agent(s) -> (action, value)``.
        For "RecurrentPPO": ``agent(s, state=None, episode_starts=None)
        -> ((action, next_state), value)``. Feed ``next_state`` back in as
        ``state`` on the following call.

    Raises:
        ValueError: if ``algorithm`` is not one of the supported names.
    """
    if algorithm == "QL":
        # States never seen during training fall back to an all-zero row.
        Q = defaultdict(lambda: np.zeros(env.action_space.n))
        # NOTE: allow_pickle=True unpickles the file -- only load trusted models.
        Q.update(np.load(model_path, allow_pickle=True).item())
        return lambda s: (Q[s].argmax(), Q[s].max())

    if algorithm in ("PPO", "GRPO"):
        model_cls = PPO if algorithm == "PPO" else GRPO
        model = model_cls.load(model_path, env=env, device=device)

        def agent(s):
            a = model.predict(s, deterministic=deterministic)[0]
            # predict_values expects a batched tensor, hence unsqueeze(0).
            s_t = torch.as_tensor(s).to(model.device).float().unsqueeze(0)
            with torch.no_grad():
                v = model.policy.predict_values(s_t)
            return a, v.item()

        return agent

    if algorithm == "RecurrentPPO":
        model = RecurrentPPO.load(model_path, env=env, device=device)
        policy = model.policy

        def recurrent_agent(s, state=None, episode_starts=None):
            if episode_starts is None:
                # Same default model.predict uses when episode_start is None.
                episode_starts = np.zeros((1,), dtype=bool)
            action, next_state = model.predict(
                s, state=state, episode_start=episode_starts, deterministic=deterministic
            )
            if state is None:
                # Fresh LSTM: zero (hidden, cell) state, as model.predict does.
                zeros = np.zeros(policy.lstm_hidden_state_shape, dtype=np.float32)
                state = (zeros, zeros)
            # The value is computed from the state *before* this observation.
            # NOTE(review): predict() only tracks the actor LSTM state; if the
            # policy has a separate critic LSTM this value is an approximation.
            s_t = torch.as_tensor(s).to(model.device).float().unsqueeze(0)
            lstm_states = tuple(
                torch.as_tensor(x, dtype=torch.float32, device=model.device) for x in state
            )
            starts_t = torch.as_tensor(episode_starts, dtype=torch.float32, device=model.device)
            with torch.no_grad():
                v = policy.predict_values(s_t, lstm_states=lstm_states, episode_starts=starts_t)
            return (action, next_state), v.item()

        return recurrent_agent

    raise ValueError(f"Unknown algorithm {algorithm!r}")


def _evaluate_recurrent_ppo(model_path, env, gamma):
    """Roll out a RecurrentPPO policy forever and print running statistics.

    After every finished episode this prints the episode count, the fraction
    of episodes whose final reward was positive, and the mean discounted
    return. Evaluation deliberately uses a fixed maze length of 32 with
    ``random_length`` disabled, overriding the CLI settings.
    """
    model = RecurrentPPO.load(model_path, env=env, device=args.device)

    args.maze_length = 32
    args.random_length = False
    vec_env = DummyVecEnv([lambda: make_env(args, 0)[0]])

    obs = vec_env.reset()
    lstm_states = None  # LSTM state carried between predict() calls
    # Episode-start flags tell the policy when to reset its LSTM state.
    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    discounted_return, successes, episodes, t = 0, 0, 0, 0
    while True:
        action, lstm_states = model.predict(
            obs, state=lstm_states, episode_start=episode_starts, deterministic=True
        )
        # The vectorized env resets automatically when an episode ends.
        obs, rewards, dones, _ = vec_env.step(action)
        discounted_return += (gamma ** t) * rewards[0]
        t += 1  # step index within the current episode, for discounting
        if dones[0]:
            successes += rewards[0] > 0
            episodes += 1
            t = 0
            print(episodes, successes / episodes, discounted_return / episodes)
        episode_starts = dones
        vec_env.render("human")


def _evaluate_agent(model_path, env, gamma, max_steps=1000):
    """Roll out a feed-forward agent (QL/PPO/GRPO) forever, printing each action and value.

    Each episode is capped at ``max_steps`` steps.
    """
    agent = load_agent(args.algo, model_path, env, args.device, deterministic=False)
    discounted_return, successes, episodes = 0, 0, 0
    while True:
        episodes += 1
        state, _ = env.reset()
        for t in range(max_steps):
            action, value = agent(state)
            print(action, value)
            state, reward, terminated, truncated, _ = env.step(action)
            discounted_return += (gamma ** t) * reward
            if terminated or truncated:
                successes += reward > 0
                break


if __name__ == "__main__":
    # Instantiate env and locate the saved model.
    env, name = make_env(args, 0)
    save_path = args.path + name
    model_path = save_path + ("-values.npy" if args.algo == "QL" else "_values")
    print(model_path)

    print("Observation space: ", env.observation_space)
    print("Action space: ", env.action_space)
    print("save_path", save_path)

    # SB3 models are saved with a ".zip" suffix that `load` adds implicitly.
    if not os.path.exists(model_path) and not os.path.exists(model_path + ".zip"):
        print("FAILED ", model_path)
        raise SystemExit(1)

    gamma = 0.99
    if args.algo == "RecurrentPPO":
        _evaluate_recurrent_ppo(model_path, env, gamma)
    else:
        _evaluate_agent(model_path, env, gamma)