# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].
import re 
import os
import yaml
import csv
import torch
import cv2
import shutil
import argparse
import numpy as np
from typing import List
from multiprocessing import Value
from copy import deepcopy
from loguru import logger
from colosseum import TASKS_TTM_FOLDER
from colosseum.rlbench.utils import ObservationConfigExt
from torch.nn.parallel import DistributedDataParallel as DDP

# Process-wide environment flags, set here before the imports that follow.
# NOTE(review): presumably lowers TensorFlow's native log verbosity - confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# NOTE(review): presumably suppresses the bitsandbytes welcome banner - confirm.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

from rlbench.backend import task as rlbench_task
from rlbench.backend.utils import task_file_to_task_class
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.backend.exceptions import BoundaryError, WaypointError, TaskEnvironmentError, NoWaypointsError, DemoError, InvalidActionError
from yarr.utils.stat_accumulator import SimpleAccumulator
from yarr.agents.agent import VideoSummary
from sam2act_colosseum.utils.agent import Summary
import sam2act_colosseum.rvt.mvt_config as default_mvt_cfg
import sam2act_colosseum.rvt.rvt_agent as rvt_agent
import sam2act_colosseum.rvt.config as default_exp_cfg
from sam2act_colosseum.utils.stat_accumulator import SimpleAccumulator
from sam2act_colosseum.rvt.mvt import MVT
from sam2act_colosseum.utils.agent import ActResult
from third_libraries.vggt.avt_vggt.utils.env_utils import (
    CAMERAS,
    SCENE_BOUNDS,
    IMAGE_SIZE,   
    COLOSSEUM_TASKS,   
    EndEffectorPoseViaPlanning2 as EndEffectorPoseViaPlanning,
)
from third_libraries.vggt.avt_vggt.utils.mvt_utils import PreprocessAgent2
from sam2act_colosseum.envs.custom_colosseum_env import CustomColosseumEnv as Env 
from sam2act_colosseum.envs.custom_colosseum_env import get_colosseum_cfg

