import sys
import os
import pathlib

import numpy as np
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset
import hydra
from omegaconf import DictConfig, OmegaConf
import torch
import dill
import wandb
import json
from diffusion_policy.dataset.pusht_image_dynamics_dataset import DynamicsModelDataset
from diffusion_policy.workspace.base_workspace import BaseWorkspace

# Use line-buffering for both stdout and stderr.
# Each stream is re-opened in text mode on its existing file descriptor with
# buffering=1 (line buffering), and sys.stdout / sys.stderr are rebound to the
# new file objects. This runs at import time, before main() is invoked.
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

# Env-runner class paths keyed by (dataset is PushT, sequential runner needed).
_ENV_RUNNER_TARGETS = {
    (False, False): 'diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner',
    (False, True): 'diffusion_policy.env_runner.robomimic_image_sequential_runner.SequentialRobomimicImageRunner',
    (True, False): 'diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner',
    (True, True): 'diffusion_policy.env_runner.pusht_image_sequential_runner.SequentialPushTImageRunner',
}


def _apply_eval_overrides(run_cfg, cfg):
    """Overwrite rollout settings of the checkpoint config with eval settings.

    ``run_cfg`` (the config stored in the checkpoint payload) is mutated in
    place; nothing is returned.

    Args:
        run_cfg: Config object taken from ``payload['cfg']``.
        cfg: Evaluation config passed to ``main``.
    """
    runner_cfg = run_cfg.task.env_runner

    # The action horizon is written to every place the checkpoint config
    # carries it: top level, env runner and policy.
    run_cfg.n_action_steps = cfg.n_action_steps
    runner_cfg.n_action_steps = cfg.n_action_steps
    run_cfg.policy.n_action_steps = cfg.n_action_steps

    # Only test rollouts are run; the train-split rollouts are disabled and
    # every test rollout is also counted as a visualised one.
    runner_cfg.n_test = cfg.n_test
    runner_cfg.n_test_vis = cfg.n_test
    runner_cfg.n_train = 0
    runner_cfg.n_train_vis = 0
    runner_cfg.test_start_seed = cfg.test_start_seed

    # The episode length is keyed off a substring of the checkpoint path; for
    # any other path the value stored in the checkpoint config is kept.
    if 'tool_hang' in cfg.policy_checkpoint:
        runner_cfg.max_steps = 780
    elif 'pushT' in cfg.policy_checkpoint:
        runner_cfg.max_steps = 400

    runner_cfg.n_envs = runner_cfg.n_test + runner_cfg.n_train


def _initialize_planner(policy, cfg, run_cfg, dataset_cfg):
    """Call the policy's planner initializer matching the evaluation mode.

    The PushT initializer is used only when ``cfg.avoid_ood`` is set and the
    dataset target contains ``'PushT'``; every other combination uses the
    robomimic initializer. Both receive the same keyword arguments.

    Args:
        policy: Policy object exposing ``initialize_pusht_planner`` and
            ``initialize_robomimic_planner``.
        cfg: Evaluation config passed to ``main``.
        run_cfg: Checkpoint config (after eval overrides were applied).
        dataset_cfg: ``task.dataset`` node of the checkpoint config.
    """
    # NOTE(review): with avoid_ood disabled, a PushT checkpoint is still routed
    # to the robomimic initializer. This mirrors the previous behaviour --
    # confirm that it is intended.
    if cfg.avoid_ood and 'PushT' in dataset_cfg._target_:
        init_planner = policy.initialize_pusht_planner
    else:
        init_planner = policy.initialize_robomimic_planner

    init_planner(
        demo_dataset_config=dataset_cfg,
        dynamics_model_ckpt=cfg.dynamics_model_checkpoint,
        decoder_path=cfg.decoder_path,
        value_func_path=cfg.value_func_path,
        action_step=run_cfg.n_action_steps,
        output_dir=cfg.output_dir,
        method=cfg.method,
        guidance_start_timestep=cfg.guidance_start_timestep,
        guidance_scale=cfg.guidance_scale,
        threshold=cfg.threshold,
    )


