import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import CDT, CDTTrainer, CDT_with_action_AE, MTCDT, MTCDTTrainer, PromptCDT, PromptCDTTrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model


@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Train one PromptCDT model jointly on a fixed set of 26 offline tasks.

    The task list, per-environment dimensions, pretrained auto-encoder
    checkpoint paths and evaluation targets are hard-coded tables below.
    Each of the 13 environment families appears as two consecutive tasks
    that share one state/action encoder (indexed through ``task_envs``).

    Every update step performs one ``trainer.train_one_step`` call per task.
    Every ``args.eval_every`` steps (and on the final step) each task is
    evaluated for all of its ``(target_return, target_cost)`` pairs, the
    results are stored in the logger and a checkpoint is saved.

    Args:
        args: Command-line configuration. Values that differ from the
            ``CDTTrainConfig`` defaults override the default config
            registered for the first task in ``CDT_DEFAULT_CONFIG``.

    Raises:
        NotImplementedError: If ``density != 1.0`` or ``cost_reverse`` is
            set; neither is supported by this multi-task script.
        ValueError: If the hard-coded per-task / per-environment tables do
            not have consistent lengths.
    """
    # ------------------------------------------------------------------
    # Static per-task tables (index i refers to the same task everywhere).
    # ------------------------------------------------------------------
    tasks = [
        "OfflinePointButton1Gymnasium-v0", "OfflinePointButton2Gymnasium-v0",
        "OfflinePointCircle1Gymnasium-v0", "OfflinePointCircle2Gymnasium-v0",
        "OfflinePointGoal1Gymnasium-v0", "OfflinePointGoal2Gymnasium-v0",
        "OfflinePointPush1Gymnasium-v0", "OfflinePointPush2Gymnasium-v0",
        "OfflineHalfCheetahVelocityGymnasium-v0", "OfflineHalfCheetahVelocityGymnasium-v1",
        "OfflineHopperVelocityGymnasium-v0", "OfflineHopperVelocityGymnasium-v1",
        "OfflineCarButton1Gymnasium-v0", "OfflineCarButton2Gymnasium-v0",
        "OfflineCarCircle1Gymnasium-v0", "OfflineCarCircle2Gymnasium-v0",
        "OfflineCarGoal1Gymnasium-v0", "OfflineCarGoal2Gymnasium-v0",
        "OfflineCarPush1Gymnasium-v0", "OfflineCarPush2Gymnasium-v0",
        "OfflineAntVelocityGymnasium-v0", "OfflineAntVelocityGymnasium-v1",
        "OfflineSwimmerVelocityGymnasium-v0", "OfflineSwimmerVelocityGymnasium-v1",
        "OfflineWalker2dVelocityGymnasium-v0", "OfflineWalker2dVelocityGymnasium-v1",
    ]
    # Short names used as logging tabs and passed to the trainer.
    task_names = [
        "PointButton1", "PointButton2", "PointCircle1", "PointCircle2",
        "PointGoal1", "PointGoal2", "PointPush1", "PointPush2",
        "HalfCheetahVel-v0", "HalfCheetahVel-v1", "HopperVel-v0", "HopperVel-v1",
        "CarButton1", "CarButton2", "CarCircle1", "CarCircle2",
        "CarGoal1", "CarGoal2", "CarPush1", "CarPush2",
        "AntVel-v0", "AntVel-v1", "SwimmerVel-v0", "SwimmerVel-v1",
        "Walker2dVel-v0", "Walker2dVel-v1",
    ]
    # task index -> environment-family index (selects the shared encoders).
    task_envs = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
                 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12]
    episode_lens = [1000, 1000, 500, 500, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000,
                    1000, 1000, 500, 500, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000,
                    1000, 1000]
    # Evaluation targets per task: three (target_return, target_cost) pairs.
    target_returns = [
        ((40.0, 20), (40.0, 40), (40.0, 80)),        # PointButton1
        ((40.0, 20), (40.0, 40), (40.0, 80)),        # PointButton2
        ((50.0, 20), (52.5, 40), (55.0, 80)),        # PointCircle1
        ((45.0, 20), (47.5, 40), (50.0, 80)),        # PointCircle2
        ((30.0, 20), (30.0, 40), (30.0, 80)),        # PointGoal1
        ((30.0, 20), (30.0, 40), (30.0, 80)),        # PointGoal2
        ((15.0, 20), (15.0, 40), (15.0, 80)),        # PointPush1
        ((12.0, 20), (12.0, 40), (12.0, 80)),        # PointPush2
        ((3000.0, 20), (3000.0, 40), (3000.0, 80)),  # HalfCheetahVel-v0
        ((3000.0, 20), (3000.0, 40), (3000.0, 80)),  # HalfCheetahVel-v1
        ((1750.0, 20), (1750.0, 40), (1750.0, 80)),  # HopperVel-v0
        ((1750.0, 20), (1750.0, 40), (1750.0, 80)),  # HopperVel-v1
        ((40.0, 20), (40.0, 40), (40.0, 80)),        # CarButton1
        ((40.0, 20), (40.0, 40), (40.0, 80)),        # CarButton2
        ((20.0, 20), (22.5, 40), (25.0, 80)),        # CarCircle1
        ((20.0, 20), (21.0, 40), (22.0, 80)),        # CarCircle2
        ((40.0, 20), (40.0, 40), (40.0, 80)),        # CarGoal1
        ((30.0, 20), (30.0, 40), (30.0, 80)),        # CarGoal2
        ((15.0, 20), (15.0, 40), (15.0, 80)),        # CarPush1
        ((12.0, 20), (12.0, 40), (12.0, 80)),        # CarPush2
        ((2800.0, 20), (2800.0, 40), (2800.0, 80)),  # AntVel-v0
        ((2800.0, 20), (2800.0, 40), (2800.0, 80)),  # AntVel-v1
        ((160.0, 20), (160.0, 40), (160.0, 80)),     # SwimmerVel-v0
        ((160.0, 20), (160.0, 40), (160.0, 80)),     # SwimmerVel-v1
        ((2800.0, 20), (2800.0, 40), (2800.0, 80)),  # Walker2dVel-v0
        ((2800.0, 20), (2800.0, 40), (2800.0, 80)),  # Walker2dVel-v1
    ]
    # Per-task SequenceDataset settings (deg / max_reward / max_rew_decrease / min_reward).
    degs = [0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0,
            1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]
    max_rewards = [45.0, 50.0, 65.0, 55.0, 35.0, 35.0, 20, 15, 3000, 3000, 2000, 2000, 45, 50,
                   30, 30, 50, 35, 20, 15, 3000, 3000, 250, 250, 3600, 3600]
    max_rew_decreases = [5, 10, 5, 5, 5, 5, 5, 3, 500, 500, 300, 300, 10, 10,
                         10, 10, 5, 5, 5, 3, 500, 500, 50, 50, 800, 800]
    min_rewards = [1] * len(tasks)

    # ------------------------------------------------------------------
    # Static per-environment tables (index e is an entry of task_envs).
    # ------------------------------------------------------------------
    env_state_dims = [76, 28, 60, 76, 17, 11, 88, 40, 72, 88, 27, 8, 17]
    env_action_dims = [2, 2, 2, 2, 6, 3, 2, 2, 2, 2, 8, 2, 6]
    env_state_encoder_paths = [
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
    ]
    # None means no pretrained action auto-encoder is loaded for that environment.
    env_action_encoder_paths = [
        None,  # PointButton
        None,  # PointCircle
        None,  # PointGoal
        None,  # PointPush
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,  # CarButton
        None,  # CarCircle
        None,  # CarGoal
        None,  # CarPush
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,  # SwimmerVelocity
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
    ]

    # The tables above are aligned by hand; fail loudly if they drift apart.
    task_num = len(tasks)
    per_task_tables = {
        "task_names": task_names,
        "task_envs": task_envs,
        "episode_lens": episode_lens,
        "target_returns": target_returns,
        "degs": degs,
        "max_rewards": max_rewards,
        "max_rew_decreases": max_rew_decreases,
        "min_rewards": min_rewards,
    }
    for table_name, table in per_task_tables.items():
        if len(table) != task_num:
            raise ValueError(
                f"{table_name} has {len(table)} entries, expected {task_num}")
    num_envs = max(task_envs) + 1
    per_env_tables = {
        "env_state_dims": env_state_dims,
        "env_action_dims": env_action_dims,
        "env_state_encoder_paths": env_state_encoder_paths,
        "env_action_encoder_paths": env_action_encoder_paths,
    }
    for table_name, table in per_env_tables.items():
        if len(table) != num_envs:
            raise ValueError(
                f"{table_name} has {len(table)} entries, expected {num_envs}")

    # Per-task views of the per-environment tables: both tasks of an
    # environment share the same dimensions and checkpoints.
    state_dims = [env_state_dims[e] for e in task_envs]
    action_dims = [env_action_dims[e] for e in task_envs]
    state_encoder_paths = [env_state_encoder_paths[e] for e in task_envs]
    action_encoder_paths = [env_action_encoder_paths[e] for e in task_envs]

    # ------------------------------------------------------------------
    # Resolve the configuration: defaults of the first task, overridden by
    # whatever the user changed relative to the CDTTrainConfig defaults.
    # ------------------------------------------------------------------
    args.task = tasks[0]
    cfg, old_cfg = asdict(args), asdict(CDTTrainConfig())
    differing_values = {
        key: value for key, value in cfg.items() if value != old_cfg[key]
    }
    cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # Unsupported options are rejected up front with a real exception.
    # A bare `assert` would be stripped under `python -O`, and failing here
    # avoids loading 26 datasets (or training until the first evaluation)
    # before finding out.
    if args.density != 1.0:
        raise NotImplementedError(
            "density != 1.0 is not supported by the multi-task script: the "
            "density config is keyed by a single task")
    if args.cost_reverse:
        raise NotImplementedError(
            "cost_reverse is not supported by the multi-task script: "
            "evaluation with reversed costs is not implemented")

    # setup logger
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = f"PromptCDT-task_num-{task_num}"
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # ------------------------------------------------------------------
    # Environments and datasets (three passes, in this order: create all,
    # pre-process all, wrap all).
    # ------------------------------------------------------------------
    env_ls, data_ls, target_entropy_ls = [], [], []
    for task in tasks:
        env = gym.make(task)
        env.set_target_cost(args.cost_limit)
        env_ls.append(env)
        data_ls.append(env.get_dataset())
        # One target entropy per task: minus the action dimensionality.
        target_entropy_ls.append(-env.action_space.shape[0])

    # density is guaranteed to be 1.0 here, so no density bins are supplied.
    data_ls = [
        env.pre_process_data(data,
                             args.outliers_percent,
                             args.noise_scale,
                             args.inpaint_ranges,
                             args.epsilon,
                             args.density,
                             cbins=None,
                             rbins=None,
                             max_npb=None,
                             min_npb=None)
        for env, data in zip(env_ls, data_ls)
    ]

    env_ls = [
        OfflineEnvWrapper(wrap_env(env=env, reward_scale=args.reward_scale))
        for env in env_ls
    ]

    # ------------------------------------------------------------------
    # Trainable per-environment encoders (these become part of PromptCDT).
    # ------------------------------------------------------------------
    state_encoder_ls, action_encoder_ls = [], []
    for env_state_dim, env_action_dim in zip(env_state_dims, env_action_dims):
        state_encoder = State_AE(
            state_dim=env_state_dim,
            encode_dim=args.state_encode_dim,
            hidden_sizes=args.state_encoder_hidden_sizes,
        )
        state_encoder.to(args.device)
        action_encoder = Action_AE(
            action_dim=env_action_dim,
            encode_dim=args.action_encode_dim,
            hidden_sizes=args.action_encoder_hidden_sizes,
            require_tanh=False,
            decode_mu_std=True,
        )
        action_encoder.to(args.device)
        state_encoder_ls.append(state_encoder)
        action_encoder_ls.append(action_encoder)

    # ------------------------------------------------------------------
    # Pretrained per-task auto-encoders and the context encoder.
    # NOTE(review): nothing further down in this function reads
    # pretrained_se_ls, pretrained_ae_ls or encoder. They are still loaded
    # so that a missing or incompatible checkpoint fails at start-up;
    # confirm whether they can be dropped.
    # ------------------------------------------------------------------
    pretrained_se_ls, pretrained_ae_ls = [], []
    for se_path, ae_path, s_dim, a_dim in zip(state_encoder_paths,
                                              action_encoder_paths,
                                              state_dims, action_dims):
        senc_cfg, senc_model = load_config_and_model(
            se_path, True, device=torch.device("cpu"))
        pretrained_se = State_AE(
            state_dim=s_dim,
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"],
        )
        pretrained_se.load_state_dict(senc_model["model_state"])
        pretrained_se.eval()
        pretrained_se_ls.append(pretrained_se)

        if ae_path is None:
            pretrained_ae_ls.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(
            ae_path, True, device=torch.device("cpu"))
        pretrained_ae = Action_AE(
            action_dim=a_dim,
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"],
        )
        pretrained_ae.load_state_dict(aenc_model["model_state"])
        pretrained_ae.eval()
        pretrained_ae_ls.append(pretrained_ae)

    enc_cfg, enc_model = load_config_and_model(
        args.context_encoder_path, False, device=torch.device(args.device))
    enc_cfg = types.SimpleNamespace(**enc_cfg)
    # Width of the (state, action, next-state) encoding part of the context
    # encoder input; the two encoder classes add 1 and 2 extra inputs.
    transition_dim = enc_cfg.state_encoding_dim * 2 + enc_cfg.action_encoding_dim
    if enc_cfg.simple_mlp:
        encoder = SimpleMlpEncoder(
            transition_dim + 2,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
        ).to(args.device)
    else:
        encoder = SafetyAwareEncoder(
            transition_dim + 1,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
            simple_gate=enc_cfg.simple_gate,
        ).to(args.device)
    encoder.load_state_dict(enc_model["encoder_state"])
    encoder.eval()

    # ------------------------------------------------------------------
    # Model, checkpointing and trainer.
    # ------------------------------------------------------------------
    # The CDT backbone works on encoded states/actions, so its dimensions
    # come from the encoder sizes rather than from any single environment.
    state_dim = args.state_encode_dim
    if args.prompt_concat:
        state_dim += args.prompt_dim
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=args.action_encode_dim,
        max_action=env_ls[0].action_space.high[0],
        embedding_dim=args.embedding_dim,
        # The prompt segment is placed in the same sequence as the trajectory.
        seq_len=args.seq_len + args.prompt_seq_len,
        episode_len=args.episode_len,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        attention_dropout=args.attention_dropout,
        residual_dropout=args.residual_dropout,
        embedding_dropout=args.embedding_dropout,
        time_emb=args.time_emb,
        use_rew=args.use_rew,
        use_cost=args.use_cost,
        cost_transform=args.cost_transform,
        add_cost_feat=args.add_cost_feat,
        mul_cost_feat=args.mul_cost_feat,
        cat_cost_feat=args.cat_cost_feat,
        action_head_layers=args.action_head_layers,
        cost_prefix=args.cost_prefix,
        stochastic=args.stochastic,
        init_temperature=args.init_temperature,
        target_entropy=target_entropy_ls,
        use_prompt=False,
        prompt_prefix=args.prompt_prefix,
        prompt_concat=args.prompt_concat,
        prompt_dim=args.prompt_dim,
    ).to(args.device)

    model = PromptCDT(cdt_model, action_encoder_ls, state_encoder_ls,
                      device=args.device)
    model.to(args.device)

    def checkpoint_fn():
        """Return the state saved by ``logger.save_checkpoint``."""
        return {"model_state": model.state_dict()}

    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
    logger.setup_checkpoint_fn(checkpoint_fn)

    trainer = PromptCDTTrainer(model,
                               env_ls,
                               logger=logger,
                               learning_rate=args.learning_rate,
                               weight_decay=args.weight_decay,
                               betas=args.betas,
                               clip_grad=args.clip_grad,
                               lr_warmup_steps=args.lr_warmup_steps,
                               reward_scale=args.reward_scale,
                               cost_scale=args.cost_scale,
                               loss_cost_weight=args.loss_cost_weight,
                               loss_state_weight=args.loss_state_weight,
                               cost_reverse=args.cost_reverse,
                               no_entropy=args.no_entropy,
                               device=args.device)

    # ------------------------------------------------------------------
    # Data pipelines: one trajectory stream and one prompt stream per task.
    # ------------------------------------------------------------------
    def cost_transform(x):
        """Cost transform handed to SequenceDataset (linear or reciprocal)."""
        return 70 - x if args.linear else 1 / (x + 10)

    def make_batch_iter(task_idx, seq_len):
        """Build the dataset and loader for one task and return its iterator.

        The trajectory and prompt streams differ only in ``seq_len``.
        """
        dataset = SequenceDataset(
            data_ls[task_idx],
            seq_len=seq_len,
            reward_scale=args.reward_scale,
            cost_scale=args.cost_scale,
            deg=degs[task_idx],
            pf_sample=args.pf_sample,
            max_rew_decrease=max_rew_decreases[task_idx],
            beta=args.beta,
            augment_percent=args.augment_percent,
            cost_reverse=args.cost_reverse,
            max_reward=max_rewards[task_idx],
            min_reward=min_rewards[task_idx],
            pf_only=args.pf_only,
            rmin=args.rmin,
            cost_bins=args.cost_bins,
            npb=args.npb,
            cost_sample=args.cost_sample,
            cost_transform=cost_transform,
            start_sampling=args.start_sampling,
            prob=args.prob,
            random_aug=args.random_aug,
            aug_rmin=args.aug_rmin,
            aug_rmax=args.aug_rmax,
            aug_cmin=args.aug_cmin,
            aug_cmax=args.aug_cmax,
            cgap=args.cgap,
            rstd=args.rstd,
            cstd=args.cstd,
        )
        loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=0,
        )
        return iter(loader)

    dataloader_iter_ls, prompt_dataloader_iter_ls = [], []
    for i in range(task_num):
        # Keep this order (trajectory stream first, then prompt stream) so
        # construction-time random draws happen in a fixed sequence.
        dataloader_iter_ls.append(make_batch_iter(i, args.seq_len))
        prompt_dataloader_iter_ls.append(make_batch_iter(i, args.prompt_seq_len))

    # ------------------------------------------------------------------
    # Training loop.
    # ------------------------------------------------------------------
    for step in trange(args.update_steps, desc="Training"):
        # train: one gradient step per task
        for i in range(task_num):
            batch = next(dataloader_iter_ls[i])
            prompt_batch = next(prompt_dataloader_iter_ls[i])
            # Only the first four prompt fields are consumed below; the full
            # 8-way unpack is kept so a changed batch layout fails here.
            (prompt_states, prompt_actions, prompt_returns, prompt_costs_return,
             _prompt_time_steps, _prompt_mask, _prompt_episode_cost,
             _prompt_costs) = [
                 b.to(args.device).to(torch.float32) for b in prompt_batch
             ]
            (states, actions, returns, costs_return, time_steps, mask,
             episode_cost, costs) = [b.to(args.device) for b in batch]
            trainer.train_one_step(states, actions, returns, costs_return,
                                   prompt_states, prompt_actions, prompt_returns,
                                   prompt_costs_return, time_steps, mask,
                                   episode_cost, costs, task_names[i], i,
                                   task_envs[i])

        is_eval_step = ((step + 1) % args.eval_every == 0
                        or step == args.update_steps - 1)
        if not is_eval_step:
            logger.write_without_reset(step)
            continue

        # evaluation
        for i in range(task_num):
            prompt_batch = next(prompt_dataloader_iter_ls[i])
            # b[0:1] keeps only the first sequence of the prompt batch.
            (prompt_states, prompt_actions, prompt_returns, prompt_costs_return,
             _prompt_time_steps, _prompt_mask, _prompt_episode_cost,
             _prompt_costs) = [
                 b[0:1].to(args.device).to(torch.float32) for b in prompt_batch
             ]
            log_cost, log_reward, log_len = {}, {}, {}
            for reward_return, cost_return in target_returns[i]:
                # Targets are given in raw units; rescale them before
                # handing them to the trainer.
                ret, cost, length = trainer.evaluate(
                    args.eval_episodes, reward_return * args.reward_scale,
                    cost_return * args.cost_scale, i, task_envs[i],
                    episode_lens[i], state_dims[i], action_dims[i],
                    prompt_states, prompt_actions, prompt_returns,
                    prompt_costs_return)
                print(task_names[i], reward_return, cost_return, ret, cost)

                name = f"c_{int(cost_return)}_r_{int(reward_return)}"
                log_cost[name] = cost
                log_reward[name] = ret
                log_len[name] = length

            logger.store(tab=task_names[i] + "/cost", **log_cost)
            logger.store(tab=task_names[i] + "/ret", **log_reward)
            logger.store(tab=task_names[i] + "/length", **log_len)

        # Save the current weights; no separate "best" checkpoint is saved.
        logger.save_checkpoint()
        logger.write(step, display=False)


# Script entry point. train() is decorated with pyrallis.wrap(), so it is
# called without arguments here; pyrallis is expected to build the
# CDTTrainConfig argument (e.g. from command-line options).
if __name__ == "__main__":
    train()
