import functools
import json
import os
import pathlib
import sys

# Process-wide environment configuration. These assignments run before the
# torch / wandb / lift3d imports below so the libraries see them at import time.

# Select the EGL backend for MuJoCo rendering (used for headless rendering).
os.environ["MUJOCO_GL"] = "egl"

# Uncomment to make Hydra print full stack traces on errors.
# os.environ["HYDRA_FULL_ERROR"] = "1"
# NOTE(review): these overwrite any WANDB_* values already exported in the
# shell with empty strings — confirm that clobbering a real key is intended.
os.environ["WANDB_API_KEY"]=""
os.environ["WANDB_USER_EMAIL"]=""
os.environ["WANDB_USERNAME"]=""

# Restrict this process to GPU 0; must be set before CUDA is initialized,
# hence before `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import hydra
import torch
import wandb
from hydra.utils import call, instantiate
from omegaconf import OmegaConf
from termcolor import colored

from lift3d.envs import Evaluator
from lift3d.helpers.common import Logger, WandBLogger, set_seed
from lift3d.helpers.pytorch import AverageMeter, log_params_to_file

import imageio
import numpy as np
from datetime import datetime
from typing import Optional

def _frames_to_uint8_hwc(frames) -> np.ndarray:
    """Convert a stack of frames to a channel-last uint8 array.

    Args:
        frames: Array-like stack shaped (T, C, H, W) or (T, H, W, C). Other
            ranks are passed through without any layout change.

    Returns:
        np.ndarray: The frames as uint8, channel-last when the input was 4-D.
    """
    frames = np.asarray(frames)

    # Heuristic: a 4-D stack whose axis 1 looks like a channel count
    # (1, 3 or 4) is treated as channel-first and moved to channel-last.
    if frames.ndim == 4 and frames.shape[1] in (1, 3, 4):
        frames = frames.transpose(0, 2, 3, 1)

    if frames.dtype == np.uint8:
        return frames
    if frames.size == 0:
        # Nothing to inspect; `max()` would raise on an empty array.
        return frames.astype(np.uint8)

    # Decide the value range once for the whole video rather than per frame:
    # a per-frame decision rescales dark frames (max <= 1) of a [0, 255]
    # video by 255 and makes the output flicker.
    if frames.max() <= 1.0:
        frames = frames * 255
    # Clip before casting: astype(np.uint8) wraps out-of-range values
    # (e.g. 256 -> 0, -1 -> 255) instead of saturating them.
    return np.clip(frames, 0, 255).astype(np.uint8)


def save_video_locally(frames: np.ndarray, 
                      fps: int = 30, 
                      output_dir: str = "videos",
                      filename: Optional[str] = None) -> str:
    """Save a stack of video frames to a local MP4 file.

    Args:
        frames: Video frames shaped (T, C, H, W) or (T, H, W, C). Non-uint8
            frames are treated as [0, 1] when the maximum over the whole video
            is <= 1.0, otherwise as [0, 255]. Values are clipped to [0, 255]
            before the uint8 conversion.
        fps: Frame rate of the written video.
        output_dir: Directory to write into; created if it does not exist.
        filename: Output file name. When None, a name is generated from the
            current timestamp. ".mp4" is appended if missing.

    Returns:
        str: Path of the written video file.
    """
    os.makedirs(output_dir, exist_ok=True)

    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"validation_{timestamp}.mp4"
    if not filename.endswith('.mp4'):
        filename += '.mp4'
    video_path = os.path.join(output_dir, filename)

    frames = _frames_to_uint8_hwc(frames)
    with imageio.get_writer(video_path, fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)

    return video_path

@hydra.main(version_base=None, config_path="../config", config_name="evaluate_metaworld")
def main(config):
    """Evaluate a trained agent checkpoint and save a rollout video.

    Builds the evaluator and model from the Hydra config, loads weights from
    ``config.path``, runs ``config.evaluation.validation_trajs_num`` evaluation
    trajectories, writes the rendered frames to an MP4 and reports the results
    through ``evaluator.callback`` and the logger.

    Optional config key:
        evaluation.video_dir: Directory for the saved video. When absent, the
            previously hard-coded directory is used.
    """
    #############################
    # log important information #
    #############################
    Logger.log_info(
        f'Running {colored(pathlib.Path(__file__).absolute(), "red")} with following config:'
    )
    Logger.log_info(f'Task: {colored(config.task_name, "green")}')
    Logger.log_info(f'Image size: {colored(config.image_size, "green")}')
    Logger.log_info(
        f'Agent: {colored(config.agent.name, color="green")}\n{json.dumps(OmegaConf.to_container(config.agent, resolve=True), indent=4)}'
    )
    Logger.log_info(
        f'Benchmark: {colored(config.benchmark.name, color="green")}\n{json.dumps(OmegaConf.to_container(config.benchmark, resolve=True), indent=4)}'
    )
    Logger.print_seperator()

    ############
    # set seed #
    ############
    set_seed(config.seed)

    #############
    # evaluator #
    #############
    evaluator: Evaluator = instantiate(
        config=config.benchmark.evaluator_instantiate_config,
        task_name=config.task_name,
    )

    #########
    # Model #
    #########
    robot_state_dim = 4
    action_dim = 4
    model = instantiate(
        config=config.agent.instantiate_config,
        robot_state_dim=robot_state_dim,
        action_dim=action_dim,
    )

    checkpoint_path = config.path
    # Load onto CPU first so a checkpoint saved from a GPU run can be restored
    # on any machine (including CPU-only ones); `.to()` below then moves the
    # model to the configured device.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(config.device)

    ##############
    # Evaluation #
    ##############
    epoch_logging_info = {"epoch_step": 1}
    avg_success, avg_rewards = evaluator.evaluate(
        config.evaluation.validation_trajs_num, model
    )

    # NOTE(review): the default is a user-specific absolute path that names a
    # single task ("dial-turn"); set `evaluation.video_dir` when evaluating
    # other tasks or on other machines.
    video_dir = OmegaConf.select(
        config,
        "evaluation.video_dir",
        default="/data2/zehao/LIFT3D/videos/dial-turn_structpolicy",
    )
    video_path = save_video_locally(
        frames=evaluator.env.get_frames().transpose(0, 3, 1, 2),
        fps=30,
        output_dir=video_dir,
    )
    epoch_logging_info.update(
        {
            "validation/success": avg_success,
            "validation/rewards": avg_rewards,
            "validation/video_path": video_path,
        }
    )
    evaluator.callback(epoch_logging_info)
    Logger.log_info(
        f"avg_success={avg_success}, "
        f"avg_rewards={avg_rewards}"
    )


# Script entry point. `main` is wrapped by @hydra.main, which supplies its
# `config` argument, so it is invoked here without arguments.
if __name__ == "__main__":
    main()
