"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output --sampler random
"""

import sys
# Use line-buffering for both stdout and stderr: each stream is re-opened on
# its existing file descriptor in text mode with buffering=1, which selects
# line buffering. Presumably this keeps log output timely when the script is
# piped or redirected -- confirm against how the script is launched.
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import os
import pathlib
import click
import hydra
import torch
import dill
import wandb
import json
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from omegaconf import OmegaConf # cw
import pdb

def _load_payload(path):
    """Load a dill-pickled checkpoint payload from ``path`` and close the file."""
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    with open(path, 'rb') as f:
        return torch.load(f, pickle_module=dill)


def _result_filename(sampler, seed, perturb, noise, alpha, beta,
                     horizon, ahorizon, nsample, nmode, decay):
    """Build the JSON file name that encodes the evaluation settings."""
    prefix = f'{sampler}_seed={seed}_perturb={perturb}_noise={noise}'
    span = f'{horizon}-{ahorizon}'
    if sampler == 'ema':
        return f'{prefix}_lambda={beta}_{span}.json'
    if sampler == 'bid':
        return f'{prefix}_{span}_{nsample}_{nmode}_{decay}.json'
    if sampler in ('ours', 'ac', 'sg'):
        return f'{prefix}_weight={alpha}_tau={beta}_{span}.json'
    # 'random' and any sampler name not listed above share the plain layout,
    # so an unrecognised sampler no longer loses its results at the very end.
    return f'{prefix}_{span}.json'


@click.command()
@click.option('-c', '--checkpoint', required=True)
@click.option('-o', '--output_dir', required=True)
@click.option('-e', '--noise', default=0.0)
@click.option('-p', '--perturb', default=0.0)
@click.option('-ah', '--ahorizon', default=8)
@click.option('-t', '--ntest', default=200)
@click.option('-s', '--sampler', required=True)
@click.option('-n', '--nsample', default=1)
@click.option('-m', '--nmode', default=1)
@click.option('-k', '--decay', default=0.9)
@click.option('-r', '--reference', default=None)
@click.option('-d', '--seed', default=0)
@click.option('-dev', '--device', default='cuda:0')
@click.option('-timesteps', '--timesteps', default=100)
@click.option('-alpha', '--alpha', default=0.0)
@click.option('-beta', '--beta', default=1.0)
def main(checkpoint, output_dir, noise, perturb, ahorizon, ntest, sampler, nsample, nmode, decay, reference, seed, device, alpha, timesteps, beta): # cw
    """Evaluate a policy checkpoint and write the runner log as JSON.
    \f

    Args:
        checkpoint: Path to the dill-pickled payload of the policy to evaluate;
            the payload must contain a ``'cfg'`` entry.
        output_dir: Directory for outputs; created if it does not exist.
        noise: Forwarded to ``env_runner.set_sampler``.
        perturb: When > 0 the env runner is instantiated with
            ``perturb_level=perturb`` and ``max_steps=400``.
        ahorizon: Value assigned to ``n_action_steps`` on the policy, the
            reference model and the env runner config.
        ntest: Value assigned to the env runner's ``n_test``.
        sampler: Sampler name forwarded to ``env_runner.set_sampler``; also
            selects the layout of the output file name.
        nsample: Forwarded to ``env_runner.set_sampler``.
        nmode: Forwarded to ``env_runner.set_sampler``.
        decay: Forwarded to ``env_runner.set_sampler``.
        reference: Optional path to a second checkpoint whose model is handed
            to ``env_runner.set_reference``. Loading is best-effort.
        seed: Sets ``test_start_seed`` to ``20000 + 10000 * seed``.
        device: Torch device string the models are moved to.
        alpha: Written to ``cfg.policy.alpha`` (reported as "Weight (SG)").
        timesteps: Written to ``cfg.policy.timesteps``.
        beta: Passed to the env runner as ``beta`` (reported as "Tau (AC)").
    """
    if os.path.exists(output_dir):
        print(f"Output path {output_dir} already exists and will be overwrited.")
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # load reference (best-effort: a failure skips the reference model
    # instead of aborting the evaluation)
    weak = None
    if reference:
        try:
            ref_payload = _load_payload(reference)
            ref_cfg = ref_payload['cfg']
            ref_cls = hydra.utils.get_class(ref_cfg._target_)
            ref_workspace: BaseWorkspace = ref_cls(ref_cfg, output_dir=output_dir)
            ref_workspace.load_payload(ref_payload, exclude_keys=None, include_keys=None)
            weak = ref_workspace.model
            if ref_cfg.training.use_ema:
                weak = ref_workspace.ema_model
            weak.n_action_steps = ahorizon
            weak.to(torch.device(device))
            weak.eval()
            print('Loaded weak model')
        except Exception as e:
            weak = None
            print('Skipped weak model')
            # Report why, so a bad path or incompatible checkpoint is visible.
            print(f'  reason: {type(e).__name__}: {e}')

    print('\n' + '='*30)
    print('         Experiment Info')
    print('='*30)

    print(f'{"Seed":15s}: {seed}')
    print(f'{"CUDA Device":15s}: {device}')

    print(f'{"Num Episodes":15s}: {ntest}')
    print(f'{"Noise Level":15s}: {noise}')
    print(f'{"Perturbation":15s}: {perturb}')

    print(f'{"Weight (SG)":15s}: {alpha}')
    print(f'{"Tau (AC)":15s}: {beta}')

    print('='*30 + '\n')

    payload = _load_payload(checkpoint)
    cfg = payload['cfg']
    # Disable struct mode so keys can be added or overridden below.
    OmegaConf.set_struct(cfg, False)
    cfg['task']['env_runner']['n_test'] = ntest
    cfg['task']['env_runner']['n_obs_steps'] = 8

    # DDIM-30 Solver: override the policy's scheduler config before the
    # workspace (and therefore the policy) is constructed from cfg.
    cfg['policy']['num_inference_steps'] = 30
    cfg['policy']['alpha'] = alpha
    cfg['policy']['timesteps'] = timesteps
    cfg['policy']['noise_scheduler'] = {
        '_target_': 'diffusers.schedulers.scheduling_ddim.DDIMScheduler',
        'num_train_timesteps': 100,
        'beta_start': 0.0001,
        'beta_end': 0.02,
        'beta_schedule': 'squaredcos_cap_v2',
        'clip_sample': True,
        'set_alpha_to_one': True,
        'steps_offset': 0,
        'prediction_type': 'epsilon',
    }

    cls = hydra.utils.get_class(cfg._target_)
    workspace: BaseWorkspace = cls(cfg, output_dir=output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    # turn off video
    cfg.task.env_runner['n_train_vis'] = 0
    cfg.task.env_runner['n_test_vis'] = 0
    cfg.task.env_runner['n_test'] = ntest
    cfg.task.env_runner['n_action_steps'] = ahorizon
    cfg.task.env_runner['test_start_seed'] = 20000 + 10000 * seed
    policy.n_action_steps = ahorizon

    print("\nEvaluation setting:")
    print("env:")
    # n_latency_steps may be absent from the runner config; show nan then.
    try:
        delay_text = f"{cfg.task.env_runner.n_latency_steps: <19}"
    except Exception:
        delay_text = f"{float('nan'): <19}"
    print(f"  delay_horizon = {delay_text} act_horizon = {cfg.task.env_runner.n_action_steps: <15} obsv_horizon = {cfg.task.env_runner.n_obs_steps}")
    print("policy:")
    print(f"  pred_horizon = {policy.horizon: <20} act_horizon = {policy.n_action_steps: <15} obsv_horizon = {policy.n_obs_steps}")

    # run eval
    runner_kwargs = dict(output_dir=output_dir, beta=beta)
    if perturb > 0:
        runner_kwargs.update(max_steps=400, perturb_level=perturb)
    env_runner = hydra.utils.instantiate(cfg.task.env_runner, **runner_kwargs)

    # set sampler
    env_runner.set_sampler(sampler, nsample, nmode, noise, decay)

    if weak is not None:
        env_runner.set_reference(weak)

    runner_log = env_runner.run(policy)

    print(f"train: {runner_log['train/mean_score']}         test: {runner_log['test/mean_score']}")

    # dump log to json (entries whose key mentions 'video' are left out)
    json_log = {'checkpoint': checkpoint}
    json_log.update(
        (key, value) for key, value in runner_log.items() if 'video' not in key
    )

    out_path = os.path.join(
        output_dir,
        _result_filename(sampler, seed, perturb, noise, alpha, beta,
                         policy.horizon, ahorizon, nsample, nmode, decay),
    )
    with open(out_path, 'w') as f:
        json.dump(json_log, f, indent=2, sort_keys=True)

# Script entry point: main is a click command, so it is called without
# arguments here and click fills them in from the command line.
if __name__ == '__main__':
    main()