def natural_sort_key(s):
    """Build a sort key in which embedded digit runs compare numerically.

    The string is split on runs of digits; digit chunks are converted to
    ``int`` and every other chunk is lower-cased, so ``"a_2"`` sorts before
    ``"a_10"``.
    """
    key = []
    for chunk in re.split(r'(\d+)', s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key

def find_directories_starting_with(directory, prefix):
    """Return names of directories below ``directory`` that start with ``prefix``.

    The tree is traversed recursively with ``os.walk``. Only bare directory
    names (not full paths) are collected, and the result is ordered with
    ``natural_sort_key``.
    """
    hits = [
        name
        for _root, subdirs, _files in os.walk(directory)
        for name in subdirs
        if name.startswith(prefix)
    ]
    hits.sort(key=natural_sort_key)
    return hits

class ReplayTransition(object):
    """Plain record holding the data of a single rollout step."""

    def __init__(self, observation: dict, action: np.ndarray,
                 reward: float, terminal: bool, timeout: bool,
                 final_observation: dict = None,
                 summaries: List[Summary] = None,
                 info: dict = None):
        # Per-step payload, stored exactly as given.
        self.observation, self.action = observation, action
        self.reward, self.terminal, self.timeout = reward, terminal, timeout
        # Only populated on the last timestep.
        self.final_observation = final_observation
        # A falsy value (None or empty) is replaced by a fresh empty list.
        self.summaries = summaries if summaries else []
        self.info = info


class RolloutGenerator(object):
    """Steps an agent through one episode and yields a ``ReplayTransition`` per step.

    Args:
        env_device: device handed to ``torch.tensor`` when observations are
            batched for the agent (e.g. ``'cuda:0'``).
    """

    def __init__(self, env_device): # env_device = 'cuda:0'
        self._env_device = env_device

    def _get_type(self, x):
        """Return the dtype used to store ``x``: float64 is narrowed to float32."""
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def _prep_obs_history(self, obs_history):
        """Turn every observation history into a tensor with a leading axis of size 1.

        Going through ``np.array`` first avoids building a tensor directly
        from a Python list of ndarrays, and keeps the in-loop and
        terminal-step calls to ``agent.act`` consistent.
        """
        return {
            k: torch.tensor(np.array([v]), device=self._env_device)
            for k, v in obs_history.items()
        }

    def generator(self, env: Env, agent: rvt_agent.RVTAgent,
                  episode_length: int, timesteps: int,
                  eval: bool, eval_demo_seed: int = 0, # i.e. current_variation_id
                  record_enabled: bool = False,
                  replay_ground_truth: bool = False,
                  task_name: str = "",current_rollout: int = 0):
        """Run one episode and yield a ``ReplayTransition`` after every step.

        Args:
            env: environment to roll out in.
            agent: agent whose ``act`` is called with the batched observation
                history (unless ground-truth actions are replayed).
            episode_length: maximum number of steps; a non-terminal final step
                is marked as a timeout and forced terminal.
            timesteps: length of the per-key observation history window.
            eval: if true the env is reset via ``eval_reset_to_demo``,
                otherwise via ``reset``; also passed as ``deterministic``.
            eval_demo_seed: forwarded to ``eval_reset_to_demo`` and
                ``get_ground_truth_action``.
            record_enabled: takes part in the ``record_end`` condition below.
            replay_ground_truth: replay actions from
                ``env.get_ground_truth_action`` instead of querying the agent.
            task_name: forwarded to ``eval_reset_to_demo``.
            current_rollout: forwarded to ``eval_reset_to_demo``.

        Yields:
            ReplayTransition: one per environment step.

        Raises:
            ValueError: if ``replay_ground_truth`` is set without ``eval``
                (ground-truth actions are only fetched in the eval branch).
        """
        if replay_ground_truth and not eval:
            # Previously this surfaced later as an unbound `actions` variable.
            raise ValueError("replay_ground_truth=True requires eval=True")

        actions = None
        if eval:
            obs = env.eval_reset_to_demo(task_name, eval_demo_seed, current_rollout)
            if replay_ground_truth:
                actions = env.get_ground_truth_action(eval_demo_seed)
        else:
            obs = env.reset()

        agent.reset()
        # The window starts as `timesteps` references to the initial observation.
        obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps for k, v in obs.items()}

        for ep_step in range(episode_length):
            is_last_step = ep_step == episode_length - 1

            prepped_data = self._prep_obs_history(obs_history)
            if not replay_ground_truth:
                act_result = agent.act(prepped_data, deterministic=eval)
            else:
                if ep_step >= len(actions):
                    # Ran out of ground-truth actions: end the episode.
                    return
                act_result = ActResult(actions[ep_step])

            agent_obs_elems = {k: np.array(v) for k, v in act_result.observation_elements.items()}
            extra_replay_elements = {k: np.array(v) for k, v in act_result.replay_elements.items()}
            transition = env.step(act_result)

            obs_tp1 = dict(transition.observation)
            timeout = False
            if is_last_step:
                # If last transition, and not terminal, then we timed out
                timeout = not transition.terminal
                if timeout:
                    transition.terminal = True
                    if "needs_reset" in transition.info:
                        transition.info["needs_reset"] = True

            obs_and_replay_elems = {}
            obs_and_replay_elems.update(obs)
            obs_and_replay_elems.update(agent_obs_elems)
            obs_and_replay_elems.update(extra_replay_elements)

            # Slide the history window: newest observation in, oldest out.
            for k in obs_history.keys():
                obs_history[k].append(transition.observation[k])
                obs_history[k].pop(0)

            transition.info["active_task_id"] = env.active_task_id

            replay_transition = ReplayTransition(
                obs_and_replay_elems, act_result.action, transition.reward,
                transition.terminal, timeout, summaries=transition.summaries,
                info=transition.info)

            if transition.terminal or timeout:
                # If the agent gives us observations then we need to call act
                # one last time (i.e. acting in the terminal state).
                if len(act_result.observation_elements) > 0:
                    prepped_data = self._prep_obs_history(obs_history)
                    act_result = agent.act(prepped_data, deterministic=eval)

                    agent_obs_elems_tp1 = {k: np.array(v) for k, v in act_result.observation_elements.items()}

                    obs_tp1.update(agent_obs_elems_tp1)
                replay_transition.final_observation = obs_tp1

            # NOTE(review): parentheses only make the existing operator
            # precedence explicit; `record_end` is therefore also invoked on
            # timeout / last step when `record_enabled` is False. Confirm
            # whether `record_enabled and (...)` was the intent before
            # changing it, since that would alter what gets recorded.
            if (record_enabled and transition.terminal) or timeout or is_last_step:
                env._action_mode.arm_action_mode.record_end(env._task._scene, steps=60, step_scene=True)

            obs = dict(transition.observation)

            yield replay_transition

            if transition.info.get("needs_reset", transition.terminal):
                return

