import os

import argparse
import torch
import numpy as np
from modules import *
import yaml
from envs.d4rl_env import GoalReachingD4rlEnv
from tqdm import tqdm


def _run_episode(env, model, eval_cfg):
    """Roll out one episode following planned milestones; return the summed reward.

    The model plans a milestone sequence from the initial observation to the
    goal. The policy is conditioned on the first remaining milestone, which is
    dropped once it is reached (distance below ``threshold``) or once more than
    ``timelimit`` steps have been spent on it. The last remaining milestone is
    never dropped.

    Args:
        env: environment whose ``reset()`` returns ``(obs, goal)`` and whose
            ``step(act)`` returns a 4-tuple ``(obs, rew, done, info)``.
        model: object exposing ``planning``, ``get_action`` and
            ``compute_distance``.
        eval_cfg: mapping with ``'threshold'`` and ``'timelimit'`` keys.

    Returns:
        The sum of rewards collected over the episode.
    """
    threshold = eval_cfg['threshold']
    timelimit = eval_cfg['timelimit']

    obs, goal = env.reset()
    # The first planned point is discarded (presumably the start state itself
    # -- TODO confirm against DTAMP.planning).
    milestones = model.planning(obs, goal)[1:]
    done = False
    score = 0
    timestep = 0  # steps spent on the current milestone
    while not done:
        # Slice (not index) so the model receives the same container type it
        # was given by planning(), restricted to the current milestone.
        act = model.get_action(obs, milestones[:1])
        obs, rew, done, _ = env.step(act)
        score += rew
        timestep += 1
        if len(milestones) > 1:
            # Distance is evaluated first, matching the original short-circuit order.
            reached = model.compute_distance(obs, milestones[:1]) < threshold
            if reached or timestep > timelimit:
                milestones = milestones[1:]
                timestep = 0
    return score


def main():
    """Evaluate a trained goal-reaching agent on a D4RL environment.

    Parses command-line arguments, loads ``configs/<domain>/<algo>.yaml``,
    restores the model from ``<checkpoint_dir>/checkpoint.pt`` and runs
    ``--eval_episodes`` episodes. The running average score is shown in the
    progress-bar description.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_episodes', type=int, default=100)
    parser.add_argument('--env', type=str, default='antmaze-medium-play-v2')
    parser.add_argument('--algo', type=str, default='dtamp')
    parser.add_argument('--checkpoint_dir', type=str, default=None)
    # NOTE(review): this flag is parsed but never read by this script.
    parser.add_argument('--multigoal', action='store_true', dest='multigoal', default=False)
    args = parser.parse_args()
    if args.checkpoint_dir is None:
        # Otherwise os.path.join(None, ...) below fails with an opaque TypeError.
        parser.error('--checkpoint_dir is required')

    # Configs are grouped by domain: the env-name prefix before the first '-'.
    domain = args.env.split('-')[0]
    with open(f'configs/{domain}/{args.algo}.yaml', encoding='utf-8') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    env = GoalReachingD4rlEnv(args.env)

    if args.algo == 'dtamp':
        model = DTAMP(**config['model_cfg'], dataset_cfg=config['dataset_cfg'])
    else:
        raise NotImplementedError(f'Unsupported algorithm: {args.algo}')

    checkpoint_pth = os.path.join(args.checkpoint_dir, 'checkpoint.pt')
    checkpoint = torch.load(checkpoint_pth)
    model.load_state_dict(checkpoint['model'])
    # NOTE(review): the model is never switched to eval mode here; if DTAMP
    # contains dropout/batch-norm layers, consider calling model.eval().

    scores = np.zeros(args.eval_episodes)
    pbar = tqdm(range(args.eval_episodes), desc='Average score: 0.000')
    for i in pbar:
        scores[i] = _run_episode(env, model, config['evaluation_cfg'])
        pbar.set_description(desc='Average score: %.3f' % (np.mean(scores[:i + 1])))


# Run the evaluation only when this file is executed as a script, not on import.
if '__main__' == __name__:
    main()
