import numpy as np
from arguments import get_args
from discrete_action_robots_modules.dqn_agent import DQNAgent
from discrete_action_robots_modules.fb_norm import FBAgent
from discrete_action_robots_modules.robots import FetchReach, FetchPush
from discrete_action_robots_modules.psm import PSMAgent
from discrete_action_robots_modules.laplace import LaplacianAgent
import random
import torch
import datetime
import wandb
import time
import os



def get_env_params(env):
    """Collect the sizes the agents need from a freshly reset environment.

    Calls ``env.reset()`` once and reads the leading dimension of the
    ``'observation'`` and ``'desired_goal'`` entries of the returned dict,
    together with the number of discrete actions and the episode step limit.

    Args:
        env: environment exposing ``reset()`` (returning a dict whose
            ``'observation'`` and ``'desired_goal'`` values have a ``shape``),
            ``num_actions`` and ``_max_episode_steps``.

    Returns:
        dict with keys ``'obs'``, ``'goal'``, ``'action'`` and
        ``'max_timesteps'``.
    """
    first_obs = env.reset()
    obs_dim = first_obs['observation'].shape[0]
    goal_dim = first_obs['desired_goal'].shape[0]
    return {
        'obs': obs_dim,
        'goal': goal_dim,
        'action': env.num_actions,
        'max_timesteps': env._max_episode_steps,
    }


def launch(args):
    """Seed all RNGs, set up wandb logging and train the selected agent.

    Args:
        args: parsed command-line arguments from ``get_args()``. This function
            reads ``args.agent``, ``args.seed`` and ``args.cuda``; the whole
            namespace is also passed to ``wandb.init`` as the run config and
            to the agent constructor.

    Raises:
        NotImplementedError: if ``args.agent`` is not a supported agent name.
            The check runs before any environment, log directory or wandb run
            is created, so a typo leaves no dangling artifacts behind.
    """
    # CLI agent name -> trainer class. Every trainer is constructed as
    # ``Agent(args, env, env_params)`` and trained with ``.learn()``.
    agent_classes = {
        'DQN': DQNAgent,
        'FB': FBAgent,
        'PSM': PSMAgent,
        'LAPLACE': LaplacianAgent,
    }
    try:
        agent_cls = agent_classes[args.agent]
    except KeyError:
        raise NotImplementedError(
            "Unknown agent {!r}; expected one of {}".format(
                args.agent, sorted(agent_classes))
        ) from None

    # NOTE(review): the environment is hard-coded to FetchReach even though
    # FetchPush is imported at the top of the file — confirm this is intended.
    env = FetchReach()

    # Seed every RNG source for reproducibility.
    env.seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # One timestamped directory per run: results/<agent>/<YYYYmmdd-HHMMSS>.
    log_dir = os.path.join(
        'results', args.agent,
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(log_dir, exist_ok=True)

    # Set up the logger.
    wandb.init(project="discrete_action_robots", name=args.agent,
               config=args, dir=log_dir)

    # Each metric namespace is plotted against its own frame counter
    # (e.g. 'train/*' uses 'train/frame') instead of wandb's global step.
    namespaces = ('train', 'eval', 'inf')
    for ns in namespaces:
        wandb.define_metric(ns + '/frame')
    for ns in namespaces:
        wandb.define_metric(ns + '/*', step_metric=ns + '/frame')

    env_params = get_env_params(env)
    trainer = agent_cls(args, env, env_params)
    trainer.learn()


if __name__ == '__main__':
    # Script entry point: parse the command line and start training.
    launch(get_args())