def load_agent_state(agent_path, agent=None, only_epoch=False):
    """Load checkpoint weights into ``agent`` and return the checkpoint epoch.

    Args:
        agent_path: path of the checkpoint. For a ``PreprocessAgent2`` it is
            passed to ``agent._pose_agent.load_weights``; otherwise it is read
            with ``torch.load`` onto the CPU.
        agent: agent exposing its model as ``_q`` or ``_network``.
        only_epoch: if true, only the stored epoch is read and no weights are
            loaded.

    Returns:
        The ``"epoch"`` entry of the checkpoint, or ``0`` for a
        ``PreprocessAgent2``.

    Raises:
        AttributeError: if weights should be loaded but ``agent`` has neither
            ``_q`` nor ``_network``.
    """
    if isinstance(agent, PreprocessAgent2):
        assert not only_epoch
        agent._pose_agent.load_weights(agent_path)
        return 0

    checkpoint = torch.load(agent_path, map_location="cpu")
    epoch = checkpoint["epoch"]

    if only_epoch:
        return epoch

    if hasattr(agent, "_q"):
        model = agent._q
    elif hasattr(agent, "_network"):
        model = agent._network
    else:
        raise AttributeError(
            "agent has neither `_q` nor `_network`; cannot locate the model "
            "to load weights into"
        )

    # Unwrap DistributedDataParallel so the state-dict keys line up.
    if isinstance(model, DDP):
        model = model.module

    # Fallback chain: strict load -> load into `mvt1` only -> non-strict load.
    try:
        model.load_state_dict(checkpoint["model_state"])
    except RuntimeError:
        try:
            print(
                "WARNING: loading states in mvt1. "
                "Be cautious if you are using a two stage network."
            )
            model.mvt1.load_state_dict(checkpoint["model_state"])
        except RuntimeError:
            print(
                "WARNING: loading states with strick=False! "
                "KNOW WHAT YOU ARE DOING!!"
            )
            model.load_state_dict(checkpoint["model_state"], strict=False)

    # NOTE: optimizer / lr-scheduler state stored in the checkpoint is not
    # restored by this function.
    return epoch

