import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import CDT, CDTTrainer, CDT_with_action_AE, MTCDT, MTCDTTrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model


def _sample_prompt_encoding(prompt_iter, encoder, device):
    """Draw prompt batches until one is acceptable, then encode it as context.

    A batch is rejected (and another one drawn) while every transition in it has
    a strictly positive cost; the first batch containing at least one
    transition with cost <= 0 is encoded.

    Args:
        prompt_iter: iterator over a ``DataLoader`` wrapping a
            ``TransitionDataset``; each batch unpacks to
            (states, next_states, actions, rewards, costs, done).
        encoder: pretrained context encoder, called as ``encoder(x, costs)``.
        device: device the batch tensors are moved to (cast to float32).

    Returns:
        The encoder output for the accepted batch, computed under ``no_grad``.

    Raises:
        RuntimeError: if the iterator is exhausted before an acceptable batch
            is found (previously this surfaced as a bare ``StopIteration``).
    """
    while True:
        try:
            prompt_batch = next(prompt_iter)
        except StopIteration as err:
            raise RuntimeError(
                "prompt data exhausted before finding a batch with a transition of cost <= 0"
            ) from err
        (prompt_states, prompt_next_states, prompt_actions,
         prompt_rewards, prompt_costs, _prompt_done) = [
            b.to(device).to(torch.float32) for b in prompt_batch
        ]
        if not torch.all(prompt_costs > 0):
            break
    with torch.no_grad():
        encoder_input = torch.cat(
            [prompt_states, prompt_actions, prompt_next_states, prompt_rewards.reshape(-1, 1)],
            dim=-1,
        )
        return encoder(encoder_input, prompt_costs)


