import time
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional, Tuple, Union

import ray
import torch

from openrlhf.datasets.utils import zero_pad_sequences
from openrlhf.models.utils import compute_approx_kl, compute_reward, masked_mean, process_sequences
from openrlhf.trainer.ray.ilr_launcher import ILRRayActorGroup
from openrlhf.utils.logging_utils import init_logger
from openrlhf.utils.remote_rm_utils import remote_rm_fn_ray
import math
from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
logger = init_logger(__name__)


def to(tensor: Union[torch.Tensor, list[torch.Tensor]], device):
    """Move a tensor, or every tensor in a (possibly nested) list, to ``device``.

    Values that are neither a tensor nor a list are returned unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.to(device)
    if isinstance(tensor, list):
        # Recurse so nested lists are handled element by element.
        return [to(item, device) for item in tensor]
    return tensor


def pin_memory(tensor: Union[torch.Tensor, list[torch.Tensor]]):
    """Call ``pin_memory()`` on a tensor, or on every tensor in a (possibly nested) list.

    Values that are neither a tensor nor a list are returned unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.pin_memory()
    if isinstance(tensor, list):
        # Recurse so nested lists are handled element by element.
        return [pin_memory(item) for item in tensor]
    return tensor


@dataclass
class Experience:
    """Experience is a batch of data.
    These data should have the same sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    base_action_log_probs: (B, A)
    values: (B, A)
    returns: (B, A)
    advantages: (B, A)
    attention_mask: (B, S)
    action_mask: (B, A)
    kl: (B, A)

    "A" is the number of actions.

    Any optional field may be ``None``; ``to_device`` and ``pin_memory`` pass
    ``None`` values through unchanged, including a missing ``info`` dict.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    base_action_log_probs: Optional[torch.Tensor]
    values: Optional[torch.Tensor]
    returns: Optional[torch.Tensor]
    advantages: Optional[torch.Tensor]
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    info: Optional[dict]
    kl: Optional[torch.Tensor] = None

    @torch.no_grad()
    def to_device(self, device: torch.device):
        """Move every tensor field (and every value in ``info``) to ``device``.

        Mutates the instance in place and returns ``self`` for chaining.
        """
        for name in (
            "sequences",
            "action_log_probs",
            "base_action_log_probs",
            "returns",
            "advantages",
            "values",
            "attention_mask",
            "action_mask",
            "kl",
        ):
            setattr(self, name, to(getattr(self, name), device))
        # ``info`` is declared Optional, so do not assume it is a dict.
        if self.info is not None:
            self.info = {key: to(value, device) for key, value in self.info.items()}
        return self

    def pin_memory(self):
        """Pin every tensor field (and every value in ``info``) in host memory.

        Mutates the instance in place and returns ``self`` for chaining.
        """
        for name in (
            "sequences",
            "action_log_probs",
            "base_action_log_probs",
            "returns",
            "advantages",
            "values",
            "attention_mask",
            "action_mask",
            "kl",
        ):
            setattr(self, name, pin_memory(getattr(self, name)))
        # ``info`` is declared Optional, so do not assume it is a dict.
        if self.info is not None:
            self.info = {key: pin_memory(value) for key, value in self.info.items()}
        return self


@dataclass
class Samples:
    """Samples is a batch of generated data.

    Two storage formats are described: batched (sequences are padded) and packed
    (prompt and response are concatenated without padding).

    Shapes of each tensor; when two shapes are shown, the first is for the batched
    format and the second for the packed format:
    sequences: (B, S) or (1, total_length), the tokens of both prompt and response.
    attention_mask: (B, S) or (1, total_length), the attention mask for sequences.
    action_mask: (B, A) or None, the action (response) mask marking which part of the
        sequence is the response. When the samples are packed, this is None.
    packed_seq_lens: None or (B,), the length of each sample in the packed samples.
        Stored by ``__init__`` but not declared as a dataclass field.
    response_length: (B,), the number of tokens in the response.
    total_length: (B,), the total number of tokens in the sequences.
    prompts: the prompts used to generate responses
    labels: the labels stored alongside the prompts
    """

    sequences: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    response_length: torch.Tensor
    total_length: torch.Tensor
    prompts: list[str]
    labels: list[str]

    def __init__(
        self,
        sequences=None,
        attention_mask=None,
        action_mask=None,
        response_length=None,
        total_length=None,
        prompts=None,
        labels=None,
        packed_seq_lens=None,
    ):
        # Tensors are stored as given; missing prompts/labels become fresh empty lists.
        self.sequences, self.attention_mask, self.action_mask = sequences, attention_mask, action_mask
        self.response_length, self.total_length = response_length, total_length
        self.prompts = prompts or []
        self.labels = labels or []
        self.packed_seq_lens = packed_seq_lens

    def split(self, split_size: int):
        """Split this batch along dim 0 into a list of ``Samples`` of at most ``split_size`` rows.

        ``response_length`` and ``total_length`` of every piece are recomputed from its
        ``action_mask`` and ``attention_mask``; prompts and labels are sliced to match.
        """
        seq_chunks = self.sequences.split(split_size, dim=0)
        attn_chunks = self.attention_mask.split(split_size, dim=0)
        action_chunks = self.action_mask.split(split_size, dim=0)

        pieces = []
        offset = 0
        for seq_chunk, attn_chunk, action_chunk in zip(seq_chunks, attn_chunks, action_chunks):
            piece = Samples(
                sequences=seq_chunk,
                attention_mask=attn_chunk,
                action_mask=action_chunk,
                response_length=action_chunk.float().sum(dim=-1),
                total_length=attn_chunk.float().sum(dim=-1),
            )
            upper = offset + split_size
            piece.prompts = self.prompts[offset:upper]
            piece.labels = self.labels[offset:upper]
            pieces.append(piece)
            offset = upper
        return pieces


class RemoteExperienceMaker(ABC):
    def __init__(
        self,
        actor_model_group: ILRRayActorGroup,
        critic_model_group: ILRRayActorGroup,
        reward_model_group: ILRRayActorGroup,
        initial_model_group: ILRRayActorGroup,
        tokenizer_llm1,
        tokenizer_llm2,
        prompt_max_len: int,
        kl_controller_llm1,
        kl_controller_llm2,
        strategy=None,
        remote_rm_url: Union[list[str], str] = None,
        vllm_engines_llm1: List = None,
        vllm_engines_llm2: List = None,
        packing_samples=False,
        **kwargs,
    ):
        """Store the model groups, tokenizers and KL controllers for both LLMs.

        Args:
            actor_model_group / critic_model_group / reward_model_group /
                initial_model_group: remote model groups; the critic and initial
                groups are compared against ``None`` elsewhere in this class.
            tokenizer_llm1 / tokenizer_llm2: tokenizers selected by ``llm_mark``.
            prompt_max_len: maximum prompt length used when tokenizing prompts.
            kl_controller_llm1 / kl_controller_llm2: objects whose ``.value`` is
                read when computing rewards.
            strategy: must expose ``.args``; although it defaults to ``None`` the
                constructor dereferences it, so a real strategy is required.
            remote_rm_url: a URL/path or list of them. If the first entry ends with
                ``.py`` it is loaded as a module providing
                ``reward_func(queries, prompts, labels)``.
            vllm_engines_llm1 / vllm_engines_llm2: vLLM engine handles per model.
            packing_samples: stored as-is.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.vllm_engines_llm1 = vllm_engines_llm1
        self.vllm_engines_llm2 = vllm_engines_llm2
        self.packing_samples = packing_samples
        self.actor_model_group = actor_model_group
        self.critic_model_group = critic_model_group
        self.reward_model_group = reward_model_group
        self.initial_model_group = initial_model_group
        self.tokenizer_llm1 = tokenizer_llm1
        self.tokenizer_llm2 = tokenizer_llm2
        self.prompt_max_len = prompt_max_len
        self.kl_ctl_llm1 = kl_controller_llm1
        self.kl_ctl_llm2 = kl_controller_llm2
        self.strategy = strategy
        self.advantage_estimator = strategy.args.advantage_estimator
        self.args = strategy.args

        # custom reward func for reinforced finetuning
        self.custom_reward_func = None
        self.remote_rm_url = [remote_rm_url] if isinstance(remote_rm_url, str) else remote_rm_url
        # Inspect the normalised list: indexing the raw argument would test only the
        # first *character* when a plain string such as "reward.py" is passed.
        if self.remote_rm_url and self.remote_rm_url[0].endswith(".py"):
            reward_script = self.remote_rm_url[0]
            print(f"Loading custom `reward_func(queries, prompts, labels)` from {reward_script}")
            import importlib.util

            spec = importlib.util.spec_from_file_location("reward_func", reward_script)
            reward_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(reward_module)
            self.custom_reward_func = ray.remote(reward_module.reward_func)

    # tokenizer
    def tokenize_fn(self, texts, max_length, padding=True, device=None, llm_mark=None):
        if llm_mark == 'llm1':
            tokenizer = self.tokenizer_llm1
        elif llm_mark == 'llm2':
            tokenizer = self.tokenizer_llm2
        if not padding:
            # when padding is False, return tokenized texts as list
            return tokenizer(
                texts,
                add_special_tokens=False,
                max_length=max_length,
                truncation=True,
            )
        batch = tokenizer(
            texts,
            return_tensors="pt",
            add_special_tokens=False,
            max_length=max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}

    @torch.no_grad()
    def make_experience_list(
        self,
        all_prompts_llm1: Union[str, List[str]],
        all_labels_llm1,
        all_levels_llm1,
        all_prompts_llm2: Union[str, List[str]],
        all_labels_llm2,
        all_levels_llm2,
        **generate_kwargs,
    ) -> Tuple[List[Experience], List[Experience]]:
        """
        Make a list of experience with the micro_rollout_batch_size, one list per LLM.

        This method will first calculate the response sequences and rewards for the given prompts.
        Then, if we need certain processing for the rewards or do certain filtering, we can process the rollout as a whole.
        After that, we will calculate the advantages and returns for each experience.

        Returns:
            ``(experiences_llm1, experiences_llm2)``.
        """
        # NOTE: no vLLM wake_up/sleep is issued here; `generate_with_input` performs
        # its own wake_up/sleep calls when `vllm_enable_sleep` is set.
        rollout_samples_llm1, rollout_samples_llm2 = self.generate_samples(
            all_prompts_llm1,
            all_labels_llm1,
            all_levels_llm1,
            all_prompts_llm2,
            all_labels_llm2,
            all_levels_llm2,
            **generate_kwargs,
        )

        # Make experiences (models forward: logprobs, values, rewards, and kl divergence)
        experiences_llm1 = self.make_experience(rollout_samples_llm1, 'llm1')
        experiences_llm2 = self.make_experience(rollout_samples_llm2, 'llm2')

        # Process experiences (reward shaping, etc.)
        # Each model receives the other model's experiences as the peer argument.
        # llm1 is processed first, so llm2 sees llm1's already-processed experiences.
        experiences_llm1 = self.compute_advantages_and_returns(experiences_llm1, experiences_llm2, 'llm1')
        experiences_llm2 = self.compute_advantages_and_returns(experiences_llm2, experiences_llm1, 'llm2')
        return experiences_llm1, experiences_llm2

    @torch.no_grad()
    def make_experience(self, samples_list: List[Samples], llm_mark) -> List[Experience]:
        """
        Turn samples into experience by calculating logprobs, values, rewards, and kl divergence.

        Args:
            samples_list: rollout micro-batches for one LLM.
            llm_mark: ``'llm1'`` or ``'llm2'``; selects the tokenizer and is forwarded
                to the actor and initial model groups.

        Returns:
            One ``Experience`` per entry of ``samples_list``; ``returns`` and
            ``advantages`` are left as ``None``.

        Raises:
            ValueError: if ``llm_mark`` is not ``'llm1'`` or ``'llm2'``.
        """
        if llm_mark == 'llm1':
            tokenizer = self.tokenizer_llm1
        elif llm_mark == 'llm2':
            tokenizer = self.tokenizer_llm2
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'")
        start_time = time.time()
        logger.info(f"🚀 Starting experience making with {len(samples_list[0].sequences) * len(samples_list)} batches")

        args = self.strategy.args
        device = "cpu"
        experiences = []

        # Extract all information from samples in one pass
        # Convert samples into lists of tensors and metadata for batch processing
        sequences_list = [s.sequences for s in samples_list]
        attention_mask_list = [s.attention_mask for s in samples_list]
        action_mask_list = [s.action_mask for s in samples_list]
        prompts_list = [p for s in samples_list for p in s.prompts]
        labels_list = [l for s in samples_list for l in s.labels]

        # A single flag keeps the dispatch below and the result handling further
        # down in agreement (the original mixed `not url` with `url is None`).
        use_reward_model = not self.remote_rm_url

        # Batch call reward model
        r_refs = None
        if use_reward_model:
            # The reward model group receives decoded text rather than token ids.
            sequences_reward_list = [tokenizer.batch_decode(seq, skip_special_tokens=True) for seq in sequences_list]

            r_refs = self.reward_model_group.async_run_method_batch(
                method_name="forward",
                sequences=sequences_reward_list,
                attention_mask=attention_mask_list,
                pad_sequence=[True] * len(samples_list),
            )
        else:
            # Use the tokenizer selected by llm_mark; `self.tokenizer` is never set
            # on this class, so the original lookup raised AttributeError.
            queries_list = sum(
                [tokenizer.batch_decode(seq, skip_special_tokens=False) for seq in sequences_list], []
            )

            if self.custom_reward_func:
                # Let Ray automatically distribute the workload across available resources
                batch_size = self.strategy.args.micro_rollout_batch_size
                num_chunks = (len(queries_list) + batch_size - 1) // batch_size
                r_refs = []
                for i in range(num_chunks):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, len(queries_list))
                    r = self.custom_reward_func.remote(
                        queries_list[start_idx:end_idx],
                        prompts_list[start_idx:end_idx],
                        labels_list[start_idx:end_idx],
                    )
                    r_refs.append(r)
            else:
                # Distribute data across different remote reward function servers
                num_servers = len(self.remote_rm_url)
                batch_size = (len(queries_list) + num_servers - 1) // num_servers
                r_refs = []
                for i in range(num_servers):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, len(queries_list))
                    rm = self.remote_rm_url[i]
                    r = remote_rm_fn_ray.remote(
                        rm,
                        queries=queries_list[start_idx:end_idx],
                        prompts=prompts_list[start_idx:end_idx],
                        labels=labels_list[start_idx:end_idx],
                    )
                    r_refs.append(r)

        # Sync to avoid GPU OOM when colocate models
        if args.colocate_all_models and use_reward_model:
            ray.get(r_refs)
            ray.get(self.reward_model_group.async_run_method(method_name="empty_cache"))

        # Batch call actor model
        action_log_probs_ref = self.actor_model_group.async_run_method_batch(
            method_name="forward",
            sequences=sequences_list,
            action_mask=action_mask_list,
            attention_mask=attention_mask_list,
            llm_mark=[llm_mark] * len(samples_list),
        )

        # Sync to avoid GPU OOM when colocate models
        if args.colocate_all_models or args.colocate_actor_ref:
            ray.get(action_log_probs_ref)
            ray.get(self.actor_model_group.async_run_method(method_name="empty_cache"))

        # Number of placeholder entries used when a model group is absent.
        num_placeholder = len(samples_list) * args.ring_attn_size * args.ds_tensor_parallel_size

        # Batch call critic model
        if self.critic_model_group is not None:
            if args.colocate_critic_reward and use_reward_model:
                ray.get(r_refs)
                ray.get(self.reward_model_group.async_run_method(method_name="empty_cache"))

            value_ref = self.critic_model_group.async_run_method_batch(
                method_name="forward",
                sequences=sequences_list,
                action_mask=action_mask_list,
                attention_mask=attention_mask_list,
            )
            if args.colocate_all_models or args.colocate_critic_reward:
                ray.get(value_ref)
                ray.get(self.critic_model_group.async_run_method(method_name="empty_cache"))
        else:
            value_ref = ray.put([[None]] * num_placeholder)

        # Batch call initial model
        if self.initial_model_group is not None:
            base_action_log_probs_ref = self.initial_model_group.async_run_method_batch(
                method_name="forward",
                sequences=sequences_list,
                action_mask=action_mask_list,
                attention_mask=attention_mask_list,
                llm_mark=[llm_mark] * len(samples_list),
            )

            if args.colocate_all_models or args.colocate_actor_ref:
                ray.get(base_action_log_probs_ref)
                ray.get(self.initial_model_group.async_run_method(method_name="empty_cache"))
        else:
            base_action_log_probs_ref = ray.put([[None]] * num_placeholder)

        # Wait for all remote calls to complete and flatten the results
        # Note: the results duplicated ring_attn_size * ds_tensor_parallel_size times
        duplicate_factor = args.ring_attn_size * args.ds_tensor_parallel_size
        action_log_probs_list = sum(ray.get(action_log_probs_ref)[::duplicate_factor], [])
        base_action_log_probs_list = sum(ray.get(base_action_log_probs_ref)[::duplicate_factor], [])
        value_list = sum(ray.get(value_ref)[::duplicate_factor], [])
        rewards_list = ray.get(r_refs)
        if use_reward_model:
            rewards_list = sum(rewards_list[::duplicate_factor], [])
        else:
            # Remote/custom rewards arrive per request chunk; regroup per micro-batch.
            rewards_list = torch.cat(rewards_list, dim=0).chunk(len(samples_list))

        assert (
            len(samples_list)
            == len(action_log_probs_list)
            == len(base_action_log_probs_list)
            == len(value_list)
            == len(rewards_list)
        ), f"len(samples_list): {len(samples_list)}, len(action_log_probs_list): {len(action_log_probs_list)}, len(base_action_log_probs_list): {len(base_action_log_probs_list)}, len(value_list): {len(value_list)}, len(rewards_list): {len(rewards_list)}"

        # Process results for each sample
        for i, (samples, action_log_probs, base_action_log_probs, value, rewards) in enumerate(
            zip(samples_list, action_log_probs_list, base_action_log_probs_list, value_list, rewards_list)
        ):
            if (self.initial_model_group is not None) and (not args.use_kl_loss):
                kl = compute_approx_kl(
                    action_log_probs,
                    base_action_log_probs,
                    kl_estimator=self.strategy.args.kl_estimator,
                )
            else:
                kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=device)
            kl_mean = masked_mean(kl, samples.action_mask, dim=-1)

            sequences = samples.sequences
            attention_mask = samples.attention_mask

            # Base log-probs are only kept on the experience when the KL loss is used.
            if not args.use_kl_loss:
                base_action_log_probs = None

            info = {
                "kl": kl_mean,
                "reward": rewards,
                "response_length": samples.response_length,
                "total_length": samples.total_length,
            }

            experience = Experience(
                sequences,
                action_log_probs,
                base_action_log_probs,
                value,
                None,
                None,
                attention_mask,
                samples.action_mask,
                info,
                kl,
            )

            experiences.append(experience)

        end_time = time.time()
        duration = end_time - start_time
        time_str = str(timedelta(seconds=duration)).split(".")[0]
        logger.info(f"✨ Experience making completed in {time_str}")
        return experiences

    @torch.no_grad()
    def compute_advantages_and_returns(
        self, experiences: List[Experience], experiences_llm2: List[Experience], llm_mark, **kwargs
    ) -> List[Experience]:
        """
        Shape the rewards of ``experiences`` and fill in their advantages and returns.

        Args:
            experiences: experiences of the model being processed; mutated in place.
            experiences_llm2: experiences of the *other* model (despite the name, these
                are llm1's experiences when ``llm_mark == 'llm2'``). Only their
                ``info["reward"]`` is read, and only the "ilr" estimator uses it.
            llm_mark: ``'llm1'`` or ``'llm2'``; selects the KL controller.

        Output:
        - experiences: the same list of Experience, with ``advantages``, ``returns``
          and ``info["return"]`` set and ``kl`` cleared.

        Raises:
            ValueError: if ``llm_mark`` is not ``'llm1'`` or ``'llm2'``.
        """
        args = self.strategy.args

        if llm_mark == 'llm1':
            kl_ctl = self.kl_ctl_llm1
        elif llm_mark == 'llm2':
            kl_ctl = self.kl_ctl_llm2
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'")

        # get rewards from experiences
        rewards = [experience.info["reward"] for experience in experiences]
        rewards_llm2 = [experience.info["reward"] for experience in experiences_llm2]

        # reward shaping
        if args.advantage_estimator == "rloo":
            rewards = torch.cat(rewards).reshape(-1, args.n_samples_per_prompt)
            baseline = (rewards.sum(-1, keepdim=True) - rewards) / (args.n_samples_per_prompt - 1)
            rewards = rewards - baseline
            rewards = rewards.reshape(-1).chunk(len(experiences))
        elif args.advantage_estimator in ["reinforce_baseline", "dr_grpo"]:
            # REINFORCE++-baseline and Dr. GRPO removed the `/std` in GRPO as `/ std` is not needed in RL variance reduction theory.
            # And `k3 KL` has a larger variance than `k1 KL` under a categorical distribution.
            rewards = torch.cat(rewards).reshape(-1, args.n_samples_per_prompt)
            rewards = rewards - rewards.mean(-1, keepdim=True)
            rewards = rewards.reshape(-1).chunk(len(experiences))
        elif args.advantage_estimator == "group_norm":
            rewards = torch.cat(rewards).reshape(-1, args.n_samples_per_prompt)
            rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9)
            rewards = rewards.reshape(-1).chunk(len(experiences))
        elif args.advantage_estimator == "ilr":
            # comparison-aware reward here:
            # add a bonus in [-1, 1] measuring how far each reward sits from the peer
            # group's mean, scaled by the peer group's (max - min) range, then apply
            # the usual per-group mean/std normalisation.
            rewards = torch.cat(rewards).reshape(-1, args.n_samples_per_prompt)
            rewards_llm2 = torch.cat(rewards_llm2).reshape(-1, args.n_samples_per_prompt)
            rewards = rewards + ((rewards - rewards_llm2.mean(-1, keepdim=True)) / (rewards_llm2.max(-1, keepdim=True)[0] - rewards_llm2.min(-1, keepdim=True)[0] + 1e-9)).clamp(-1, 1)
            rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9)
            rewards = rewards.reshape(-1).chunk(len(experiences))

        # calculate return and advantages
        for experience, reward in zip(experiences, rewards):
            reward = compute_reward(
                reward,
                kl_ctl.value,
                experience.kl,
                action_mask=experience.action_mask,
                reward_clip_range=args.reward_clip_range,
            )

            if self.advantage_estimator == "gae":
                experience.advantages, experience.returns = self.get_advantages_and_returns(
                    experience.values,
                    reward,
                    experience.action_mask,
                    args.gamma,
                    args.lambd,
                )
            elif self.advantage_estimator in ["reinforce", "rloo", "reinforce_baseline", "group_norm", "dr_grpo", "ilr"]:
                if args.gamma != 1.0 and self.advantage_estimator in [
                    "rloo",
                    "reinforce_baseline",
                    "group_norm",
                    "dr_grpo",
                    "ilr",
                ]:
                    # Overrides args.gamma in place, so the warning fires at most once.
                    logger.warning("gamma is set to 1.0 for rloo, reinforce_baseline, and group_norm")
                    args.gamma = 1.0

                experience.returns = self.get_cumulative_returns(
                    reward,
                    experience.action_mask,
                    args.gamma,
                )
                experience.advantages = deepcopy(experience.returns)
            else:
                raise Exception(f"Unkown advantage_estimator {self.advantage_estimator}")

            # calculate the return info.
            return_sums = reward.sum(dim=-1)
            experience.info["return"] = return_sums
            # remove unnecessary info
            experience.kl = None

        # Normalize advantages across all experiences
        if self.args.advantage_estimator not in ["group_norm", "dr_grpo"]:
            all_advantages = []
            all_action_masks = []
            for exp in experiences:
                all_advantages.append(exp.advantages)
                all_action_masks.append(exp.action_mask)

            advantages_vector = zero_pad_sequences(all_advantages).float().flatten()
            action_masks_vector = zero_pad_sequences(all_action_masks).flatten()
            num_actions = action_masks_vector.sum()

            # mean
            mean = (advantages_vector * action_masks_vector).sum() / num_actions
            # std
            if not self.args.no_advantage_std_norm:
                var = ((advantages_vector - mean).pow(2) * action_masks_vector).sum() / num_actions
                rstd = var.clamp(min=1e-8).rsqrt()
            else:
                rstd = 1

            # Apply normalization to each experience
            for exp in experiences:
                exp.advantages = (exp.advantages - mean) * rstd

        return experiences

    @torch.no_grad()
    def get_advantages_and_returns(
        self,
        values: torch.Tensor,
        rewards: torch.Tensor,
        action_mask: torch.Tensor,
        gamma: float,
        lambd: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Function that computes advantages and returns from rewards and values.
        Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
        Note that rewards may include a KL divergence loss term.

        Advantages looks like this:
        Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
              - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

        Returns looks like this:
        Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
                   + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

        Input:
        - values: Tensor of shape (batch_size, response_size)
        - rewards: Tensor of shape (batch_size, response_size)

        Output:
        - advantages: Tensor of shape (batch_size, response_size)
        - returns: Tensor of shape (batch_size, response_size)
        """
        lastgaelam = 0
        advantages_reversed = []
        response_length = rewards.size(1)

        # Mask invalid responses
        if action_mask is not None:
            values = action_mask * values
            rewards = action_mask * rewards

        for t in reversed(range(response_length)):
            nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
            delta = rewards[:, t] + gamma * nextvalues - values[:, t]
            lastgaelam = delta + gamma * lambd * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values
        return advantages.detach(), returns

    @torch.no_grad()
    def get_cumulative_returns(
        self,
        rewards: torch.Tensor,
        action_mask: torch.Tensor,
        gamma: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Function that computes advantages and returns from rewards using REINFORCE.
        REINFORCE uses cumulative returns without the GAE (Generalized Advantage Estimation).

        Input:
        - rewards: Tensor of shape (batch_size, response_size)
        - action_mask: Tensor of shape (batch_size, response_size), binary mask
        - gamma: discount factor

        Output:
        - returns: Tensor of shape (batch_size, response_size)
        """
        response_length = rewards.size(1)
        returns = torch.zeros_like(rewards)
        cumulative_return = torch.zeros(rewards.size(0), device=rewards.device)

        # Mask invalid responses if action_mask is provided
        if action_mask is not None:
            rewards = action_mask * rewards

        # Calculate returns by accumulating discounted rewards
        for t in reversed(range(response_length)):
            cumulative_return = rewards[:, t] + gamma * cumulative_return
            returns[:, t] = cumulative_return

        return returns

    @torch.no_grad()
    def generate_samples(
        self,
        all_prompts_llm1: List[str],
        all_labels_llm1,
        all_levels_llm1,
        all_prompts_llm2: List[str],
        all_labels_llm2,
        all_levels_llm2,
        **generate_kwargs,
    ) -> Tuple[List[Samples], List[Samples]]:
        """
        Generate samples for both LLMs and return them in batches.

        When not using vllm, we will fallback to the default implementation,
        in which actor will be used to generate samples.

        Returns:
            Whatever ``_generate_vllm`` returns; the caller in this class unpacks it
            as ``(samples_llm1, samples_llm2)``.
        """
        if self.vllm_engines_llm1 is None:
            # The HF fallback takes a single prompt/label set and is currently
            # unimplemented. Forward the llm1 inputs so callers get its
            # NotImplementedError instead of a NameError on undefined names.
            return self._generate_with_hf(all_prompts_llm1, all_labels_llm1, **generate_kwargs)

        # vLLM generation
        return self._generate_vllm(
            all_prompts_llm1,
            all_labels_llm1,
            all_levels_llm1,
            all_prompts_llm2,
            all_labels_llm2,
            all_levels_llm2,
            **generate_kwargs,
        )

    @torch.no_grad()
    def _generate_with_hf(self, all_prompts: List[str], all_labels, **generate_kwargs) -> List[Samples]:
        """Placeholder for Hugging Face generation; always raises NotImplementedError."""
        message = "HF generation is not implemented yet"
        raise NotImplementedError(message)

    def IRF(self, difficulty, model1_ability, model2_ability, limbda=0.5):
        model1_pcan = 1 / (1 + math.exp(-1.7 * (model1_ability - difficulty)))
        model2_pcan = 1 / (1 + math.exp(-1.7 * (model2_ability - difficulty)))
        pcan = (model1_pcan + model2_pcan) / 2
        pcannot = 1 - pcan
        result = limbda * pcan - (1-limbda) * pcannot
        if result > 0:
            return 'competition'
        else:
            return 'cooperation'

    def generate_with_input(self, llms, all_prompts, sampling_params, llm_mark):
        """Run ``all_prompts`` through the given vLLM engines and return all outputs.

        The prompts are rendered with the chat template of the tokenizer selected by
        ``llm_mark`` and tokenized without padding. They are then split into
        contiguous, evenly sized slices, one per engine. When ``vllm_enable_sleep``
        is set, the engines are woken up before the requests are submitted and put
        back to sleep after the responses have been collected.
        """
        if llm_mark in ('llm1', 'llm2'):
            chat_tokenizer = self.tokenizer_llm1 if llm_mark == 'llm1' else self.tokenizer_llm2
            all_prompts = chat_tokenizer.apply_chat_template(all_prompts, tokenize=False, add_generation_prompt=True)
        tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False, llm_mark=llm_mark)
        all_prompt_token_ids = tokenized["input_ids"]

        # vLLM wakeup when vllm_enable_sleep
        if self.strategy.args.vllm_enable_sleep:
            batch_vllm_engine_call(llms, "wake_up")

        # Ceil-divide the prompts so every engine gets one contiguous slice.
        num_engines = len(llms)
        per_engine = (len(all_prompt_token_ids) + num_engines - 1) // num_engines
        submit_refs = [
            engine.add_requests.remote(
                0,
                sampling_params=sampling_params,
                prompt_token_ids=all_prompt_token_ids[idx * per_engine : (idx + 1) * per_engine],
            )
            for idx, engine in enumerate(llms)
        ]
        ray.get(submit_refs)

        # Fetch each engine's responses and concatenate them in engine order.
        response_refs = [engine.get_responses.remote(0) for engine in llms]
        all_outputs = sum(ray.get(response_refs), [])

        # vLLM offload when vllm_enable_sleep
        if self.strategy.args.vllm_enable_sleep:
            batch_vllm_engine_call(llms, "sleep")
        return all_outputs

