import argparse
import sys
import os

# TODO: find a cleaner way to do this; disables HuggingFace tokenizers parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import hydra
import json
import numpy as np
import pprint
import time
import torch
import wandb
import yaml
from easydict import EasyDict
from hydra.utils import get_original_cwd, to_absolute_path
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from transformers import AutoModel, pipeline, AutoTokenizer, logging
from pathlib import Path

from diffusion_policy.env_runner.libero import get_libero_path,benchmark

from diffusion_policy.env_runner.libero.benchmark import get_benchmark
from diffusion_policy.env_runner.libero.envs import OffScreenRenderEnv, SubprocVectorEnv
from diffusion_policy.env_runner.libero.utils.time_utils import Timer
from diffusion_policy.env_runner.libero.utils.video_utils import VideoWriter
from diffusion_policy.env_runner.lifelong.algos import *
from diffusion_policy.env_runner.lifelong.datasets import get_dataset, SequenceVLDataset, GroupedTaskDataset
from diffusion_policy.env_runner.lifelong.metric import (
    evaluate_loss,
    evaluate_success,
    raw_obs_to_tensor_obs,
)
from diffusion_policy.env_runner.lifelong.utils import (
    control_seed,
    safe_device,
    torch_load_model,
    NpEncoder,
    compute_flops,
)

from diffusion_policy.env_runner.lifelong.main import get_task_embs

import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.tensor_utils as TensorUtils

import time


# CLI benchmark key -> upper-case LIBERO benchmark identifier.
benchmark_map = dict(
    libero_10="LIBERO_10",
    libero_spatial="LIBERO_SPATIAL",
    libero_object="LIBERO_OBJECT",
    libero_goal="LIBERO_GOAL",
)

# CLI algorithm key -> algorithm name (presumably a class exported by
# lifelong.algos; TODO confirm).
algo_map = dict(
    base="Sequential",
    er="ER",
    ewc="EWC",
    packnet="PackNet",
    multitask="Multitask",
)

# CLI policy key -> policy name (presumably a policy class name; TODO confirm).
policy_map = dict(
    bc_rnn_policy="BCRNNPolicy",
    bc_transformer_policy="BCTransformerPolicy",
    bc_vilt_policy="BCViLTPolicy",
)


