import os
import time
from abc import ABC
from datetime import timedelta

import ray
import torch
from tqdm import tqdm

#from openrlhf.datasets import PromptDataset
from openrlhf.trainer.ppo_utils import AdaptiveKLController, FixedKLController
#from openrlhf.trainer.ppo_utils.experience_maker import RemoteExperienceMaker
from openrlhf.trainer.ppo_utils.NEW_MULTI import RemoteExperienceMaker
from openrlhf.trainer.ray.launcher import PPORayActorGroup
from openrlhf.utils import blending_datasets, get_tokenizer
from openrlhf.utils.deepspeed import DeepspeedStrategy
from openrlhf.utils.logging_utils import init_logger
from openrlhf.utils.remote_rm_utils import remote_rm_fn_ray
import random

# Turn on autograd anomaly detection as soon as this module is imported.
# NOTE(review): PyTorch documents anomaly detection as a debugging aid that slows
# down execution — confirm it is meant to stay enabled for regular training runs.
torch.autograd.set_detect_anomaly(True)
# Module-level logger created through the project's init_logger helper.
logger = init_logger(__name__)

#api_base_urls_16 = ['http://localhost:8000', 'http://localhost:8001', 'http://localhost:8002', 'http://localhost:8003', 'http://localhost:8004', 'http://localhost:8005', 'http://localhost:8006', 'http://localhost:8007', 'http://localhost:8008', 'http://localhost:8009', 'http://localhost:8010', 'http://localhost:8011', 'http://localhost:8012', 'http://localhost:8013', 'http://localhost:8014', 'http://localhost:8015', 'http://localhost:8016', 'http://localhost:8017', 'http://localhost:8018', 'http://localhost:8019', 'http://localhost:8020', 'http://localhost:8021', 'http://localhost:8022', 'http://localhost:8023', 'http://localhost:8024', 'http://localhost:8025', 'http://localhost:8026', 'http://localhost:8027', 'http://localhost:8028', 'http://localhost:8029', 'http://localhost:8030', 'http://localhost:8031', 'http://localhost:8032', 'http://localhost:8033', 'http://localhost:8034', 'http://localhost:8035', 'http://localhost:8036', 'http://localhost:8037', 'http://localhost:8038', 'http://localhost:8039', 'http://localhost:8040', 'http://localhost:8041', 'http://localhost:8042', 'http://localhost:8043', 'http://localhost:8044', 'http://localhost:8045', 'http://localhost:8046', 'http://localhost:8047', 'http://localhost:8048', 'http://localhost:8049', 'http://localhost:8050', 'http://localhost:8051', 'http://localhost:8052', 'http://localhost:8053', 'http://localhost:8054', 'http://localhost:8055', 'http://localhost:8056', 'http://localhost:8057', 'http://localhost:8058', 'http://localhost:8059', 'http://localhost:8060', 'http://localhost:8061', 'http://localhost:8062', 'http://localhost:8063']  # API服务地址列表
# Base URLs of the API services; passed to RemoteExperienceMaker as ``api_base_urls``.
# NOTE(review): the name ends in "_16" but the list holds 12 endpoints
# (ports 8000-8011) — confirm which one is intended.
api_base_urls_16 = ['http://localhost:8000', 'http://localhost:8001', 'http://localhost:8002', 'http://localhost:8003', 'http://localhost:8004', 'http://localhost:8005','http://localhost:8006', 'http://localhost:8007', 'http://localhost:8008', 'http://localhost:8009', 'http://localhost:8010', 'http://localhost:8011']
# Number of tasks randomly sampled for each episode; read by
# ``PPOTrainer.prepare_datasets`` as ``tasks_per_episode``.
select_num = 40
@ray.remote
class PPOTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) / REINFORCE++ / GRPO / RLOO and their variants.
    Single Controller with Multiple ActorGroups
    """

    def __init__(
        self,
        pretrain: str,
        strategy: DeepspeedStrategy,
        actor_model_group: PPORayActorGroup,
        critic_model_group: PPORayActorGroup,
        reward_model_group: PPORayActorGroup,
        reference_model_group: PPORayActorGroup,
        vllm_engines=None,
        prompt_max_len: int = 120,
        dataloader_pin_memory: bool = True,
        prompt_split: str = "train",
        eval_split: str = "test",
        **generate_kwargs,
    ) -> None:
        """Set up tokenizer, model groups, KL controller, experience maker, data and loggers.

        Args:
            pretrain: Forwarded to ``get_tokenizer`` to load the tokenizer.
            strategy: Strategy object; its ``args`` namespace supplies the hyper-parameters.
            actor_model_group: Actor group; also handed to the experience maker.
            critic_model_group: Critic group; other methods treat ``None`` as "no critic".
            reward_model_group: Reward group handed to the experience maker.
            reference_model_group: Reference group handed to the experience maker.
            vllm_engines: Optional vLLM engines handed to the experience maker.
            prompt_max_len: Forwarded to the experience maker.
            dataloader_pin_memory: Stored on the instance.
            prompt_split: Stored on the instance.
            eval_split: Stored on the instance.
            **generate_kwargs: Stored and later forwarded to experience generation.
        """
        super().__init__()

        args = strategy.args
        self.strategy = strategy
        self.args = args

        # "left" is passed positionally to get_tokenizer (presumably the padding side — confirm).
        self.tokenizer = get_tokenizer(pretrain, None, "left", strategy, use_fast=not args.disable_fast_tokenizer)

        # Model groups and generation engines.
        self.actor_model_group = actor_model_group
        self.critic_model_group = critic_model_group
        self.reward_model_group = reward_model_group
        self.reference_model_group = reference_model_group
        self.dataloader_pin_memory = dataloader_pin_memory
        self.vllm_engines = vllm_engines

        # Dataset split names and generation settings.
        self.prompt_split = prompt_split
        self.eval_split = eval_split
        self.prompt_max_len = prompt_max_len
        self.generate_kwargs = generate_kwargs

        # Hyper-parameters mirrored from the strategy arguments.
        self.max_epochs = args.max_epochs
        self.remote_rm_url = args.remote_rm_url
        self.init_kl_coef = args.init_kl_coef
        self.kl_target = args.kl_target
        self.kl_horizon = args.kl_horizon

        # Defaults to -1 when the argument is absent.
        self.freezing_actor_steps = getattr(args, "freezing_actor_steps", -1)

        # Adaptive KL control when a target is configured, fixed coefficient otherwise.
        self.kl_ctl = (
            AdaptiveKLController(self.init_kl_coef, self.kl_target, self.kl_horizon)
            if self.kl_target
            else FixedKLController(self.init_kl_coef)
        )

        # Changed: use the new multi-turn dialogue experience maker.
        self.experience_maker = RemoteExperienceMaker(
            actor_model_group,
            critic_model_group,
            reward_model_group,
            reference_model_group,
            self.tokenizer,
            prompt_max_len,
            self.kl_ctl,
            strategy,
            self.remote_rm_url,
            vllm_engines=vllm_engines,
            packing_samples=args.packing_samples,
            api_base_urls=api_base_urls_16,  # Added: API service endpoints
            max_turns=35,  # Added: maximum number of dialogue turns
        )

        self.prepare_datasets()

        # wandb / tensorboard state; stays None unless enabled below.
        self._wandb = None
        self._tensorboard = None
        self.generated_samples_table = None

        if args.use_wandb:
            import wandb

            self._wandb = wandb
            # ``use_wandb`` doubles as the API key when no key is configured yet.
            if not wandb.api.api_key:
                wandb.login(key=args.use_wandb)
            wandb.init(
                entity=args.wandb_org,
                project=args.wandb_project,
                group=args.wandb_group,
                name=args.wandb_run_name,
                config=args.__dict__,
                reinit=True,
            )

            # Train metrics are plotted against global_step, eval metrics against epoch.
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/epoch")
            wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)
            self.generated_samples_table = wandb.Table(columns=["global_step", "text", "reward"])

        # TensorBoard is only used when wandb is not active.
        if args.use_tensorboard and self._wandb is None:
            from torch.utils.tensorboard import SummaryWriter

            tensorboard_root = args.use_tensorboard
            os.makedirs(tensorboard_root, exist_ok=True)
            self._tensorboard = SummaryWriter(log_dir=os.path.join(tensorboard_root, args.wandb_run_name))

    def fit(
        self,
    ) -> None:
        """Run the rollout -> train -> log loop over all episodes.

        For every batch of tasks yielded by ``self.prompts_dataloader`` this
        generates experiences, appends them to the actor group (and the critic
        group when present), runs one ``ppo_train`` round, updates the KL
        controller and delegates logging/checkpointing to
        ``save_logs_and_checkpoints``. The step counter and the starting
        episode are restored from the consumed-sample count reported by the
        actor group. Logger handles are closed once all episodes are done.
        """
        args = self.args

        # Number of rollout batches in one episode.
        num_rollouts_per_episodes = len(self.prompts_dataloader)

        # Resolve the -1 sentinels for the evaluation and checkpoint intervals.
        if args.eval_steps == -1:
            args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        # Broadcast the initial checkpoint to vLLM when resuming from one.
        ckpt_path = os.path.join(args.ckpt_path, "_actor")
        if args.load_checkpoint and os.path.exists(ckpt_path) and self.vllm_engines is not None:
            # vLLM wakeup when vllm_enable_sleep
            if self.strategy.args.vllm_enable_sleep:
                from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call

                batch_vllm_engine_call(self.vllm_engines, "wake_up")

            ref = self.actor_model_group.async_run_method(method_name="broadcast_to_vllm")
            ray.get(ref)

            # vLLM offload when vllm_enable_sleep
            if self.strategy.args.vllm_enable_sleep:
                batch_vllm_engine_call(self.vllm_engines, "sleep")

        # Restore the step counter, starting episode and in-episode offset.
        consumed_samples = ray.get(self.actor_model_group.async_run_method(method_name="get_consumed_samples"))[0]
        steps = consumed_samples // args.rollout_batch_size + 1
        start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

        print(f"consumed_samples: {consumed_samples}, steps: {steps}, start_episode: {start_episode}")

        for episode in range(start_episode, args.num_episodes):
            # Only the first (resumed) episode carries a non-zero offset.
            self.prompts_dataloader.set_epoch(
                episode, consumed_samples=0 if episode > start_episode else consumed_samples
            )
            self.prompts_dataloader.show_active_data()  # Show the tasks active in this episode
            print("lenth of prompts_dataloader:", len(self.prompts_dataloader))

            pbar = tqdm(
                range(len(self.prompts_dataloader)),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=False,
            )
            for batch in self.prompts_dataloader:
                # Temporary solution: roll out one batch of task ids directly.
                task_ids = [item["task_id"] for item in batch]
                experiences = self.experience_maker.make_experience_list(task_ids, **self.generate_kwargs)
                # Decode the first sequence of the first experience for logging.
                sample0 = self.tokenizer.batch_decode(
                    experiences[0].sequences[0].unsqueeze(0), skip_special_tokens=True
                )
                refs = self.actor_model_group.async_run_method_batch(method_name="append", experience=experiences)
                if self.critic_model_group is not None:
                    refs.extend(
                        self.critic_model_group.async_run_method_batch(method_name="append", experience=experiences)
                    )
                ray.get(refs)

                status = self.ppo_train(steps)

                if "kl" in status:
                    self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                pbar.set_postfix(status)

                # Add generated samples to status dictionary
                status["generated_samples"] = [sample0[0], experiences[0].info["reward"][0]]
                # logs/checkpoints
                client_states = {"consumed_samples": steps * args.rollout_batch_size}
                self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)

                pbar.update()
                steps = steps + 1

            # Release this episode's progress bar before the next one is created.
            pbar.close()

        if self._wandb is not None:
            self._wandb.finish()
        if self._tensorboard is not None:
            self._tensorboard.close()

    def ppo_train(self, global_steps):
        """Run one training round on the critic (if any) and the actor, then sync vLLM.

        Args:
            global_steps: Current global step. The actor is trained only when
                this exceeds ``self.freezing_actor_steps``.

        Returns:
            dict: Status entries merged from the first element of each remote
            ``fit`` result that was collected.
        """
        cfg = self.strategy.args
        actor_group = self.actor_model_group
        critic_group = self.critic_model_group
        merged_status = {}

        # Launch critic training on the remote group.
        if critic_group is not None:
            # With deepspeed sleep enabled, reload states (and wait) before fitting.
            if cfg.deepspeed_enable_sleep:
                ray.get(critic_group.async_run_method(method_name="reload_states"))

            critic_status_ref = critic_group.async_run_method(method_name="fit")

            # Colocated or sleeping setups collect the critic result right away.
            if cfg.colocate_all_models or cfg.deepspeed_enable_sleep:
                critic_result = ray.get(critic_status_ref)
                merged_status.update(critic_result[0])
            if cfg.deepspeed_enable_sleep:
                ray.get(critic_group.async_run_method(method_name="offload_states"))

        # The actor is left untouched until the freezing period is over.
        if global_steps > self.freezing_actor_steps:
            # reload/offload are submitted without waiting on their results.
            if cfg.deepspeed_enable_sleep:
                actor_group.async_run_method(method_name="reload_states")

            actor_result = ray.get(actor_group.async_run_method(method_name="fit", kl_ctl=self.kl_ctl.value))
            merged_status.update(actor_result[0])

            if cfg.deepspeed_enable_sleep:
                actor_group.async_run_method(method_name="offload_states")

            # Broadcast the updated weights to the vLLM engines.
            if self.vllm_engines is not None:
                if cfg.vllm_enable_sleep:
                    from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call

                    batch_vllm_engine_call(self.vllm_engines, "wake_up")

                ray.get(actor_group.async_run_method(method_name="broadcast_to_vllm"))

                if cfg.vllm_enable_sleep:
                    batch_vllm_engine_call(self.vllm_engines, "sleep")

        # Non-colocated runs collect the critic result only after the actor step.
        if critic_group and not cfg.colocate_all_models:
            merged_status.update(ray.get(critic_status_ref)[0])

        return merged_status

    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
        """Log the step's metrics and save a checkpoint when the intervals are hit.

        Args:
            args: Namespace providing ``logging_steps`` and ``save_steps``.
            global_step: Current global step; used for interval checks, as the
                logging step and in the checkpoint tag.
            step_bar: Progress bar of the caller (not used here).
            logs_dict: Metrics to log. May contain ``"generated_samples"`` as a
                ``(text, reward)`` pair. The caller's dict is not modified.
            client_states: Extra state passed to the actor's ``save_checkpoint``.
        """
        # Work on a private copy: avoids shared mutable defaults and keeps the
        # pop() below from changing the caller's dict.
        logs_dict = {} if logs_dict is None else dict(logs_dict)
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None:
                # Add generated samples to wandb using Table
                if "generated_samples" in logs_dict:
                    # A fresh table is built on every log call, see
                    # https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
                    new_table = self._wandb.Table(
                        columns=self.generated_samples_table.columns, data=self.generated_samples_table.data
                    )
                    new_table.add_data(global_step, *logs_dict.pop("generated_samples"))
                    self.generated_samples_table = new_table
                    self._wandb.log({"train/generated_samples": new_table})
                logs = {
                    "train/%s" % k: v
                    for k, v in {
                        **logs_dict,
                        "global_step": global_step,
                    }.items()
                }
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None:
                for k, v in logs_dict.items():
                    if k == "generated_samples":
                        # Record generated samples in TensorBoard using simple text format
                        text, reward = v
                        formatted_text = f"Sample:\n{text}\n\nReward: {reward:.4f}"
                        self._tensorboard.add_text("train/generated_samples", formatted_text, global_step)
                    else:
                        self._tensorboard.add_scalar(f"train/{k}", v, global_step)

        # # TODO: Add evaluation mechanism for PPO
        # if global_step % args.eval_steps == 0 and self.eval_dataloader and len(self.eval_dataloader) > 0:
        #     self.evaluate(self.eval_dataloader, global_step, args.eval_temperature, args.eval_n_samples_per_prompt)
        # save ckpt
        # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            ref = self.actor_model_group.async_run_method(
                method_name="save_checkpoint", tag=tag, client_states=client_states
            )
            if self.critic_model_group is not None:
                ref.extend(self.critic_model_group.async_run_method(method_name="save_checkpoint", tag=tag))
            # Block until every checkpoint write has finished.
            ray.get(ref)

    def evaluate(self, eval_dataloader, global_step, temperature=0.6, n_samples_per_prompt=1):
        """Evaluate model performance on eval dataset.

        Args:
            eval_dataloader: DataLoader containing evaluation prompts, labels and data sources
            global_step: Current training step for logging
            temperature: Sampling temperature written into the generation kwargs
            n_samples_per_prompt: Number of samples to generate per prompt for pass@k calculation
        """
        start_time = time.time()
        logger.info(f"⏰ Evaluation start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        # vLLM wakeup when vllm_enable_sleep
        if self.strategy.args.vllm_enable_sleep:
            from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call

            batch_vllm_engine_call(self.vllm_engines, "wake_up")

        with torch.no_grad():
            # First collect all prompts and labels
            all_prompts = []
            all_labels = []
            all_datasources = []

            for datasources, prompts, labels in eval_dataloader:
                all_prompts.extend(prompts)
                all_labels.extend(labels)
                all_datasources.extend(datasources)

            # Generate samples and calculate rewards
            generate_kwargs = self.generate_kwargs.copy()
            generate_kwargs["temperature"] = temperature
            generate_kwargs["n_samples_per_prompt"] = n_samples_per_prompt
            samples = self.experience_maker.generate_samples(all_prompts, all_labels, **generate_kwargs)
            queries_list = self.tokenizer.batch_decode(samples.sequences, skip_special_tokens=False)

            # Repeat every prompt/label once per generated sample. A nested
            # comprehension keeps this linear, unlike sum(list_of_lists, []).
            all_prompts = [prompt for prompt in all_prompts for _ in range(n_samples_per_prompt)]
            all_labels = [label for label in all_labels for _ in range(n_samples_per_prompt)]

            # Calculate rewards
            if self.experience_maker.custom_reward_func:
                # Let Ray automatically distribute the workload across available resources
                batch_size = self.strategy.args.micro_rollout_batch_size
                num_chunks = (len(queries_list) + batch_size - 1) // batch_size
                r_refs = []
                for i in range(num_chunks):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, len(queries_list))
                    r = self.experience_maker.custom_reward_func.remote(
                        queries_list[start_idx:end_idx],
                        all_prompts[start_idx:end_idx],
                        all_labels[start_idx:end_idx],
                    )
                    r_refs.append(r)
            else:
                # Distribute data across different remote reward function servers
                num_servers = len(self.remote_rm_url)
                batch_size = (len(queries_list) + num_servers - 1) // num_servers
                r_refs = []
                for i in range(num_servers):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, len(queries_list))
                    rm = self.remote_rm_url[i]
                    r = remote_rm_fn_ray.remote(
                        rm,
                        queries=queries_list[start_idx:end_idx],
                        prompts=all_prompts[start_idx:end_idx],
                        labels=all_labels[start_idx:end_idx],
                    )
                    r_refs.append(r)

            # Reshape rewards to (num_prompts, n_samples_per_prompt)
            rewards = ray.get(r_refs)
            rewards = torch.cat(rewards, dim=0).reshape(-1, n_samples_per_prompt)

            # Collect local statistics for each data source
            global_metrics = {}  # {datasource: {"pass{n_samples_per_prompt}": 0, "pass1": 0, "count": 0}}

            for i, datasource in enumerate(all_datasources):
                if datasource not in global_metrics:
                    global_metrics[datasource] = {f"pass{n_samples_per_prompt}": 0, "pass1": 0, "count": 0}

                # pass@k uses the best reward of the prompt, pass@1 the mean reward.
                prompt_rewards = rewards[i]
                if n_samples_per_prompt > 1:
                    global_metrics[datasource][f"pass{n_samples_per_prompt}"] += prompt_rewards.max().float().item()
                global_metrics[datasource]["pass1"] += prompt_rewards.mean().float().item()
                global_metrics[datasource]["count"] += 1

            # Calculate global averages
            logs = {}
            for datasource, metrics in global_metrics.items():
                logs[f"eval_{datasource}_pass{n_samples_per_prompt}"] = (
                    metrics[f"pass{n_samples_per_prompt}"] / metrics["count"]
                )
                logs[f"eval_{datasource}_pass1"] = metrics["pass1"] / metrics["count"]

            # Log to wandb/tensorboard
            if self._wandb is not None:
                logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": global_step}.items()}
                self._wandb.log(logs)
            elif self._tensorboard is not None:
                for k, v in logs.items():
                    self._tensorboard.add_scalar(f"eval/{k}", v, global_step)

        # Put vLLM back to sleep; the helper was imported under the same flag above.
        if self.strategy.args.vllm_enable_sleep:
            batch_vllm_engine_call(self.vllm_engines, "sleep")

        end_time = time.time()
        duration = end_time - start_time
        time_str = str(timedelta(seconds=duration)).split(".")[0]
        logger.info(f"✨ Evaluation completed in {time_str}")

    def prepare_datasets(self):
        """Build the task loaders and compute ``self.max_steps``.

        Sets ``self.prompts_dataloader`` (a loader that re-samples a random
        subset of the fixed task ids on every ``set_epoch`` call),
        ``self.eval_dataloader`` (``None`` unless ``args.eval_dataset`` is set)
        and ``self.max_steps``.
        """
        args = self.args

        # Fixed task ids used as the data source (original note: difficulty 1 and difficulty 2).
        task_ids = ['07b42fd_1', '07b42fd_2', '07b42fd_3', '229360a_1', '229360a_2', '229360a_3', '27e1026_1', '27e1026_2', '27e1026_3', '287e338_1', '287e338_2', '287e338_3', '692c77d_1', '692c77d_2', '692c77d_3', '82e2fac_1', '82e2fac_2', '82e2fac_3', 'aa8502b_1', 'aa8502b_2', 'aa8502b_3', 'b7a9ee9_1', 'b7a9ee9_2', 'b7a9ee9_3', 'c901732_1', 'c901732_2', 'c901732_3', 'ccb4494_1', 'ccb4494_2', 'ccb4494_3', 'ce359b5_1', 'ce359b5_2', 'ce359b5_3', 'e7a10f8_1', 'e7a10f8_2', 'e7a10f8_3', 'e85d92a_1', 'e85d92a_2', 'e85d92a_3', 'e3d6c94_1', 'e3d6c94_2', 'e3d6c94_3', 'd0b1f43_1', 'd0b1f43_2', 'd0b1f43_3', '2a163ab_1', '2a163ab_2', '2a163ab_3', '60d0b5b_1', '60d0b5b_2', '60d0b5b_3', '6ea6792_1', '6ea6792_2', '6ea6792_3', '29caf6f_1', '29caf6f_2', '29caf6f_3', 'cf6abd2_1', 'cf6abd2_2', 'cf6abd2_3', '771d8fc_1', '771d8fc_2', '771d8fc_3', '7d7fbf6_1', '7d7fbf6_2', '7d7fbf6_3', '76f2c72_1', '76f2c72_2', '76f2c72_3', '302c169_1', '302c169_2', '302c169_3']
        # Alternative task subsets kept from earlier experiments:
        #task_ids = ['07b42fd_1', '07b42fd_2', '229360a_3', '27e1026_1']
        #task_ids = ['27e1026_3', 'b7a9ee9_1', 'e3d6c94_3', 'c901732_2', 'aa8502b_1', '771d8fc_3', '302c169_2', '287e338_2', '60d0b5b_3', '229360a_3', '60d0b5b_1', '7d7fbf6_2', '76f2c72_2', '302c169_1', 'cf6abd2_3', '692c77d_3', '287e338_1', '692c77d_1', 'ccb4494_2', 'ccb4494_3', '2a163ab_2', 'e7a10f8_2', 'e3d6c94_1', '771d8fc_2', 'b7a9ee9_3', '7d7fbf6_1', '692c77d_2', 'cf6abd2_1', '2a163ab_3', '76f2c72_1', '2a163ab_1', 'd0b1f43_3']

        # One sample per task id. The full list is kept here because the
        # per-episode sampling happens inside the loader.
        all_train_data = [{"task_id": task_id} for task_id in task_ids]

        # Number of tasks randomly sampled for every episode.
        tasks_per_episode = select_num

        class CustomDataLoader:
            """Minimal batch loader that draws a fresh random task subset per epoch.

            ``set_epoch`` must be called before iterating; until then there is
            no active data and iteration yields nothing.
            """

            def __init__(self, all_data, batch_size, shuffle, tasks_per_epoch=None):
                """
                Args:
                    all_data: Complete list of samples, e.g. ``[{'task_id': '...'}, ...]``.
                    batch_size: Number of samples per yielded batch.
                    shuffle: Whether to shuffle the active samples on each ``iter()``.
                    tasks_per_epoch: Number of samples drawn per epoch (episode);
                        ``None`` means all of ``all_data``.
                """
                self.all_data = all_data
                self.batch_size = batch_size
                self.shuffle = shuffle
                self.tasks_per_epoch = len(all_data) if tasks_per_epoch is None else tasks_per_epoch

                # Samples selected for the current epoch (episode).
                self.active_data = []
                self.index = 0
                self.epoch = 0
                # Offset the next iteration starts from; set by set_epoch.
                self._resume_index = 0

            def _num_active(self):
                # Never sample more tasks than exist.
                return min(self.tasks_per_epoch, len(self.all_data))

            def __iter__(self):
                if self.shuffle:
                    random.shuffle(self.active_data)
                # Honour the resume offset from set_epoch exactly once; later
                # iterations over the same epoch start from the beginning.
                self.index = self._resume_index
                self._resume_index = 0
                return self

            def __next__(self):
                if self.index >= len(self.active_data):
                    # Marks the end of one epoch's batches, not of all data.
                    raise StopIteration
                batch = self.active_data[self.index:self.index + self.batch_size]
                self.index += self.batch_size
                return batch

            def __len__(self):
                # Number of batches per epoch; before the first set_epoch this
                # is derived from the number of tasks that will be sampled.
                num_samples = len(self.active_data) if self.active_data else self._num_active()
                return (num_samples + self.batch_size - 1) // self.batch_size

            def show_active_data(self):
                """Print the samples active in the current epoch (episode)."""
                print(f"Active data for epoch {self.epoch}:")
                print(f"Total tasks num: {len(self.active_data)}, \n Tasks: {self.active_data}")

            def sampler(self):
                return self

            def set_epoch(self, epoch, consumed_samples=0):
                """Re-sample the active data and remember where iteration resumes.

                Args:
                    epoch: Epoch (episode) index, stored for display.
                    consumed_samples: Number of leading samples to skip in the
                        next iteration (used when resuming from a checkpoint).
                """
                self.epoch = epoch

                # 1. Draw a random subset of all samples.
                self.active_data = random.sample(self.all_data, k=self._num_active())

                # 2. Record the resume offset; __iter__ applies it.
                self.index = consumed_samples
                self._resume_index = consumed_samples

        # Training loader: receives the full data and the per-epoch sample count.
        prompts_dataloader = CustomDataLoader(
            all_data=all_train_data,
            batch_size=args.rollout_batch_size,
            shuffle=True,
            tasks_per_epoch=tasks_per_episode,
        )

        # Optional evaluation loader over (at most ``args.max_samples``) task ids.
        if getattr(args, "eval_dataset", None):
            eval_data = [{"task_id": task_id} for task_id in task_ids[:min(args.max_samples, len(task_ids))]]
            eval_dataloader = CustomDataLoader(
                all_data=eval_data,
                batch_size=1,
                shuffle=True,
                tasks_per_epoch=len(eval_data),
            )
        else:
            eval_dataloader = None

        self.prompts_dataloader = prompts_dataloader
        self.eval_dataloader = eval_dataloader

        # The number of steps per episode is fixed; use the same clamp as the
        # loader so the count matches the tasks actually sampled.
        tasks_each_episode = min(tasks_per_episode, len(all_train_data))
        steps_per_episode = (tasks_each_episode + args.train_batch_size - 1) // args.train_batch_size
        self.max_steps = steps_per_episode * args.num_episodes * args.max_epochs

    def get_max_steps(self):
        """Return ``self.max_steps``, the step count computed in ``prepare_datasets``."""
        return self.max_steps
