from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import datetime

import dsrl
import numpy as np
import pyrallis
import torch
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from pyrallis import field
from fsrl.utils import TensorboardLogger

from osrl.algorithms import CDT, CDTTrainer
from osrl.algorithms import RTG_model, RTGTrainer
from osrl.common.exp_util import load_config_and_model, seed_all

# conservative False results: 0.542,0.525;4.807,5.605
# conservative True results: 0.212,0.241;1.547,1.842 (rtg_sample_quantile 1.0 rtg_sample_quantile_end 1.0; the posterior model is called at every step to update the rtg)
# conservative True results: 0.236,0.304;1.733,2.376 (rtg_sample_quantile 1.0 rtg_sample_quantile_end 1.0; the posterior model is called to update the rtg only when cost is 1)
# conservative True results: 0.156,0.147;1.102,1.466 (rtg_sample_quantile 1.0 rtg_sample_quantile_end 1.0; the posterior model is called to update the rtg only when cost is 1)
# CDT checkpoint directories, indexed by ``EvalConfig.task_id``.
# Entries are grouped in threes per task: the first path of each group has no
# seed tag in its run name, followed by the ``seed1`` and ``seed2`` runs of
# the same task. The task order matches ``rtg_model_paths`` and
# ``target_returns``, which hold one entry per group of three.
model_paths= [
    "logs/CDT/OfflinePointButton1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-a7a1/CDT_update_steps100000_use_promptFalse-a7a1",
    "logs/CDT/OfflinePointButton1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-c096/CDT_seed1_update_steps100000_use_promptFalse-c096",
    "logs/CDT/OfflinePointButton1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-ed88/CDT_seed2_update_steps100000_use_promptFalse-ed88",
    "logs/CDT/OfflinePointButton2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-8bc2/CDT_update_steps100000_use_promptFalse-8bc2",
    "logs/CDT/OfflinePointButton2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-cacd/CDT_seed1_update_steps100000_use_promptFalse-cacd",
    "logs/CDT/OfflinePointButton2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-ec57/CDT_seed2_update_steps100000_use_promptFalse-ec57",
    "logs/CDT/OfflinePointCircle1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-c1bd/CDT_update_steps100000_use_promptFalse-c1bd",
    "logs/CDT/OfflinePointCircle1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-4c36/CDT_seed1_update_steps100000_use_promptFalse-4c36",
    "logs/CDT/OfflinePointCircle1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-34d1/CDT_seed2_update_steps100000_use_promptFalse-34d1",
    "logs/CDT/OfflinePointCircle2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-f43b/CDT_update_steps100000_use_promptFalse-f43b",
    "logs/CDT/OfflinePointCircle2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-698f/CDT_seed1_update_steps100000_use_promptFalse-698f",
    "logs/CDT/OfflinePointCircle2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-af89/CDT_seed2_update_steps100000_use_promptFalse-af89",
    "logs/CDT/OfflinePointGoal1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-2118/CDT_update_steps100000_use_promptFalse-2118",
    "logs/CDT/OfflinePointGoal1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-c7de/CDT_seed1_update_steps100000_use_promptFalse-c7de",
    "logs/CDT/OfflinePointGoal1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-4421/CDT_seed2_update_steps100000_use_promptFalse-4421",
    "logs/CDT/OfflinePointGoal2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-f8ff/CDT_update_steps100000_use_promptFalse-f8ff",
    "logs/CDT/OfflinePointGoal2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-c648/CDT_seed1_update_steps100000_use_promptFalse-c648",
    "logs/CDT/OfflinePointGoal2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-b69a/CDT_seed2_update_steps100000_use_promptFalse-b69a",
    "logs/CDT/OfflinePointPush1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-3a23/CDT_update_steps100000_use_promptFalse-3a23",
    "logs/CDT/OfflinePointPush1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-6127/CDT_seed1_update_steps100000_use_promptFalse-6127",
    "logs/CDT/OfflinePointPush1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-34fd/CDT_seed2_update_steps100000_use_promptFalse-34fd",
    "logs/CDT/OfflinePointPush2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-1ed3/CDT_update_steps100000_use_promptFalse-1ed3",
    "logs/CDT/OfflinePointPush2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-9e55/CDT_seed1_update_steps100000_use_promptFalse-9e55",
    "logs/CDT/OfflinePointPush2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-f0a9/CDT_seed2_update_steps100000_use_promptFalse-f0a9",
    "logs/CDT/OfflineCarButton1Gymnasium-v0-cost-10/CDT_use_promptFalse-bbc6/CDT_use_promptFalse-bbc6",
    "logs/CDT/OfflineCarButton1Gymnasium-v0-cost-10/CDT_seed1_use_promptFalse-339c/CDT_seed1_use_promptFalse-339c",
    "logs/CDT/OfflineCarButton1Gymnasium-v0-cost-10/CDT_seed2_use_promptFalse-9c8c/CDT_seed2_use_promptFalse-9c8c",
    "logs/CDT/OfflineCarButton2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-78ee/CDT_update_steps100000_use_promptFalse-78ee",
    "logs/CDT/OfflineCarButton2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-5610/CDT_seed1_update_steps100000_use_promptFalse-5610",
    "logs/CDT/OfflineCarButton2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-d1aa/CDT_seed2_update_steps100000_use_promptFalse-d1aa",
    "logs/CDT/OfflineCarCircle1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-0e90/CDT_update_steps100000_use_promptFalse-0e90",
    "logs/CDT/OfflineCarCircle1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-39ea/CDT_seed1_update_steps100000_use_promptFalse-39ea",
    "logs/CDT/OfflineCarCircle1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-0b63/CDT_seed2_update_steps100000_use_promptFalse-0b63",
    "logs/CDT/OfflineCarCircle2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-3027/CDT_update_steps100000_use_promptFalse-3027",
    "logs/CDT/OfflineCarCircle2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-f8e1/CDT_seed1_update_steps100000_use_promptFalse-f8e1",
    "logs/CDT/OfflineCarCircle2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-9011/CDT_seed2_update_steps100000_use_promptFalse-9011",
    "logs/CDT/OfflineCarGoal1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-e933/CDT_update_steps100000_use_promptFalse-e933",
    "logs/CDT/OfflineCarGoal1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-2329/CDT_seed1_update_steps100000_use_promptFalse-2329",
    "logs/CDT/OfflineCarGoal1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-a649/CDT_seed2_update_steps100000_use_promptFalse-a649",
    "logs/CDT/OfflineCarGoal2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-5f13/CDT_update_steps100000_use_promptFalse-5f13",
    "logs/CDT/OfflineCarGoal2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-4b91/CDT_seed1_update_steps100000_use_promptFalse-4b91",
    "logs/CDT/OfflineCarGoal2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-2ef9/CDT_seed2_update_steps100000_use_promptFalse-2ef9",
    "logs/CDT/OfflineCarPush1Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-9092/CDT_update_steps100000_use_promptFalse-9092",
    "logs/CDT/OfflineCarPush1Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-b731/CDT_seed1_update_steps100000_use_promptFalse-b731",
    "logs/CDT/OfflineCarPush1Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-797e/CDT_seed2_update_steps100000_use_promptFalse-797e",
    "logs/CDT/OfflineCarPush2Gymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-442f/CDT_update_steps100000_use_promptFalse-442f",
    "logs/CDT/OfflineCarPush2Gymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-7673/CDT_seed1_update_steps100000_use_promptFalse-7673",
    "logs/CDT/OfflineCarPush2Gymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-7828/CDT_seed2_update_steps100000_use_promptFalse-7828",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-c311/CDT_update_steps100000_use_promptFalse-c311",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-061c/CDT_seed1_update_steps100000_use_promptFalse-061c",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-a0ee/CDT_seed2_update_steps100000_use_promptFalse-a0ee",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v1-cost-10/CDT_update_steps100000_use_promptFalse-ae23/CDT_update_steps100000_use_promptFalse-ae23",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v1-cost-10/CDT_seed1_update_steps100000_use_promptFalse-989b/CDT_seed1_update_steps100000_use_promptFalse-989b",
    "logs/CDT/OfflineSwimmerVelocityGymnasium-v1-cost-10/CDT_seed2_update_steps100000_use_promptFalse-880c/CDT_seed2_update_steps100000_use_promptFalse-880c",
    "logs/CDT/OfflineHopperVelocityGymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-26c5/CDT_update_steps100000_use_promptFalse-26c5",
    "logs/CDT/OfflineHopperVelocityGymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-da24/CDT_seed1_update_steps100000_use_promptFalse-da24",
    "logs/CDT/OfflineHopperVelocityGymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-f9d9/CDT_seed2_update_steps100000_use_promptFalse-f9d9",
    "logs/CDT/OfflineHopperVelocityGymnasium-v1-cost-10/CDT_update_steps100000_use_promptFalse-e3ce/CDT_update_steps100000_use_promptFalse-e3ce",
    "logs/CDT/OfflineHopperVelocityGymnasium-v1-cost-10/CDT_seed1_update_steps100000_use_promptFalse-d46b/CDT_seed1_update_steps100000_use_promptFalse-d46b",
    "logs/CDT/OfflineHopperVelocityGymnasium-v1-cost-10/CDT_seed2_update_steps100000_use_promptFalse-073b/CDT_seed2_update_steps100000_use_promptFalse-073b",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-d40a/CDT_update_steps100000_use_promptFalse-d40a",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-7265/CDT_seed1_update_steps100000_use_promptFalse-7265",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-b892/CDT_seed2_update_steps100000_use_promptFalse-b892",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/CDT_update_steps100000_use_promptFalse-1c0a/CDT_update_steps100000_use_promptFalse-1c0a",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/CDT_seed1_update_steps100000_use_promptFalse-5bfe/CDT_seed1_update_steps100000_use_promptFalse-5bfe",
    "logs/CDT/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/CDT_seed2_update_steps100000_use_promptFalse-c5f1/CDT_seed2_update_steps100000_use_promptFalse-c5f1",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-e78a/CDT_update_steps100000_use_promptFalse-e78a",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-c16e/CDT_seed1_update_steps100000_use_promptFalse-c16e",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-7551/CDT_seed2_update_steps100000_use_promptFalse-7551",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v1-cost-10/CDT_update_steps100000_use_promptFalse-18af/CDT_update_steps100000_use_promptFalse-18af",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v1-cost-10/CDT_seed1_update_steps100000_use_promptFalse-f6e7/CDT_seed1_update_steps100000_use_promptFalse-f6e7",
    "logs/CDT/OfflineWalker2dVelocityGymnasium-v1-cost-10/CDT_seed2_update_steps100000_use_promptFalse-ed2c/CDT_seed2_update_steps100000_use_promptFalse-ed2c",
    "logs/CDT/OfflineAntVelocityGymnasium-v0-cost-10/CDT_update_steps100000_use_promptFalse-eae5/CDT_update_steps100000_use_promptFalse-eae5",
    "logs/CDT/OfflineAntVelocityGymnasium-v0-cost-10/CDT_seed1_update_steps100000_use_promptFalse-1f44/CDT_seed1_update_steps100000_use_promptFalse-1f44",
    "logs/CDT/OfflineAntVelocityGymnasium-v0-cost-10/CDT_seed2_update_steps100000_use_promptFalse-7def/CDT_seed2_update_steps100000_use_promptFalse-7def",
    "logs/CDT/OfflineAntVelocityGymnasium-v1-cost-10/CDT_update_steps100000_use_promptFalse-e2f1/CDT_update_steps100000_use_promptFalse-e2f1",
    "logs/CDT/OfflineAntVelocityGymnasium-v1-cost-10/CDT_seed1_update_steps100000_use_promptFalse-fb17/CDT_seed1_update_steps100000_use_promptFalse-fb17",
    "logs/CDT/OfflineAntVelocityGymnasium-v1-cost-10/CDT_seed2_update_steps100000_use_promptFalse-f193/CDT_seed2_update_steps100000_use_promptFalse-f193"
]

