import wandb
import argparse

from envs.make_env import *
from utils import *

from algos.grpo import GRPO
from sb3_contrib import RecurrentPPO
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed


# Values accepted by --mask_type. They are enforced through argparse
# ``choices`` rather than an ``assert`` so that validation still happens when
# Python runs with -O (which strips assert statements), and so that the help
# output always lists exactly the values that are accepted.
MASK_TYPES = (
    "fully_obs", "no_stack", "framestack", "masked", "ca_masked", "all_masked",
    "ca_all_masked", "all_history_masked", "ca_all_history_masked", "demir",
)

# Command-line interface of the training script. The parsed namespace is
# exposed as the module-level ``args`` and read by the training entry point.
parser = argparse.ArgumentParser()
parser.add_argument('--env', default="tmaze-v0", help="Environment name")
parser.add_argument('--algo', default="PPO", help="RL learning algorithm: DQN, PPO, RecurrentPPO, GRPO")
parser.add_argument('--mask_type', default="fully_obs", choices=MASK_TYPES,
                    help="Mask type, one of: " + ", ".join(MASK_TYPES))
parser.add_argument('--cube_cam', default="orthographic", help="full, face, orthographic")
parser.add_argument('--scramble_steps', type=int, default=5, help="Scramble steps for cube env")
parser.add_argument('--maze_length', type=int, default=1, help="Maze length for tmaze")
# NOTE(review): this flag has no help text; its meaning is not visible from this file.
parser.add_argument("--random_length", help="", action='store_true', default=False)
parser.add_argument('--active', action='store_true', default=False, help="Active tmaze mode")
parser.add_argument('--continual', action='store_true', default=False, help="Continual setting")
parser.add_argument('--visible_goal_steps', type=int, default=2, help="Number of steps where the environment goal is visible in GCRL tasks")
parser.add_argument('--max_episode_steps', type=int, default=50, help="Max number of steps per episode")
parser.add_argument('--num_stack', type=int, default=1, help="Memory length (sequence length)")
# argparse applies ``type`` only to string defaults, so a default of 1e6 would
# leave args.maxiter as a float; use an integer literal to match type=int.
parser.add_argument('--maxiter', type=int, default=1_000_000, help="Max training timesteps")
parser.add_argument('--features_dim', type=int, default=256, help="Input dim of policy layer")
parser.add_argument('--hidden_size', type=int, default=128, help="Hidden dim of memory architecture layer")
parser.add_argument('--run', type=int, default=None, help="Random seed / run id")
parser.add_argument('--nenvs', type=int, default=1, help="Number of envs/processes")
parser.add_argument('--path', default="./data/", help="Save path for logs and models")
parser.add_argument('--device', default="cuda", help="Device for Pytorch")
parser.add_argument('--arch', choices=['cnn', 'mlp', 'transformer', 'lstm'], default='mlp', help="Policy architecture")
parser.add_argument('--render_mode', default='rgb_array', help="Render mode")
args = parser.parse_args()


if __name__ == "__main__":
    # Instantiate envs
    env, name = make_env(args, 0)
    save_path = args.path + name

    print("Observation space: ", env.observation_space)
    print("Action space: ", env.action_space)
    print("save_path", save_path)
    wandb.init(project="anonymised", entity="anonymised", name=name, mode='disabled')

    env = LoggerWrapper(env, args.mask_type)
    vec_env = env
    if args.nenvs > 1:
        def make_env_(rank: int, seed: int = 0):
            """Return a zero-argument factory that builds the env for worker `rank`.

            Rank 0 reuses the already-built (LoggerWrapper-wrapped) `env`;
            other ranks build a fresh env through `make_env`.
            NOTE(review): envs for rank > 0 are not wrapped in LoggerWrapper --
            confirm this asymmetry is intended.
            """
            def _init():
                env_ = env if rank == 0 else make_env(args, rank)[0]
                # --run defaults to None, and `None + rank` raised TypeError.
                # Without a run id, reset unseeded instead of crashing.
                reset_seed = None if args.run is None else args.run + rank
                env_.reset(seed=reset_seed)
                return env_
            # Runs in the parent process when the factory is created. The call
            # site below passes only `rank`, so `seed` is always the default 0.
            # NOTE(review): confirm the global seed should not follow --run.
            set_random_seed(seed)
            return _init
        vec_env = SubprocVecEnv([make_env_(i) for i in range(args.nenvs)])

    try:
        # Instantiate and train model
        policy_type, policy_kwargs = get_policy_type(args)

        # Constructor arguments shared by every supported algorithm.
        common_kwargs = dict(policy_kwargs=policy_kwargs, device=args.device,
                             verbose=1, tensorboard_log=save_path)
        # Algorithms that train on `vec_env` with identical rollout settings.
        on_policy_algos = {"PPO": PPO, "RecurrentPPO": RecurrentPPO, "GRPO": GRPO}

        if args.algo == "DQN":
            # CnnPolicy keeps the library defaults; any other policy uses the
            # hand-picked hyperparameters below.
            if policy_type == "CnnPolicy":
                dqn_kwargs = {}
            else:
                dqn_kwargs = dict(learning_rate=0.001, gamma=0.99, learning_starts=10000,
                                  target_update_interval=1000, train_freq=1,
                                  exploration_fraction=0.5, exploration_final_eps=0.1)
            # NOTE(review): DQN is given the single `env`, not `vec_env`, so
            # --nenvs has no effect on DQN training -- confirm this is intended.
            model = DQN(policy_type, env, **dqn_kwargs, **common_kwargs)
        elif args.algo in on_policy_algos:
            algo_cls = on_policy_algos[args.algo]
            model = algo_cls(policy_type, vec_env, n_steps=128, batch_size=128,
                             **common_kwargs)
        else:
            raise ValueError(f"Unsupported algorithm: {args.algo}")

        model.learn(int(args.maxiter), callback=SaveLogCallback(env,log_dir=save_path))
    finally:
        # Shut down worker processes if a SubprocVecEnv was created. The
        # single-env case is left exactly as before (no explicit close).
        if vec_env is not env:
            vec_env.close()
