import wandb
import gym
import argparse
import sys

from train_behavioral_policy import train_agent
from configs.train_behavioral_policy_defaults import get_cfg_defaults

# setting path
sys.path.append('/data/home/linial04/d4mrl/')

import sim
import utils

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('--no-wandb', action='store_true')
    parser.add_argument('--seed', type=int, default=11)

    parser.add_argument('--train-env', type=str)
    parser.add_argument('--eval-env', type=str)
    parser.add_argument('--transform_eval_env', action='store_true')

    parser.add_argument('--stacked-frames', type=int, default=0)
    parser.add_argument('--transform-list', nargs='+')
    args = parser.parse_args()

    if args.transform_list is None:
        transform_list = []
    else:
        transform_list = [(args.transform_list[i], utils.num(args.transform_list[i + 1])) for i in
                          range(0, len(args.transform_list) - 1, 2)]

    args = parser.parse_args()

    # Set train_env (could also include different xml)
    if args.transform_eval_env:
        eval_env, wandb_name = sim.get_transformed_env(args.eval_env, transform_list)
        train_env = gym.make(args.train_env)
    else:
        eval_env = gym.make(args.eval_env)
        train_env, wandb_name = sim.get_transformed_env(args.train_env, transform_list)

    # Added this for evaluating on env with hidden dims as well
    for transform in transform_list:
        if 'obs_hidden_dims' in transform[0]:
            eval_env, _ = sim.get_hidden_dims_env(eval_env, transform[1], '')
            break

    if args.stacked_frames > 0:
        train_env = utils.stack_frames(train_env, args.stacked_frames, flatten=True)
        eval_env = utils.stack_frames(eval_env, args.stacked_frames, flatten=True)
        wandb_name += f'_stacked_frames_{args.stacked_frames}'

    config_list = [
        'system.seed', args.seed,
        'wandb.enable', not args.no_wandb,
        'wandb.project_name', f'{args.eval_env}',
        'wandb.entity_name', 'mechanistic_offline_rl',
        'wandb.name', f'{wandb_name}_b_pi',
        'env.train_env', args.train_env,
        'env.eval_env', f'{args.eval_env}',
        'train.agent_path', f'create_data/trained_agents/{wandb_name}/',
        'simulator.transform_list', transform_list
    ]

    config = get_cfg_defaults(config_list=config_list)

    if config.wandb.enable:
        wandb.init(project=config.wandb.project_name, entity=config.wandb.entity_name, config={})

    train_agent(config, train_env, eval_env)