# Posterior RTG-model checkpoint directories, one active entry per task, in
# the same task order as ``model_paths``. ``EvalConfig`` indexes this list
# with ``task_id`` divided by three (three CDT checkpoints share one task).
# Commented-out entries are other checkpoints for the same task that are
# currently not selected; keeping exactly one active entry per task preserves
# the index alignment.
rtg_model_paths = [
    "logs/OfflinePointButton1Gymnasium-v0-cost-10/RTG_model_use_promptFalse-e48e_posterior/RTG_model_use_promptFalse-e48e_posterior",
    "logs/OfflinePointButton2Gymnasium-v0-cost-10/RTG_model-4ab7_posterior/RTG_model-4ab7_posterior",
    "logs/OfflinePointCircle1Gymnasium-v0-cost-10/RTG_model-43d1_posterior/RTG_model-43d1_posterior",
    "logs/OfflinePointCircle2Gymnasium-v0-cost-10/RTG_model-de14_posterior/RTG_model-de14_posterior",
    "logs/OfflinePointGoal1Gymnasium-v0-cost-10/RTG_model-52b5_posterior/RTG_model-52b5_posterior",
    "logs/OfflinePointGoal2Gymnasium-v0-cost-10/RTG_model-305a_posterior/RTG_model-305a_posterior",
    "logs/OfflinePointPush1Gymnasium-v0-cost-10/RTG_model-6913_posterior/RTG_model-6913_posterior",
    "logs/OfflinePointPush2Gymnasium-v0-cost-10/RTG_model-eefa_posterior/RTG_model-eefa_posterior",
    "logs/OfflineCarButton1Gymnasium-v0-cost-10/RTG_model_use_promptFalse-7558_posterior/RTG_model_use_promptFalse-7558_posterior",
    "logs/OfflineCarButton2Gymnasium-v0-cost-10/RTG_model_use_promptFalse-43b9_posterior/RTG_model_use_promptFalse-43b9_posterior",
    "logs/OfflineCarCircle1Gymnasium-v0-cost-10/RTG_model_use_promptFalse-aa82_posterior/RTG_model_use_promptFalse-aa82_posterior",
    "logs/OfflineCarCircle2Gymnasium-v0-cost-10/RTG_model_use_promptFalse-fe8f_posterior/RTG_model_use_promptFalse-fe8f_posterior",
    "logs/OfflineCarGoal1Gymnasium-v0-cost-10/RTG_model_use_promptFalse-92f8_posterior/RTG_model_use_promptFalse-92f8_posterior",
    "logs/OfflineCarGoal2Gymnasium-v0-cost-10/RTG_model_use_promptFalse-1306_posterior/RTG_model_use_promptFalse-1306_posterior",
    "logs/OfflineCarPush1Gymnasium-v0-cost-10/RTG_model_use_promptFalse-fef1_posterior/RTG_model_use_promptFalse-fef1_posterior",
    "logs/OfflineCarPush2Gymnasium-v0-cost-10/RTG_model_use_promptFalse-56b3_posterior/RTG_model_use_promptFalse-56b3_posterior",
    "logs/OfflineSwimmerVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-74c4_posterior/RTG_model_use_promptFalse-74c4_posterior",
    # "logs/OfflineSwimmerVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-36f1_posterior/RTG_model_use_promptFalse-36f1_posterior",
    "logs/OfflineSwimmerVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-9680_posterior/RTG_model_use_promptFalse-9680_posterior",
    # "logs/OfflineHopperVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-df8a_posterior/RTG_model_use_promptFalse-df8a_posterior",
    "logs/OfflineHopperVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-8d6b_posterior/RTG_model_use_promptFalse-8d6b_posterior",
    # "logs/OfflineHopperVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-17dd_posterior/RTG_model_use_promptFalse-17dd_posterior",
    "logs/OfflineHopperVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-3721_posterior/RTG_model_use_promptFalse-3721_posterior",
    #"logs/OfflineHalfCheetahVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-b016_posterior/RTG_model_use_promptFalse-b016_posterior",
    "logs/OfflineHalfCheetahVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-7b60_posterior/RTG_model_use_promptFalse-7b60_posterior",
    "logs/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-f7f3_posterior/RTG_model_use_promptFalse-f7f3_posterior",
    # "logs/OfflineHalfCheetahVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-1ffb_posterior/RTG_model_use_promptFalse-1ffb_posterior",
    "logs/OfflineWalker2dVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-b3d1_posterior/RTG_model_use_promptFalse-b3d1_posterior",
    # "logs/OfflineWalker2dVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-d444_posterior/RTG_model_use_promptFalse-d444_posterior",
    # "logs/OfflineWalker2dVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-c809_posterior/RTG_model_use_promptFalse-c809_posterior",
    "logs/OfflineWalker2dVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-177c_posterior/RTG_model_use_promptFalse-177c_posterior",
    # "logs/OfflineAntVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-f5e5_posterior/RTG_model_use_promptFalse-f5e5_posterior",
    "logs/OfflineAntVelocityGymnasium-v0-cost-10/RTG_model_use_promptFalse-2b47_posterior/RTG_model_use_promptFalse-2b47_posterior",
    # "logs/OfflineAntVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-fd44_posterior/RTG_model_use_promptFalse-fd44_posterior"
    "logs/OfflineAntVelocityGymnasium-v1-cost-10/RTG_model_use_promptFalse-58cb_posterior/RTG_model_use_promptFalse-58cb_posterior"
]