def parse_args(argv=None):
    """Parse command-line arguments for the LIBERO evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Derived fields on the returned namespace:
        device_id: rewritten to the torch device string ``"cuda:<id>"``.
            ``--device_id`` defaults to 0; previously omitting it produced
            the invalid string ``"cuda:None"``.
        save_dir: ``"<experiment_dir>_saved"``.

    Validation is reported through ``parser.error``, which prints usage and
    exits with status 2:
        * ``--algo multitask`` requires ``--ep`` in {0, 5, ..., 45}.
        * every other algo requires ``--load_task`` in {0, ..., 9}.

    Returns:
        argparse.Namespace with the parsed and derived arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument("--experiment_dir", type=str, default="experiments")
    # for which task suite
    parser.add_argument(
        "--benchmark",
        type=str,
        required=True,
        choices=["libero_10", "libero_spatial", "libero_object", "libero_goal"],
    )
    parser.add_argument("--task_id", type=int, required=True)
    # method detail
    parser.add_argument(
        "--algo",
        type=str,
        required=True,
        choices=["base", "er", "ewc", "packnet", "multitask"],
    )
    parser.add_argument(
        "--policy",
        type=str,
        required=True,
        choices=["bc_rnn_policy", "bc_transformer_policy", "bc_vilt_policy"],
    )
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--ep", type=int)
    parser.add_argument("--load_task", type=int)
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument("--save-videos", action="store_true")
    args = parser.parse_args(argv)
    args.device_id = f"cuda:{args.device_id}"
    args.save_dir = f"{args.experiment_dir}_saved"

    # Validate via parser.error rather than assert: asserts vanish under -O.
    if args.algo == "multitask":
        # Checkpoint epochs are multiples of 5 below 50, i.e. 0..45.
        if args.ep not in range(0, 50, 5):
            parser.error("[error] ep should be in [0, 5, ..., 45]")
    elif args.load_task not in range(10):
        parser.error("[error] load_task should be in [0, ..., 9]")
    return args


# Valid benchmark names: "libero_10", "libero_spatial", "libero_object", "libero_goal".

class LiberoImageRunner():
    """Roll out a policy on one LIBERO task in parallel envs and report success.

    A ``SubprocVectorEnv`` of ``env_num`` ``OffScreenRenderEnv`` instances is
    built for task ``task_id`` of the suite ``benchmark_name``. :meth:`run`
    seeds the envs, loads the task's stored initial states, steps the policy
    for at most ``max_steps`` steps and returns per-env success flags plus
    the mean success rate.

    Args:
        output_dir: Directory under which videos go (``<output_dir>/media``)
            when ``save_video`` is true.
        test_start_seed: Seed passed to ``env.seed`` at the start of each run.
        benchmark_name: Key into ``benchmark.get_benchmark_dict()``.
        n_obs_steps: Stored on the instance; not used by the rollout here.
        task_id: Index of the task within the suite (was hard-coded to 3).
        cfg: Config consumed by ``raw_obs_to_tensor_obs`` and
            ``get_task_embs``. Required by :meth:`run`.
        env_num: Number of parallel environments.
        max_steps: Maximum number of policy steps per rollout. The default of
            600 is assumed to match LIBERO's stock eval config; TODO confirm.
        img_h, img_w: Camera height/width forwarded to ``OffScreenRenderEnv``.
            The 128 defaults are assumed to match LIBERO's stock data config;
            TODO confirm.
        save_video: If true, record ``agentview_image`` frames via ``VideoWriter``.
        task_emb: Optional precomputed task embedding. When omitted, it is
            computed from ``cfg`` on first use.
    """

    def __init__(self,
                 output_dir,
                 test_start_seed=10000,
                 benchmark_name='libero_10',
                 n_obs_steps=2,
                 task_id=3,
                 cfg=None,
                 env_num=20,
                 max_steps=600,
                 img_h=128,
                 img_w=128,
                 save_video=False,
                 task_emb=None):
        self.output_dir = output_dir
        # Kept for interface compatibility; the rollout below does not use it.
        self.n_obs_steps = n_obs_steps
        self.test_start_seed = test_start_seed
        self.task_id = task_id
        self.cfg = cfg
        self.env_num = env_num
        self.max_steps = max_steps
        self.save_video = save_video
        self.task_emb = task_emb

        benchmark_dict = benchmark.get_benchmark_dict()
        self.task_suite = benchmark_dict[benchmark_name]()
        self.task = self.task_suite.get_task(task_id)

        task_bddl_file = os.path.join(
            get_libero_path("bddl_files"), self.task.problem_folder, self.task.bddl_file
        )
        self._env_args = {
            "bddl_file_name": task_bddl_file,
            "camera_heights": img_h,
            "camera_widths": img_w,
        }
        self.env = self._make_env()

    def _make_env(self):
        """Spawn ``self.env_num`` off-screen LIBERO envs in subprocesses."""
        env_args = self._env_args
        return SubprocVectorEnv(
            [lambda: OffScreenRenderEnv(**env_args) for _ in range(self.env_num)]
        )

    def _get_task_emb(self):
        """Return this task's embedding, computing and caching it on first use.

        Embeddings are computed for every task in the suite and then indexed
        by ``task_id``, so that index-dependent embedding formats stay consistent.
        """
        if self.task_emb is None:
            descriptions = [
                self.task_suite.get_task(i).language
                for i in range(self.task_suite.n_tasks)
            ]
            self.task_emb = get_task_embs(self.cfg, descriptions)[self.task_id]
        return self.task_emb

    def run(self, policy):
        """Roll ``policy`` out in every env and report per-env and mean success.

        Args:
            policy: Object whose ``predict_action(obs_dict)`` returns a dict
                holding an ``'action'`` tensor. It is presumably shaped
                (env_num, 7), matching the zero warm-up actions; TODO confirm.

        Returns:
            dict with ``'test/sim_max_reward_<k>'``, which is 1 if env ``k``
            reported done at any step and 0 otherwise, and
            ``'test/mean_score'``, the fraction of envs that finished.

        Raises:
            ValueError: if the runner was constructed without ``cfg``.

        The vector env is closed when the rollout ends, even on error. A later
        call spawns a fresh one.
        """
        if self.cfg is None:
            raise ValueError(
                "LiberoImageRunner.run needs `cfg` to convert observations "
                "(pass cfg=... to the constructor)"
            )
        if self.env is None:
            # A previous run closed the env; spawn a fresh one.
            self.env = self._make_env()
        env = self.env
        env_num = self.env_num
        log_data = dict()

        video_writer = None
        if self.save_video:
            video_writer = VideoWriter(os.path.join(self.output_dir, "media"), self.save_video)

        try:
            env.reset()
            env.seed(self.test_start_seed)

            init_states_path = os.path.join(
                get_libero_path("init_states"),
                self.task.problem_folder,
                self.task.init_states_file,
            )
            init_states = torch.load(init_states_path)
            # Reuse the stored initial states round-robin when there are fewer than env_num.
            indices = np.arange(env_num) % init_states.shape[0]
            obs = env.set_init_state(init_states[indices])
            task_emb = self._get_task_emb()

            # simulate the physics without any actions
            for _ in range(5):
                env.step(np.zeros((env_num, 7)))

            dones = [False] * env_num
            with torch.no_grad():
                for _ in range(self.max_steps):
                    obs_dict = raw_obs_to_tensor_obs(obs, task_emb, self.cfg)
                    action_dict = policy.predict_action(obs_dict)
                    action = action_dict['action'].detach().to('cpu').numpy()

                    obs, _, done, _ = env.step(action)
                    if video_writer is not None:
                        video_writer.append_vector_obs(
                            obs, dones, camera_name="agentview_image"
                        )

                    # An env counts as successful once it has reported done.
                    for k in range(env_num):
                        dones[k] = dones[k] or done[k]
                    if all(dones):
                        break

            if video_writer is not None:
                video_writer.save()
        finally:
            env.close()
            self.env = None

        num_success = 0
        for k in range(env_num):
            log_data[f'test/sim_max_reward_{k}'] = int(dones[k])
            num_success += int(dones[k])
        log_data['test/mean_score'] = num_success / env_num
        return log_data

