import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import pickle
import copy

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import CDT, CDTTrainer, CDT_with_action_AE, MTCDT, MTCDTTrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.algorithms import LoraLinear, replace_linear_with_lora, load_lora, unload_lora, print_trainable_parameters, merge_lora
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model


def _sample_prompt_encoding(prompt_iter, encoder, device):
    """Encode one context batch that has at least one non-positive cost.

    Batches are drawn from ``prompt_iter`` until ``costs > 0`` no longer holds
    for every element. The accepted batch (moved to ``device`` as float32) is
    concatenated as ``[states, actions, next_states, rewards]`` along the last
    axis and passed to ``encoder`` together with the costs, with gradient
    tracking disabled.

    Args:
        prompt_iter: Iterator yielding six tensors per batch, unpacked in the
            order (states, next_states, actions, rewards, costs, done).
        encoder: Context encoder, called as ``encoder(inputs, costs)``.
        device: Device the batch is moved to before encoding.

    Returns:
        Whatever ``encoder`` returns for the accepted batch.

    Raises:
        StopIteration: If ``prompt_iter`` runs out before a batch is accepted.
    """
    while True:
        (prompt_states, prompt_next_states, prompt_actions, prompt_rewards,
         prompt_costs, _prompt_done) = [
            b.to(device).to(torch.float32) for b in next(prompt_iter)
        ]
        # Reject batches in which every transition has positive cost.
        if not torch.all(prompt_costs > 0):
            break
    with torch.no_grad():
        encoder_input = torch.cat(
            [prompt_states, prompt_actions, prompt_next_states,
             prompt_rewards.reshape(-1, 1)],
            dim=-1)
        return encoder(encoder_input, prompt_costs)


