from argparse import ArgumentParser

from gym.wrappers import RescaleAction
import dmc2gym
import numpy as np

from torch.utils.tensorboard import SummaryWriter

from mdlc_sac import MDLCSeq_Agent
from utils import make_checkpoint, make_name, ParseBoolean
from train_agent_single import train_agent_model_free



def train_agent_seq(agent, tasks, params):
    """Train ``agent`` on a sequence of tasks, updating its default policy after each.

    For every task index in the chosen order this function:
      1. trains the control policy with ``train_agent_model_free``,
      2. runs ``agent.optimize_default_policy``,
      3. writes a checkpoint via ``make_checkpoint``,
      4. resets the control policy when ``params['reset_control']`` is truthy.

    Args:
        agent: object exposing ``control_policy``, ``optimize_default_policy``
            and ``reset_control``.
        tasks: sequence of environments, indexed in parallel with
            ``params['task_names']``.
        params: run configuration dict. Keys read here: ``save_dir``,
            ``agent``, ``env``, ``order``, ``task_names``,
            ``n_default_updates``, ``beta_pretrain``, ``beta``,
            ``beta_warmup`` and ``reset_control``. ``params['task']`` is
            overwritten with the current task name on every iteration.
    """
    # track performance
    # NOTE(review): make_name() runs before params['task'] is first assigned
    # below; if it reads every key in params['name_list'] (main() lists
    # "task" there), confirm that key is present at this point.
    name = make_name(params)
    writer = SummaryWriter(
        log_dir=f"{params['save_dir']}/{params['agent']}_seq_runs/{params['env']}/" + name
    )

    # With no explicit order, visit every task once in a random order.
    if params['order'] is None:
        task_order = np.random.permutation(len(tasks))
    else:
        task_order = params['order']

    try:
        for task_idx in task_order:
            # set new task
            task = tasks[task_idx]
            params['task'] = params['task_names'][task_idx]

            # train control policy on new task
            print(f"===== training control policy on {params['task']} =====")
            agent.control_policy = train_agent_model_free(
                agent, task, params, writer=writer, log_interval=1000, gif_interval=50000
            ).control_policy

            # default policy replay updated by train_agent_model_free
            # update default policy
            print("===== done. now updating default policy =====")
            agent.optimize_default_policy(
                params['n_default_updates'],
                state_filter=None,
                beta_start=params['beta_pretrain'],
                beta_max=params['beta'],
                beta_warmup=params['beta_warmup'],
                writer=writer,
                task_name=params['task'],
            )

            # make checkpoint
            make_checkpoint(agent, params, task=params['task'], timestep=None, mt=False)

            # reset control policy
            if params['reset_control']:
                agent.reset_control()
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()



def _parse_task_order(order, n_tasks):
    """Turn a dash-separated index string such as ``'2-0-1'`` into a list of ints.

    Args:
        order: the raw ``--order`` value, or None when the flag was omitted.
        n_tasks: number of available tasks; every index must be in
            ``[0, n_tasks)``.

    Returns:
        A list of task indices, or None when ``order`` is None.

    Raises:
        ValueError: if a field is not an integer or an index is out of range.
    """
    if order is None:
        return None
    try:
        indices = [int(field) for field in order.split('-')]
    except ValueError:
        raise ValueError(
            f"--order must be integers separated by dashes, got {order!r}"
        ) from None
    out_of_range = [i for i in indices if not 0 <= i < n_tasks]
    if out_of_range:
        raise ValueError(
            f"--order indices {out_of_range} are outside the valid range 0..{n_tasks - 1}"
        )
    return indices


def main():
    """Parse command-line options, build the tasks and agent, and run sequential training.

    Raises:
        NotImplementedError: if ``--env`` is neither ``walker`` nor ``cartpole``.
    """
    # Registry of selectable agents; also used to validate --agent.
    agent_dict = {
        "mdlc-seq": MDLCSeq_Agent
    }

    parser = ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='.', help='directory to save results')
    parser.add_argument('--env', type=str, default='walker')
    parser.add_argument('--order', type=str, default=None, help='order of task idxs, separated by dashes')
    # The default must be a registered agent, otherwise the lookup below fails.
    parser.add_argument('--agent', type=str, default='mdlc-seq', choices=sorted(agent_dict),
                        help='agent choice: mdlc-seq')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--reset_control', type=ParseBoolean, default=True)
    parser.add_argument('--use_obs_filter', dest='obs_filter', action='store_true')
    parser.add_argument('--update_every_n_steps', type=int, default=1)
    parser.add_argument('--n_default_updates', type=int, default=int(5e4))
    parser.add_argument('--n_random_actions', type=int, default=10000)
    parser.add_argument('--n_collect_steps', type=int, default=1000)
    parser.add_argument('--n_evals', type=int, default=1)
    parser.add_argument('--beta', type=float, default=0.01)
    parser.add_argument('--beta_pretrain', type=float, default=0.75)
    parser.add_argument('--beta_warmup', type=float, default=0.1)
    parser.add_argument('--norm_vdo', type=ParseBoolean, default=False)
    parser.add_argument('--default_start', type=float, default=0.5)
    parser.add_argument('--experiment_name', type=str, default='')
    parser.add_argument('--make_gif', dest='make_gif', action='store_true')
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.add_argument('--prev_task', type=str, default=None)
    parser.add_argument('--control_vdo', type=ParseBoolean, default=False)
    parser.add_argument('--default_vdo', type=ParseBoolean, default=False)
    parser.add_argument('--learned_asymmetry', type=ParseBoolean, default=False)
    parser.add_argument('--total_steps', type=int, default=int(1e7))
    parser.set_defaults(obs_filter=False)
    parser.set_defaults(save_model=False)

    args = parser.parse_args()
    params = vars(args)

    seed = params['seed']
    env_name = params['env']
    if env_name == "walker":
        task_names = ['stand', 'walk', 'run']
    elif env_name == "cartpole":
        task_names = ['balance', 'swingup', 'balance', 'swingup_sparse']
    else:
        raise NotImplementedError("Unsupported environment.")
    params['task_names'] = task_names

    # Validate the order before building environments so a bad --order fails
    # immediately. None is kept as-is: train_agent_seq then picks a random order.
    try:
        params['order'] = _parse_task_order(params['order'], len(task_names))
    except ValueError as err:
        parser.error(str(err))

    tasks = [dmc2gym.make(domain_name=env_name, task_name=task, seed=seed) for task in task_names]
    tasks = [RescaleAction(task, -1, 1) for task in tasks]

    # Dimensions are taken from the first task's spaces.
    state_dim = tasks[0].observation_space.shape[0]
    action_dim = tasks[0].action_space.shape[0]

    agent_kwargs = {
        "beta": params['beta'],
        "control_vdo": params['control_vdo'],
        "default_vdo": params['default_vdo'],
        "target_entropy": None if params['agent'] == 'sac' else 0.0,
        "norm_vdo": params['norm_vdo'],
        "learned_asymmetry": params['learned_asymmetry']
    }
    agent = agent_dict[params['agent']](
        seed, state_dim, action_dim, **agent_kwargs
        )

    # Parameter names stored under params['name_list'] for downstream helpers.
    name_list = ["env", "task", "agent", "seed", "beta", "beta_pretrain", "beta_warmup"]
    name_list += ["control_vdo", "default_vdo", "total_steps"]
    params['name_list'] = name_list

    train_agent_seq(agent=agent, tasks=tasks, params=params)


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()