@pyrallis.wrap()
def train(args: CDTTrainConfig):
    """Evaluate a pretrained multi-task CDT (MTCDT) checkpoint.

    Despite its name, this function performs no optimisation. It rebuilds the
    MTCDT architecture, loads weights from a fixed checkpoint under ``logs/``,
    and for every evaluated task runs ``trainer.evaluate`` at each
    (reward, cost) target in ``target_returns``. The resulting return, cost and
    episode length per target are printed and logged under
    ``<task name>/cost``, ``<task name>/ret`` and ``<task name>/length``.

    Args:
        args: config parsed by ``pyrallis``. Fields whose value differs from a
            fresh ``CDTTrainConfig()`` override the per-task defaults
            (``CDT_DEFAULT_CONFIG``) of the first task in ``tasks``.

    Raises:
        NotImplementedError: if ``args.density != 1.0`` (density-based
            resampling is not supported by this script).
        ValueError: if the per-task metadata tables disagree in length.
    """
    # ---- per-task metadata: every table below is indexed in parallel with `tasks`
    tasks = ["OfflinePointButton1Gymnasium-v0","OfflinePointButton2Gymnasium-v0","OfflinePointCircle1Gymnasium-v0","OfflinePointCircle2Gymnasium-v0",
             "OfflinePointGoal1Gymnasium-v0","OfflinePointGoal2Gymnasium-v0","OfflinePointPush1Gymnasium-v0","OfflinePointPush2Gymnasium-v0",
             "OfflineHalfCheetahVelocityGymnasium-v0","OfflineHalfCheetahVelocityGymnasium-v1","OfflineHopperVelocityGymnasium-v0","OfflineHopperVelocityGymnasium-v1",
             "OfflineCarButton1Gymnasium-v0","OfflineCarButton2Gymnasium-v0","OfflineCarCircle1Gymnasium-v0","OfflineCarCircle2Gymnasium-v0",
             "OfflineCarGoal1Gymnasium-v0","OfflineCarGoal2Gymnasium-v0","OfflineCarPush1Gymnasium-v0","OfflineCarPush2Gymnasium-v0",
             "OfflineAntVelocityGymnasium-v0","OfflineAntVelocityGymnasium-v1","OfflineSwimmerVelocityGymnasium-v0","OfflineSwimmerVelocityGymnasium-v1",
             "OfflineWalker2dVelocityGymnasium-v0","OfflineWalker2dVelocityGymnasium-v1"]
    task_names = ["PointButton1","PointButton2","PointCircle1","PointCircle2","PointGoal1","PointGoal2","PointPush1","PointPush2",
                  "HalfCheetahVel-v0","HalfCheetahVel-v1","HopperVel-v0","HopperVel-v1",
                  "CarButton1","CarButton2","CarCircle1","CarCircle2","CarGoal1","CarGoal2","CarPush1","CarPush2",
                  "AntVel-v0","AntVel-v1","SwimmerVel-v0","SwimmerVel-v1","Walker2dVel-v0","Walker2dVel-v1"]
    # Environment-family index of each task; indexes env_state_dims /
    # env_action_dims and the per-family state/action AEs built below.
    task_envs = [0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12]
    state_encoder_paths = [
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE"
    ]
    # None means: no pretrained action AE for this task; raw actions are fed
    # to the prompt TransitionDataset.
    action_encoder_paths = [
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,
        None,
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE"
    ]
    episode_lens = [1000,1000,500,500,1000,1000,1000,1000,1000,1000,1000,1000,
                    1000,1000,500,500,1000,1000,1000,1000,1000,1000,1000,1000,1000,1000]
    state_dims = [76,76,28,28,60,60,76,76,17,17,11,11,88,88,40,40,72,72,88,88,27,27,8,8,17,17]
    action_dims = [2,2,2,2,2,2,2,2,6,6,3,3,2,2,2,2,2,2,2,2,8,8,2,2,6,6]
    # Per environment family (indexed by the values of task_envs).
    env_state_dims = [76,28,60,76,17,11,88,40,72,88,27,8,17]
    env_action_dims = [2,2,2,2,6,3,2,2,2,2,8,2,6]
    # (reward target, cost target) pairs evaluated for each task.
    target_returns = [((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),((40.0, 20), (40.0, 40), (40.0, 80)),((50.0, 20), (52.5, 40), (55.0, 80)),((45.0, 20), (47.5, 40), (50.0, 80)),
                      ((30.0, 20), (30.0, 40), (30.0, 80)),((30.0, 20), (30.0, 40), (30.0, 80)),((15.0, 20), (15.0, 40), (15.0, 80)),((12.0, 20), (12.0, 40), (12.0, 80)),
                      ((3000.0, 20), (3000.0, 40), (3000.0, 80)),((3000.0, 20), (3000.0, 40), (3000.0, 80)),((1750.0, 20), (1750.0, 40), (1750.0, 80)),((1750.0, 20), (1750.0, 40), (1750.0, 80)),
                      ((40.0, 20), (40.0, 40), (40.0, 80)),((40.0, 20), (40.0, 40), (40.0, 80)),((20.0, 20), (22.5, 40), (25.0, 80)),((20.0, 20), (21.0, 40), (22.0, 80)),
                      ((40.0, 20), (40.0, 40), (40.0, 80)),((30.0, 20), (30.0, 40), (30.0, 80)),((15.0, 20), (15.0, 40), (15.0, 80)),((12.0, 20), (12.0, 40), (12.0, 80)),
                      ((2800.0, 20), (2800.0, 40), (2800.0, 80)),((2800.0, 20), (2800.0, 40), (2800.0, 80)),((160.0, 20), (160.0, 40), (160.0, 80)),((160.0, 20), (160.0, 40), (160.0, 80)),
                      ((2800.0, 20), (2800.0, 40), (2800.0, 80)),((2800.0, 20), (2800.0, 40), (2800.0, 80))]
    # NOTE: the four tables below are not read by this evaluation path; they
    # are kept alongside the other per-task tables.
    degs=[0,0,1,1,0,1,0,0,1,1,1,1,0,0,1,1,1,1,0,0,1,1,1,1,1,1]
    max_rewards=[45.0,50.0,65.0,55.0,35.0,35.0,20,15,3000,3000,2000,2000,45,50,30,30,50,35,20,15,3000,3000,250,250,3600,3600]
    max_rew_decreases=[5,10,5,5,5,5,5,3,500,500,300,300,10,10,10,10,5,5,5,3,500,500,50,50,800,800]
    min_rewards=[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]

    # Guard against a table being edited out of sync with `tasks`.
    per_task_tables = (task_names, task_envs, state_encoder_paths, action_encoder_paths,
                       episode_lens, state_dims, action_dims, target_returns)
    if any(len(table) != len(tasks) for table in per_task_tables):
        raise ValueError("per-task metadata tables must all have len(tasks) entries")

    # Only a prefix of `tasks` is evaluated; widen this slice to evaluate more.
    # All per-task loops below index env_ls/data_ls/... by position in this
    # prefix, which matches the position in the full tables.
    eval_tasks = tasks[:1]

    # update config: start from the defaults of the first task and overlay any
    # values the user changed relative to a default CDTTrainConfig
    args.task = tasks[0]
    cfg, old_cfg = asdict(args), asdict(CDTTrainConfig())
    differing_values = {key: cfg[key] for key in cfg.keys() if cfg[key] != old_cfg[key]}
    cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)
    # The model was trained on all tasks; this count also names the log group.
    task_num = len(tasks)

    # setup logger
    default_cfg = asdict(CDT_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = "MTCDT" + "-task_num-" + str(task_num)
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # Fail before loading any dataset. An explicit raise is used instead of
    # `assert`, which `python -O` strips.
    if args.density != 1.0:
        raise NotImplementedError(
            "density-based data resampling is not supported by this script (requires density == 1.0)"
        )
    cbins, rbins, max_npb, min_npb = None, None, None, None

    # initialize environments and their offline datasets
    env_ls = []
    data_ls = []
    target_entropy_ls = []
    for task in eval_tasks:
        temp_env = gym.make(task)
        temp_env.set_target_cost(args.cost_limit)
        env_ls.append(temp_env)
        data_ls.append(temp_env.get_dataset())
        target_entropy_ls.append(-temp_env.action_space.shape[0])

    for i, task_env in enumerate(env_ls):
        data_ls[i] = task_env.pre_process_data(data_ls[i],
                                               args.outliers_percent,
                                               args.noise_scale,
                                               args.inpaint_ranges,
                                               args.epsilon,
                                               args.density,
                                               cbins=cbins,
                                               rbins=rbins,
                                               max_npb=max_npb,
                                               min_npb=min_npb)

    # wrapper
    for i, task_env in enumerate(env_ls):
        wrapped_env = wrap_env(
            env=task_env,
            reward_scale=args.reward_scale,
        )
        env_ls[i] = OfflineEnvWrapper(wrapped_env)

    # One state AE and one action AE per environment family; their weights
    # are overwritten by the MTCDT checkpoint loaded below.
    state_encoder_ls = []
    action_encoder_ls = []
    for i in range(task_envs[-1] + 1):
        state_encoder = State_AE(
            state_dim=env_state_dims[i],
            encode_dim=args.state_encode_dim,
            hidden_sizes=args.state_encoder_hidden_sizes,
        )
        state_encoder.to(args.device)
        action_encoder = Action_AE(
            action_dim=env_action_dims[i],
            encode_dim=args.action_encode_dim,
            hidden_sizes=args.action_encoder_hidden_sizes,
            require_tanh=False,
            decode_mu_std=True,
        )
        action_encoder.to(args.device)
        state_encoder_ls.append(state_encoder)
        action_encoder_ls.append(action_encoder)

    # Per-task pretrained AEs (loaded on CPU, eval mode) used only to encode
    # the prompt transitions fed to the context encoder.
    pretrained_se_ls = []
    pretrained_ae_ls = []
    for i in range(len(eval_tasks)):
        senc_cfg, senc_model = load_config_and_model(state_encoder_paths[i], True, device=torch.device("cpu"))
        pretrained_se = State_AE(
            state_dim=state_dims[i],
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"]
        )
        pretrained_se.load_state_dict(senc_model["model_state"])
        pretrained_se.eval()
        pretrained_se_ls.append(pretrained_se)

        if action_encoder_paths[i] is None:
            pretrained_ae_ls.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(action_encoder_paths[i], True, device=torch.device("cpu"))
        pretrained_ae = Action_AE(
            action_dim=action_dims[i],
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"]
        )
        pretrained_ae.load_state_dict(aenc_model["model_state"])
        pretrained_ae.eval()
        pretrained_ae_ls.append(pretrained_ae)

    # Context (prompt) encoder; its input width depends on the architecture
    # recorded in the encoder's saved config.
    enc_cfg, enc_model = load_config_and_model(args.context_encoder_path, False, device=torch.device(args.device))
    enc_cfg = types.SimpleNamespace(**enc_cfg)
    if not enc_cfg.simple_mlp:
        encoder = SafetyAwareEncoder(
            enc_cfg.state_encoding_dim*2+enc_cfg.action_encoding_dim+1,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
            simple_gate=enc_cfg.simple_gate
            ).to(args.device)
    else:
        encoder = SimpleMlpEncoder(
            enc_cfg.state_encoding_dim*2+enc_cfg.action_encoding_dim+2,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim
            ).to(args.device)
    encoder.load_state_dict(enc_model["encoder_state"])
    encoder.eval()

    # model setup: the CDT operates on encoded states/actions
    state_dim = args.state_encode_dim
    if args.prompt_concat:
        state_dim += args.prompt_dim
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=args.action_encode_dim,
        max_action=env_ls[0].action_space.high[0],
        embedding_dim=args.embedding_dim,
        seq_len=args.seq_len,
        episode_len=args.episode_len,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        attention_dropout=args.attention_dropout,
        residual_dropout=args.residual_dropout,
        embedding_dropout=args.embedding_dropout,
        time_emb=args.time_emb,
        use_rew=args.use_rew,
        use_cost=args.use_cost,
        cost_transform=args.cost_transform,
        add_cost_feat=args.add_cost_feat,
        mul_cost_feat=args.mul_cost_feat,
        cat_cost_feat=args.cat_cost_feat,
        action_head_layers=args.action_head_layers,
        cost_prefix=args.cost_prefix,
        stochastic=args.stochastic,
        init_temperature=args.init_temperature,
        target_entropy=target_entropy_ls,
        use_prompt=args.use_prompt,
        prompt_prefix=args.prompt_prefix,
        prompt_concat=args.prompt_concat,
        prompt_dim=args.prompt_dim
    ).to(args.device)

    # Pretrained multi-task checkpoint that is being evaluated.
    mtcdt_checkpoint_path = "logs/MTCDT-task_num-26/CDT_prompt_prefixFalse-d034/CDT_prompt_prefixFalse-d034"
    _cfg, model_cdt = load_config_and_model(mtcdt_checkpoint_path, False)

    model = MTCDT(cdt_model, action_encoder_ls, state_encoder_ls)
    model.to(args.device)
    model.load_state_dict(model_cdt["model_state"])

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer (used here only for its evaluate() rollouts)
    trainer = MTCDTTrainer(model,
                           env_ls,
                           reward_scale=args.reward_scale,
                           cost_scale=args.cost_scale,
                           device=args.device)

    # One prompt iterator per evaluated task, batched to the context size the
    # context encoder was trained with.
    prompt_dataloader_iter_ls = []
    for i in range(len(eval_tasks)):
        prompt_dataset = TransitionDataset(data_ls[i],
                                           reward_scale=args.reward_scale,
                                           cost_scale=args.cost_scale,
                                           state_encoder=pretrained_se_ls[i],
                                           action_encoder=pretrained_ae_ls[i]
                                           )
        prompt_loader = DataLoader(
            prompt_dataset,
            batch_size=enc_cfg.context_size,
            pin_memory=True,
            num_workers=0,
        )
        prompt_dataloader_iter_ls.append(iter(prompt_loader))

    # evaluation: was `range(len(task[:1]))`, which sliced the leaked loop
    # variable `task` (a string) instead of the task list.
    for i in range(len(eval_tasks)):
        prompt_encoding = _sample_prompt_encoding(prompt_dataloader_iter_ls[i], encoder, args.device)
        log_cost, log_reward, log_len = {}, {}, {}
        for reward_return, cost_return in target_returns[i]:
            if args.cost_reverse:
                # critical step, rescale the return!
                scaled_cost_target = (args.episode_len - cost_return) * args.cost_scale
            else:
                scaled_cost_target = cost_return * args.cost_scale
            ret, cost, length = trainer.evaluate(
                args.eval_episodes, reward_return * args.reward_scale,
                scaled_cost_target, i, task_envs[i], episode_lens[i], state_dims[i], action_dims[i],
                prompt=prompt_encoding)
            print(reward_return, cost_return)
            print(ret, cost, length)

            name = "c_" + str(int(cost_return)) + "_r_" + str(int(reward_return))
            log_cost[name] = cost
            log_reward[name] = ret
            log_len[name] = length

        logger.store(tab=task_names[i]+"/cost", **log_cost)
        logger.store(tab=task_names[i]+"/ret", **log_reward)
        logger.store(tab=task_names[i]+"/length", **log_len)
        # Flush the stored metrics; previously they were stored but never
        # written, so they never reached the log files.
        logger.write(i, display=False)
        logger.save_checkpoint()


if __name__ == "__main__":
    # `train` is decorated with pyrallis.wrap(), which supplies `args` from the
    # command line, so it is invoked here without arguments.
    train()
