import matplotlib.pyplot as plt, os
import wandb
import gym
import numpy as np
import argparse
from wrappers import *
from library import *

import envs
from minigrid.wrappers import ImgObsWrapper

#####################################################################################
# Command-line interface for the experiment.
# Single source of truth for the accepted --mask_type values; argparse uses it
# both to validate the flag and to list the options in the usage/help output.
MASK_TYPES = (
    "demir",
    "fully_obs",
    "no_stack",
    "framestack",
    "masked",
    "ca_masked",
    "all_masked",
    "ca_all_masked",
    "all_history_masked",
    "ca_all_history_masked",
)

parser = argparse.ArgumentParser()
parser.add_argument('--env', default="tmaze-v0", help="Environment name: tmaze-v0, xormaze-v0")
parser.add_argument('--algo', default="QL", help="RL learning algorithm")
# `choices` rejects unknown values with a usage error instead of relying on an
# `assert`, which is stripped when Python runs with -O.
parser.add_argument('--mask_type', default="fully_obs", choices=MASK_TYPES, help=", ".join(MASK_TYPES))
parser.add_argument('--maze_length', type=int, default=1, help="Maze length")
parser.add_argument("--random_length", help="", action='store_true', default=False)
parser.add_argument("--intrinsic_rewards", help="", action='store_true', default=False)
parser.add_argument("--active", help="", action='store_true', default=False)
parser.add_argument('--num_stack', type=int, default=1, help="Memory length")
# argparse only applies `type` to string defaults, so the default has to be an
# int already; a float literal such as 1e6 would reach the learner as 1000000.0.
parser.add_argument('--maxiter', type=int, default=1000000, help="Max training timesteps")
parser.add_argument('--run', default=None, type=int, help="")
parser.add_argument('--path', default="./data/", help="")
args = parser.parse_args()

# Make sure the output directory exists before anything tries to write to it.
os.makedirs(args.path, exist_ok=True)
#####################################################################################

# Build the training and evaluation environments and the run name.
#
# maze_type is assigned for every environment (not only the maze ones) because
# the "demir" run name built further down reads it unconditionally; leaving it
# undefined made non-maze "demir" runs fail with a NameError.
maze_type = "active" if args.active else "passive"

is_tmaze = "tmaze" in args.env
if is_tmaze or "xormaze" in args.env:
    # Both maze families take the same constructor arguments; only `continual`
    # differs (enabled for tmaze, disabled for xormaze).
    maze_kwargs = dict(
        length=args.maze_length,
        random_length=args.random_length,
        active=args.active,
        continual=is_tmaze,
        fix_start=True,
        goal_obs=False,
        fully_obs=args.mask_type == "fully_obs",
    )
    env = gym.make(args.env, **maze_kwargs)
    env_eval = gym.make(args.env, **maze_kwargs)
    # The run name uses the canonical family id rather than the raw --env value.
    maze_family = "tmaze-v0" if is_tmaze else "xormaze-v0"
    name = "{}-env_{}_{}_maze_length_{}-num_stack_{}-mask_type_{}-run_{}".format(
        args.algo, maze_type, maze_family, args.maze_length, args.num_stack, args.mask_type, args.run
    )
else:
    # Any other environment id is created with its default configuration.
    env = gym.make(args.env)
    env_eval = gym.make(args.env)
    name = "{}-env_{}-num_stack_{}-mask_type_{}-run_{}".format(
        args.algo, args.env, args.num_stack, args.mask_type, args.run
    )

# Apply observation wrappers to the training and evaluation environments.
# Both environments always receive the same wrappers, in the same order.
if "MiniGrid" in args.env:
    env, env_eval = ImgObsWrapper(env), ImgObsWrapper(env_eval)

if args.mask_type == "framestack":
    env, env_eval = FrameStack(env, args.num_stack), FrameStack(env_eval, args.num_stack)

# Substring match: every mask type whose name contains "masked" uses this wrapper.
if "masked" in args.mask_type:
    env, env_eval = MaskedFrameStack(env, args.num_stack), MaskedFrameStack(env_eval, args.num_stack)

if args.mask_type == "demir":
    # "demir" runs use their own run name, which also records the
    # intrinsic-reward setting.
    name = "{}-env_{}_{}_maze_length_{}-num_stack_{}-mask_type_{}-intrinsic_rewards_{}-run_{}".format(
        args.algo,
        maze_type,
        args.env,
        args.maze_length,
        args.num_stack,
        args.mask_type,
        args.intrinsic_rewards,
        args.run,
    )
    env = DemirFrameStack(env, args.num_stack, intrinsic_rewards=args.intrinsic_rewards)
    env_eval = DemirFrameStack(env_eval, args.num_stack, intrinsic_rewards=args.intrinsic_rewards)

# Report the final (wrapped) spaces so the run log shows what the agent sees.
print("Observation space: ", env.observation_space)
print("action_space: ", env.action_space)
#####################################################################################

# Weights & Biases logging is initialised in disabled mode.
wandb.init(project="anonymised", entity="anonymised", name=name, mode='disabled')
# Join with os.path.join so the run name lands inside --path even when the
# directory is given without a trailing slash (plain concatenation would turn
# "./data" + name into a sibling path "./data<name>").
save_path = os.path.join(args.path, name)
print(save_path)

if args.algo == "QL":
    # Q-learning hyperparameters.
    exploration = "online"  # online, offline, edecay
    # Online exploration uses a fixed epsilon of 0.1; the other modes start at 1.
    if exploration == "online":
        epsilon = 0.1
    else:
        epsilon = 1
    alpha = 0.1
    gamma = 0.99
    mean_episodes = 10000
    maxiter = args.maxiter

    # Both environments are wrapped in TupleObs before being handed to the learner.
    env, env_eval = TupleObs(env), TupleObs(env_eval)

    A, stats = Q_learning(
        env,
        env_eval,
        epsilon=epsilon,
        mask_type=args.mask_type,
        epsilon_type=exploration,
        alpha=alpha,
        gamma=gamma,
        maxiter=maxiter,
        save_path=save_path,
        p=True,
        mean_episodes=mean_episodes,
    )