@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Fine-tune the pretrained multi-task CDT on one new task with LoRA.

    The task is selected by ``args.target_task`` (an index into ``new_tasks``).
    The function then:

    1. layers the non-default fields of ``args`` over the task's entry in
       ``CDT_DEFAULT_CONFIG`` and sets up a ``TensorboardLogger``;
    2. loads the pretrained MTCDT checkpoint, builds the environment and reads
       the pickled fine-tuning dataset;
    3. builds the per-environment state/action autoencoders, loads the
       pretrained per-task autoencoders and the context encoder;
    4. rebuilds ``MTCDT``, loads the checkpoint weights and swaps linear layers
       for LoRA layers via ``replace_linear_with_lora``;
    5. runs ``args.update_steps`` training steps, evaluating and saving a
       checkpoint every ``args.eval_every`` steps and on the final step.

    Args:
        args: Run configuration; supplied by ``pyrallis.wrap`` when the
            function is called without arguments.

    Raises:
        ValueError: If ``args.target_task`` is not a valid index into the
            new-task tables.
    """
    # ------------------------------------------------------------------
    # Tables describing the 26 pre-training tasks. Every 26-entry list below
    # is index-aligned with `tasks`; consecutive task pairs share one
    # environment, so the 13-entry `env_*` lists are indexed by environment.
    # ------------------------------------------------------------------
    tasks = [
        "OfflinePointButton1Gymnasium-v0", "OfflinePointButton2Gymnasium-v0",
        "OfflinePointCircle1Gymnasium-v0", "OfflinePointCircle2Gymnasium-v0",
        "OfflinePointGoal1Gymnasium-v0", "OfflinePointGoal2Gymnasium-v0",
        "OfflinePointPush1Gymnasium-v0", "OfflinePointPush2Gymnasium-v0",
        "OfflineHalfCheetahVelocityGymnasium-v0", "OfflineHalfCheetahVelocityGymnasium-v1",
        "OfflineHopperVelocityGymnasium-v0", "OfflineHopperVelocityGymnasium-v1",
        "OfflineCarButton1Gymnasium-v0", "OfflineCarButton2Gymnasium-v0",
        "OfflineCarCircle1Gymnasium-v0", "OfflineCarCircle2Gymnasium-v0",
        "OfflineCarGoal1Gymnasium-v0", "OfflineCarGoal2Gymnasium-v0",
        "OfflineCarPush1Gymnasium-v0", "OfflineCarPush2Gymnasium-v0",
        "OfflineAntVelocityGymnasium-v0", "OfflineAntVelocityGymnasium-v1",
        "OfflineSwimmerVelocityGymnasium-v0", "OfflineSwimmerVelocityGymnasium-v1",
        "OfflineWalker2dVelocityGymnasium-v0", "OfflineWalker2dVelocityGymnasium-v1",
    ]

    state_encoder_paths = [
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
    ]
    # `None` means no pretrained action autoencoder is loaded for that task.
    action_encoder_paths = [
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,
        None,
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
    ]

    state_dims = [76, 76, 28, 28, 60, 60, 76, 76, 17, 17, 11, 11, 88, 88,
                  40, 40, 72, 72, 88, 88, 27, 27, 8, 8, 17, 17]
    action_dims = [2, 2, 2, 2, 2, 2, 2, 2, 6, 6, 3, 3, 2, 2,
                   2, 2, 2, 2, 2, 2, 8, 8, 2, 2, 6, 6]
    env_state_dims = [76, 28, 60, 76, 17, 11, 88, 40, 72, 88, 27, 8, 17]
    env_action_dims = [2, 2, 2, 2, 6, 3, 2, 2, 2, 2, 8, 2, 6]

    # ------------------------------------------------------------------
    # Tables describing the fine-tuning tasks, all indexed by `target_task`.
    # `new_task_envs` indexes the per-environment lists above and
    # `new_task2tasks` indexes the per-task lists above.
    # ------------------------------------------------------------------
    new_tasks = ["OfflineAntVelocityGymnasium-v2", "OfflineHalfCheetahVelocityGymnasium-v2", "OfflineHopperVelocityGymnasium-v2",
                 "OfflineSwimmerVelocityGymnasium-v2", "OfflineWalker2dVelocityGymnasium-v2"]
    new_task_names = ["AntVel-v2", "HalfCheetahVel-v2", "HopperVel-v2", "SwimmerVel-v2", "Walker2dVel-v2"]
    new_task_envs = [10, 4, 5, 11, 12]
    new_task2tasks = [20, 8, 10, 22, 24]
    new_episode_lens = [1000, 1000, 1000, 1000, 1000]
    new_state_dims = [27, 17, 11, 8, 17]
    new_action_dims = [8, 6, 3, 2, 6]
    # Per task: a tuple of (target reward return, target cost return) pairs.
    new_target_returns = [((2600.0, 10),), ((2300.0, 10),), ((1600.0, 10),), ((120.0, 10),), ((2200.0, 10),)]
    new_degs = [1, 1, 1, 1, 1]
    new_max_rewards = [2600, 2300, 1600, 120, 2200]
    new_min_rewards = [1, 1, 1, 1, 1]

    data_paths = ["finetune_data/OfflineAntVelocityGymnasium-v2.pkl",
                  "finetune_data/OfflineHalfCheetahVelocityGymnasium-v2.pkl",
                  "finetune_data/OfflineHopperVelocityGymnasium-v2.pkl",
                  "finetune_data/OfflineSwimmerVelocityGymnasium-v2.pkl",
                  "finetune_data/OfflineWalker2dVelocityGymnasium-v2.pkl"]

    target_task = args.target_task
    # Fail early with a readable message instead of an IndexError (or a
    # silently wrapped negative index) further down.
    if not 0 <= target_task < len(new_tasks):
        raise ValueError(
            f"target_task must be in [0, {len(new_tasks) - 1}], got {target_task!r}")

    # ------------------------------------------------------------------
    # Config: task defaults overlaid with every field the user changed
    # relative to a default-constructed CDTTrainConfig.
    # ------------------------------------------------------------------
    args.task = new_tasks[target_task]
    cfg, old_cfg = asdict(args), asdict(CDTTrainConfig())
    differing_values = {key: value for key, value in cfg.items() if value != old_cfg[key]}
    cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # ------------------------------------------------------------------
    # Logger
    # ------------------------------------------------------------------
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = "MTCDT" + "-lora-" + new_task_names[target_task]
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # Pretrained multi-task checkpoint. Its stored config is not needed here,
    # so it is bound to a throwaway name rather than overwriting `cfg`.
    path = "logs/MTCDT-task_num-26/CDT_seed1-2d7d/CDT_seed1-2d7d"
    _checkpoint_cfg, model_cdt = load_config_and_model(path, False, device=args.device)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # ------------------------------------------------------------------
    # Environment and fine-tuning data
    # ------------------------------------------------------------------
    env = gym.make(new_tasks[target_task])
    env.set_target_cost(args.cost_limit)
    # SECURITY: pickle.load executes arbitrary code embedded in the file; only
    # point `data_paths` at datasets you produced or otherwise trust.
    with open(data_paths[target_task], 'rb') as f:
        data = pickle.load(f)
    env = wrap_env(
        env=env,
        reward_scale=args.reward_scale,
    )
    env = OfflineEnvWrapper(env)
    target_entropy = -env.action_space.shape[0]

    # ------------------------------------------------------------------
    # Per-environment autoencoders that become part of MTCDT (one pair per
    # environment; their weights are loaded with the MTCDT checkpoint below).
    # NOTE: the original author remarked that a `linear_only` option is
    # important for these autoencoders but left it disabled; unchanged here.
    # ------------------------------------------------------------------
    state_encoder_ls = []
    action_encoder_ls = []
    for env_state_dim, env_action_dim in zip(env_state_dims, env_action_dims):
        state_encoder = State_AE(
            state_dim=env_state_dim,
            encode_dim=args.state_encode_dim,
            hidden_sizes=args.state_encoder_hidden_sizes,
        )
        state_encoder.to(args.device)
        action_encoder = Action_AE(
            action_dim=env_action_dim,
            encode_dim=args.action_encode_dim,
            hidden_sizes=args.action_encoder_hidden_sizes,
            require_tanh=False,
            decode_mu_std=True,
        )
        action_encoder.to(args.device)
        state_encoder_ls.append(state_encoder)
        action_encoder_ls.append(action_encoder)

    # ------------------------------------------------------------------
    # Frozen (eval-mode, CPU) pretrained autoencoders, one slot per
    # pre-training task.
    # NOTE(review): only index new_task2tasks[target_task] is consumed below.
    # All entries are still loaded so that construction order is unchanged
    # (building modules presumably draws from the seeded RNG) - confirm before
    # trimming this loop.
    # ------------------------------------------------------------------
    pretrained_se_ls = []
    pretrained_ae_ls = []
    for _task, se_path, ae_path, task_state_dim, task_action_dim in zip(
            tasks, state_encoder_paths, action_encoder_paths, state_dims, action_dims):
        senc_cfg, senc_model = load_config_and_model(se_path, True, device=torch.device("cpu"))
        pretrained_se = State_AE(
            state_dim=task_state_dim,
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"]
        )
        pretrained_se.load_state_dict(senc_model["model_state"])
        pretrained_se.eval()
        pretrained_se_ls.append(pretrained_se)

        if ae_path is None:
            pretrained_ae_ls.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(ae_path, True, device=torch.device("cpu"))
        pretrained_ae = Action_AE(
            action_dim=task_action_dim,
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"]
        )
        pretrained_ae.load_state_dict(aenc_model["model_state"])
        pretrained_ae.eval()
        pretrained_ae_ls.append(pretrained_ae)

    # ------------------------------------------------------------------
    # Context (prompt) encoder, used in eval mode only.
    # ------------------------------------------------------------------
    enc_cfg, enc_model = load_config_and_model(args.context_encoder_path, False, device=torch.device(args.device))
    enc_cfg = types.SimpleNamespace(**enc_cfg)
    if enc_cfg.simple_mlp:
        encoder = SimpleMlpEncoder(
            enc_cfg.state_encoding_dim * 2 + enc_cfg.action_encoding_dim + 2,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim
        ).to(args.device)
    else:
        encoder = SafetyAwareEncoder(
            enc_cfg.state_encoding_dim * 2 + enc_cfg.action_encoding_dim + 1,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
            simple_gate=enc_cfg.simple_gate
        ).to(args.device)
    encoder.load_state_dict(enc_model["encoder_state"])
    encoder.eval()

    # ------------------------------------------------------------------
    # Model: CDT over encoded states/actions, wrapped in MTCDT, then LoRA.
    # ------------------------------------------------------------------
    state_dim = args.state_encode_dim
    if args.prompt_concat:
        # The prompt is concatenated to the state, so the input gets wider.
        state_dim += args.prompt_dim
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=args.action_encode_dim,
        max_action=env.action_space.high[0],
        embedding_dim=args.embedding_dim,
        seq_len=args.seq_len,
        episode_len=args.episode_len,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        attention_dropout=args.attention_dropout,
        residual_dropout=args.residual_dropout,
        embedding_dropout=args.embedding_dropout,
        time_emb=args.time_emb,
        use_rew=args.use_rew,
        use_cost=args.use_cost,
        cost_transform=args.cost_transform,
        add_cost_feat=args.add_cost_feat,
        mul_cost_feat=args.mul_cost_feat,
        cat_cost_feat=args.cat_cost_feat,
        action_head_layers=args.action_head_layers,
        cost_prefix=args.cost_prefix,
        stochastic=args.stochastic,
        init_temperature=args.init_temperature,
        target_entropy=[target_entropy],
        use_prompt=args.use_prompt,
        prompt_prefix=args.prompt_prefix,
        prompt_concat=args.prompt_concat,
        prompt_dim=args.prompt_dim
    ).to(args.device)

    model = MTCDT(cdt_model, action_encoder_ls, state_encoder_ls)
    model.to(args.device)
    model.load_state_dict(model_cdt["model_state"])

    # Swap linear layers for LoRA layers on a copy of the loaded model and
    # train that copy from here on.
    lora_model = copy.deepcopy(model)
    replace_linear_with_lora(lora_model, r=args.lora_rank, alpha=2 * args.lora_rank)
    print_trainable_parameters(lora_model)
    model = lora_model
    model.to(args.device)

    def checkpoint_fn():
        # LoRA weights are saved unmerged (merge_lora is intentionally not called).
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = MTCDTTrainer(model,
                           [env],
                           logger=logger,
                           learning_rate=args.learning_rate,
                           weight_decay=args.weight_decay,
                           betas=args.betas,
                           clip_grad=args.clip_grad,
                           lr_warmup_steps=args.lr_warmup_steps,
                           reward_scale=args.reward_scale,
                           cost_scale=args.cost_scale,
                           loss_cost_weight=args.loss_cost_weight,
                           loss_state_weight=args.loss_state_weight,
                           cost_reverse=args.cost_reverse,
                           no_entropy=args.no_entropy,
                           device=args.device)

    def cost_transform(x):
        # Linear (70 - x) or reciprocal (1 / (x + 10)) transform, chosen by args.linear.
        return 70 - x if args.linear else 1 / (x + 10)

    # ------------------------------------------------------------------
    # Data pipelines
    # ------------------------------------------------------------------
    dataset = SequenceDataset(
        data,
        seq_len=args.seq_len,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        deg=new_degs[target_task],
        pf_sample=False,
        max_rew_decrease=args.max_rew_decrease,
        beta=args.beta,
        augment_percent=0,
        cost_reverse=args.cost_reverse,
        max_reward=new_max_rewards[target_task],
        min_reward=new_min_rewards[target_task],
        pf_only=args.pf_only,
        rmin=args.rmin,
        cost_bins=args.cost_bins,
        npb=args.npb,
        cost_sample=args.cost_sample,
        cost_transform=cost_transform,
        start_sampling=args.start_sampling,
        prob=args.prob,
        random_aug=args.random_aug,
        aug_rmin=args.aug_rmin,
        aug_rmax=args.aug_rmax,
        aug_cmin=args.aug_cmin,
        aug_cmax=args.aug_cmax,
        cgap=args.cgap,
        rstd=args.rstd,
        cstd=args.cstd
    )

    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=0,
    )
    trainloader_iter = iter(trainloader)

    # Transitions for the context encoder, encoded with the pretrained
    # autoencoders of the pre-training task mapped to this new task.
    source_task = new_task2tasks[target_task]
    prompt_dataset = TransitionDataset(data,
                                       reward_scale=args.reward_scale,
                                       cost_scale=args.cost_scale,
                                       state_encoder=pretrained_se_ls[source_task],
                                       action_encoder=pretrained_ae_ls[source_task]
                                       )
    prompt_loader = DataLoader(
        prompt_dataset,
        batch_size=enc_cfg.context_size,
        pin_memory=True,
        num_workers=0,
    )
    promptloader_iter = iter(prompt_loader)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    task_name = new_task_names[target_task]
    env_index = new_task_envs[target_task]
    for step in trange(args.update_steps, desc="Training"):
        batch = next(trainloader_iter)
        prompt_encoding = _sample_prompt_encoding(promptloader_iter, encoder, args.device)
        states, actions, returns, costs_return, time_steps, mask, episode_cost, costs = [
            b.to(args.device) for b in batch
        ]
        # Broadcast the single prompt vector over the batch: one prefix token
        # per sequence, or one copy per timestep otherwise.
        if args.prompt_prefix and not args.prompt_concat:
            prompt_encoding = prompt_encoding.reshape(1, 1, -1).expand(states.shape[0], -1, -1)
        else:
            prompt_encoding = prompt_encoding.reshape(1, 1, -1).expand(states.shape[0], states.shape[1], -1)
        trainer.train_one_step(states, actions, returns, costs_return, time_steps, mask,
                               episode_cost, costs, task_name, 0, env_index, prompt_encoding)

        is_eval_step = (step + 1) % args.eval_every == 0 or step == args.update_steps - 1
        if not is_eval_step:
            logger.write_without_reset(step)
            continue

        # evaluation: a fresh prompt, passed to the trainer without expansion
        prompt_encoding = _sample_prompt_encoding(promptloader_iter, encoder, args.device)
        log_cost, log_reward, log_len = {}, {}, {}
        for reward_return, cost_return in new_target_returns[target_task]:
            if args.cost_reverse:
                # critical step, rescale the return!
                cost_target = (args.episode_len - cost_return) * args.cost_scale
            else:
                cost_target = cost_return * args.cost_scale
            ret, cost, length = trainer.evaluate(
                args.eval_episodes, reward_return * args.reward_scale,
                cost_target, 0, env_index, new_episode_lens[target_task],
                new_state_dims[target_task], new_action_dims[target_task],
                prompt=prompt_encoding)
            print(task_name, reward_return, cost_return, ret, cost)

            name = "c_" + str(int(cost_return)) + "_r_" + str(int(reward_return))
            log_cost[name] = cost
            log_reward[name] = ret
            log_len[name] = length

        logger.store(tab="cost", **log_cost)
        logger.store(tab="ret", **log_reward)
        logger.store(tab="length", **log_len)

        # save the current weight
        logger.save_checkpoint()
        logger.write(step, display=False)


# Script entry point. `train` is called with no arguments on purpose: its
# `args` parameter is supplied by the `@pyrallis.wrap()` decorator.
if __name__ == "__main__":
    train()