#     def generate_with_input(self, llm1s, all_prompts_llm1, llm2s, all_prompts_llm2, sampling_params):
#         all_prompts_llm1 = self.tokenizer_llm1.apply_chat_template(all_prompts_llm1, tokenize=False, add_generation_prompt=True)
#         all_prompts_llm2 = self.tokenizer_llm2.apply_chat_template(all_prompts_llm2, tokenize=False, add_generation_prompt=True)
#         all_prompt_token_ids_llm1 = self.tokenize_fn(all_prompts_llm1, self.prompt_max_len, padding=False, llm_mark='llm1')["input_ids"]
#         all_prompt_token_ids_llm2 = self.tokenize_fn(all_prompts_llm2, self.prompt_max_len, padding=False, llm_mark='llm2')["input_ids"]

#         # Distribute requests to engines and collect responses to outputs
#         refs = []
#         batch_size = (len(all_prompt_token_ids_llm1) + len(llm1s) - 1) // len(llm1s)
#         for i, llm in enumerate(llm1s):
#             prompt_token_ids_llm1 = all_prompt_token_ids_llm1[i * batch_size : (i + 1) * batch_size]
#             refs.append(llm.add_requests.remote(0, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids_llm1))
#         for i, llm in enumerate(llm2s):
#             prompt_token_ids_llm2 = all_prompt_token_ids_llm2[i * batch_size : (i + 1) * batch_size]
#             refs.append(llm.add_requests.remote(0, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids_llm2))
#         ray.get(refs)