@hydra.main(config_path=".", config_name="eval_config")
def main(cfg: DictConfig):
    """Evaluate a policy checkpoint in its environment and save the results.

    Steps, in order: prepare ``cfg.output_dir`` and store the eval config
    there, load the checkpoint, patch its stored config with the evaluation
    settings, rebuild the workspace and policy, load the normalizer,
    initialize the planner, run the env runner and write the returned log to
    ``eval_results.json``.

    Args:
        cfg: Evaluation configuration (``eval_config``) provided by Hydra.

    Raises:
        SystemExit: If the output directory already exists and the user does
            not answer ``y`` to the overwrite prompt.
    """
    output_dir = cfg.output_dir

    # Ask before reusing an existing output directory. Existing content is not
    # removed; answering 'y' only allows the run to continue.
    if os.path.exists(output_dir):
        confirm = input(f"Output path {output_dir} already exists! Overwrite? (y/N): ")
        if confirm.lower() != 'y':
            sys.exit(1)
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the configuration used for this run next to the results.
    config_save_path = os.path.join(output_dir, 'eval_config.yaml')
    OmegaConf.save(config=cfg, f=config_save_path)
    print(f"Configuration saved to {config_save_path}")

    # The checkpoint is unpickled with dill; only load checkpoints you trust.
    with open(cfg.policy_checkpoint, 'rb') as f:
        payload = torch.load(f, pickle_module=dill)

    # run_cfg is the very object stored in the payload, so the overrides are
    # also visible through payload['cfg'].
    run_cfg = payload['cfg']
    _apply_eval_overrides(run_cfg, cfg)
    dataset_cfg = run_cfg.task.dataset
    is_pusht = 'PushT' in dataset_cfg._target_

    # Rebuild the workspace class named in the checkpoint config and restore
    # its state from the payload.
    cls = hydra.utils.get_class(run_cfg._target_)
    workspace: BaseWorkspace = cls(run_cfg, output_dir=output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # Evaluate the EMA model when the checkpoint config enables it.
    policy = workspace.model
    if run_cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(cfg.device)
    policy.to(device)
    policy.eval()

    # The normalizer is read from 'normalizer.pth' two directory levels above
    # the checkpoint file. map_location lets a normalizer that was saved from
    # another device (e.g. a GPU that is not present here) load correctly.
    normalizer_dir = os.path.dirname(os.path.dirname(cfg.policy_checkpoint))
    normalizer_path = os.path.join(normalizer_dir, 'normalizer.pth')
    policy.normalizer.load_state_dict(
        torch.load(normalizer_path, map_location=device)
    )
    policy.normalizer.to(device)

    _initialize_planner(policy, cfg, run_cfg, dataset_cfg)

    # The sequential runner is selected as soon as any of these options is
    # set. NOTE: cfg.perturb only affects this choice; no perturbation
    # parameters are forwarded to the runner here.
    sequential = bool(cfg.avoid_ood or cfg.save_hdf5 or cfg.perturb)
    run_cfg.task.env_runner._target_ = _ENV_RUNNER_TARGETS[(is_pusht, sequential)]
    env_runner = hydra.utils.instantiate(
        run_cfg.task.env_runner,
        output_dir=output_dir
    )

    if cfg.save_hdf5:
        env_runner.save_hdf5 = True
        # PushT results go to a .zarr path, all other tasks to .hdf5.
        results_name = 'eval_results.zarr' if is_pusht else 'eval_results.hdf5'
        env_runner.output_path = os.path.join(output_dir, results_name)

    runner_log = env_runner.run(
        policy,
        avoid_ood=cfg.avoid_ood,
        num_samples=cfg.num_samples,
        method=cfg.method,
    )

    # wandb Video objects are not JSON serializable; store their file path.
    results = {
        key: value._path
        if isinstance(value, wandb.sdk.data_types.video.Video)
        else value
        for key, value in runner_log.items()
    }

    results_path = os.path.join(output_dir, 'eval_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2, sort_keys=True)

    print(f"Evaluation results saved to {results_path}")

# Script entry point. main() is called without arguments; its `cfg` parameter
# is expected to be supplied by the @hydra.main decorator applied to it.
if __name__ == '__main__':
    main()