def get_eval_parser():
    """Build the command-line parser for the evaluation script.

    ``--headless`` and ``--save-video`` default to ``True``; because they are
    ``store_true`` flags they could not be switched off, so the companion
    flags ``--no-headless`` and ``--no-save-video`` are provided as well.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tasks", type=str, nargs="+", default=["all"]     # single-task example: ["insert_onto_square_peg"]
    )
    parser.add_argument("--model-folder", type=str,         # absolute path of the folder holding model.pth (without model.pth)
                        default="/RVT/rvt/runs/rvt2/")
    parser.add_argument("--eval-datafolder", type=str,      # test data path: the parent directory of e.g. basketball_in_hoop_0
                        default="/data3/grail/colosseum_dataset")
    parser.add_argument(
        "--start-episode",
        type=int,
        default=0,
        help="start to evaluate from which episode",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=25,
        help="how many episodes to be evaluated for each task",
    )
    parser.add_argument(
        "--episode-length",
        type=int,
        default=25,
        help="maximum control steps allowed for each episode",
    )

    parser.add_argument("--headless", action="store_true", default=True)
    parser.add_argument(
        "--no-headless",
        dest="headless",
        action="store_false",
        default=True,
        help="disable headless mode (enabled by default)",
    )
    parser.add_argument("--ground-truth", action="store_true", default=False)
    parser.add_argument("--exp_cfg_path", type=str, default=None)
    parser.add_argument("--mvt_cfg_path", type=str, default=None)
    parser.add_argument("--peract_official", action="store_true")
    parser.add_argument(
        "--peract_model_dir",
        type=str,
        default="runs/peract_official/seed0/weights/600000",
    )
    parser.add_argument("--device", type=int, default=7)
    parser.add_argument("--log-name", type=str, default='')
    parser.add_argument("--model-name", type=str, default="model_14.pth")       # checkpoint to evaluate
    parser.add_argument("--use-input-place-with-mean", action="store_true")
    parser.add_argument("--save-video", action="store_true", default=True)
    parser.add_argument(
        "--no-save-video",
        dest="save_video",
        action="store_false",
        default=True,
        help="do not save rollout videos (saved by default)",
    )
    parser.add_argument("--skip", action="store_true", default=False)

    return parser


def load_agent(
    model_path=None,
    peract_official=False,
    peract_model_dir=None,
    exp_cfg_path=None,
    mvt_cfg_path=None,
    eval_log_dir="",
    device=0,
    use_input_place_with_mean=False,
):
    """Build an ``RVTAgent`` from saved configs, load its weights and return it.

    Args:
        model_path: checkpoint path; ``exp_cfg.yaml`` / ``mvt_cfg.yaml`` are
            read from its directory unless explicit config paths are given.
        peract_official: loading the official PerAct agent is not implemented
            here; passing ``True`` raises ``NotImplementedError``.
        peract_model_dir: unused (kept for interface compatibility).
        exp_cfg_path: optional experiment config overriding the one next to
            the checkpoint.
        mvt_cfg_path: optional MVT config overriding the one next to the
            checkpoint.
        eval_log_dir: base directory; the agent logs to
            ``{eval_log_dir}/eval_run``.
        device: CUDA device index.
        use_input_place_with_mean: keep ``rvt.place_with_mean`` from the
            config instead of forcing it to ``True``.

    Returns:
        The built agent, switched to eval mode.

    Raises:
        NotImplementedError: for ``peract_official=True`` or when
            ``exp_cfg.agent`` is not ``"our"``.
    """
    device = f"cuda:{device}"

    if peract_official:
        # Previously this path fell through and failed on an unbound `agent`.
        raise NotImplementedError(
            "Loading the official PerAct agent is not supported by this script."
        )

    assert model_path is not None

    # load exp_cfg
    model_folder = os.path.dirname(model_path)

    exp_cfg = default_exp_cfg.get_cfg_defaults()
    if exp_cfg_path is not None:
        exp_cfg.merge_from_file(exp_cfg_path)
    else:
        exp_cfg.merge_from_file(os.path.join(model_folder, "exp_cfg.yaml"))

    # Remember the configured value unconditionally so the stage-two branch
    # below can always restore it (it used to be unbound when
    # `use_input_place_with_mean` was True).
    old_place_with_mean = exp_cfg.rvt.place_with_mean

    # NOTE: to not use place_with_mean in evaluation
    # needed for rvt-1 but not rvt-2
    if not use_input_place_with_mean:
        # for backward compatibility
        exp_cfg.rvt.place_with_mean = True

    exp_cfg.freeze()

    # create agent
    if exp_cfg.agent != "our":
        raise NotImplementedError

    mvt_cfg = default_mvt_cfg.get_cfg_defaults()
    if mvt_cfg_path is not None:
        mvt_cfg.merge_from_file(mvt_cfg_path)
    else:
        mvt_cfg.merge_from_file(os.path.join(model_folder, "mvt_cfg.yaml"))

    mvt_cfg.freeze()

    # for rvt-2 we do not change place_with_mean regardless of the arg
    # done this way to ensure backward compatibility and allow the
    # flexibility for rvt-1
    if mvt_cfg.stage_two:
        exp_cfg.defrost()
        exp_cfg.rvt.place_with_mean = old_place_with_mean
        exp_cfg.freeze()

    rvt = MVT(
        renderer_device=device,
        **mvt_cfg,
    )

    agent = rvt_agent.RVTAgent(
        network=rvt.to(device),
        image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
        add_lang=mvt_cfg.add_lang,
        stage_two=mvt_cfg.stage_two,
        rot_ver=mvt_cfg.rot_ver,
        scene_bounds=SCENE_BOUNDS,
        cameras=CAMERAS,
        log_dir=f"{eval_log_dir}/eval_run",
        **exp_cfg.peract,
        **exp_cfg.rvt,
    )

    agent.build(training=False, device=device)
    load_agent_state(model_path, agent)
    agent.eval()

    print("Agent Information")
    print(agent)
    return agent


@torch.no_grad()
def eval(
    agent,
    tasks,
    eval_datafolder,
    start_episode=0,
    eval_episodes=25,
    episode_length=25,
    replay_ground_truth=False,
    device=0,
    headless=True,
    logging=False,
    log_dir=None,
    verbose=True,
    save_video=False,
):
    """Evaluate ``agent`` on every matching task directory and collect scores.

    For each task, every directory under ``eval_datafolder`` whose name starts
    with the task name and has the form ``<task>_<variation id>`` is
    evaluated for episodes ``start_episode`` up to (excluding)
    ``start_episode + eval_episodes``. A fresh environment is created for
    each episode and always shut down afterwards.

    Args:
        agent: agent to evaluate; put into eval mode for the run and back into
            train mode at the end.
        tasks: list of task names; if the first entry is ``"all"`` the
            ``COLOSSEUM_TASKS`` list is used.
        eval_datafolder: root folder containing the task/variation directories.
        start_episode: index of the first episode to evaluate.
        eval_episodes: number of episodes per directory.
        episode_length: maximum number of steps per episode.
        replay_ground_truth: accepted for interface compatibility.
            NOTE(review): it is not forwarded; the rollout generator is always
            called with ``replay_ground_truth=False`` - confirm intent.
        device: CUDA device index used by the rollout generator.
        headless: forwarded to the environment.
        logging: if true, per-directory results and a final average row are
            appended to ``<log_dir>/eval_results.csv``.
        log_dir: output directory; required when ``logging`` or ``save_video``.
        verbose: print the task list and a line per evaluated episode.
        save_video: write the ``VideoSummary`` entries of each directory to
            ``<log_dir>/videos`` as mp4 files.

    Returns:
        tuple: ``(scores, final_summaries)``. ``scores`` holds one entry per
        evaluated directory (the value of its return summary, or
        ``"unknown"`` if none is available). ``final_summaries`` is a list of
        ``{dir_name: stats}`` dicts and is only filled when ``logging``.

    Raises:
        ValueError: for an unknown task, a directory name that does not match
            ``<task>_<variation id>``, or ``save_video`` without ``log_dir``.
    """

    agent.eval()
    if isinstance(agent, rvt_agent.RVTAgent):
        agent.load_clip()

    device = f"cuda:{device}"

    if save_video and log_dir is None:
        # Fail up front rather than after a whole task has been evaluated.
        raise ValueError("save_video=True requires log_dir to be set")

    csv_fieldnames = ["task", "success rate", "length", "total_transitions"]
    # Summary name -> CSV column / stats key.
    stat_columns = {
        "eval_envs/return": "success rate",
        "eval_envs/length": "length",
        "eval_envs/total_transitions": "total_transitions",
    }
    csv_path = None
    if logging:
        assert log_dir is not None
        # create metric saving writer
        csv_path = os.path.join(log_dir, "eval_results.csv")
        if not os.path.exists(csv_path):
            with open(csv_path, "w") as csv_fp:
                csv_writer = csv.DictWriter(csv_fp, fieldnames=csv_fieldnames)
                csv_writer.writeheader()

    task_files = [
        t.replace(".py", "")
        for t in os.listdir(rlbench_task.TASKS_PATH)
        if t != "__init__.py" and t.endswith(".py")
    ]

    task_classes = []
    if tasks and tasks[0] == "all":
        tasks = COLOSSEUM_TASKS
        if verbose:
            print(f"evaluate on {len(tasks)} tasks: ", tasks)

    for task in tasks:
        if task not in task_files:
            raise ValueError("Task %s not recognised!." % task)
        task_classes.append(task_file_to_task_class(task))

    # evaluate agent
    rollout_generator = RolloutGenerator(device)
    stats_accumulator = SimpleAccumulator(eval_video_fps=30)
    current_task_id = -1

    scores = []
    final_summaries = []
    for task_id, task in enumerate(tasks):
        for dir_name in find_directories_starting_with(eval_datafolder, task):
            print("Start evaluation for ", dir_name)
            # Split the directory name into task name and variation id.
            match = re.fullmatch(r"(.+?)_(\d+)$", dir_name)
            if not match:
                raise ValueError(f"Invalid directory format: {dir_name}")
            current_task = match.group(1)
            current_variation_id = int(match.group(2))
            config = get_colosseum_cfg(current_task, current_variation_id)
            data_cfg, env_cfg = config.data, config.env

            # Walk the directory tree once instead of once per episode.
            available_dirs = {
                name
                for _, dirs, _ in os.walk(os.path.join(eval_datafolder, dir_name))
                for name in dirs
            }

            task_rewards = []
            for ep in range(start_episode, start_episode + eval_episodes):
                if f'episode{ep}' not in available_dirs:
                    logger.error(f'Evaluating {dir_name} | Episode {ep} | Error: Directory episode{ep} not exist.')
                    continue
                eval_env = Env(
                    task_classes=task_classes,
                    obs_config=ObservationConfigExt(data_cfg),
                    action_mode=MoveArmThenGripper(
                        arm_action_mode=EndEffectorPoseViaPlanning(),
                        gripper_action_mode=Discrete(),
                    ),
                    headless=headless,
                    path_task_ttms=TASKS_TTM_FOLDER,
                    dataset_root=eval_datafolder,
                    episode_length=episode_length,
                    swap_task_every=eval_episodes,
                    include_lang_goal_in_obs=True,
                    time_in_state=True,
                    record_every_n=1 if save_video else -1,
                )
                eval_env.eval = True
                eval_env.launch(current_task, env_cfg)

                episode_rollout = []
                generator = rollout_generator.generator(
                    env=eval_env,
                    agent=agent,
                    episode_length=episode_length,
                    timesteps=1,
                    eval=True,
                    eval_demo_seed=current_variation_id,
                    record_enabled=False,
                    replay_ground_truth=False,
                    task_name=dir_name,
                    current_rollout=ep,
                )
                try:
                    for replay_transition in generator:
                        episode_rollout.append(replay_transition)
                except (RuntimeError, IndexError, BoundaryError, WaypointError, NoWaypointsError, DemoError, InvalidActionError,
                        TaskEnvironmentError) as e:
                    # Known simulator/planning failures: keep whatever was
                    # collected and carry on with the next episode.
                    logger.error(f"Evaluating {dir_name} | Episode {ep} | Error: " + str(e))
                finally:
                    # Shut the environment down exactly once, whether the
                    # rollout succeeded, failed with a handled error, or is
                    # propagating an unexpected exception.
                    lang_goal = getattr(eval_env, "_lang_goal", None)
                    eval_env.shutdown()

                for transition in episode_rollout:
                    stats_accumulator.step(transition, True)
                    current_task_id = transition.info["active_task_id"]
                    assert current_task_id == task_id

                if len(episode_rollout) > 0:
                    reward = episode_rollout[-1].reward
                else:
                    reward = 0.0  # or mark as an invalid value
                    logger.error(f"Invalid rollout in Episode {ep}")

                task_rewards.append(reward)
                if verbose:
                    print(f"Evaluating {dir_name} | Episode {ep} | Score: {reward} | Episode Length: {len(episode_rollout)} | Lang Goal: {lang_goal}")

            # report summaries
            summaries = []
            summaries.extend(stats_accumulator.pop())
            task_name = dir_name

            task_stats = {}
            for s in summaries:
                # Read the stat under its original name, then tag the name
                # with the directory it belongs to.
                if logging and s.name in stat_columns:
                    task_stats[stat_columns[s.name]] = s.value
                if "eval" in s.name:
                    s.name = "%s/%s" % (s.name, task_name)

            if logging:
                # writer csv first
                with open(csv_path, "a") as csv_fp:
                    csv_writer = csv.DictWriter(csv_fp, fieldnames=csv_fieldnames)
                    csv_results = {"task": task_name}
                    csv_results.update(task_stats)
                    csv_writer.writerow(csv_results)
                final_summaries.append({dir_name: task_stats})

            # Falls back to "unknown" both when there are no summaries and
            # when none of them is the return summary.
            return_key = f"eval_envs/return/{task_name}"
            task_score = next(
                (s.value for s in summaries if return_key in s.name), "unknown"
            )

            print(f"[Evaluation] Finished {task_name} | Final Score: {task_score}\n")

            scores.append(task_score)
            print("scores: ", scores)

            if save_video:
                record_fps = 25
                record_folder = os.path.join(log_dir, "videos")
                os.makedirs(record_folder, exist_ok=True)
                video_success_cnt = 0
                video_fail_cnt = 0
                video_cnt = 0
                # Target resolution; originally
                # frame_height, frame_width = video.shape[1], video.shape[2]
                target_width, target_height = 1280, 720
                for summary in summaries:
                    if isinstance(summary, VideoSummary):
                        video = deepcopy(summary.value)
                        # Move axis 1 to the end and reverse the last axis.
                        # NOTE(review): presumably (T, C, H, W) RGB ->
                        # (T, H, W, C) BGR for OpenCV - confirm.
                        video = np.transpose(video, (0, 2, 3, 1))
                        video = video[:, :, :, ::-1]
                        # NOTE(review): assumes the n-th video belongs to the
                        # n-th entry of task_rewards - confirm this holds when
                        # an episode yields no transitions.
                        if task_rewards[video_cnt] > 99:
                            video_path = os.path.join(
                                record_folder,
                                f"{task_name}_success_{video_success_cnt}.mp4",
                            )
                            video_success_cnt += 1
                        else:
                            video_path = os.path.join(
                                record_folder, f"{task_name}_fail_{video_fail_cnt}.mp4"
                            )
                            video_fail_cnt += 1

                        # Create the video writer.
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        writer = cv2.VideoWriter(
                            video_path,
                            fourcc,
                            record_fps,
                            (target_width, target_height)
                        )

                        if not writer.isOpened():
                            print(f"编码器不支持")

                        # Write the frames; a frame that fails is reported
                        # and skipped (best effort).
                        for idx in range(len(video)):
                            try:
                                frame = cv2.resize(video[idx], (target_width, target_height),
                                                    interpolation=cv2.INTER_CUBIC)
                                writer.write(frame)
                            except Exception as e:
                                    print(f"帧处理失败: {str(e)}")
                                    continue

                        writer.release()  # release the writer's resources
                        video_cnt += 1

    # also add average scores at the end
    if logging:
        # "unknown" placeholders cannot be averaged; only numeric scores count.
        numeric_scores = [s for s in scores if not isinstance(s, str)]
        if numeric_scores:
            with open(csv_path, "a") as csv_fp:
                csv_writer = csv.DictWriter(csv_fp, fieldnames=csv_fieldnames)
                csv_results = {"task": "average"}
                csv_results["success rate"] = sum(numeric_scores) / len(numeric_scores)
                csv_writer.writerow(csv_results)
        else:
            logger.warning("No numeric scores available; skipping the average row.")

    # set agent to back train mode
    agent.train()

    # unloading clip to save memory
    if isinstance(agent, rvt_agent.RVTAgent):
        agent.unload_clip()
        agent._network.free_mem()

    return scores, final_summaries


def get_model_index(filename):
    """
    Extract the numeric index from a checkpoint file name.

    :param filename: path of file of format /.../model_idx.pth
    :return: idx as int, or None if the name is shorter than 9 characters,
        does not end in ".pth", or the part after the last "_" is not an
        integer
    """
    if len(filename) < 9 or not filename.endswith(".pth"):
        return None
    try:
        return int(filename[:-4].split("_")[-1])
    except ValueError:
        # The trailing token is not an integer (e.g. "model_last.pth").
        return None


def _eval(args):
    """Load the agent selected by ``args``, evaluate it and print the scores.

    Args:
        args: parsed command-line namespace (see ``get_eval_parser``) with an
            additional ``eval_log_dir`` attribute.
    """

    model_paths = []
    if not (args.peract_official):
        assert args.model_name is not None
        # Full path of the checkpoint.
        model_paths.append(os.path.join(args.model_folder, args.model_name))
    else:
        model_paths.append(None)

    for model_path in model_paths:
        tasks_to_eval = deepcopy(args.tasks)

        if not (args.peract_official):
            agent = load_agent(
                model_path=model_path,
                exp_cfg_path=args.exp_cfg_path,
                mvt_cfg_path=args.mvt_cfg_path,
                eval_log_dir=args.eval_log_dir,
                device=args.device,
                use_input_place_with_mean=args.use_input_place_with_mean,
            )

            agent_eval_log_dir = os.path.join(
                args.eval_log_dir, os.path.basename(model_path).split(".")[0]
            )

        else:
            agent = load_agent(
                peract_official=args.peract_official,
                peract_model_dir=args.peract_model_dir,
                device=args.device,
                use_input_place_with_mean=args.use_input_place_with_mean,
            )
            agent_eval_log_dir = os.path.join(args.eval_log_dir, "final")

        os.makedirs(agent_eval_log_dir, exist_ok=True)
        # `eval` returns a (scores, final_summaries) pair; both lists hold one
        # entry per evaluated task directory when logging is enabled.
        scores, final_summaries = eval(
            agent=agent,
            tasks=tasks_to_eval,
            eval_datafolder=args.eval_datafolder,
            start_episode=args.start_episode,
            eval_episodes=args.eval_episodes,
            episode_length=args.episode_length,
            replay_ground_truth=args.ground_truth,  # False
            device=args.device,
            headless=args.headless,
            logging=True,
            log_dir=agent_eval_log_dir,
            verbose=True,
            save_video=args.save_video,
        )
        print(f"model {model_path}, scores {scores}")

        # Key the scores by the evaluated directory name rather than by the
        # requested task list, which may be ["all"] or expand to several
        # variation directories per task.
        task_scores = {}
        for entry, score in zip(final_summaries, scores):
            for evaluated_name in entry:
                task_scores[evaluated_name] = score

        print("save ", task_scores)
        # NOTE: TensorBoard reporting of `task_scores` is currently disabled.


if __name__ == "__main__":
    parser = get_eval_parser()
    args = parser.parse_args()

    if args.log_name is None:
        args.log_name = "none"

    # Results go under "<root>/eval/<log name>", where the root is the PerAct
    # model dir for the official PerAct agent and the model folder otherwise.
    log_root = args.peract_model_dir if args.peract_official else args.model_folder
    args.eval_log_dir = os.path.join(log_root, "eval", args.log_name)
    os.makedirs(args.eval_log_dir, exist_ok=True)

    # save the arguments for future reference
    config_path = os.path.join(args.eval_log_dir, "eval_config.yaml")
    with open(config_path, "w") as fp:
        yaml.dump(vars(args), fp)

    _eval(args)