import argparse
import functools
import gc
import os
import pathlib
import sys



os.environ["MUJOCO_GL"] = "osmesa"

import numpy as np
import ruamel.yaml as yaml

sys.path.append(str(pathlib.Path(__file__).parent))

import exploration as expl
import models
import tools
import envs.wrappers as wrappers
from parallel import Parallel, Damy

import torch
from torch import nn
from torch import distributions as torchd
#[todo] start
import concurrent.futures
import networks
import copy
import json
import gym
import shutil
import math
import atexit
import signal
import sys
import traceback

already_called = False



#[todo] end
def to_np(x):
    """Detach tensor ``x`` from autograd, move it to host memory and return a NumPy array."""
    return x.detach().cpu().numpy()


class Dreamer(nn.Module):
    def __init__(self, obs_space, act_space, config, logger, dataset):
        """Assemble the world model, the task behavior(s) and the exploration behavior.

        Args:
            obs_space: Observation space, forwarded to ``models.WorldModel``.
            act_space: Action space, forwarded to the world model, the random
                inference actors and the exploration behaviors.
            config: Run configuration. ``config.train_ratio`` is rescaled in
                place when ``config.adapt_train_ratio`` is set and a task list
                is present.
            logger: Logger; its ``step`` attribute seeds the update-step counter.
            dataset: Training data source, stored untouched and consumed in
                ``__call__``.
        """
        super(Dreamer, self).__init__()
        self._config = config
        self._logger = logger
        self._should_log = tools.Every(config.log_every)
        batch_steps = config.batch_size * config.batch_length
        # NOTE(review): main() compares train_env_name_list against the string
        # "None"; confirm it is normalised to a real None before reaching here.
        if config.train_env_name_list is not None:
            num_envs = len(config.train_env_name_list)
            if config.adapt_train_ratio:
                # Share the update budget between the tasks.
                config.train_ratio = config.train_ratio / num_envs

            interval = self._config.actor_train_seperate_interval
            if interval == 0:
                self.start_all_train = 0
            else:
                # Largest multiple of one full round-robin cycle
                # (interval * num_envs) that fits into the step budget; from
                # this step on choose_env() reports "all".
                # Author's note: the interval has already been multiplied by
                # the number of envs and steps already divided by action_repeat.
                self.start_all_train = int(
                    interval * ((self._config.steps / interval) // num_envs) * num_envs
                )
            print(f"self.start_all_train:{self.start_all_train}")
        self._should_train = tools.Every(batch_steps / config.train_ratio)
        self._should_pretrain = tools.Once()
        self._should_reset = tools.Every(config.reset_every)
        self._should_expl = tools.Until(int(config.expl_until / config.action_repeat))
        self._metrics = {}
        # this is update step
        self._step = logger.step // config.action_repeat
        self._update_count = 0
        self._dataset = dataset
        self._wm = models.WorldModel(obs_space, act_space, self._step, config)
        self.video_pred_idx = 0
        self._act_space = act_space
        self.inference_actors = None
        distinct_actors = config.multi_actor and config.multi_actor_mode == "distinct"
        if distinct_actors:
            # One independent behavior (actor + critic) per training task.
            self._task_behavior = nn.ModuleDict(
                {key: models.ImagBehavior(config, self._wm) for key in config.train_env_name_list}
            ).to(config.device)
            if config.actor_inference_mode == "seperate":
                print(f"config.actor_inference_mode: seperate")
                # Every task starts out acting randomly; __call__ swaps in the
                # trained actors according to the training schedule.
                self.inference_actors = {
                    key: networks.RandomActor(act_space, config)
                    for key in config.train_env_name_list
                }
                if not config.actor_train_seperate_all_random:
                    print(f"config.actor_train_seperate_all_random: False")
                # The bookkeeping list is read and written by __call__ and
                # change_inference_actor in BOTH schedule variants, so it must
                # exist regardless of actor_train_seperate_all_random.
                self.inference_actors_random = [True] * len(config.train_env_name_list)
        else:
            self._task_behavior = models.ImagBehavior(config, self._wm)
        if (
            config.compile and os.name != "nt"
        ):  # compilation is not supported on windows
            self._wm = torch.compile(self._wm)
            if distinct_actors:
                self._task_behavior = nn.ModuleDict(
                    {key: torch.compile(self._task_behavior[key]) for key in self._task_behavior.keys()}
                ).to(config.device)
            else:
                self._task_behavior = torch.compile(self._task_behavior)
        reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
        # Dispatch table keyed by config.expl_behavior; only the selected
        # factory is invoked.
        if distinct_actors:
            expl_factories = dict(
                greedy=lambda: self._task_behavior,
                random=lambda: nn.ModuleDict(
                    {key: expl.Random(config, act_space) for key in config.train_env_name_list}
                ),
                plan2explore=lambda: nn.ModuleDict(
                    {key: expl.Plan2Explore(config, self._wm, reward) for key in config.train_env_name_list}
                ),
            )
        else:
            expl_factories = dict(
                greedy=lambda: self._task_behavior,
                random=lambda: expl.Random(config, act_space),
                plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
            )
        self._expl_behavior = expl_factories[config.expl_behavior]().to(self._config.device)


    def __call__(self, obs, reset, state=None, training=True):
        """Run one agent step: update inference actors, optionally train, then act.

        Args:
            obs: Batch of observations handed to ``self._policy``.
            reset: Per-environment reset flags; only its length is used here
                (it is added to the step counter while training).
            state: Recurrent policy state from the previous call, or ``None``.
            training: When true, training updates and logging are performed
                and the step counters advance.

        Returns:
            ``(policy_output, state)`` exactly as returned by ``self._policy``.
        """
        step = self._step
        if self._config.multi_actor_mode == "distinct" and self._config.actor_inference_mode == "seperate":
            idx = self.choose_env(interval=self._config.actor_train_seperate_interval)[1]
            if idx is None:
                # Final joint-training phase: every task acts with its trained actor.
                for i in range(len(self.inference_actors)):
                    if self.inference_actors_random[i]:
                        self.change_inference_actor(i, to_random=False)
            elif not self._config.actor_train_seperate_all_random:
                # Tasks whose turn has already come keep their trained actor.
                for i in range(idx + 1):
                    if self.inference_actors_random[i]:
                        self.change_inference_actor(i, to_random=False)
            else:
                # Only the task currently being trained acts with its own
                # actor; the previous one goes back to random actions.
                if self.inference_actors_random[idx]:
                    self.change_inference_actor(idx, to_random=False)
                prev = idx - 1  # -1 wraps to the last task when idx == 0
                # Swap only once (flag check) and never undo the swap above
                # when there is a single task (prev would alias idx).
                if len(self.inference_actors_random) > 1 and not self.inference_actors_random[prev]:
                    self.change_inference_actor(prev, to_random=True)
        if self._config.task_grouping:
            interval = self._config.actor_train_seperate_interval
            # interval == 0 means "no separate phase" elsewhere in this class;
            # there is no period to align with, so skip instead of dividing by 0.
            # "<=" gives a small window so that a step counter that jumps just
            # past the boundary still triggers the grouping for this period.
            if interval > 0 and self._step % interval <= len(self._config.train_env_name_list):
                self.task_grouping()
        if training:
            steps = (
                self._config.pretrain
                if self._should_pretrain()
                else self._should_train(step)
            )
            for _ in range(steps):
                if not isinstance(self._dataset, dict):
                    train_data = next(self._dataset)
                else:
                    # Per-task datasets: build one merged batch across tasks.
                    train_data = tools.sample_merged_data(self._config.train_env_name_list, self._dataset, obs_filter=self._config.obs_filter)
                self._train(train_data, total_update=steps)
                self._update_count += 1
                self._metrics["update_count"] = self._update_count
            if self._should_log(step):
                for name, values in self._metrics.items():
                    self._logger.scalar(name, float(np.mean(values)))
                    self._metrics[name] = []
                if self._config.video_pred_log:
                    if not isinstance(self._dataset, dict):
                        video_data = next(self._dataset)
                        env_name = self._config.task.split("_", 1)[1]
                    else:
                        # Cycle through the tasks, one video per logging event.
                        env_name = self._config.train_env_name_list[self.video_pred_idx]
                        video_data = next(self._dataset[env_name])
                        self.video_pred_idx = (self.video_pred_idx + 1) % len(self._config.train_env_name_list)
                    openl = self._wm.video_pred(video_data, env_name=env_name)
                    self._logger.video("train_openl", to_np(openl))
                self._logger.write(fps=True)

        policy_output, state = self._policy(obs, state, training)

        if training:
            self._step += len(reset)
            self._logger.step = self._config.action_repeat * self._step

        return policy_output, state

    #[todo] start
    def task_grouping(self):
        gradient_set = {}
        gradient_list = []
        for env_name in self._config.train_env_name_list:
            avg_gradient = []
            for _ in range(10):
                data = next(self._dataset[env_name])
                gradient,shapes = self._wm.get_input_gradients(data)
                avg_gradient.append(gradient)
                # print(f"grads:{grads}")
            avg_gradient = torch.stack(avg_gradient)
            avg_gradient = torch.mean(avg_gradient, dim=0)
            gradient_set[env_name] = avg_gradient
            gradient_list.append(avg_gradient.cpu().numpy())
            # print(f"avg_gradient:{avg_gradient}")
        gradient_list = np.array(gradient_list)
        gradient_list = torch.tensor(gradient_list)
        cols_all_zeros = torch.all(gradient_list == 0, dim=0)

        # 创建一个布尔掩码，选择非全0的列
        mask = ~cols_all_zeros

        def get_avg_gradient(gradient_list):
            harmo_gradient = torch.zeros_like(gradient_list[0])
            for i in range(len(gradient_list)):
                harmo_gradient += gradient_list[i]
            harmo_gradient /= len(gradient_list)
            return harmo_gradient

        # 使用布尔掩码选择非全0的列
        gradient_list = gradient_list[:, mask]
        avg_gradient = get_avg_gradient(gradient_list)
        conflicts = []
        for i in gradient_list:
            conflicts.append(avg_gradient * i)
        conflicts = torch.stack(conflicts).cuda()
        min_std = 10000
        final_group = None

        def run_kmeans_multiple_times(data, K, num_runs=40, max_iters=100, tol=1e-4):
            best_centroids = None
            best_labels = None
            best_distance = float('inf')

            def kmeans(data, K, centroids, max_iters=100, tol=1e-4):
                for i in range(max_iters):
                    # 计算每个点到每个中心的距离
                    distances = torch.cdist(data, centroids)

                    # 找到最近的中心点
                    labels = torch.argmin(distances, dim=1)

                    # 更新中心点
                    new_centroids = torch.stack([data[labels == k].mean(dim=0) for k in range(K)])
                    # print(new_centroids.shape)
                    # print(centroids.shape)
                    # if new_centroids.shape!=centroids.shape:
                    #     continue
                    # 检查中心点是否收敛
                    if torch.all(torch.abs(new_centroids - centroids) < tol):
                        break

                    centroids = new_centroids

                return centroids, labels

            for _ in range(num_runs):
                # 使用 K-means++ 初始化中心点
                indices = torch.randperm(data.size(0))[:K]
                # print(indices)
                centroids = data[indices]  # kmeans_plusplus_init(data, K)##data[indices]##
                centroids = centroids.cuda()  # 将中心点移动到GPU
                # print(centroids)
                # 运行 K-means
                # print(centroids.shape)
                final_centroids, final_labels = kmeans(data, K, centroids, max_iters, tol)
                # print(final_labels)
                # 计算总距离
                distances = torch.cdist(data, final_centroids)
                total_distance = distances.min(dim=1)[0].sum().item()

                # 更新最优结果
                if total_distance < best_distance:
                    best_distance = total_distance
                    best_centroids = final_centroids
                    best_labels = final_labels

            return best_centroids, best_labels

        group_num = self._config.group_num

        for _ in range(40):
            final_centroids, final_labels = run_kmeans_multiple_times(conflicts.cuda(), group_num)
            l = [[] for _ in range(group_num)]
            for i in range(len(final_labels)):
                l[final_labels[i]].append(i)
            ll = []
            for i in l:
                ll.append(len(i))
            std = np.var(ll)
            if std < min_std:
                min_std = std
                final_group = copy.deepcopy(l)

        for i in final_group:
            print(len(i), i)
        dic = {}
        for i in range(len(final_group)):
            dic[f'group_{i}'] = [self._config.train_env_name_list[j] for j in final_group[i]]
        logdir = pathlib.Path(self._config.logdir).expanduser()
        save_path = f'{logdir}/task_grouping_output'
        os.makedirs(save_path, exist_ok=True)

        with open(f'{save_path}/gradient_grouped_task_{group_num}_seed{self._config.seed}_step{self._step}.json', 'w') as file:
            json.dump(dic, file, indent=2)


    def change_inference_actor(self, env_idx, to_random):
        """Swap the actor used at inference time for one task.

        Args:
            env_idx: Index into ``config.train_env_name_list`` (negative
                indices are accepted, as with any list index).
            to_random: True installs a fresh ``networks.RandomActor``; False
                installs the trained actor of that task's behavior.
        """
        task_name = self._config.train_env_name_list[env_idx]
        if to_random:
            replacement = networks.RandomActor(self._act_space, self._config)
        else:
            replacement = self._task_behavior[task_name].actor
        self.inference_actors[task_name] = replacement
        # Keep the per-task "is currently random" flag in sync with the swap.
        self.inference_actors_random[env_idx] = True if to_random else False


    def _process_env_policy(self, env_idx, name, feat, training, action=None, logprob=None):
        """
        处理每个环境的计算，更新 action 和 logprob。
        """
        ## 轮询式(env0,env1,env2,env0,env1,env2...)
        # indices = torch.arange(feat.size(0))[torch.arange(feat.size(0)) % len(self._config.train_env_name_list) == env_idx]
        ## 整块式(env0,env0,env1,env1,env2,env2...)
        env_size = feat.size(0) // len(self._config.train_env_name_list)
        start_idx = env_idx * env_size
        end_idx = start_idx + env_size
        indices = torch.arange(start_idx, end_idx)
        # print(f"indices:{indices}.len:{len(indices)}")
        if len(indices) > 0:
            if not self._config.actor_inference_mode == "seperate":
                if not training:
                    temp_actor = self._task_behavior[name].actor(feat[indices])
                    action_result = temp_actor.mode()
                elif self._should_expl(self._step):
                    temp_actor = self._expl_behavior[name].actor(feat[indices])
                    action_result = temp_actor.sample()
                else:
                    temp_actor = self._task_behavior[name].actor(feat[indices])
                    action_result = temp_actor.sample()
            else:
                if not training:
                    # print(f"name:{name}")
                    # print(f"self.inference_actors:{self.inference_actors}")
                    temp_actor = self.inference_actors[name](feat[indices])
                    action_result = temp_actor.mode()
                elif self._should_expl(self._step):
                    if self._config.expl_behavior != "greedy":
                        raise NotImplementedError("Not implemented yet")
                    temp_actor = self.inference_actors[name](feat[indices])
                    action_result = temp_actor.sample()
                else:
                    temp_actor = self.inference_actors[name](feat[indices])
                    action_result = temp_actor.sample()
                    
            logprob_result = temp_actor.log_prob(action_result).to(self._config.device)
            action_result = action_result.to(self._config.device)
            # print(f"indices:{indices}")

            if action is not None:
                action[indices] = action_result
                logprob[indices] = logprob_result
            # print(f"action_result.shape:{action_result.shape}")

            return action_result,logprob_result



    def _parallelize_policy(self, feat, training):
        """
        使用并行化来处理多个环境的计算
        """
        action = torch.zeros(feat.size(0), self._config.num_actions).to(self._config.device)
        logprob = torch.zeros(feat.size(0)).to(self._config.device)
        #
        # # 使用线程池来并行处理每个环境
        # with concurrent.futures.ThreadPoolExecutor() as executor:
        #     futures = []
        #     for env_idx, name in enumerate(self._config.train_env_name_list):
        #         # 提交每个环境的计算任务
        #         futures.append(executor.submit(self._process_env_policy, env_idx, name, feat, training, action, logprob))
        #
        #     # 等待所有任务完成
        #     concurrent.futures.wait(futures)
        #     if "cuda" in self._config.device:
        #         torch.cuda.synchronize()

        # action_list, logprob_list = [], []

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self._process_env_policy, env_idx, name, feat, training)
                for env_idx, name in enumerate(self._config.train_env_name_list)
            ]
            concurrent.futures.wait(futures)
            results = [f.result() for f in futures]
            for env_idx, (a, lp) in enumerate(results):
                # indices = torch.arange(feat.size(0))[
                #     torch.arange(feat.size(0)) % len(self._config.train_env_name_list) == env_idx]

                env_size = feat.size(0) // len(self._config.train_env_name_list)
                start_idx = env_idx * env_size
                end_idx = start_idx + env_size
                indices = torch.arange(start_idx, end_idx)

                action[indices] = a
                logprob[indices] = lp
        if "cuda" in self._config.device:
            torch.cuda.synchronize()  # 确保所有GPU计算完成
        # action = torch.stack(action_list)
        # logprob = torch.stack(logprob_list)
        return action, logprob

    def _parallelize_with_cuda_streams_policy(self, feat, training):
        """Evaluate the per-task actors on separate CUDA streams.

        Rows are laid out in contiguous blocks per task, each block holding
        ``feat.size(0) // num_tasks`` rows (same layout as
        ``_process_env_policy``). Requires a CUDA device.

        Args:
            feat: Feature batch covering all tasks, on ``config.device``.
            training: False takes the distribution mode; True samples.

        Returns:
            ``(action, logprob)`` tensors on ``config.device``; trailing rows
            that do not fill a whole block stay zero.
        """
        env_names = self._config.train_env_name_list
        num_envs = len(env_names)
        device = self._config.device

        # Output buffers, filled block by block below.
        action = torch.zeros(feat.size(0), self._config.num_actions, device=device)
        logprob = torch.zeros(feat.size(0), device=device)

        env_size = feat.size(0) // num_envs
        # Stream on which feat and the output buffers were produced.
        producer = torch.cuda.current_stream(device)
        streams = [torch.cuda.Stream(device=device) for _ in range(num_envs)]
        separate = self._config.actor_inference_mode == "seperate"

        for env_idx, (name, stream) in enumerate(zip(env_names, streams)):
            if env_size == 0:
                continue
            # Plain slices avoid the device-to-host sync a boolean mask check
            # needs and stay valid when the batch is not evenly divisible.
            rows = slice(env_idx * env_size, (env_idx + 1) * env_size)

            # A side stream is not ordered after the producing stream unless
            # told so; without this it may read feat before it is written.
            stream.wait_stream(producer)
            with torch.cuda.stream(stream):
                feat_env = feat[rows]
                exploring = training and self._should_expl(self._step)
                if separate:
                    if exploring and self._config.expl_behavior != "greedy":
                        raise NotImplementedError("Not implemented yet")
                    actor = self.inference_actors[name](feat_env)
                elif exploring:
                    actor = self._expl_behavior[name].actor(feat_env)
                else:
                    actor = self._task_behavior[name].actor(feat_env)
                action_env = actor.sample() if training else actor.mode()
                logprob_env = actor.log_prob(action_env)

                # Blocks are disjoint, so the streams never write the same rows.
                action[rows] = action_env
                logprob[rows] = logprob_env

        # Wait for every stream before the caller reads the buffers.
        torch.cuda.synchronize(device=device)

        return action, logprob

    def choose_env(self, interval=None):
        """Return which task the round-robin schedule selects at the current step.

        Args:
            interval: Number of steps each task keeps its turn. ``0`` disables
                the schedule.

        Returns:
            ``(env_name, idx)`` for the selected task, or ``("all", None)``
            when the schedule is disabled or the joint phase has been reached.

        Raises:
            NotImplementedError: If ``interval`` is ``None``.
        """
        if interval is None:
            raise NotImplementedError("Not implemented yet")
        if interval == 0:
            return "all", None

        if interval == self._config.actor_train_seperate_interval:
            # Precomputed in __init__ for the configured interval.
            joint_from = self.start_all_train
        else:
            cycle = interval * len(self._config.train_env_name_list)
            joint_from = cycle * ((self._config.steps / cycle) // len(self._config.train_env_name_list))

        if not self._step < joint_from:
            return "all", None

        turn = self._step // interval
        env_idx = int(turn % len(self._config.train_env_name_list))
        return self._config.train_env_name_list[env_idx], env_idx
    #[todo] end

    def _policy(self, obs, state, training):
        if state is None:
            latent = action = None
        else:
            latent, action = state
        obs = self._wm.preprocess(obs)
        embed = self._wm.encoder(obs)
        # print(f"obs['image'].shape:{obs['image'].shape}")
        # if state:
        #     print(f"state[1].shape:{state[1].shape}")
        #     print(f"state[0]['stoch'].shape:{state[0]['stoch'].shape}")
        latent, _ = self._wm.dynamics.obs_step(latent, action, embed, obs["is_first"])
        if self._config.eval_state_mean:
            latent["stoch"] = latent["mean"]
        feat = self._wm.dynamics.get_feat(latent)
        #[todo] start
        if self._config.multi_actor and self._config.multi_actor_mode == "distinct":
            if self._config.multi_actor_sample == "thread":
                action, logprob = self._parallelize_policy(feat, training)  # 调用并行化方法
            elif self._config.multi_actor_sample == "stream":
                action, logprob = self._parallelize_with_cuda_streams_policy(feat, training)
            else:
                action = torch.zeros(feat.size(0), self._config.num_actions).to(self._config.device)
                logprob = torch.zeros(feat.size(0)).to(self._config.device)
                for env_idx, name in enumerate(self._config.train_env_name_list):
                    #批量分环境处理数据(适配envs数目是train_env_name_list数目两倍及以上的情况)
                    self._process_env_policy(env_idx,name,feat,training,action,logprob)
        else:
        #[todo] end
            if not training:
                actor = self._task_behavior.actor(feat)
                action = actor.mode()
            elif self._should_expl(self._step):
                actor = self._expl_behavior.actor(feat)
                action = actor.sample()
            else:
                actor = self._task_behavior.actor(feat)
                action = actor.sample()
            logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        if self._config.actor["dist"] == "onehot_gumble":
            action = torch.one_hot(
                torch.argmax(action, dim=-1), self._config.num_actions
            )

        policy_output = {"action": action, "logprob": logprob}
        state = (latent, action)
        return policy_output, state


    def _train(self, data, total_update=None): #[todo]
        metrics = {}
        #[todo] start
        if self._config.wm_with_moe and self._config.expert_type == "RSSM":
            post, context, mets = self._wm._train_router_stage(data, step=self._step, total_update=total_update)
        #[todo] end
        else:
            post, context, mets = self._wm._train(data, step=self._step, total_update=total_update) #[todo]
        metrics.update(mets)
        start = post
        reward = lambda f, s, a, env_name=None: self._wm.heads["reward"](   #[todo]
            self._wm.dynamics.get_feat(s), env_name=env_name
        ).mode()
        if self._config.multi_actor and self._config.multi_actor_mode == "distinct": #[todo]
            num_splits = len(self._config.train_env_name_list)
            # 在第0维拆分
            splits = {k: start[k].chunk(num_splits, dim=0) for k in start.keys()}
            # 组装成新字典，结构和start一致
            start_process = {
                name: {k: splits[k][i] for k in start.keys()}
                for i, name in enumerate(self._config.train_env_name_list)
            }
            current_env_name = self.choose_env(interval=self._config.actor_train_seperate_interval)[0]
            if self._config.actor_train_mode == "seperate" and current_env_name != "all":
                temp_task_metrics = {f"actor_{current_env_name}_{inner}": array for
                                     inner, array in
                                     self._task_behavior[current_env_name]._train(start_process[current_env_name], reward, env_name=current_env_name)[-1].items()} #[todo]
                # print(f"temp_task_metrics:{temp_task_metrics}")
                metrics.update(temp_task_metrics)
            else:
                #  create a CUDA stream per behavior (if using CPU only, skip streams)
                if "cuda" in self._config.device and self._config.multi_actor_train == "stream":
                    streams = {name: torch.cuda.Stream(device=self._config.device) for name in self._config.train_env_name_list}
                else:
                    streams = None
                if streams is not None:
                    for n in self._config.train_env_name_list:
                        with torch.cuda.stream(streams[n]):
                            temp_mets = self._task_behavior[n]._train(start_process[n], reward, env_name=n)[-1]
                            metrics.update({f"actor_{n}_{inner}": temp_mets[inner] for inner in temp_mets.keys()})
                    # wait for all streams to complete
                    torch.cuda.synchronize(self._config.device)
                else:
                    temp_task_metrics = {f"actor_{n}_{inner}": array for n in self._config.train_env_name_list for inner, array in self._task_behavior[n]._train(start_process[n], reward, env_name=n)[-1].items()}
                    # print(f"temp_task_metrics:{temp_task_metrics}")
                    metrics.update(temp_task_metrics)
                if self._config.expl_behavior != "greedy":
                    raise NotImplementedError("context,data not segment")
                    temp_expl_metrics = {f"expl_{n}":self._expl_behavior[n]._train(start_process[n], context, data)[-1] for n in self._config.train_env_name_list}
                    metrics.update(temp_expl_metrics)
        else:
            metrics.update(self._task_behavior._train(start, reward)[-1]) #[todo]
            if self._config.expl_behavior != "greedy":
                # if self._config.multi_actor:
                #     mets = self._expl_behavior[env_name].train(start, context, data, env_name=env_name)[-1] #[todo]
                # else:
                mets = self._expl_behavior.train(start, context, data)[-1] #[todo]
                metrics.update({"expl_" + key: value for key, value in mets.items()})
        for name, value in metrics.items():
            if not name in self._metrics.keys():
                self._metrics[name] = [value]
            else:
                self._metrics[name].append(value)


def count_steps(folder):
    """Count the environment steps stored as ``*.npz`` episodes.

    The number after the last ``-`` in a file path (minus the ``.npz``
    suffix) is the episode length; each episode contributes length - 1 steps.

    Args:
        folder: A directory path object, or a dict whose values are such
            directories (their counts are summed).

    Returns:
        Total number of steps.
    """
    def steps_in(directory):
        total = 0
        for episode in directory.glob("*.npz"):
            length_field = str(episode).split("-")[-1][:-4]
            total += int(length_field) - 1
        return total

    if not isinstance(folder, dict):
        return steps_in(folder)
    return sum(steps_in(directory) for directory in folder.values())

#[todo] start
def copy_npz_files(src_folder, dest_folder, target_steps):
    """Copy the oldest episodes from ``src_folder`` until ``target_steps`` is covered.

    Episodes are ordered by the part of the file stem before the first ``-``
    (the timestamp). The step count of an episode is the number after the
    last ``-`` minus one. Files are always copied whole: the episode that
    reaches the target is copied in full and copying stops there.

    Args:
        src_folder: Directory path object containing ``*.npz`` episodes.
        dest_folder: Existing destination handed to ``shutil.copy``.
        target_steps: Number of steps to cover; nothing is copied when it is
            not positive.

    Returns:
        List of the source paths that were copied, in copy order.
    """
    ordered = sorted(src_folder.glob("*.npz"), key=lambda path: path.stem.split('-')[0])

    copied = []
    accumulated = 0
    for episode in ordered:
        # stem already drops the ".npz" suffix
        episode_steps = int(episode.stem.split('-')[-1]) - 1
        reaches_target = accumulated + episode_steps >= target_steps
        if not reaches_target or target_steps - accumulated > 0:
            shutil.copy(episode, dest_folder)
            copied.append(episode)
        if reaches_target:
            break
        accumulated += episode_steps

    return copied
#[todo] end



def make_dataset(episodes, config):
    """Build a batched dataset from stored episodes.

    Sequences of ``config.batch_length`` are drawn with
    ``tools.sample_episodes`` and grouped into batches of
    ``config.batch_size`` by ``tools.from_generator``.
    """
    return tools.from_generator(
        tools.sample_episodes(episodes, config.batch_length),
        config.batch_size,
    )


def make_env(config, mode, id):
    """Create one wrapped environment for the suite named in ``config.task``.

    ``config.task`` has the form ``"<suite>_<task>"``. In the multi-task
    setups the concrete task is picked from ``config.train_env_name_list`` by
    ``id`` modulo the list length.

    Args:
        config: Run configuration. The DMC multi-task branch writes
            ``config.num_actions`` (and ``config.num_states`` for proprio
            tasks) from the ``specific_*`` settings.
        mode: Run mode string; only consulted by the dmlab suite.
        id: Index of the environment slot being created.

    Returns:
        The wrapped environment.

    Raises:
        NotImplementedError: If the suite is unknown.
    """
    suite, task = config.task.split("_", 1)
    print(f"task in make_env:{task}")

    def slot_task():
        # Task assigned to this environment slot in multi-task runs.
        return config.train_env_name_list[id % len(config.train_env_name_list)]

    if suite == "dmc":
        import envs.dmc as dmc

        multitask = "multitask" in task
        env = dmc.DeepMindControl(
            slot_task() if multitask else task,
            config.action_repeat,
            config.size,
            seed=config.seed,
            render_image=config.render_image,
        )
        if multitask:
            # Tasks share common action (and state) sizes through the wrapper.
            if "proprio" in task:
                config.num_states = config.specific_num_states
                config.num_actions = config.specific_num_actions
                env = wrappers.MultitaskWrapper(env, action_dim=config.num_actions, obs_dim=config.num_states)
            else:
                config.num_actions = config.specific_num_actions
                env = wrappers.MultitaskWrapper(env, action_dim=config.num_actions, obs_type="rgb")
        env = wrappers.NormalizeActions(env)
    elif suite == "atari":
        import envs.atari as atari

        env = atari.Atari(
            slot_task() if task == "multitask" else task,
            config.action_repeat,
            config.size,
            gray=config.grayscale,
            noops=config.noops,
            lives=config.lives,
            sticky=config.stickey,
            actions=config.actions,
            resize=config.resize,
            seed=config.seed,
        )
        env = wrappers.OneHotAction(env)
    elif suite == "metaworld":
        import multiprocessing as mp
        mp.set_start_method("spawn", force=True)
        import envs.meta_world.metaworld_old as metaworld

        env = metaworld.MetaWorld(slot_task(), seed=config.seed, size=config.size)
        env = metaworld.NormalizedEnvWrapper(env, normalize_reward=True)
        env = wrappers.NormalizeActions(env)
    elif suite == "dmlab":
        import envs.dmlab as dmlab

        env = dmlab.DeepMindLabyrinth(
            task,
            mode if "train" in mode else "test",
            config.action_repeat,
            seed=config.seed,
        )
        env = wrappers.OneHotAction(env)
    elif suite == "memorymaze":
        from envs.memorymaze import MemoryMaze

        env = MemoryMaze(task, seed=config.seed)
        env = wrappers.OneHotAction(env)
    elif suite == "crafter":
        import envs.crafter as crafter

        env = crafter.Crafter(task, config.size, seed=config.seed)
        env = wrappers.OneHotAction(env)
    elif suite == "minecraft":
        import envs.minecraft as minecraft

        env = minecraft.make_env(task, size=config.size, break_speed=config.break_speed)
        env = wrappers.OneHotAction(env)
    elif suite == "mt160":
        import envs.MT160.mt160 as mt160

        # mt160 environments are returned without the common wrappers below.
        return mt160.Mt160()
    else:
        raise NotImplementedError(suite)

    # Wrappers shared by every suite except mt160.
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    env = wrappers.UUID(env)
    if suite == "minecraft":
        env = wrappers.RewardObs(env)
    return env


def main(config):
    """Run the complete training / evaluation pipeline described by ``config``.

    The function normalises the configuration in place, prepares the log and
    episode directories, builds the train/eval environments, prefills the
    episode store, constructs the ``Dreamer`` agent (optionally assembling
    mixture-of-experts world models from previously saved checkpoints),
    restores a checkpoint when one exists and then alternates evaluation and
    training until ``config.steps`` is reached.  A checkpoint is written to
    the log directory after every training round, and also on
    SIGINT / SIGTERM / uncaught exceptions.

    Args:
        config: Namespace-like object holding the parsed hyper-parameters.
            It is mutated in place (step counts, directories, derived fields).

    Raises:
        NotImplementedError: For configuration combinations that this script
            rejects explicitly (see the individual ``raise`` sites).
    """
    #[todo] start
    if not config.add_token_embed:
        print(f"Not use token embed")
        # Rewrite the encoder/decoder key patterns so that "token_embed" is
        # no longer matched; only the three known patterns are supported.
        if config.encoder["mlp_keys"] == config.decoder["mlp_keys"]:
            if config.encoder["mlp_keys"] == "^token_embed$":
                config.encoder["mlp_keys"] = '$^'
                config.decoder["mlp_keys"] = '$^'
            elif config.encoder["mlp_keys"] == '.*':
                config.encoder["mlp_keys"] = r"^(?!token_embed$).*"
                config.decoder["mlp_keys"] = r"^(?!token_embed$).*"
            elif config.encoder["mlp_keys"] == "^(token_embed|state)$":
                config.encoder["mlp_keys"] = "^state$"
                config.decoder["mlp_keys"] = "^state$"
            else:
                raise NotImplementedError("Need to improve")
        else:
            raise NotImplementedError("Need to improve")

    if "dmc_multitask" in config.task:
        # A set makes membership checks cheap (vision mode has no "state").
        obs_filter = {"image", "is_terminal", "is_first", "state", "token_embed", "discount","reward","action","logprob"}
    else:
        obs_filter = None
    config.obs_filter = obs_filter

    if config.wm_with_moe and "RSSM" in config.expert_type:
        config.multi_stage = True
        config.wm_use_router = True

    #[todo] end
    tools.set_seed_everywhere(config.seed)
    if config.deterministic_run:
        tools.enable_deterministic_run()
    logdir = pathlib.Path(config.logdir).expanduser()
    # Step-based settings are divided by the action repeat before use.
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat
    #[todo] start
    # NOTE(review): the sentinel checked here is the *string* "None", yet
    # len(config.train_env_name_list) is used unconditionally below, which
    # would silently use len("None") == 4 in that case -- confirm intent.
    if config.train_env_name_list != "None":
        config.train_env_name_list = tools.config2list(config.train_env_name_list)
        if config.envs < len(config.train_env_name_list):
            if config.envs == 1:
                config.envs = len(config.train_env_name_list)
            else:
                raise NotImplementedError("Please modify config.envs")
        else:
            if config.envs % len(config.train_env_name_list) != 0:
                raise NotImplementedError("Please modify config.envs")

    if config.add_token_embed:
        raise NotImplementedError()
    else:
        task_desc_dict = None

    config.eval_every = int(config.eval_every * len(config.train_env_name_list))

    config.dpmm_every = config.dpmm_every * len(config.train_env_name_list)
    config.init_dpmm_steps = config.init_dpmm_steps * len(config.train_env_name_list)

    # Steps at which an extra numbered checkpoint is written; the training
    # loop consumes this list from its tail.
    save_steps_list = [int(config.steps * 0.9), int(config.steps * 0.8), int(config.steps * 0.6)]
    suite, task = config.task.split("_", 1)
    if suite == "metaworld":
        save_steps_list = [int(200000 * 0.8), int(150000 * 0.8)]
    if suite == "dmc" and "multitask" in task:
        if config.steps == 2e5:
            save_steps_list.extend([int(1e5 * 0.9), int(1e5 * 0.8), int(1e5 * 0.6)])
    if config.multi_actor:
        config.steps = int(config.steps * len(config.train_env_name_list))
        save_steps_list = [i * len(config.train_env_name_list) for i in save_steps_list]
        config.eval_episode_num = config.eval_episode_num * len(config.train_env_name_list)

    if suite != "metaworld":
        save_steps_list = []


    if len(save_steps_list) > 2:
        # Adjacent snapshot steps must be further apart than the prefill
        # budget, otherwise the save windows in the training loop overlap.
        min_diff = min([abs(save_steps_list[i+1] - save_steps_list[i]) for i in range(len(save_steps_list)-1)])
        temp_thres = config.prefill if not config.multi_actor else config.prefill * len(config.train_env_name_list)
        if min_diff <= temp_thres:
            raise NotImplementedError("Please modify code using save_steps_list")

    #[todo] end


    print("Logdir", logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    #[todo] start
    if config.multi_actor and config.multi_actor_mode == "distinct": #[todo]
        # One episode directory per task name.
        config.traindir = {n: config.traindir / n if config.traindir else logdir / "train_eps" / n for n in config.train_env_name_list}
        config.evaldir = {n: config.evaldir / n if config.evaldir else logdir / "eval_eps" / n for n in config.train_env_name_list}
    #[todo] end
    else:
        config.traindir = config.traindir or logdir / "train_eps"
        config.evaldir = config.evaldir or logdir / "eval_eps"
    #[todo] start
    if config.multi_actor and config.multi_actor_mode == "distinct":
        for ep_dir in config.traindir.values():
            ep_dir.mkdir(parents=True, exist_ok=True)
        for ep_dir in config.evaldir.values():
            ep_dir.mkdir(parents=True, exist_ok=True)

        if config.wm_with_moe and config.expert_type == "RSSM_append":
            current_group = os.path.basename(config.logdir)
            load_ckpt_step = int(config.steps * config.moe_start_steps_ratio)
            load_ckpt_step = tools.find_closest_pt_files_step(logdir / f"../../{current_group}",load_ckpt_step,config.eval_every)
            # Copy every train episode recorded up to load_ckpt_step.
            if count_steps(config.traindir) < load_ckpt_step:
                for name in config.train_env_name_list:
                    target_step = math.ceil(load_ckpt_step / len(config.train_env_name_list))
                    copy_npz_files(logdir / f"../../{current_group}/train_eps/{name}", config.traindir[name], target_step)
    #[todo] end
    else:
        config.traindir.mkdir(parents=True, exist_ok=True)
        config.evaldir.mkdir(parents=True, exist_ok=True)


    step = count_steps(config.traindir)
    # step in logger is environmental step
    logger = tools.Logger(logdir, config.action_repeat * step)

    print("Create envs.")
    if config.offline_traindir:
        if config.multi_actor and config.multi_actor_mode == "distinct":  # [todo]
            # str.format() returns a str, so the per-task sub-directory has to
            # be joined with os.path.join ("str / str" raises TypeError).
            directory = {n: os.path.join(config.offline_traindir.format(**vars(config)), n) for n in config.train_env_name_list}
        else:
            directory = config.offline_traindir.format(**vars(config))
    else:
        directory = config.traindir
    if config.multi_actor and config.multi_actor_mode == "distinct":  # [todo]
        train_eps = {n: tools.load_episodes(directory[n], limit=config.dataset_size) for n in
                     config.train_env_name_list}
    else:
        train_eps = tools.load_episodes(directory, limit=config.dataset_size)

    if config.offline_evaldir:
        if config.multi_actor and config.multi_actor_mode == "distinct":  # [todo]
            # Same str-vs-path fix as for the offline train directory.
            directory = {n: os.path.join(config.offline_evaldir.format(**vars(config)), n) for n in config.train_env_name_list}
        else:
            directory = config.offline_evaldir.format(**vars(config))
    else:
        directory = config.evaldir
    if config.multi_actor and config.multi_actor_mode == "distinct":  # [todo]
        eval_eps = {n: tools.load_episodes(directory[n], limit=1) for n in config.train_env_name_list}
    else:
        eval_eps = tools.load_episodes(directory, limit=1)
    make = lambda mode, env_id: make_env(config, mode, env_id)
    #[todo]
    if config.parallel:
        # ``i=i`` binds the loop value at definition time (late-binding guard).
        train_envs = [
            Parallel(lambda i=i: make_env(config, "train", i), "process")
            for i in range(config.envs)
        ]
        eval_envs = [
            Parallel(lambda i=i: make_env(config, "eval", i), "process")
            for i in range(config.envs)
        ]
    else:
        train_envs = [make("train", i) for i in range(config.envs)]
        eval_envs = [make("eval", i) for i in range(config.envs)]
        train_envs = [Damy(env) for env in train_envs]
        eval_envs = [Damy(env) for env in eval_envs]
    acts = train_envs[0].action_space
    print("Action Space", acts)
    #[todo] start
    suite, task = config.task.split("_", 1)
    if suite == "dmc" and "proprio" in task:
        # Record every task's native action size, keyed by task name, and
        # track the largest state/action dimensions for the log line below.
        max_state_dim = 0
        max_action_dim = 0
        action_mask={}
        for env in train_envs:
            max_state_dim = max(env.ori_observation_space.spaces['state'].shape[0], max_state_dim)
            env_acts = env.ori_action_space
            env_num_actions = env_acts.n if hasattr(env_acts, "n") else env_acts.shape[0]
            max_action_dim = max(env_num_actions, max_action_dim)
            action_mask[env.task] = env_num_actions
        print(f"max_state_dim:{max_state_dim},max_action_dim:{max_action_dim}")
        # NOTE(review): config.num_actions is not assigned in this branch, so
        # it presumably has to come from the configuration -- confirm.
        config.action_mask = action_mask
    else:
        config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

    if not hasattr(acts, "discrete"):
        acts = gym.spaces.Box(low=acts.low[0], high=acts.high[0], shape=(config.num_actions,), dtype=np.float32)
    #[todo] end


    state = None
    if not config.offline_traindir:
        if config.multi_actor: #[todo]
            config.prefill = config.prefill * len(config.train_env_name_list)
        prefill = max(0, config.prefill - count_steps(config.traindir))
        print(f"Prefill dataset ({prefill} steps).")
        if hasattr(acts, "discrete"):
            random_actor = tools.OneHotDist(
                torch.zeros(config.num_actions).repeat(config.envs, 1)
            )
        else:
            random_actor = torchd.independent.Independent(
                torchd.uniform.Uniform(
                    torch.tensor(acts.low).repeat(config.envs, 1),
                    torch.tensor(acts.high).repeat(config.envs, 1),
                ),
                1,
            )

        def random_agent(o, d, s, **kwargs): #[todo]
            """Sample a random action batch; the inputs are ignored."""
            action = random_actor.sample()
            logprob = random_actor.log_prob(action)
            return {"action": action, "logprob": logprob}, None
        #[todo] start
        if config.wm_with_moe and config.expert_type == "RSSM":
            # Build the prefill agent without MoE heads/encoders, then restore
            # the two flags once the agent has been assembled.
            head_with_moe = config.head_with_moe
            encoder_with_moe = config.encoder_with_moe
            config.head_with_moe = False
            config.encoder_with_moe = False
            prefill_agent = Dreamer(
                train_envs[0].observation_space,
                train_envs[0].action_space,
                config,
                logger,
                dataset=None,
            ).to(config.device)
            temp_prefix_dict = tools.get_list_prefix(config.train_env_name_list)
            # This could instead list the folders present under logdir; the
            # hard-coded form is a reminder that this spot may need changes.
            if (logdir / f"group_0").exists():
                prefix_dict = {f"group_{i}": None for i in range(config.group_num)}
            elif (logdir / f"{next(iter(temp_prefix_dict.keys()))}").exists():
                prefix_dict = temp_prefix_dict
            else:
                raise NotImplementedError(f"Not implemented grouping method.")
            agent_olds = {}
            # when using this code, logdir must have corresponding ckpt
            for prefix in prefix_dict.keys():
                sub_dir_path = os.path.join(config.logdir, f"{prefix}/train_eps")
                env_in_prefix = os.listdir(sub_dir_path)
                temp_config = copy.copy(config)
                temp_config.train_env_name_list = env_in_prefix
                agent_temp = Dreamer(
                    train_envs[0].observation_space,
                    train_envs[0].action_space,
                    temp_config,
                    logger,
                    dataset=None,
                ).to(config.device)
                load_ckpt_step = int(
                    config.steps / len(config.train_env_name_list) * config.moe_start_steps_ratio * len(env_in_prefix))
                ckpt_path = logdir / f"{prefix}/{load_ckpt_step}.pt"
                checkpoint = torch.load(ckpt_path, map_location=config.device)
                agent_temp.load_state_dict(checkpoint["agent_state_dict"])
                tools.recursively_load_optim_state_dict(agent_temp, checkpoint["optims_state_dict"])
                agent_olds[prefix] = agent_temp
            # Merge the behaviour modules of every group agent into one dict.
            merged_T = nn.ModuleDict()
            for name, agent_old in agent_olds.items():
                for key, module_I in agent_old._task_behavior.items():
                    merged_T[key] = module_I
            prefill_agent._task_behavior = merged_T

            prefill_agent = functools.partial(prefill_agent, training=False)

            config.head_with_moe = head_with_moe
            config.encoder_with_moe = encoder_with_moe
        elif config.wm_with_moe and config.expert_type == "RSSM_append":
            temp_config = copy.copy(config)
            temp_config.encoder_with_moe = False
            temp_config.wm_with_moe = False
            temp_config.head_with_moe = False
            prefill_agent = Dreamer(
                train_envs[0].observation_space,
                train_envs[0].action_space,
                temp_config,
                logger,
                dataset=None,
            ).to(config.device)
            # current_group must be bound before it is used to locate the
            # checkpoint directory (it was previously assigned one line late).
            current_group = os.path.basename(config.logdir)
            load_ckpt_step = int(config.steps * config.moe_start_steps_ratio)
            load_ckpt_step = tools.find_closest_pt_files_step(logdir / f"../../{current_group}",load_ckpt_step,config.eval_every)
            ckpt_path = logdir / f"../../{current_group}/{load_ckpt_step}.pt"
            checkpoint = torch.load(ckpt_path, map_location=config.device)
            prefill_agent.load_state_dict(checkpoint["agent_state_dict"])
            tools.recursively_load_optim_state_dict(prefill_agent, checkpoint["optims_state_dict"])
            prefill_agent = functools.partial(prefill_agent, training=False)
        else:
            prefill_agent = random_agent

        #[todo] end

        state = tools.simulate(
            prefill_agent, #[todo]
            train_envs,
            train_eps,
            config.traindir,
            logger,
            limit=config.dataset_size,
            steps=prefill,
            train_env_name_list=config.train_env_name_list,  # [todo]
            task_desc_dict=task_desc_dict, #[todo]
            obs_filter=obs_filter, #[todo]

        )
        logger.step += prefill * config.action_repeat
        print(f"Logger: ({logger.step} steps).")
        del prefill_agent #[todo]

    print("Simulate agent.")
    if config.multi_actor and config.multi_actor_mode == "distinct":  # [todo]
        train_dataset = {n: make_dataset(train_eps[n], config) for n in config.train_env_name_list}
        eval_dataset = {n: make_dataset(eval_eps[n], config) for n in config.train_env_name_list}
    else:
        train_dataset = make_dataset(train_eps, config)
        eval_dataset = make_dataset(eval_eps, config)
    #[todo] start
    if config.wm_with_moe and "RSSM" in config.expert_type:
        config.actor_train_seperate_interval = 0
        if config.expert_type == "RSSM":
            # Compilation is switched off while the agent is constructed and
            # restored afterwards (the world model is compiled further below).
            saved_compile = config.compile
            config.compile = False
    #[todo] end
    agent = Dreamer(
        train_envs[0].observation_space,
        train_envs[0].action_space,
        config,
        logger,
        train_dataset,
    ).to(config.device)
    agent.requires_grad_(requires_grad=False)
    #[todo] start
    if config.wm_with_moe and "RSSM" in config.expert_type:
        # logdir is temporarily redirected for the grouping lookups and is
        # restored from temp_dir at the end of this branch.
        temp_dir = copy.copy(logdir)
        logdir = logdir if config.expert_type == "RSSM" else logdir / "../.."
        config.grouping_dict = tools.get_grouping_dict(logdir)
        config.task_group_map = {
            task: group_name
            for group_name, task_list in config.grouping_dict.items()
            for task in task_list
        }
        print(f"config.task_group_map:{config.task_group_map}")
        if config.expert_type == "RSSM":
            config.compile = saved_compile
        temp_prefix_dict = tools.get_list_prefix(config.train_env_name_list)
        # This could instead list the folders present under logdir; the
        # hard-coded form is a reminder that this spot may need changes.
        if tools.has_group_dir_simple(logdir):
            prefix_dict = {f"group_{i}": None for i in range(config.group_num)}
            if config.expert_type == "RSSM_append":
                current_group = os.path.basename(config.logdir)
                prefix_dict.pop(current_group)

        elif (logdir / f"{next(iter(temp_prefix_dict.keys()))}").exists():
            prefix_dict = temp_prefix_dict
        else:
            raise NotImplementedError(f"Not implemented grouping method.")
        agent_olds = {}
        # Pre-bind so the later ``del agent_olds, agent_temp`` cannot fail
        # when prefix_dict is empty and the loop body never runs.
        agent_temp = None
        # when using this code, logdir must have corresponding ckpt
        for prefix in prefix_dict.keys():
            ckpt_config_logdir = config.logdir if config.ckpt_logdir == "None" else config.ckpt_logdir
            ckpt_logdir = logdir if config.ckpt_logdir == "None" else pathlib.Path(config.ckpt_logdir).expanduser() / "../.."

            sub_dir_path = os.path.join(config.logdir if config.expert_type == "RSSM" else ckpt_config_logdir + r"/../..", f"{prefix}/train_eps")
            env_in_prefix = os.listdir(sub_dir_path)
            print(f"env_in_prefix:{env_in_prefix}")
            temp_config = copy.copy(config)
            temp_config.train_env_name_list = env_in_prefix
            temp_config.head_with_moe = False
            temp_config.encoder_with_moe = False
            temp_config.wm_with_moe = False
            agent_temp = Dreamer(
                train_envs[0].observation_space,
                train_envs[0].action_space,
                temp_config,
                logger,
                dataset=None,
            ).to(config.device)
            load_ckpt_step = int(config.steps / len(config.train_env_name_list) * config.moe_start_steps_ratio * len(env_in_prefix))
            load_ckpt_step = tools.find_closest_pt_files_step(ckpt_logdir / f"{prefix}",load_ckpt_step,config.eval_every * len(env_in_prefix))

            ckpt_path = ckpt_logdir / f"{prefix}/{load_ckpt_step}.pt"
            checkpoint = torch.load(ckpt_path, map_location=config.device)
            agent_temp.load_state_dict(checkpoint["agent_state_dict"])
            tools.recursively_load_optim_state_dict(agent_temp, checkpoint["optims_state_dict"])
            agent_olds[prefix] = agent_temp


        if config.expert_type == "RSSM":
            agent._wm.experts = nn.ModuleDict({key: agent_olds[key]._wm for key in agent_olds.keys()})
            agent._wm.init_moe()
            # Merge the behaviour modules of every group agent
            # (_task_behavior is assumed to be a ModuleDict here) and point
            # each of them at the new shared world model.
            merged_T = nn.ModuleDict()
            for name, agent_old in agent_olds.items():
                for key, module_I in agent_old._task_behavior.items():
                    module_I._world_model = agent._wm
                    merged_T[key] = module_I
            agent._task_behavior = merged_T
            print(f"(1.0 - config.moe_start_steps_ratio):{(1.0 - config.moe_start_steps_ratio)}")
            config.steps = round(config.steps * (1.0 - config.moe_start_steps_ratio))

            if (
                config.compile and os.name != "nt"
            ):  # compilation is not supported on windows
                agent._wm = torch.compile(agent._wm)
                # The values of task_behavior were already compiled.

        # Reaching this branch is assumed to mean the MoE has to be trained.
        config.moe_start_steps = 0
        # MoE training does not need the periodic numbered snapshots.
        save_steps_list = []

        logdir = temp_dir


    video_pred_idx = 0
    if config.multi_actor:
        config.moe_start_steps = config.moe_start_steps * len(config.train_env_name_list)

    # Decide which checkpoint file (relative to logdir) to resume from.
    if not config.wm_with_moe:
        if config.resume == "Newest":
            save_name = os.path.basename(tools.get_latest_ckpt(logdir))
        else:
            save_name = 'latest.pt'
    elif config.expert_type == "RSSM_append" and not ((logdir / 'w_moe_latest.pt').exists()):
        current_group = os.path.basename(config.logdir)
        load_ckpt_step = int(config.steps * config.moe_start_steps_ratio)
        load_ckpt_step = tools.find_closest_pt_files_step(logdir / f"../../{current_group}", load_ckpt_step, config.eval_every)
        save_name = f"../../{current_group}/{load_ckpt_step}.pt"
        if not (logdir / save_name).exists():
            save_name = f"../../{current_group}/{int(load_ckpt_step - config.prefill)}.pt"
    else:
        if agent._step <= config.moe_start_steps:
            save_name = 'wo_moe_latest.pt'
        else:
            if not agent._wm.dynamics.has_moe:
                if config.expert_type == "RSSM_append":
                    agent._wm.experts = nn.ModuleDict({key: agent_olds[key]._wm for key in agent_olds.keys()})
                agent._wm.init_moe()
            load_ckpt_step = tools.find_closest_pt_files_step(logdir, agent._step, config.eval_every)
            # Use the numbered checkpoint only when it is not older than the
            # agent's step; in every other case fall back to the rolling MoE
            # checkpoint so that save_name is always bound.
            if load_ckpt_step is not None and load_ckpt_step >= agent._step:
                save_name = f"{load_ckpt_step}.pt"
            else:
                save_name = 'w_moe_latest.pt'
    #[todo] end
    if (logdir / save_name).exists():
        print(f"logdir / save_name:{logdir / save_name}")
        checkpoint = torch.load(logdir / save_name)
        try:
            agent.load_state_dict(checkpoint["agent_state_dict"])
        except RuntimeError:
            # load_state_dict reports key/shape mismatches as RuntimeError:
            # keep only the required MoE experts, drop the rest, and retry.
            agent._wm.remove_experts()
            agent.load_state_dict(checkpoint["agent_state_dict"])
        tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
        agent._should_pretrain._once = False

    #[todo] start
    if config.wm_with_moe and config.expert_type == "RSSM_append":
        if not agent._wm.dynamics.has_moe:
            agent._wm.experts = nn.ModuleDict({key: agent_olds[key]._wm for key in agent_olds.keys()})
            agent._wm.init_moe()
        del agent_olds,agent_temp
        # Keep only the required MoE experts; drop the unused ones.
        agent._wm.remove_experts()
        gc.collect()
        torch.cuda.empty_cache()

    def collect_checkpoint():
        """Return the dict of agent and optimizer state that gets saved."""
        return {
            "agent_state_dict": agent.state_dict(),
            "optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
        }

    def latest_ckpt_name():
        """Return the checkpoint file name matching the agent's current step."""
        if agent._step < config.steps:
            if not config.wm_with_moe:
                suffix = 'latest'
            else:
                if agent._step <= config.moe_start_steps:
                    suffix = 'wo_moe_latest'
                else:
                    suffix = 'w_moe_latest'
        else:
            suffix = agent._step
        return f"{suffix}.pt"

    def cleanup(reason):
        """Print the exit reason once and save a checkpoint on abnormal exit."""
        global already_called
        if already_called:
            return
        already_called = True

        print("\n==============================")
        print("Exit")
        print("Reason:", reason)
        if reason != "finish":
            print("Start Saving ckpt...")
            items_to_save = collect_checkpoint()
            torch.save(items_to_save, logdir / latest_ckpt_name())
        print("==============================\n")

    # Catch Ctrl+C / kill / window close.
    def signal_handler(sig, frame):
        cleanup(f"receive signal {sig}")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # kill pid / docker stop / window close

    # Catch unhandled exceptions.
    def exception_handler(exc_type, exc, tb):
        traceback.print_exception(exc_type, exc, tb)
        cleanup("exception")
        sys.exit(1)

    sys.excepthook = exception_handler

    # Catch normal exit (return / sys.exit / end of main).
    atexit.register(lambda: cleanup("finish"))
    #[todo] end

    print(f"agent._step:{agent._step},config.steps:{config.steps},config.eval_every:{config.eval_every}")
    # make sure eval will be executed once after config.steps
    while agent._step <= config.steps + config.eval_every: #[todo]
        logger.write()
        if config.eval_episode_num > 0:
            print(f"eval_episode_num:{config.eval_episode_num}")
            print("Start evaluation.")
            eval_policy = functools.partial(agent, training=False)
            tools.simulate(
                eval_policy,
                eval_envs,
                eval_eps,
                config.evaldir,
                logger,
                is_eval=True,
                episodes=config.eval_episode_num,
                train_env_name_list=config.train_env_name_list,  # [todo]
                task_desc_dict=task_desc_dict, #[todo]
                obs_filter=obs_filter, #[todo]

            )
            if config.video_pred_log:
                if not isinstance(eval_dataset, dict):
                    video_data = next(eval_dataset)
                    env_name = config.task.split("_", 1)[1]  # [todo]
                else:  # [todo]
                    # Rotate through the tasks, one per evaluation round.
                    env_name = config.train_env_name_list[video_pred_idx]
                    video_data = next(eval_dataset[env_name])
                    video_pred_idx = (video_pred_idx + 1) % len(config.train_env_name_list)
                video_pred = agent._wm.video_pred(video_data, env_name=env_name) #[todo]
                logger.video("eval_openl", to_np(video_pred))
        #[todo] start
        if agent._step >= config.steps:
            print(f"agent._step >= config.steps, finish training.")
            break
        #[todo] end
        print("Start training.")
        #[todo] start
        if config.wm_with_moe and agent._step >= config.moe_start_steps:
            if not agent._wm.dynamics.has_moe:
                agent._wm.init_moe()

        #[todo] end
        state = tools.simulate(
            agent,
            train_envs,
            train_eps,
            config.traindir,
            logger,
            limit=config.dataset_size,
            steps=config.eval_every,
            state=state,
            train_env_name_list=config.train_env_name_list, #[todo]
            task_desc_dict=task_desc_dict,  #[todo]
            obs_filter=obs_filter,  #[todo]

        )
        #[todo] start
        # Rolling checkpoint after every training round.
        items_to_save = collect_checkpoint()
        torch.save(items_to_save, logdir / latest_ckpt_name())

        if config.use_dpmm:
            agent._wm.save_dpmm()

        if suite == "metaworld":
            # Extra numbered snapshot when this round crosses specific_step
            # (with or without the prefill offset).
            specific_step = int(150000 * 0.8 * len(config.train_env_name_list))
            if (agent._step <= specific_step and agent._step + config.eval_every > specific_step) or (
                    agent._step <= specific_step + config.prefill and agent._step + config.eval_every >
                    specific_step + config.prefill):
                torch.save(items_to_save, logdir / f"{agent._step}.pt")
        if len(save_steps_list) > 0:
            # The list is consumed from its tail: snapshot when the next
            # round would cross the last entry, then drop it once passed.
            if (agent._step <= save_steps_list[-1] and agent._step + config.eval_every > save_steps_list[-1]) or (agent._step <= save_steps_list[-1] + config.prefill and agent._step + config.eval_every > save_steps_list[-1] + config.prefill):
                torch.save(items_to_save,logdir / f"{agent._step}.pt")
            if agent._step > save_steps_list[-1] + config.prefill:
                save_steps_list.pop()

        #[todo] end
    for env in train_envs + eval_envs:
        try:
            env.close()
        except Exception as exc:
            # Best-effort shutdown: report the failure but keep closing the
            # remaining environments.
            print(f"Failed to close env: {exc}")



if __name__ == "__main__":
    # Stage 1: read only ``--configs`` so we know which named sections of
    # configs.yaml to layer on top of "defaults"; everything else is kept in
    # ``remaining`` for the second parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+")
    args, remaining = parser.parse_known_args()
    config_file = pathlib.Path(sys.argv[0]).parent / "configs.yaml"
    configs = yaml.safe_load(config_file.read_text())

    def recursive_update(base, update):
        """Merge ``update`` into ``base`` in place, descending into dicts.

        A dict value whose key already exists in ``base`` is merged
        recursively; every other value overwrites (or creates) the entry.
        """
        for key in update:
            value = update[key]
            if key in base and isinstance(value, dict):
                recursive_update(base[key], value)
                continue
            base[key] = value

    name_list = ["defaults"]
    if args.configs:
        name_list.extend(args.configs)
    defaults = {}
    for name in name_list:
        recursive_update(defaults, configs[name])

    # Stage 2: expose every merged setting as a ``--key`` option whose type
    # and default are derived from the YAML value.
    parser = argparse.ArgumentParser()
    for key in sorted(defaults):
        value = defaults[key]
        arg_type = tools.args_type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))


    main(parser.parse_args(remaining))