# Default target returns, one row per task in the same task order as
# ``rtg_model_paths`` (``EvalConfig`` indexes both with ``task_id`` divided
# by three). Each row holds four values; ``eval`` zips the selected row with
# ``EvalConfig.costs``, so the i-th return is evaluated together with the
# i-th cost target.
target_returns = [
    [40,40,40,40],
    [40,40,40,40],
    [50,50,52.5,55],
    [45,45,47.5,50],
    [30,30,30,30],
    [30,30,30,30],
    [15,15,15,15],
    [12,12,12,12],
    [35,35,35,35],
    [40,40,40,40],
    [20,20,22.5,25],
    [20,20,21,22],
    [40,40,40,40],
    [30,30,30,30],
    [15,15,15,15],
    [12,12,12,12],
    [160,160,160,160],
    [160,160,160,160],
    [1750,1750,1750,1750],
    [1750,1750,1750,1750],
    [3000,3000,3000,3000],
    [3000,3000,3000,3000],
    [2800,2800,2800,2800],
    [2800,2800,2800,2800],
    [2800,2800,2800,2800],
    [2800,2800,2800,2800]
]

@dataclass
class EvalConfig:
    """Command-line configuration for evaluating a CDT checkpoint.

    ``path``, ``rtg_model_path`` and ``returns`` default to ``None`` and are
    derived from ``task_id`` in ``__post_init__``. Computing them in the class
    body would freeze them to the default ``task_id`` at import time, so
    overriding ``task_id`` on the command line would silently keep evaluating
    the default task. Any of the three may still be given explicitly, in
    which case the explicit value wins.

    Raises:
        ValueError: if a value has to be derived and ``task_id`` is not a
            valid index into ``model_paths``.
    """

    # Index into ``model_paths``; three consecutive indices share one task.
    task_id: int = 2  # HalfCheetah 60 Hopper 54 Swimmer 48 Ant 72
    # CDT checkpoint directory; defaults to ``model_paths[task_id]``.
    path: Optional[str] = None
    # RTG-model checkpoint directory; defaults to the entry of
    # ``rtg_model_paths`` for the task that ``task_id`` belongs to.
    rtg_model_path: Optional[str] = None
    # Target returns; defaults to a copy of the ``target_returns`` row for the
    # task. Must have the same length as ``costs``.
    returns: Optional[List[float]] = None
    # Target costs, paired element-wise with ``returns``.
    costs: List[float] = field(default=[10, 20, 40, 80], is_mutable=True)
    # Alternative full grid of targets:
    # returns: List[float] = field(default=[10,10,10,10,20,20,20,20,30,30,30,30,40,40,40,40], is_mutable=True)
    # costs: List[float] = field(default=[10,20,40,80,10,20,40,80,10,20,40,80,10,20,40,80], is_mutable=True)
    # Not read by ``eval`` in this file.
    noise_scale: Optional[List[float]] = None
    # Number of episodes passed to ``trainer.evaluate`` per target pair.
    eval_episodes: int = 20
    # Forwarded to ``load_config_and_model`` for the CDT checkpoint.
    best: bool = False
    # Torch device the models are moved to.
    device: str = "cuda:5"
    # Torch thread count; only applied when ``device == "cpu"``.
    threads: int = 16
    # When True, an RTG model is built and handed to the trainer.
    conservative: bool = True
    # The following four values are forwarded unchanged to ``CDTTrainer``.
    rtg_sample_num: int = 1000
    rtg_sample_quantile: float = 0.99
    rtg_sample_quantile_end: float = 0.8
    rtg_update_every_step: bool = True
    # Forwarded to ``trainer.evaluate``.
    deterministic: bool = False

    def __post_init__(self):
        """Fill in task-dependent defaults that were not given explicitly."""
        if (self.path is not None and self.rtg_model_path is not None
                and self.returns is not None):
            return

        if not 0 <= self.task_id < len(model_paths):
            raise ValueError(
                f"task_id must be in [0, {len(model_paths) - 1}], "
                f"got {self.task_id}")

        # ``model_paths`` lists three checkpoints per task, while the
        # per-task tables hold a single entry per task.
        checkpoints_per_task = 3
        task_index = self.task_id // checkpoints_per_task

        if self.path is None:
            self.path = model_paths[self.task_id]
        if self.rtg_model_path is None:
            self.rtg_model_path = rtg_model_paths[task_index]
        if self.returns is None:
            # Copy so that instances never share (or mutate) the table row.
            self.returns = list(target_returns[task_index])


