import os
import argparse
import torch
import numpy as np
from modules import *
import yaml
from tqdm import tqdm
from envs.calvin_env import CalvinEnv


def _parse_args():
    """Parse the command-line options of the CALVIN evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_episodes', type=int, default=1000)
    parser.add_argument('--algo', type=str, default='dtamp')
    parser.add_argument('--n_tasks', type=int, default=1)
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/calvin/dtamp_calvin_s_0')
    parser.add_argument('--lmp_model_dir', type=str, default='checkpoints/calvin/play_lmp_calvin_s_0')
    return parser.parse_args()


def _load_yaml(path):
    """Load a YAML config file and close the handle afterwards.

    ``yaml.FullLoader`` is kept for compatibility with the existing configs;
    the files are repository-local and assumed trusted.
    """
    with open(path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def _load_lmp_model(lmp_model_dir):
    """Build the PlayLMP skill decoder and load ``checkpoint_best.pt`` from *lmp_model_dir*."""
    lmp_config = _load_yaml('configs/calvin/play_lmp.yaml')
    lmp_model = PlayLMP(**lmp_config['model_cfg'], dataset_cfg=None)
    lmp_checkpoint = torch.load(os.path.join(lmp_model_dir, 'checkpoint_best.pt'))
    lmp_model.load_state_dict(lmp_checkpoint['model'])
    return lmp_model


def _load_planner(algo, model_cfg, checkpoint_dir):
    """Build the planner selected by *algo* and load ``checkpoint.pt`` from *checkpoint_dir*.

    Raises:
        NotImplementedError: if *algo* is not ``'dtamp'``.
    """
    if algo == 'dtamp':
        model = DTAMP(**model_cfg, dataset_cfg=None)
    else:
        raise NotImplementedError(f'Unsupported algo: {algo!r}')
    checkpoint = torch.load(os.path.join(checkpoint_dir, 'checkpoint.pt'))
    model.load_state_dict(checkpoint['model'])
    return model


def _should_advance(model, obs, milestones, timestep, eval_cfg):
    """Return True when the rollout should switch to the next milestone.

    The last remaining milestone is never dropped. Otherwise we advance when
    the current milestone is reached (distance below ``threshold``) or the
    per-milestone step budget ``timelimit`` is exceeded.
    """
    if len(milestones) <= 1:
        return False
    return (model.compute_distance(obs, milestones[:1]) < eval_cfg['threshold']
            or timestep > eval_cfg['timelimit'])


def _run_episode(env, model, lmp_model, eval_cfg, task_id):
    """Roll out one evaluation episode and return the sum of its rewards.

    With ``eval_cfg['use_skill']`` false, ``model.get_action`` yields low-level
    actions that are sent straight to the env. Otherwise it yields a skill that
    *lmp_model* decodes into up to ``skill_duration`` actions before a new
    skill is requested.
    """
    use_skill = eval_cfg['use_skill']
    obs, goal = env.reset(task_id=task_id)
    # The first element returned by planning() is skipped (presumably the
    # current state — TODO confirm against DTAMP.planning).
    milestones = model.planning(obs, goal)[1:]
    done = False
    timestep = 0  # steps spent chasing the current milestone
    score = 0
    while not done:
        if not use_skill:
            act = model.get_action(obs, milestones[:1])
            obs, rew, done, _ = env.step(act)
            score += rew
            timestep += 1
            if _should_advance(model, obs, milestones, timestep, eval_cfg):
                milestones = milestones[1:]
                timestep = 0
            continue

        skill = model.get_action(obs, milestones[:1])
        obs_list = [obs]  # observation history fed to the skill decoder
        for _ in range(eval_cfg['skill_duration']):
            act = lmp_model.decode_skill(obs_list, skill)
            obs, rew, done, _ = env.step(act)
            score += rew
            obs_list.append(obs)
            timestep += 1
            if done:
                break
            if _should_advance(model, obs, milestones, timestep, eval_cfg):
                milestones = milestones[1:]
                timestep = 0
                # Re-plan the skill towards the new milestone immediately.
                break
    return score


def main():
    """Evaluate a trained planner on CALVIN tasks, showing the running average score."""
    args = _parse_args()
    config = _load_yaml(f'configs/calvin/{args.algo}.yaml')
    eval_cfg = config['evaluation_cfg']

    env = CalvinEnv(calvin_dir=eval_cfg['calvin_dir'],
                    data_dir=config['dataset_cfg']['data_dir'])
    env.prepare_tasks(args.n_tasks)
    env.order_rollouts()

    use_skill = eval_cfg['use_skill']
    # A non-positive skill_duration would never step the env, so `done` could
    # never become True and the episode loop would spin forever.
    if use_skill and eval_cfg['skill_duration'] < 1:
        raise ValueError(
            f"evaluation_cfg.skill_duration must be >= 1, got {eval_cfg['skill_duration']}")
    lmp_model = _load_lmp_model(args.lmp_model_dir) if use_skill else None
    model = _load_planner(args.algo, config['model_cfg'], args.checkpoint_dir)

    scores = np.zeros(args.eval_episodes)
    # The context manager closes the bar even if an episode raises.
    with tqdm(range(args.eval_episodes), desc='Average score: 0.000') as pbar:
        for i in pbar:
            scores[i] = _run_episode(env, model, lmp_model, eval_cfg, task_id=i)
            pbar.set_description(desc='Average score: %.3f' % np.mean(scores[:i + 1]))


if __name__ == '__main__':
    main()
