
import sys

from ema_pytorch.post_hoc_ema import default

# Use line-buffering for both stdout and stderr so progress output shows up
# promptly even when redirected to a file or pipe. Reconfiguring the existing
# text streams in place (instead of opening a second file object on the same
# file descriptor) keeps their encoding/error settings and avoids having two
# independent buffers writing to one fd.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, 'reconfigure', None)
    # Streams replaced by wrappers that are not TextIOWrapper may lack
    # reconfigure(); leave those as they are.
    if _reconfigure is not None:
        _reconfigure(line_buffering=True)
del _stream, _reconfigure

import os
import pathlib
import click
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import utils.utils as utils
from datasets import create_engine, eval_libero
import wandb
import json

# Default LIBERO benchmark suite(s). Other available suites: libero_goal,
# libero_object, libero_90.
# NOTE: main() binds its own local ``task_suites`` from the checkpoint config,
# so this module-level value is not what main() evaluates.
task_suites = ["libero_10"]

@click.command()
@click.option('-c', '--checkpoint',
              default='/path/to/check_point',
              required=True)
@click.option('-d', '--device', default='cuda:0')
@click.option('-p', '--dataset_path',
              default='/path/to/dataset/libero_10',
              required=False)
def main(checkpoint, device, dataset_path):
    """Restore a trained policy from a checkpoint and run evaluation.

    Args:
        checkpoint: Path passed to ``utils.get_latest_checkpoint`` to locate
            the checkpoint file to load.
        device: Torch device string the policy is moved to (e.g. ``cuda:0``
            or ``cpu``).
        dataset_path: Dataset location forwarded to ``cfg.task.dataset`` when
            the saved config's ``name`` is ``"libero"``.

    The checkpoint's saved ``config`` determines the seed, the policy to
    build, the benchmark suite and the experiment directory. Results are
    written under ``<experiment_dir>/Eval_<epoch>``. In the ``libero`` branch
    they are written by ``eval_libero``. Otherwise the configured env runner
    writes them, and its log is dumped to ``eval_log.json``.
    """
    # Load the checkpoint; it is read for its 'config', 'epoch' and 'model'
    # entries below.
    checkpoint_path = utils.get_latest_checkpoint(checkpoint)
    state_dict = utils.load_state(checkpoint_path)

    print('autoloading based on saved parameters')
    cfg = OmegaConf.create(state_dict['config'])

    seed = cfg.seed
    torch.manual_seed(seed)
    # Evaluate the benchmark recorded in the checkpoint config (this local
    # shadows the module-level ``task_suites`` default).
    task_suites = [cfg.task.benchmark_name]

    model = instantiate(state_dict['config']['algo']['policy'])
    final_epoch = state_dict['epoch']
    print(f"final_epoch:{final_epoch}")

    # Move the policy straight to the requested device. It used to be sent
    # to a hard-coded 'cuda:0' first, which fails on CPU-only hosts and
    # allocates on GPU 0 even when a different device is requested.
    model.to(device)
    print(next(model.parameters()).device)
    model.eval()
    # load_state_dict copies the saved tensors into the already-placed
    # parameters.
    model.load_state_dict(state_dict['model'])

    experiment_dir, _ = utils.get_experiment_dir(cfg, evaluate=True)
    # Without exist_ok this raises FileExistsError if the directory already
    # exists (presumably intentional, to avoid reusing an old eval dir).
    os.makedirs(experiment_dir)
    result_path = os.path.join(experiment_dir, f"Eval_{final_epoch}")
    print('Saving to:', result_path)
    print('Running evaluation...')

    if cfg.name == "libero":
        _train_loader, agent = instantiate(cfg.task.dataset,
                                           dataset_path=dataset_path)
        agent.set_policy(model)
        eval_libero(agent, result_path, num_episodes=1, seed=seed,
                    task_suites=task_suites)
    else:
        # Run evaluation through the env runner configured for this task.
        env_runner = instantiate(cfg.task.env_runner, output_dir=result_path)
        runner_log = env_runner.run(model)

        # Video objects are not JSON-serializable; record their path instead.
        json_log = {
            key: (value._path
                  if isinstance(value, wandb.sdk.data_types.video.Video)
                  else value)
            for key, value in runner_log.items()
        }
        # Nothing here guarantees result_path exists yet; create it before
        # writing the log file into it.
        os.makedirs(result_path, exist_ok=True)
        out_path = os.path.join(result_path, 'eval_log.json')
        with open(out_path, 'w') as f:
            json.dump(json_log, f, indent=2, sort_keys=True)


# Script entry point: click parses the command-line options and calls main().
if __name__ == "__main__":
    main()