@pyrallis.wrap()
def eval(args: EvalConfig):
    """Evaluate a trained CDT checkpoint on a list of (return, cost) targets.

    Loads the CDT checkpoint at ``args.path`` and the RTG-model checkpoint at
    ``args.rtg_model_path``, rebuilds the environment and models from the
    stored training config, and calls ``trainer.evaluate`` once per
    ``(target_ret, target_cost)`` pair from ``args.returns`` / ``args.costs``.
    Per-target results are printed and written to a logger under
    ``<args.path>/eval``; the averages over all targets are printed at the end.

    Args:
        args: evaluation configuration (filled in by ``pyrallis.wrap``).

    Raises:
        AssertionError: if ``args.returns`` and ``args.costs`` differ in length.
    """
    cfg, model = load_config_and_model(args.path, args.best)
    # The RTG checkpoint is loaded unconditionally: its config is consulted
    # for ``need_rescale`` below even when ``args.conservative`` is False.
    rtg_cfg, model_rtg = load_config_and_model(args.rtg_model_path, True)

    timestamp = datetime.datetime.now().strftime("%y-%m%d-%H%M%S")
    logger = TensorboardLogger(args.path+"/eval", log_txt=True, name=timestamp)
    logger.save_config(asdict(args), verbose=True)

    seed_all(cfg["seed"])
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # Metadrive tasks are created through ``gym``; all others use ``gymnasium``.
    if "Metadrive" in cfg["task"]:
        import gym
    else:
        import gymnasium as gym  # noqa

    env = wrap_env(
        env=gym.make(cfg["task"]),
        reward_scale=cfg["reward_scale"],
    )
    env = OfflineEnvWrapper(env)
    env.set_target_cost(cfg["cost_limit"])

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    target_entropy = -action_dim

    # Hyper-parameters whose CDT keyword name equals their key in the stored
    # training config; they are forwarded one-to-one.
    cdt_cfg_keys = (
        "embedding_dim",
        "seq_len",
        "episode_len",
        "num_layers",
        "num_heads",
        "attention_dropout",
        "residual_dropout",
        "embedding_dropout",
        "time_emb",
        "use_rew",
        "use_cost",
        "cost_transform",
        "add_cost_feat",
        "mul_cost_feat",
        "cat_cost_feat",
        "action_head_layers",
        "cost_prefix",
        "stochastic",
        "init_temperature",
    )
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=env.action_space.high[0],
        **{key: cfg[key] for key in cdt_cfg_keys},
        target_entropy=target_entropy,
    )
    cdt_model.load_state_dict(model["model_state"])
    cdt_model.to(args.device)

    if args.conservative:
        # All three embeddings share the single ``embedding_dim`` config value.
        rtg_embedding_dim = rtg_cfg["embedding_dim"]
        rtg_model = RTG_model(
            state_dim=state_dim,
            prompt_dim=rtg_cfg["prompt_dim"],
            cost_embedding_dim=rtg_embedding_dim,
            state_embedding_dim=rtg_embedding_dim,
            prompt_embedding_dim=rtg_embedding_dim,
            r_hidden_sizes=rtg_cfg["r_hidden_sizes"],
            use_state=rtg_cfg["use_state"],
            use_prompt=rtg_cfg["use_prompt"],
        )
        rtg_model.load_state_dict(model_rtg["model_state"])
        rtg_model.to(args.device)
    else:
        rtg_model = None

    trainer = CDTTrainer(
        cdt_model,
        env,
        reward_scale=cfg["reward_scale"],
        cost_scale=cfg["cost_scale"],
        cost_reverse=cfg["cost_reverse"],
        device=args.device,
        rtg_model=rtg_model,
        rtg_sample_num=args.rtg_sample_num,
        rtg_sample_quantile=args.rtg_sample_quantile,
        rtg_sample_quantile_end=args.rtg_sample_quantile_end,
        rtg_update_every_step=args.rtg_update_every_step,
    )

    rets = args.returns
    costs = args.costs
    assert len(rets) == len(
        costs
    ), f"The length of returns {len(rets)} should be equal to costs {len(costs)}!"
    if not rets:
        # Nothing to average over; avoid dividing by zero below.
        print("No (return, cost) targets were given; nothing to evaluate.")
        return

    # Older RTG configs may not carry the flag at all; treat that as False.
    need_rescale = "need_rescale" in rtg_cfg and bool(rtg_cfg["need_rescale"])

    total_normalized_ret = 0
    total_normalized_cost = 0
    for step, (target_ret, target_cost) in enumerate(zip(rets, costs), start=1):
        # Re-seed before every target so each pair is evaluated from the same
        # seed.
        seed_all(cfg["seed"])
        ret, cost, _length, target_ret_mean = trainer.evaluate(
            args.eval_episodes,
            target_ret * cfg["reward_scale"],
            target_cost * cfg["cost_scale"],
            conservative=args.conservative,
            keep_ctg_positive=True,
            need_rescale=need_rescale,
            deterministic=args.deterministic,
            return_target_return=True,
        )
        # Only the normalized return is taken from the environment; the cost
        # is normalized by the per-target cost instead.
        normalized_ret, _ = env.get_normalized_score(ret, cost)
        normalized_cost = cost / target_cost
        total_normalized_ret += normalized_ret
        total_normalized_cost += normalized_cost
        print(
            f"Target reward {target_ret}, real reward {ret}, normalized reward: {normalized_ret}; target cost {target_cost}, real cost {cost}, normalized cost: {normalized_cost}; target_ret_mean: {target_ret_mean}"
        )
        logger.store(tab="Target", target_ret=target_ret, target_cost=target_cost)
        logger.store(tab="Result",
                     normalized_reward=normalized_ret,
                     normalized_cost=normalized_cost,
                     real_reward=ret,
                     real_cost=cost,
                     target_ret_mean=target_ret_mean)
        logger.write(step, display=False)

    num_targets = len(rets)
    total_normalized_ret = total_normalized_ret / num_targets
    total_normalized_cost = total_normalized_cost / num_targets
    print(f"Normalized reward: {total_normalized_ret}; normalized cost: {total_normalized_cost}")


# Script entry point. ``eval`` is decorated with ``pyrallis.wrap()``, so it is
# called without arguments here and receives its ``EvalConfig`` from the
# decorator.
if __name__ == "__main__":
    eval()