#         # Retrieve and combine results from all outputs
#         all_output_refs = []
#         for i, llm in enumerate(llm1s):
#             all_output_refs.append(llm.get_responses.remote(0))
#         for i, llm in enumerate(llm2s):
#             all_output_refs.append(llm.get_responses.remote(0))
#         all_outputs = sum(ray.get(all_output_refs), [])
#         all_outputs_llm1 = all_outputs[:len(all_outputs)//2]
#         all_outputs_llm2 = all_outputs[len(all_outputs)//2:]
# #         print('1'*50)
# #         print(len(all_outputs_llm1), len(all_outputs_llm2))
#         return all_outputs_llm1, all_outputs_llm2

    def process_sample_list(self, all_outputs, all_prompts, all_labels, args, llm_mark):
        """Pack generation outputs into padded ``Samples`` micro-batches.

        The outputs are consumed in consecutive chunks of
        ``args.micro_rollout_batch_size``. Within a chunk every prompt is
        rendered with the chat template of the selected tokenizer, tokenized
        through ``self.tokenize_fn`` and left-padded to the longest prompt of
        the chunk; every response is right-padded to the longest response of
        the chunk. Prompt and response ids are concatenated and handed to
        ``process_sequences`` to obtain the attention and action masks.

        Args:
            all_outputs: Generation results; the response token ids are read
                from ``output.outputs[0].token_ids``.
            all_prompts: Chat-format prompts (one message list per output).
            all_labels: Labels aligned with ``all_outputs``.
            args: Object exposing ``micro_rollout_batch_size``.
            llm_mark: ``'llm1'`` or ``'llm2'``; selects which tokenizer is used.

        Returns:
            A list with one ``Samples`` object per micro-batch (tensors on CPU).

        Raises:
            ValueError: If ``llm_mark`` is neither ``'llm1'`` nor ``'llm2'``.
        """
        if llm_mark == 'llm1':
            tokenizer = self.tokenizer_llm1
        elif llm_mark == 'llm2':
            tokenizer = self.tokenizer_llm2
        else:
            # Previously an unknown mark left the tokenizer unbound and only
            # failed later (or not at all on empty input).
            raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'.")
        pad_token_id, eos_token_id = tokenizer.pad_token_id, tokenizer.eos_token_id

        micro_batch_size = args.micro_rollout_batch_size
        samples_list = []
        # Group outputs by micro_rollout_batch_size
        for start in range(0, len(all_outputs), micro_batch_size):
            stop = start + micro_batch_size
            batch_outputs = all_outputs[start:stop]
            batch_prompts = all_prompts[start:stop]
            batch_labels = all_labels[start:stop]

            # Render and tokenize each prompt exactly once; the ids are reused
            # both for the length computation and for building the sequence.
            batch_input_ids = []
            for prompt in batch_prompts:
                rendered = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
                token_ids = self.tokenize_fn(rendered, self.prompt_max_len, padding=False, llm_mark=llm_mark)["input_ids"]
                batch_input_ids.append(list(token_ids))
            batch_output_ids = [list(output.outputs[0].token_ids) for output in batch_outputs]

            # Max lengths are computed for this micro-batch only
            batch_max_input_len = max(len(ids) for ids in batch_input_ids)
            batch_max_output_len = max(len(ids) for ids in batch_output_ids)

            sequences = []
            for output_ids, input_ids in zip(batch_output_ids, batch_input_ids):
                # left padding input
                left_pad = [pad_token_id] * (batch_max_input_len - len(input_ids))
                # right padding output
                right_pad = [pad_token_id] * (batch_max_output_len - len(output_ids))
                # concat input and output
                sequences.append(left_pad + input_ids + output_ids + right_pad)

            sequences = torch.tensor(sequences)
            sequences, attention_mask, action_mask = process_sequences(
                sequences, batch_max_input_len, eos_token_id, pad_token_id
            )
            sequences = sequences.to("cpu")
            attention_mask = attention_mask.to("cpu")
            action_mask = action_mask.to("cpu")

            samples_list.append(
                Samples(
                    sequences=sequences,
                    attention_mask=attention_mask,
                    action_mask=action_mask,
                    response_length=action_mask.float().sum(dim=-1),
                    total_length=attention_mask.float().sum(dim=-1),
                    prompts=batch_prompts,
                    labels=batch_labels,
                )
            )

        return samples_list



    def _generate_vllm(self, all_prompts_llm1_input, all_labels_llm1_input, all_levels_llm1, all_prompts_llm2_input, all_labels_llm2_input, all_levels_llm2, **kwargs) -> Tuple[List[Samples], List[Samples]]:
        """Generate rollouts for both models with vLLM and pack them as samples.

        Every question is expanded to ``n_samples_per_prompt`` requests. For
        each question an interaction mode is chosen by ``self.IRF`` from the
        question level and the two model abilities; the result is used as a key
        into the prompt templates below ("competition" / "cooperation").

        ``args.interaction`` selects the protocol:

        * ``'one'``  - a single "Idea_sharing" round; its outputs are returned.
        * ``'idea'`` - three rounds (idea sharing, idea fusion on the other
          model's answer, idea innovation). For each sample the final answer
          is kept unless it disagrees with the label while the initial answer
          agrees with it, in which case the initial answer is kept.

        In both cases the ``Samples`` are built against the plain
        single-turn user prompt, not the multi-turn conversation.

        Args:
            all_prompts_llm1_input / all_prompts_llm2_input: Raw questions.
            all_labels_llm1_input / all_labels_llm2_input: Labels per question.
            all_levels_llm1 / all_levels_llm2: Per-question levels; the two
                lists must agree element-wise.
            **kwargs: Sampling overrides (``temperature``, ``top_p``,
                ``top_k``, ``max_new_tokens``, ``min_new_tokens``,
                ``skip_special_tokens``) and ``n_samples_per_prompt``
                (defaults to ``args.n_samples_per_prompt``).

        Returns:
            ``(samples_list_llm1, samples_list_llm2)``.

        Raises:
            ValueError: If ``args.interaction`` is not ``'one'`` or ``'idea'``.
        """
        from vllm import SamplingParams

        llms_1 = self.vllm_engines_llm1
        llms_2 = self.vllm_engines_llm2
        args = self.strategy.args

        # Fail fast: previously an unsupported mode ran a full generation
        # round and then implicitly returned None.
        if args.interaction not in ('one', 'idea'):
            raise ValueError(f"Unsupported interaction mode {args.interaction!r}; expected 'one' or 'idea'.")

        sampling_params = SamplingParams(
            temperature=kwargs.get("temperature", 0.5),
            top_p=kwargs.get("top_p", 1.0),
            top_k=kwargs.get("top_k", -1),
            max_tokens=kwargs.get("max_new_tokens", 2048),
            min_tokens=kwargs.get("min_new_tokens", 1),
            skip_special_tokens=kwargs.get("skip_special_tokens", False),
            include_stop_str_in_output=True,
        )

        # Expand prompt list based on the number of samples per prompt.
        # The (possibly overridden) value is used consistently below.
        n_samples_per_prompt = kwargs.pop("n_samples_per_prompt", args.n_samples_per_prompt)

        # Plain single-turn prompts that the returned Samples are built against.
        # Each entry is its own list so samples never alias each other.
        all_prompts_llm1_output = [
            [{"role": "user", "content": prompt_input}]
            for prompt_input in all_prompts_llm1_input
            for _ in range(n_samples_per_prompt)
        ]
        all_prompts_llm2_output = [
            [{"role": "user", "content": prompt_input}]
            for prompt_input in all_prompts_llm2_input
            for _ in range(n_samples_per_prompt)
        ]

        # One interaction mode per sample (repeated for every sample of a question).
        interactions = []
        level_pairs = list(zip(all_levels_llm1, all_levels_llm2))
        if level_pairs:
            # The ability lookups take no per-question argument, so they are
            # fetched once instead of once per question.
            # NOTE(review): assumes get_ability does not change while this loop runs - confirm.
            ability_llm1 = float(ray.get(self.actor_model_group.async_run_method(method_name="get_ability", llm_mark='llm1'))[0])
            ability_llm2 = float(ray.get(self.actor_model_group.async_run_method(method_name="get_ability", llm_mark='llm2'))[0])
        for level_llm1, level_llm2 in level_pairs:
            assert level_llm1 == level_llm2, "both models must receive the same question level"
            interaction = self.IRF(float(level_llm1), ability_llm1, ability_llm2)
            interactions.extend([interaction] * n_samples_per_prompt)

        # Templates are runtime model input: keep them byte-for-byte.
        interaction_prompts = {
            "competition":{
                "Idea_sharing": 
                    '''Question: {question}
                    Please reasoning step by step, and put your final answer within \\boxed{{}}. Use English to output your content.''',

                "Idea_fusion": 
                    '''Opponent's Solution: {ideas}
                    Critically analyze the opponent's ideas, identify the weakness and strength of his ideas. Use English to output your content.''',

                "Idea_innovation": 
                    '''Based on the aboved analysis, give an updated answer to Original Question: {question}.  Please reasoning step by step, and put your final answer within \\boxed{{}}. Use English to output your content.'''
            },
            "cooperation":{
                "Idea_sharing": 
                    '''Question: {question}
                     Please reasoning step by step, and put your final answer within \\boxed{{}}. Use English to output your content.''',

                "Idea_fusion": 
                    '''Partner's Contribution: {ideas}
                    Collaboratively analyze the key steps in the partner's contribution, identify those steps that can help you to improve your answer and serve as the additional advice. Use English to output your content.''',

                "Idea_innovation": 
                    '''Based on the aboved analysis, give an updated answer to Original Question: {question}.  Please reasoning step by step, and put your final answer within \\boxed{{}}. Use English to output your content.'''
            }
        }

        def template_for(sample_idx, stage):
            # Template of the given stage for the interaction mode of this sample.
            return interaction_prompts[interactions[sample_idx]][stage]

        def sharing_conversations(questions):
            # One fresh single-message conversation per (question, sample).
            return [
                [{"role": "user", "content": template_for(i * n_samples_per_prompt + j, "Idea_sharing").format(question=question)}]
                for i, question in enumerate(questions)
                for j in range(n_samples_per_prompt)
            ]

        # Idea sharing
        all_prompts_llm1 = sharing_conversations(all_prompts_llm1_input)
        all_labels_llm1 = [label for label in all_labels_llm1_input for _ in range(n_samples_per_prompt)]
        all_prompts_llm2 = sharing_conversations(all_prompts_llm2_input)
        all_labels_llm2 = [label for label in all_labels_llm2_input for _ in range(n_samples_per_prompt)]

        initial_output_llm1 = self.generate_with_input(llms_1, all_prompts_llm1, sampling_params, 'llm1')
        initial_output_llm2 = self.generate_with_input(llms_2, all_prompts_llm2, sampling_params, 'llm2')
        assert len(initial_output_llm1) == len(initial_output_llm2)

        if args.interaction == 'one':
            samples_list_llm1 = self.process_sample_list(initial_output_llm1, all_prompts_llm1_output, all_labels_llm1, args, 'llm1')
            samples_list_llm2 = self.process_sample_list(initial_output_llm2, all_prompts_llm2_output, all_labels_llm2, args, 'llm2')
            return samples_list_llm1, samples_list_llm2

        # args.interaction == 'idea' from here on.

        # Idea fusion: each model sees its own answer and the other model's answer.
        for i, (out_llm1, out_llm2) in enumerate(zip(initial_output_llm1, initial_output_llm2)):
            ideas_llm1 = out_llm1.outputs[0].text
            ideas_llm2 = out_llm2.outputs[0].text
            fusion_template = template_for(i, "Idea_fusion")
            all_prompts_llm1[i].append({'role': 'assistant', 'content': ideas_llm1})
            all_prompts_llm1[i].append({'role': 'user', 'content': fusion_template.format(ideas=ideas_llm2)})
            all_prompts_llm2[i].append({'role': 'assistant', 'content': ideas_llm2})
            all_prompts_llm2[i].append({'role': 'user', 'content': fusion_template.format(ideas=ideas_llm1)})

        discussion_llm1 = self.generate_with_input(llms_1, all_prompts_llm1, sampling_params, 'llm1')
        discussion_llm2 = self.generate_with_input(llms_2, all_prompts_llm2, sampling_params, 'llm2')
        assert len(discussion_llm1) == len(discussion_llm2)

        # Idea innovation: ask again for an answer to the original question.
        for i, (out_llm1, out_llm2) in enumerate(zip(discussion_llm1, discussion_llm2)):
            question_idx = i // n_samples_per_prompt
            innovation_template = template_for(i, "Idea_innovation")
            all_prompts_llm1[i].append({'role': 'assistant', 'content': out_llm1.outputs[0].text})
            all_prompts_llm1[i].append({'role': 'user', 'content': innovation_template.format(question=all_prompts_llm1_input[question_idx])})
            all_prompts_llm2[i].append({'role': 'assistant', 'content': out_llm2.outputs[0].text})
            all_prompts_llm2[i].append({'role': 'user', 'content': innovation_template.format(question=all_prompts_llm2_input[question_idx])})

        all_outputs_llm1 = self.generate_with_input(llms_1, all_prompts_llm1, sampling_params, 'llm1')
        all_outputs_llm2 = self.generate_with_input(llms_2, all_prompts_llm2, sampling_params, 'llm2')
        assert len(all_outputs_llm1) == len(all_outputs_llm2)

        def extract_answer(text):
            # Run the remote "extract_answer" method and use the first returned result.
            return ray.get(self.actor_model_group.async_run_method(method_name="extract_answer", pred_str=text))[0]

        def pick_oracle_output(label_answer, initial_output, final_output):
            # Keep the final answer unless it is wrong and the initial one is right.
            # The initial answer is only extracted when the final one mismatches.
            if label_answer == extract_answer(final_output.outputs[0].text):
                return final_output
            if label_answer == extract_answer(initial_output.outputs[0].text):
                return initial_output
            return final_output

        # Using label to select the best answer from initial and final answers
        oracle_outputs_llm1 = []
        oracle_outputs_llm2 = []
        for label_llm1, label_llm2, initial_llm1, initial_llm2, final_llm1, final_llm2 in zip(
            all_labels_llm1, all_labels_llm2, initial_output_llm1, initial_output_llm2, all_outputs_llm1, all_outputs_llm2
        ):
            assert label_llm1 == label_llm2
            extracted_label = extract_answer(label_llm1)
            oracle_outputs_llm1.append(pick_oracle_output(extracted_label, initial_llm1, final_llm1))
            oracle_outputs_llm2.append(pick_oracle_output(extracted_label, initial_llm2, final_llm2))

        assert len(oracle_outputs_llm1) == len(oracle_outputs_llm2) == len(all_outputs_llm1) == len(all_outputs_llm2)

        samples_list_llm1 = self.process_sample_list(oracle_outputs_llm1, all_prompts_llm1_output, all_labels_llm1, args, 'llm1')
        samples_list_llm2 = self.process_sample_list(oracle_outputs_llm2, all_prompts_llm2_output, all_labels_llm2, args, 'llm2')

        return samples_list_llm1, samples_list_llm2

