import math
import os
import socket
from abc import ABC
from typing import Dict, List, Optional, Union

import deepspeed
import ray
import torch
import torch.distributed
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.trainer import get_scheduler

from openrlhf.models import ilr_Actor, PolicyLoss
from openrlhf.models.utils import compute_approx_kl, masked_mean
from openrlhf.trainer.ppo_utils.ilr_experience_maker import Experience
from openrlhf.utils import get_tokenizer
from openrlhf.utils.deepspeed import ILRDeepspeedStrategy
from openrlhf.utils.deepspeed.deepspeed_utils import offload_deepspeed_states, reload_deepspeed_states
from openrlhf.utils.distributed_util import init_process_group, torch_dist_barrier_and_cuda_sync
from openrlhf.utils.logging_utils import init_logger
import copy
from ..ppo_utils import NaiveReplayBuffer

# Module-level logger built with openrlhf's ``init_logger`` helper.
logger = init_logger(__name__)

from .launcher import BasePPORole
from .utils import get_physical_gpu_id

import inspect
import copy
import torch.nn.functional as F
import json
from tqdm import tqdm
import re
from word2number import w2n
import time
import ray
from vllm import LLM, SamplingParams
import regex
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
from latex2sympy2 import latex2sympy
import numpy as np


class ActorILRTrainer(ABC):
    """Trainer that optimises two actors ("llm1" and "llm2") with a clipped policy loss.

    Every per-model component (strategy, actor, optimizer, scheduler, tokenizer,
    replay buffer, vLLM engines, weight-sync group) is held twice with an
    ``_llm1`` / ``_llm2`` suffix. Methods taking ``llm_mark`` select one of the
    two sets; ``llm_mark`` must be ``'llm1'`` or ``'llm2'`` and any other value
    raises ``ValueError``.
    """

    def __init__(
        self,
        strategy_llm1,
        strategy_llm2,
        actor_llm1: ilr_Actor,
        actor_llm2: ilr_Actor,
        ema_model: ilr_Actor,
        actor_optim_llm1: Optimizer,
        actor_optim_llm2: Optimizer,
        actor_scheduler_llm1,
        actor_scheduler_llm2,
        ema_beta: float = 0.992,
        micro_train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        tokenizer_llm1=None,
        tokenizer_llm2=None,
        dataloader_pin_memory: bool = True,
        vllm_engines_llm1: List = None,
        vllm_engines_llm2: List = None,
        **kwargs,
    ):
        """Store both models' training components and set up vLLM weight sync.

        Args:
            strategy_llm1, strategy_llm2: Per-model training strategies.
                ``strategy_llm1.args`` is used as the shared ``self.args``.
            actor_llm1, actor_llm2: The two actors being trained.
            ema_model: Optional EMA model; when truthy it is updated after every
                optimizer step via ``strategy.moving_average``.
            actor_optim_llm1, actor_optim_llm2: Optimizers for each actor.
            actor_scheduler_llm1, actor_scheduler_llm2: LR schedulers for each actor.
            ema_beta: Decay passed to ``strategy.moving_average``.
            micro_train_batch_size: Sample batch size of each replay buffer.
            buffer_limit, buffer_cpu_offload: Forwarded to ``NaiveReplayBuffer``.
            eps_clip: Clip range for ``PolicyLoss``.
            tokenizer_llm1, tokenizer_llm2: Tokenizers for the replay buffers and
                for decoding samples in debug output.
            dataloader_pin_memory: ``pin_memory`` flag for the training DataLoader.
            vllm_engines_llm1, vllm_engines_llm2 (List, optional): vLLM engine
                handles refreshed from the matching actor. Broadcast sync groups
                are only created (on rank 0) when both lists are given and CUDA
                IPC is not in use.
            **kwargs: Kept as ``self.generate_kwargs``.
        """
        self.strategy_llm1 = strategy_llm1
        self.strategy_llm2 = strategy_llm2
        self.args = strategy_llm1.args
        self.tokenizer_llm1 = tokenizer_llm1
        self.tokenizer_llm2 = tokenizer_llm2
        self.generate_kwargs = kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        self.micro_train_batch_size = micro_train_batch_size
        self.ema_beta = ema_beta

        self.actor_llm1 = actor_llm1
        self.actor_llm2 = actor_llm2
        self.ema_model = ema_model
        self.actor_optim_llm1 = actor_optim_llm1
        self.actor_optim_llm2 = actor_optim_llm2
        self.actor_scheduler_llm1 = actor_scheduler_llm1
        self.actor_scheduler_llm2 = actor_scheduler_llm2
        self.vllm_engines_llm1 = vllm_engines_llm1
        self.vllm_engines_llm2 = vllm_engines_llm2
        self.max_epochs = self.args.max_epochs

        self.actor_loss_fn = PolicyLoss(eps_clip)

        # Auxiliary (e.g. Mixtral MoE load-balancing) loss is only added when its coefficient is non-negligible.
        self.aux_loss = self.args.aux_loss_coef > 1e-8

        packing_samples = getattr(self.args, "packing_samples", False)
        self.replay_buffer_llm1 = NaiveReplayBuffer(
            micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples, tokenizer=tokenizer_llm1
        )
        self.replay_buffer_llm2 = NaiveReplayBuffer(
            micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples, tokenizer=tokenizer_llm2
        )

        # Init torch group for weights sync
        backend = getattr(self.args, "vllm_sync_backend", "nccl")
        # With all models colocated and NCCL, weights are shared through CUDA IPC handles instead of broadcasts.
        self.use_cuda_ipc = backend == "nccl" and bool(self.args.colocate_all_models)

        # Create a torch group with deepspeed rank 0 and all vllm ranks
        # to update the vllm engines' weights after each training stage.
        #
        # Say we have 3 vllm engines and each of them has 4 GPUs,
        # then the torch group is:
        # [    0,      1, 2, 3, 4,  5, 6, 7, 8,  9, 10, 11, 12]
        # |ds rank 0 |  engine-0  |  engine-1  |   engine-2   |
        #
        # For ZeRO-1/2:
        #   1. Broadcast parameters from rank 0 to all vllm engines
        # For ZeRO-3:
        #   1. AllGather parameters to rank 0
        #   2. Broadcast parameters from rank 0 to all vllm engines
        if (
            self.vllm_engines_llm1 is not None
            and self.vllm_engines_llm2 is not None
            and not self.use_cuda_ipc
            and torch.distributed.get_rank() == 0
        ):
            self.prepare_vllm("llm1", backend)
            self.prepare_vllm("llm2", backend)

        torch_dist_barrier_and_cuda_sync()

    @staticmethod
    def _bad_mark(llm_mark) -> ValueError:
        """Build the error raised when ``llm_mark`` is neither 'llm1' nor 'llm2'."""
        return ValueError(f"llm_mark must be 'llm1' or 'llm2', got {llm_mark!r}")

    def prepare_vllm(self, llm_mark, backend):
        """Create the weight-sync group between this process and one model's vLLM engines.

        Called from ``__init__`` on torch rank 0 only. This process is rank 0 of
        the new group and engine ``i`` is given rank offset ``i * tp + 1``. The
        resulting group (or Ray collective group name) is stored as
        ``self._model_update_group_<llm_mark>``.

        Args:
            llm_mark: ``'llm1'`` or ``'llm2'``.
            backend: Process-group backend (e.g. ``"nccl"``).

        Raises:
            ValueError: If ``llm_mark`` is not recognised.
        """
        if llm_mark == "llm1":
            strategy = self.strategy_llm1
            group_name = "openrlhf_llm1"
            vllm_engines = self.vllm_engines_llm1
        elif llm_mark == "llm2":
            strategy = self.strategy_llm2
            group_name = "openrlhf_llm2"
            vllm_engines = self.vllm_engines_llm2
        else:
            raise self._bad_mark(llm_mark)

        master_address = ray._private.services.get_node_ip_address()
        # Let the OS pick a free port for the TCP rendezvous.
        with socket.socket() as sock:
            sock.bind(("", 0))
            master_port = sock.getsockname()[1]
        print(master_address, master_port)

        vllm_num_engines = strategy.args.vllm_num_engines
        vllm_tensor_parallel_size = strategy.args.vllm_tensor_parallel_size
        world_size = vllm_num_engines * vllm_tensor_parallel_size + 1

        use_ray = getattr(strategy.args, "vllm_sync_with_ray", False)
        # Launch the engines' joins (non-blocking remote calls) before joining
        # locally, so both sides of the rendezvous are in flight together.
        refs = [
            engine.init_process_group.remote(
                master_address,
                master_port,
                i * vllm_tensor_parallel_size + 1,
                world_size,
                group_name,
                backend=backend,
                use_ray=use_ray,
            )
            for i, engine in enumerate(vllm_engines)
        ]

        if use_ray:
            import ray.util.collective as collective

            collective.init_collective_group(world_size=world_size, rank=0, backend=backend, group_name=group_name)
            model_update_group = group_name
        else:
            model_update_group = init_process_group(
                backend=backend,
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=0,
                group_name=group_name,
            )

        if llm_mark == "llm1":
            self._model_update_group_llm1 = model_update_group
        else:
            self._model_update_group_llm2 = model_update_group

        ray.get(refs)

    def ppo_train(self, kl_ctl: float, llm_mark):
        """Run ``max_epochs`` passes over one model's replay buffer.

        Args:
            kl_ctl: Multiplier applied to the KL loss in ``training_step``.
            llm_mark: ``'llm1'`` or ``'llm2'``.

        Returns:
            Per-step status dicts (after ``strategy.all_reduce``) averaged over
            every step of every epoch, or ``{}`` if no batch was produced.

        Raises:
            ValueError: If ``llm_mark`` is not recognised.
        """
        if llm_mark == "llm1":
            replay_buffer = self.replay_buffer_llm1
            strategy = self.strategy_llm1
            tokenizer = self.tokenizer_llm1
        elif llm_mark == "llm2":
            replay_buffer = self.replay_buffer_llm2
            strategy = self.strategy_llm2
            tokenizer = self.tokenizer_llm2
        else:
            raise self._bad_mark(llm_mark)

        # The replay buffer may be empty at first, so the loader is rebuilt on every call.
        dataloader = DataLoader(
            replay_buffer,
            batch_size=replay_buffer.sample_batch_size,
            # Ring attention needs ranks to see matching batches, so shuffling is disabled then.
            shuffle=strategy.ring_attn_group is None,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=replay_buffer.collate_fn,
        )
        device = torch.cuda.current_device()

        status_list = []
        for epoch in range(self.max_epochs):
            pbar = tqdm(
                dataloader,
                desc=f"{llm_mark} Train epoch [{epoch + 1}/{self.max_epochs}]",
                disable=not strategy.is_rank_0(),
            )
            for experience in pbar:
                experience.to_device(device)
                status = self.training_step(experience, kl_ctl, llm_mark)
                if status["policy_loss"] == 0:
                    # Debug aid: show the first sample of a batch whose policy loss is exactly zero.
                    first_seq = experience.sequences[0].unsqueeze(0)
                    sample_error = tokenizer.batch_decode(first_seq, skip_special_tokens=True)
                    print(first_seq)
                    print(sample_error)

                # Weight KL by response length around the reduction so that, assuming
                # all_reduce averages across ranks, the result is a length-weighted mean.
                status["kl"] *= status["response_length"]
                status = strategy.all_reduce(status)
                status["kl"] /= status["response_length"]

                short_status = {
                    "act_loss": status["policy_loss"],
                    "reward": status["reward"],
                    "return": status["return"],
                    "gen_len": status["response_length"],
                    "tot_len": status["total_length"],
                    "kl": status["kl"],
                    "act_lr": status["actor_lr"],
                }
                if "entropy_loss" in status:
                    short_status["ent_loss"] = status["entropy_loss"]

                status_list.append(status)
                pbar.set_postfix(short_status)

        if not status_list:
            return {}
        # Copy so the first step's status dict is not mutated by the accumulation.
        status_sum = dict(status_list[0])
        for step_status in status_list[1:]:
            for key, value in step_status.items():
                status_sum[key] += value
        return {key: value / len(status_list) for key, value in status_sum.items()}

    def training_step(self, experience: Experience, kl_ctl: float, llm_mark) -> Dict[str, float]:
        """Do one optimizer step for the selected actor on a single batch.

        Loss = clipped policy loss + ``kl_ctl`` * KL loss (if ``args.use_kl_loss``)
        + auxiliary loss (if enabled) - entropy bonus (if enabled).

        Args:
            experience: Batch holding sequences, masks, old/base log-probs and advantages.
            kl_ctl: KL loss multiplier.
            llm_mark: ``'llm1'`` or ``'llm2'``.

        Returns:
            Scalar statistics: ``policy_loss``, ``actor_lr``, optionally
            ``entropy_loss``, plus the mean of every entry in ``experience.info``.

        Raises:
            ValueError: If ``llm_mark`` is not recognised.
        """
        if llm_mark == "llm1":
            actor = self.actor_llm1
            actor_optim = self.actor_optim_llm1
            actor_scheduler = self.actor_scheduler_llm1
            strategy = self.strategy_llm1
            init_kl_coef = strategy.args.init_kl_coef_llm1
        elif llm_mark == "llm2":
            actor = self.actor_llm2
            actor_optim = self.actor_optim_llm2
            actor_scheduler = self.actor_scheduler_llm2
            strategy = self.strategy_llm2
            init_kl_coef = strategy.args.init_kl_coef_llm2
        else:
            raise self._bad_mark(llm_mark)

        use_entropy = self.args.entropy_loss_coef > 1e-8

        actor.train()
        action_mask = experience.action_mask

        # actor loss
        action_log_probs, output = actor(
            experience.sequences,
            action_mask,
            attention_mask=experience.attention_mask,
            return_output=True,
            ring_attn_group=strategy.ring_attn_group,
            packed_seq_lens=None,
            return_entropy=use_entropy,
        )

        actor_loss = self.actor_loss_fn(
            action_log_probs,
            experience.action_log_probs,
            experience.advantages,
            action_mask=action_mask,
        )

        if self.args.use_kl_loss:
            if init_kl_coef > 0:
                kl = compute_approx_kl(
                    action_log_probs,
                    experience.base_action_log_probs,
                    kl_estimator=self.args.kl_estimator,
                )
            else:
                kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
            kl_loss = masked_mean(kl, action_mask)
            experience.info["kl"] = kl_loss.detach()
        else:
            kl_loss = 0

        loss = actor_loss + kl_loss * kl_ctl
        # mixtral
        if self.aux_loss:
            loss += output.aux_loss * self.args.aux_loss_coef
        # Entropy bonus: only the trailing action positions are scored.
        if use_entropy:
            entropy_loss = masked_mean(output.entropy[:, -action_mask.shape[1] :], action_mask)
            loss -= entropy_loss * self.args.entropy_loss_coef

        strategy.backward(loss, actor, actor_optim)
        strategy.optimizer_step(actor_optim, actor, actor_scheduler, name=f"{llm_mark}_actor")
        if self.ema_model:
            strategy.moving_average(actor, self.ema_model, self.ema_beta, "cuda")

        status = {"policy_loss": actor_loss.detach().item(), "actor_lr": actor_scheduler.get_last_lr()[0]}
        if use_entropy:
            status["entropy_loss"] = entropy_loss.detach().item()
        for k, v in experience.info.items():
            status[k] = v.mean().item()
        return status

    def _broadcast_to_vllm(self, llm_mark):
        """Copy the selected actor's current weights into its vLLM engines.

        Parameters are sent one at a time, either by broadcast over the group
        created in ``prepare_vllm`` or, when ``self.use_cuda_ipc`` is set, by
        all-gathering CUDA IPC handles. Sharded parameters (ZeRO-3 or DeepSpeed
        tensor parallelism) are gathered first. If prefix caching is enabled,
        rank 0 also asks every engine to reset its prefix cache.

        Raises:
            ValueError: If ``llm_mark`` is not recognised.
        """
        if llm_mark == "llm1":
            vllm_engines = self.vllm_engines_llm1
            model = self.actor_llm1.model.module
            strategy = self.strategy_llm1
        elif llm_mark == "llm2":
            vllm_engines = self.vllm_engines_llm2
            model = self.actor_llm2.model.module
            strategy = self.strategy_llm2
        else:
            raise self._bad_mark(llm_mark)

        use_prefix_cache = getattr(strategy.args, "enable_prefix_caching", False)
        cache_reset_refs = []
        if use_prefix_cache and torch.distributed.get_rank() == 0:
            # Cached prefixes were computed with the old weights; clear them.
            cache_reset_refs = [engine.reset_prefix_cache.remote() for engine in vllm_engines]

        torch.cuda.empty_cache()

        num_params = len(list(model.named_parameters()))
        use_ray = getattr(strategy.args, "vllm_sync_with_ray", False)

        def _param_shape(param):
            # Under ZeRO-3 the local tensor is a shard; ds_shape is the full shape.
            return param.shape if strategy.args.zero_stage != 3 else param.ds_shape

        def _broadcast_param(name, param, count):
            """Rank 0 broadcasts one (already gathered) parameter to all engines."""
            if torch.distributed.get_rank() != 0:
                return
            # Only rank 0 owns a model-update group (see prepare_vllm).
            if llm_mark == "llm1":
                model_update_group = self._model_update_group_llm1
            else:
                model_update_group = self._model_update_group_llm2

            refs = [
                engine.update_weight.remote(
                    name, dtype=param.dtype, shape=_param_shape(param), empty_cache=count == num_params
                )
                for engine in vllm_engines
            ]
            if use_ray:
                import ray.util.collective as collective

                collective.broadcast(param.data, 0, group_name=model_update_group)
            else:
                torch.distributed.broadcast(param.data, 0, group=model_update_group)
            ray.get(refs)

        def _handle_cuda_ipc(name, param, count):
            """Share one parameter with the engines through CUDA IPC handles."""
            from torch.multiprocessing.reductions import reduce_tensor

            weight = param.data.clone()
            ipc_handle = {get_physical_gpu_id(): reduce_tensor(weight)}
            ipc_handle_list = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

            if torch.distributed.get_rank() == 0:
                ipc_handles = {}
                for handle in ipc_handle_list:
                    ipc_handles.update(handle)

                refs = [
                    engine.update_weight_cuda_ipc.remote(
                        name,
                        dtype=param.dtype,
                        shape=_param_shape(param),
                        ipc_handles=ipc_handles,
                        empty_cache=count == num_params,
                    )
                    for engine in vllm_engines
                ]
                ray.get(refs)
            torch_dist_barrier_and_cuda_sync()

        sync_param = _handle_cuda_ipc if self.use_cuda_ipc else _broadcast_param

        # count is 1-based so that count == num_params marks the last parameter (engines empty their cache then).
        for count, (name, param) in enumerate(model.named_parameters(), start=1):
            if strategy.args.ds_tensor_parallel_size > 1:
                gather = deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True)
            else:
                # For ZeRO-3, allgather the sharded parameter before syncing it.
                gather = deepspeed.zero.GatheredParameters([param], enabled=strategy.args.zero_stage == 3)
            with gather:
                sync_param(name, param, count)

        if cache_reset_refs:
            ray.get(cache_reset_refs)
        torch.cuda.empty_cache()
        torch_dist_barrier_and_cuda_sync()


@ray.remote(num_gpus=1)
class ILRActorModelRayActor(BasePPORole):
    def init_model_from_pretrained(self, strategy: ILRDeepspeedStrategy, pretrain_llm1, pretrain_llm2, max_steps, vllm_engines_llm1, vllm_engines_llm2):
        """Build both actors, their tokenizers/optimizers/schedulers and the trainer.

        ``strategy`` drives llm1; a deep copy of it drives llm2. Checkpoints are
        restored when ``args.load_checkpoint`` is set and both checkpoint
        directories exist. ``self.consumed_samples`` is then the smaller of the
        two saved counts.

        Args:
            strategy: DeepSpeed strategy whose ``args`` configure everything.
            pretrain_llm1, pretrain_llm2: Model paths/names for the two actors.
            max_steps: Total scheduler steps (warmup = ceil(max_steps * lr_warmup_ratio)).
            vllm_engines_llm1, vllm_engines_llm2: vLLM engine handles passed to the trainer.

        Raises:
            NotImplementedError: If ``args.enable_ema`` is set. The trainer shares a
                single EMA model between two independent actors, which cannot be
                meaningful.
        """
        args = strategy.args
        if args.enable_ema:
            raise NotImplementedError(
                "enable_ema is not supported for ILR training: a single EMA model cannot track two actors"
            )

        self.strategy_llm1 = strategy
        self.strategy_llm2 = copy.deepcopy(self.strategy_llm1)

        if getattr(args, "vllm_num_engines", 0) > 0:
            # To prevent hanging during NCCL synchronization of weights between DeepSpeed and vLLM.
            # see https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445
            if getattr(args, "vllm_sync_backend", "nccl") == "nccl":
                os.environ["NCCL_CUMEM_ENABLE"] = "0"

        self.pretrain_llm1 = pretrain_llm1
        self.pretrain_llm2 = pretrain_llm2

        self.save_hf_ckpt = args.save_hf_ckpt
        self.disable_ds_ckpt = args.disable_ds_ckpt
        self.vllm_engines_llm1 = vllm_engines_llm1
        self.vllm_engines_llm2 = vllm_engines_llm2
        self.max_steps = max_steps

        self._setup_distributed(self.strategy_llm1)
        self._setup_distributed(self.strategy_llm2)

        def _build_actor(pretrain_path):
            """Construct one ilr_Actor with the shared model/LoRA settings."""
            return ilr_Actor(
                pretrain_path,
                use_flash_attention_2=args.flash_attn,
                bf16=args.bf16,
                load_in_4bit=args.load_in_4bit,
                lora_rank=args.lora_rank,
                lora_alpha=args.lora_alpha,
                target_modules=args.target_modules,
                lora_dropout=args.lora_dropout,
                ds_config=strategy.get_ds_train_config(is_ilr_Actor=True),
                packing_samples=args.packing_samples,
                temperature=args.temperature,
                use_liger_kernel=args.use_liger_kernel,
            )

        def _build_scheduler(optim, base_lr):
            """Cosine schedule with warmup that decays to 10% of ``base_lr``."""
            return get_scheduler(
                "cosine_with_min_lr",
                optim,
                num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
                num_training_steps=max_steps,
                scheduler_specific_kwargs={"min_lr": base_lr * 0.1},
            )

        actor_llm1 = _build_actor(pretrain_llm1)
        strategy.print(actor_llm1)
        actor_llm2 = _build_actor(pretrain_llm2)
        strategy.print(actor_llm2)

        # configure tokenizer (left padding)
        self.tokenizer_llm1 = get_tokenizer(
            pretrain_llm1, actor_llm1.model, "left", self.strategy_llm1, use_fast=not self.strategy_llm1.args.disable_fast_tokenizer
        )
        self.tokenizer_llm2 = get_tokenizer(
            pretrain_llm2, actor_llm2.model, "left", self.strategy_llm2, use_fast=not self.strategy_llm2.args.disable_fast_tokenizer
        )

        # configure optimizer
        actor_optim_llm1 = self.strategy_llm1.create_optimizer(
            actor_llm1, lr=args.actor_learning_rate_llm1, betas=self.strategy_llm1.args.adam_betas, weight_decay=args.l2
        )
        actor_optim_llm2 = self.strategy_llm2.create_optimizer(
            actor_llm2, lr=args.actor_learning_rate_llm2, betas=self.strategy_llm2.args.adam_betas, weight_decay=args.l2
        )

        actor_scheduler_llm1 = _build_scheduler(actor_optim_llm1, args.actor_learning_rate_llm1)
        actor_scheduler_llm2 = _build_scheduler(actor_optim_llm2, args.actor_learning_rate_llm2)

        if args.gradient_checkpointing:
            ckpt_kwargs = {"use_reentrant": args.gradient_checkpointing_use_reentrant}
            actor_llm1.gradient_checkpointing_enable(gradient_checkpointing_kwargs=ckpt_kwargs)
            actor_llm2.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(ckpt_kwargs))

        # prepare models/optimizers...
        self.actor_llm1, self.actor_optim_llm1, self.actor_scheduler_llm1 = self.strategy_llm1.prepare(
            (actor_llm1, actor_optim_llm1, actor_scheduler_llm1),
            is_rlhf=True,
        )
        self.actor_llm2, self.actor_optim_llm2, self.actor_scheduler_llm2 = self.strategy_llm2.prepare(
            (actor_llm2, actor_optim_llm2, actor_scheduler_llm2),
            is_rlhf=True,
        )

        # EMA is rejected above, so there is never an EMA model here.
        self.ema_model = None

        # load checkpoint
        self.consumed_samples = 0
        ckpt_path_llm1 = os.path.join(args.ckpt_path_llm1, "_actor")
        ckpt_path_llm2 = os.path.join(args.ckpt_path_llm2, "_actor")
        if args.load_checkpoint and os.path.exists(ckpt_path_llm1) and os.path.exists(ckpt_path_llm2):
            self.strategy_llm1.print(f"Loading the checkpoint: {ckpt_path_llm1}")
            _, states_llm1 = self.strategy_llm1.load_ckpt(self.actor_llm1.model, ckpt_path_llm1)
            self.strategy_llm2.print(f"Loading the checkpoint: {ckpt_path_llm2}")
            _, states_llm2 = self.strategy_llm2.load_ckpt(self.actor_llm2.model, ckpt_path_llm2)
            # Resume from the smaller count so neither model skips unseen data.
            self.consumed_samples = min(states_llm1["consumed_samples"], states_llm2["consumed_samples"])
            strategy.print(f"consumed_samples: {self.consumed_samples}")

        # initial offload
        if args.deepspeed_enable_sleep:
            offload_deepspeed_states(self.actor_llm1.model)
            offload_deepspeed_states(self.actor_llm2.model)

        # configure Trainer
        self.trainer = ActorILRTrainer(
            self.strategy_llm1,
            self.strategy_llm2,
            self.actor_llm1,
            self.actor_llm2,
            ema_model=self.ema_model,
            actor_optim_llm1=self.actor_optim_llm1,
            actor_optim_llm2=self.actor_optim_llm2,
            actor_scheduler_llm1=self.actor_scheduler_llm1,
            actor_scheduler_llm2=self.actor_scheduler_llm2,
            micro_train_batch_size=args.micro_train_batch_size,
            tokenizer_llm1=self.tokenizer_llm1,
            tokenizer_llm2=self.tokenizer_llm2,
            eps_clip=args.eps_clip,
            ema_beta=args.ema_beta,
            vllm_engines_llm1=self.vllm_engines_llm1,
            vllm_engines_llm2=self.vllm_engines_llm2,
        )

    def fit(self, kl_ctl: float = 0, llm_mark=None):
        """Train one actor on its replay buffer, then clear that buffer.

        Args:
            kl_ctl: KL loss multiplier forwarded to ``ppo_train``.
            llm_mark: ``'llm1'`` or ``'llm2'``.

        Returns:
            The averaged status dict returned by ``self.trainer.ppo_train``.

        Raises:
            ValueError: If ``llm_mark`` is not ``'llm1'`` or ``'llm2'``.
        """
        if llm_mark == "llm1":
            actor = self.actor_llm1
            replay_buffer = self.trainer.replay_buffer_llm1
        elif llm_mark == "llm2":
            actor = self.actor_llm2
            replay_buffer = self.trainer.replay_buffer_llm2
        else:
            raise ValueError(f"llm_mark must be 'llm1' or 'llm2', got {llm_mark!r}")

        torch.cuda.empty_cache()
        actor.train()
        status = self.trainer.ppo_train(kl_ctl, llm_mark)
        # Experiences are consumed once; drop them so the next round starts fresh.
        replay_buffer.clear()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        return status

    def save_model(self):
        """Save the final weights of both actors via each model's ``strategy.save_model``.

        Each model uses its own strategy args: ``save_path_llm1`` / ``save_path_llm2``
        pick the destination, and ``enable_ema`` picks ``self.ema_model`` instead of
        the actor.
        """
        targets = (
            (self.strategy_llm1, self.actor_llm1, self.tokenizer_llm1, "save_path_llm1"),
            (self.strategy_llm2, self.actor_llm2, self.tokenizer_llm2, "save_path_llm2"),
        )
        for strat, model, tok, path_attr in targets:
            cfg = strat.args
            to_save = self.ema_model if cfg.enable_ema else model
            strat.save_model(to_save, tok, getattr(cfg, path_attr))

    def forward(
        self,
        sequences: torch.LongTensor,
        action_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        packed_seq_lens=None,
        llm_mark=None,
    ) -> torch.Tensor:
        """Compute action log-probabilities with the selected actor, without gradients.

        The actor runs in eval mode on the current CUDA device. It is always put
        back in train mode afterwards, even if the forward pass raises.

        Args:
            sequences: Token ids.
            action_mask: Mask tensor of action positions (moved to device if given).
            attention_mask: Attention mask tensor (moved to device if given).
            packed_seq_lens: Accepted for interface compatibility; not forwarded to the actor.
            llm_mark: ``'llm1'`` or ``'llm2'``.

        Returns:
            The actor's log-probabilities, moved to CPU.

        Raises:
            ValueError: If ``llm_mark`` is not ``'llm1'`` or ``'llm2'``.
        """
        if llm_mark == "llm1":
            actor = self.actor_llm1
            strategy = self.strategy_llm1
        elif llm_mark == "llm2":
            actor = self.actor_llm2
            strategy = self.strategy_llm2
        else:
            raise ValueError(f"llm_mark must be 'llm1' or 'llm2', got {llm_mark!r}")

        device = torch.cuda.current_device()

        def _to_device(tensor):
            return None if tensor is None else tensor.to(device)

        actor.eval()
        try:
            with torch.no_grad():
                action_log_probs = actor(
                    sequences.to(device),
                    _to_device(action_mask),
                    _to_device(attention_mask),
                    ring_attn_group=strategy.ring_attn_group,
                )
        finally:
            actor.train()  # reset model state
        return action_log_probs.to("cpu")

    def broadcast_to_vllm(self, llm_mark):
        """Push the weights of actor ``llm_mark`` to its vLLM engines (delegates to the trainer)."""
        trainer = self.trainer
        trainer._broadcast_to_vllm(llm_mark)

    def get_consumed_samples(self):
        """Return the sample count restored from checkpoint (0 when training from scratch)."""
        consumed = self.consumed_samples
        return consumed

    def append(self, experiences_llm1: Experience, experiences_llm2: Experience):
        """Add one experience batch to each model's replay buffer (llm1 first, then llm2)."""
        pairs = (
            (self.trainer.replay_buffer_llm1, experiences_llm1),
            (self.trainer.replay_buffer_llm2, experiences_llm2),
        )
        for buffer, batch in pairs:
            buffer.append(batch)

    def reload_states(self, llm_mark):
        """Reload the offloaded DeepSpeed states of actor ``llm_mark``.

        An unrecognised ``llm_mark`` is silently ignored.
        """
        if llm_mark == "llm1":
            target = self.actor_llm1
        elif llm_mark == "llm2":
            target = self.actor_llm2
        else:
            return
        reload_deepspeed_states(target.model)

    def offload_states(self, llm_mark):
        """Offload the DeepSpeed states of actor ``llm_mark``.

        An unrecognised ``llm_mark`` is silently ignored.
        """
        if llm_mark == "llm1":
            target = self.actor_llm1
        elif llm_mark == "llm2":
            target = self.actor_llm2
        else:
            return
        offload_deepspeed_states(target.model)

    def save_checkpoint(self, tag, client_states, llm_mark):
        """Save a training checkpoint of actor ``llm_mark`` and optionally an HF copy.

        The DeepSpeed checkpoint goes to ``<ckpt_path>/_actor``. When
        ``self.save_hf_ckpt`` is set, ``<ckpt_path>/<tag>_hf`` also receives the
        model (or ``self.ema_model`` if ``enable_ema``) with its tokenizer. All
        ranks then synchronise via ``torch_dist_barrier_and_cuda_sync``.
        """
        if llm_mark == "llm1":
            strategy, actor, tokenizer = self.strategy_llm1, self.actor_llm1, self.tokenizer_llm1
            args = strategy.args
            ckpt_path = args.ckpt_path_llm1
        elif llm_mark == "llm2":
            strategy, actor, tokenizer = self.strategy_llm2, self.actor_llm2, self.tokenizer_llm2
            args = strategy.args
            ckpt_path = args.ckpt_path_llm2

        print(actor)
        ds_ckpt_dir = os.path.join(ckpt_path, "_actor")
        strategy.save_ckpt(
            actor.model,
            ds_ckpt_dir,
            tag,
            args.max_ckpt_num,
            args.max_ckpt_mem,
            client_states,
        )
        if self.save_hf_ckpt:
            hf_model = self.ema_model if args.enable_ema else actor
            hf_dir = os.path.join(ckpt_path, f"{tag}_hf")
            strategy.save_model(hf_model, tokenizer, hf_dir)
        # Keep all ranks in step until the save has finished everywhere.
        torch_dist_barrier_and_cuda_sync()
    
    def get_ability(self, llm_mark):
        """Return the stored ability score for ``llm_mark`` (``None`` for an unknown mark).

        NOTE(review): ``ability_llm1`` / ``ability_llm2`` are not assigned in the
        code shown here; presumably they are set during evaluation. Confirm
        before calling this method prior to any eval.
        """
        if llm_mark == "llm1":
            return self.ability_llm1
        if llm_mark == "llm2":
            return self.ability_llm2
        return None

    def eval_ability(self, eval_file, llms, tokenizer, llm_mark, step):
        if llm_mark == 'llm1':
            pretrain_path = self.pretrain_llm1
        elif llm_mark == 'llm2':
            pretrain_path = self.pretrain_llm2
        print(f"Eval {llm_mark}'s ability on {eval_file} on Global Step{step}.")
        examples = self.prepare_data(eval_file)
        print("=" * 50)
        print("data:", eval_file, " ,remain samples:", len(examples))
        if len(examples) > 0:
            print(examples[0])
        
        samples = []
        for example in examples:
            idx = example["idx"]

            # parse question and answer
            example["question"] = self.parse_question(example)
            if example["question"] == "":
                continue
            gt_cot, gt_ans = self.parse_ground_truth(example)
            example["gt_ans"] = gt_ans
            full_prompt = self.construct_prompt(example)

            if idx == 0:
                print(full_prompt)

            sample = {
                "idx": idx,
                "question": example["question"],
                "gt_cot": gt_cot,
                "gt": gt_ans,
                "prompt": full_prompt,
            }

            # add remain fields
            for key in [
                "level",
                "type",
                "unit",
                "solution_type",
                "choices",
                "solution",
                "ques_type",
                "ans_type",
                "answer_type",
                "dataset",
                "subfield",
                "filed",
                "theorem",
                "answer",
            ]:
                if key in example:
                    sample[key] = example[key]
            samples.append(sample)

        # repeat n times
        input_prompts = [
            sample["prompt"] for sample in samples for _ in range(1)
        ]

        remain_prompts = input_prompts
        remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)]
        end_prompts = []
        max_func_call = 1
        stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]

        # start inference
        # measure time use
        start_time = time.time()
        for epoch in range(max_func_call):
            print("-" * 20, "Epoch", epoch)
            current_prompts = remain_prompts
            if len(current_prompts) == 0:
                break

            # get all outputs
            prompts = [item[1] for item in current_prompts]
            all_prompt_token_ids = tokenizer(
                prompts,
                add_special_tokens=False,
                max_length=1024,
                truncation=True,
            )["input_ids"]

            refs = []
            batch_size = (len(all_prompt_token_ids) + len(llms) - 1) // len(llms)
            for i, llm in enumerate(llms):
                prompt_token_ids = all_prompt_token_ids[i * batch_size : (i + 1) * batch_size]
                refs.append(
                    llm.add_requests.remote(0,
                    sampling_params=SamplingParams(
                        temperature=0,
                        top_p=1,
                        max_tokens=2048,
                        n=1,
                        stop=stop_words,
                        stop_token_ids=(
                            [151645, 151643]
                            if "qwen2" in pretrain_path.lower()
                            else None
                        ),
                    ), prompt_token_ids=prompt_token_ids)
                )
            ray.get(refs)

            # Retrieve and combine results from all outputs
            all_output_refs = []
            for i, llm in enumerate(llms):
                all_output_refs.append(llm.get_responses.remote(0))
            outputs = sum(ray.get(all_output_refs), [])


            # outputs = sorted(
            #     outputs, key=lambda x: int(x.request_id)
            # )  # sort outputs by request_id
            outputs = [output.outputs[0].text for output in outputs]


            assert len(outputs) == len(current_prompts)

            # process all outputs
            remain_prompts = []
            remain_codes = []
            for (i, query), output in zip(current_prompts, outputs):
                output = output.rstrip()
                query += output
                end_prompts.append((i, query))

        # unsolved samples
        print("Unsolved samples:", len(remain_prompts))
        end_prompts.extend(remain_prompts)
        # sort by idx
        end_prompts = sorted(end_prompts, key=lambda x: x[0])

        # remove input_prompt from end_prompt
        codes = []
        assert len(input_prompts) == len(end_prompts)
        for i in range(len(input_prompts)):
            _, end_prompt = end_prompts[i]
            code = end_prompt.split(input_prompts[i])[-1].strip()
            for stop_word in stop_words:
                if stop_word in code:
                    code = code.split(stop_word)[0].strip()
            codes.append(code)

        # extract preds
        results = [
            self.run_execute(code) for code in codes
        ]
        time_use = time.time() - start_time

        # put results back to examples
        all_samples = []
        for i, sample in enumerate(samples):
            code = codes[i : (i + 1)]
            result = results[i : (i + 1)]
            preds = [item[0] for item in result]
            reports = [item[1] for item in result]
            for j in range(len(preds)):
                if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [
                    "A",
                    "B",
                    "C",
                    "D",
                    "E",
                ]:
                    preds[j] = self.choice_answer_clean(code[j])
                elif self.is_multi_choice(sample["gt"]) and not self.is_multi_choice(preds[j]):
                    # remove any non-choice char
                    preds[j] = "".join(
                        [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]]
                    )

            sample.pop("prompt")
            sample.update({"code": code, "pred": preds, "report": reports})
            all_samples.append(sample)
        
        if llm_mark == 'llm1':
            self.ability_llm1 = self.evaluate(samples=all_samples) / 100
            print(f"{llm_mark}'s ability: {self.ability_llm1}")
        elif llm_mark == 'llm2':
            self.ability_llm2 = self.evaluate(samples=all_samples) / 100
            print(f"{llm_mark}'s ability: {self.ability_llm2}")
        
        

    def load_jsonl(self, file):
        """Yield one decoded JSON value per non-blank line of a JSON Lines file.

        Args:
            file: path to a UTF-8 encoded ``.jsonl`` file.

        Yields:
            The decoded value of each non-blank line, in file order. Blank or
            whitespace-only lines (e.g. a trailing empty line) are skipped.

        Raises:
            ValueError: if a non-blank line is not valid JSON; the message names
                the file and 1-based line number and the decode error is chained.
        """
        with open(file, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as err:
                    # Raise instead of exit(): exit() reports status 0 (success)
                    # and gives callers no chance to handle the error.
                    raise ValueError(
                        f"Error in loading {file}:{lineno}: {line.rstrip()!r}"
                    ) from err
                yield record

    def prepare_data(self, data_file):
        examples = list(self.load_jsonl(data_file))
        # add 'idx' in the first column
        if "idx" not in examples[0]:
            examples = [{"idx": i, **example} for i, example in enumerate(examples)]
        examples = sorted(examples, key=lambda x: x["idx"])
        return examples
    
    def parse_question(self, example):
        question = ""
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break
        # assert question != ""
        # Yes or No question
        _, gt_ans = self.parse_ground_truth(example)
        if isinstance(gt_ans, str):
            gt_lower = gt_ans.lower()
            if gt_lower in ["true", "false"]:
                question += " (True or False)"
            if gt_lower in ["yes", "no"]:
                question += " (Yes or No)"
        return question.strip()
    
    def parse_ground_truth(self, example):
        # parse ground truth
        gt_cot = example["solution"]
        gt_ans = self.extract_answer(gt_cot)
        # post process
        gt_cot = str(gt_cot).strip()
        gt_ans = self.strip_string(gt_ans, skip_unit=False)
        return gt_cot, gt_ans
    
    def extract_answer(self, pred_str, use_last_number=True):
        pred_str = pred_str.replace("\u043a\u0438", "")

        if "final answer is $" in pred_str and "$. I hope" in pred_str:
            # minerva_math
            tmp = pred_str.split("final answer is $", 1)[1]
            pred = tmp.split("$. I hope", 1)[0].strip()
        elif "boxed" in pred_str:
            ans = pred_str.split("boxed")[-1]
            if len(ans) == 0:
                return ""
            elif ans[0] == "{":
                stack = 1
                a = ""
                for c in ans[1:]:
                    if c == "{":
                        stack += 1
                        a += c
                    elif c == "}":
                        stack -= 1
                        if stack == 0:
                            break
                        a += c
                    else:
                        a += c
            else:
                a = ans.split("$")[0].strip()
            pred = a
        elif "he answer is" in pred_str:
            pred = pred_str.split("he answer is")[-1].strip()
        elif "final answer is" in pred_str:
            pred = pred_str.split("final answer is")[-1].strip()
        elif "答案是" in pred_str:
            # Handle Chinese few-shot multiple choice problem answer extraction
            pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
        else:  # use the last number
            if use_last_number:
                pattern = "-?\d*\.?\d+"
                pred = re.findall(pattern, pred_str.replace(",", ""))
                if len(pred) >= 1:
                    pred = pred[-1]
                else:
                    pred = ""
            else:
                pred = ""

        # multiple line
        # pred = pred.split("\n")[0]
        pred = re.sub(r"\n\s*", "", pred)
        if pred != "" and pred[0] == ":":
            pred = pred[1:]
        if pred != "" and pred[-1] == ".":
            pred = pred[:-1]
        if pred != "" and pred[-1] == "/":
            pred = pred[:-1]
        pred = self.strip_string(pred, skip_unit=False)
        return pred
    
    def strip_string(self, string, skip_unit=False):
        """Normalize an answer string so that equivalent answers compare equal.

        Applies a fixed sequence of textual rewrites: removes line breaks and
        trailing periods; canonicalizes LaTeX (tfrac/dfrac -> frac,
        array/bmatrix -> pmatrix, neq/leq/geq -> ne/le/ge, left/right dropped);
        removes a trailing text{...} unit and, unless ``skip_unit``, plain-text
        unit words; removes degree, dollar and percent signs, quotes and
        mbox{...}; converts a whole-string number word to digits; drops
        trailing ".000"; rewrites "frac12"-style and simple "a/b" fractions to
        frac{a}{b}; and finally removes all spaces.

        Args:
            string: the answer; any value is first converted with ``str``.
            skip_unit: if True, keep the plain-text unit words.

        Returns:
            The normalized string (possibly empty).
        """
        unit_texts = [
            "east",
            "degree",
            "mph",
            "kmph",
            "ft",
            "m sqaure",
            " m east",
            "sq m",
            "deg",
            "mile",
            "q .",
            "monkey",
            "prime",
            "ratio",
            "profit of rs",
            "rd",
            "o",
            "gm",
            "p . m",
            "lb",
            "tile",
            "per",
            "dm",
            "lt",
            "gain",
            "ab",
            "way",
            "west",
            "a .",
            "b .",
            "c .",
            "d .",
            "e .",
            "f .",
            "g .",
            "h .",
            "t",
            "a",
            "h",
            "no change",
            "men",
            "soldier",
            "pie",
            "bc",
            "excess",
            "st",
            "inches",
            "noon",
            "percent",
            "by",
            "gal",
            "kmh",
            "c",
            "acre",
            "rise",
            "a . m",
            "th",
            "π r 2",
            "sq",
            "mark",
            "l",
            "toy",
            "coin",
            "sq . m",
            "gallon",
            "° f",
            "profit",
            "minw",
            "yr",
            "women",
            "feet",
            "am",
            "pm",
            "hr",
            "cu cm",
            "square",
            "v â € ™",
            "are",
            "rupee",
            "rounds",
            "cubic",
            "cc",
            "mtr",
            "s",
            "ohm",
            "number",
            "kmph",
            "day",
            "hour",
            "minute",
            "min",
            "second",
            "man",
            "woman",
            "sec",
            "cube",
            "mt",
            "sq inch",
            "mp",
            "∏ cm ³",
            "hectare",
            "more",
            "sec",
            "unit",
            "cu . m",
            "cm 2",
            "rs .",
            "rs",
            "kg",
            "g",
            "month",
            "km",
            "m",
            "cm",
            "mm",
            "apple",
            "liter",
            "loss",
            "yard",
            "pure",
            "year",
            "increase",
            "decrease",
            "d",
            "less",
            "Surface",
            "litre",
            "pi sq m",
            "s .",
            "metre",
            "meter",
            "inch",
        ]

        def convert_word_number(text):
            # Replace a whole-string number word by its digits via word2number;
            # leave the text unchanged when conversion fails.
            try:
                text = str(w2n.word_to_num(text))
            except Exception:
                pass
            return text

        def _fix_sqrt(string):
            # \sqrt2 -> \sqrt{2}: wrap the word characters right after \sqrt.
            return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)

        def _fix_fracs(string):
            # \frac12 -> \frac{1}{2}, \frac1{72} -> \frac{1}{72}; braced
            # arguments are kept. Returns the input unchanged if some \frac is
            # followed by fewer than two characters.
            substrs = string.split("\\frac")
            new_str = substrs[0]
            for substr in substrs[1:]:
                new_str += "\\frac"
                if len(substr) > 0 and substr[0] == "{":
                    new_str += substr
                    continue
                if len(substr) < 2:
                    return string
                a, b, post_substr = substr[0], substr[1], substr[2:]
                if b != "{":
                    new_str += "{" + a + "}{" + b + "}" + post_substr
                else:
                    new_str += "{" + a + "}" + b + post_substr
            return new_str

        def _fix_a_slash_b(string):
            # "a/b" with integer (or sqrt-containing) parts -> \frac{a}{b}.
            parts = string.split("/")
            if len(parts) != 2:
                return string
            a, b = parts
            try:
                if "sqrt" not in a:
                    a = int(a)
                if "sqrt" not in b:
                    b = int(b)
            except ValueError:
                return string
            # Only rewrite when the parsed parts round-trip exactly
            # (rejects e.g. "01/2").
            if string != "{}/{}".format(a, b):
                return string
            return "\\frac{" + str(a) + "}{" + str(b) + "}"

        unit_texts.extend([t + "s" for t in unit_texts])

        string = str(string).strip()
        # linebreaks
        string = string.replace("\n", "")

        # right "."
        string = string.rstrip(".")

        # remove inverse spaces
        string = string.replace("\\!", "")

        # matrix
        string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
        string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
        string = string.replace("bmatrix", "pmatrix")

        # replace tfrac and dfrac with frac
        string = string.replace("tfrac", "frac")
        string = string.replace("dfrac", "frac")
        string = (
            string.replace("\\neq", "\\ne")
            .replace("\\leq", "\\le")
            .replace("\\geq", "\\ge")
        )

        # remove \left and \right
        string = string.replace("\\left", "")
        string = string.replace("\\right", "")
        string = string.replace("\\{", "{")
        string = string.replace("\\}", "}")

        # Remove a trailing \text{...} unit (miles, dollars, ...) unless that
        # would leave an empty string.
        _string = re.sub(r"\\text{.*?}$", "", string).strip()
        if _string != "" and _string != string:
            string = _string

        if not skip_unit:
            # Remove plain-text units (two passes). A unit must be bounded by
            # the string edge or a non-word character on each side.
            # NOTE(review): unit_text is not re.escape()d, so the "." in
            # entries such as "e ." matches any character.
            for _ in range(2):
                for unit_text in unit_texts:
                    _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
                    if _string != "":
                        string = _string

        # Remove circ (degrees)
        string = string.replace("^{\\circ}", "")
        string = string.replace("^\\circ", "")

        # remove dollar signs
        string = string.replace("\\$", "")
        string = string.replace("$", "")
        string = string.replace("\\(", "").replace("\\)", "")

        # convert word number to digit
        string = convert_word_number(string)

        # replace "\\text{...}" to "..."
        string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
        for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
            string = string.replace(key, "")
        string = string.replace("\\emptyset", r"{}")
        string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

        # Remove percentage signs: the escaped "\%" form first, then bare "%".
        string = string.replace("\\%", "")
        string = string.replace("%", "")

        # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
        string = string.replace(" .", " 0.")
        string = string.replace("{.", "{0.")

        # Strip one pair of enclosing brackets.
        # NOTE(review): as written this can never fire -- a string starting
        # with "{", "(" or "[" is never isalnum(). Left unchanged because
        # altering it would change existing scores; the intent was probably
        # string[1:-1].isalnum().
        if (
            string.startswith("{")
            and string.endswith("}")
            and string.isalnum()
            or string.startswith("(")
            and string.endswith(")")
            and string.isalnum()
            or string.startswith("[")
            and string.endswith("]")
            and string.isalnum()
        ):
            string = string[1:-1]

        # inf
        string = string.replace("infinity", "\\infty")
        if "\\infty" not in string:
            string = string.replace("inf", "\\infty")
        # "+\infty" -> "\infty" (was misspelled "+\\inity" and never matched).
        string = string.replace("+\\infty", "\\infty")

        # Remove every occurrence of "and" (also inside words) and \mathbf.
        string = string.replace("and", "")
        string = string.replace("\\mathbf", "")

        # use regex to remove \mbox{...}
        string = re.sub(r"\\mbox{.*?}", "", string)

        # Remove quotes (the results used to be discarded, making this a no-op).
        string = string.replace("'", "")
        string = string.replace('"', "")

        # i, j: use "i" when the string has "j" but no "i"
        if "j" in string and "i" not in string:
            string = string.replace("j", "i")

        # replace a.000b where b is not number or b is end, with ab, use regex
        string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
        string = re.sub(r"(\d+)\.0*$", r"\1", string)

        # if empty, return empty string
        if len(string) == 0:
            return string
        if string[0] == ".":
            string = "0" + string

        # Drop a short "k = " / "q = " style prefix before a single "=".
        if len(string.split("=")) == 2:
            if len(string.split("=")[0]) <= 2:
                string = string.split("=")[1]

        string = _fix_sqrt(string)
        string = string.replace(" ", "")

        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
        string = _fix_fracs(string)

        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
        string = _fix_a_slash_b(string)

        return string

    def construct_prompt(self, example):
        """Render ``example["question"]`` into a ChatML-formatted prompt.

        The prompt has a fixed system message, a user turn with the question
        followed by an instruction to reason step by step and put the final
        answer in a boxed{} expression, and an open assistant turn.

        Returns:
            The prompt string, with only surrounding spaces stripped so that the
            trailing newline after the assistant tag is kept.
        """
        input_template = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        full_prompt = input_template.format(input=example["question"])
        # Strip spaces only: the trailing "\n" after the assistant tag must stay.
        return full_prompt.strip(" ")  # important!

    def run_execute(self, result, execute=False):
        if not result or result == "error":
            return None, None
        report = None
        prediction = self.extract_answer(result)
        prediction = self.strip_string(prediction, skip_unit=False)
        return prediction, report

    def choice_answer_clean(self, pred):
        """Reduce a free-form completion to a multiple-choice letter A-E.

        If "choice is" or "answer is" occurs more than once (few-shot style
        output), only the text before the first blank line is used. Text after
        the last trigger phrase is then searched for standalone A-E letters
        (case-insensitive, returned upper-case). With a trigger the first letter
        wins, otherwise the last; with no letter the cleaned text is returned.
        """
        triggers = ("choice is", "answer is")
        text = pred.strip("\n")

        if any(text.count(trigger) > 1 for trigger in triggers):
            text = text.split("\n\n")[0]

        pieces = re.split("|".join(triggers), text)
        found_trigger = len(pieces) > 1
        if found_trigger:
            text = pieces[-1]

        text = text.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

        letters = re.findall(r"\b(A|B|C|D|E)\b", text.upper())
        candidates = letters or [text.strip().strip(".")]
        chosen = candidates[0] if found_trigger else candidates[-1]

        # Drop any trailing period / slash once more.
        return chosen.rstrip(".").rstrip("/")

    def is_multi_choice(self, answer):
        """Return True iff ``answer`` is a non-empty string of letters A-E only.

        An empty or ``None`` answer is not multiple-choice. Previously ``""``
        returned True (vacuous truth), so an empty ground truth was treated as
        a multi-choice key and any free-form prediction was reduced to ``""``
        and then matched it exactly; ``None`` raised ``TypeError``.
        """
        if not answer:
            return False
        return all(c in ("A", "B", "C", "D", "E") for c in answer)

    def evaluate(self, samples, execute=False):
        """Grade predictions against ground truth and return accuracy in percent.

        Each sample's ground truth is re-parsed from its ``"solution"`` via
        ``parse_ground_truth``. Every entry of its ``"pred"`` list is compared
        with ``math_equal``, which checks string equality, multiple-choice
        letters, numeric equality (relative tolerance 1e-4, also against
        reference/100 and reference*100), bracketed tuples, pmatrix entries,
        equations and sympy-based equivalence.

        Args:
            samples: list of sample dicts with ``"solution"`` and ``"pred"``.
                If the first sample has an ``"idx"`` key, samples are
                de-duplicated by ``idx`` (last occurrence wins) and sorted by
                it; otherwise copies with a positional ``idx`` are made. The
                (possibly copied) dicts gain ``"gt_cot"``, ``"gt"`` and
                ``"score"`` keys.
            execute: unused; kept for interface compatibility.

        Returns:
            Mean score of the first prediction column in percent, rounded to one
            decimal; ``0.0`` when ``samples`` is empty.
        """
        # isclose / simplify / N are not imported at file level. The old code
        # referenced them anyway and the NameErrors were swallowed by the
        # surrounding exception handlers, silently disabling tolerance-based
        # numeric comparison and sympy equivalence.
        from math import isclose

        from sympy import N, simplify

        if not samples:
            return 0.0

        def parse_digits(num):
            # Float value of num with "," removed; a trailing "%" (optionally
            # preceded by a backslash) divides by 100. None if not numeric.
            num = str(num).replace(",", "")
            try:
                return float(num)
            except ValueError:
                if num.endswith("%"):
                    num = num[:-1]
                    if num.endswith("\\"):
                        num = num[:-1]
                    try:
                        return float(num) / 100
                    except ValueError:
                        pass
            return None

        def is_digit(num):
            # paired with parse_digits
            return parse_digits(num) is not None

        def numeric_equal(prediction: float, reference: float):
            # Note that relative tolerance has significant impact
            # on the result of the synthesized GSM-Hard dataset
            return isclose(reference, prediction, rel_tol=1e-4)

        def symbolic_equal(a, b):
            def _parse(s):
                # Try each parser on the string with "\\\\" collapsed, then on
                # the raw string; fall back to the string itself.
                for f in [parse_latex, parse_expr, latex2sympy]:
                    try:
                        return f(s.replace("\\\\", "\\"))
                    except Exception:
                        try:
                            return f(s)
                        except Exception:
                            pass
                return s

            a = _parse(a)
            b = _parse(b)

            # direct equal
            try:
                if str(a) == str(b) or a == b:
                    return True
            except Exception:
                pass

            # simplify equal
            # NOTE(review): simplify() runs without a timeout here; a
            # pathological expression may make evaluation slow.
            try:
                if a.equals(b) or simplify(a - b) == 0:
                    return True
            except Exception:
                pass

            # equation equal
            try:
                if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
                    return True
            except Exception:
                pass

            # numeric equal after numerical evaluation
            try:
                if numeric_equal(float(N(a)), float(N(b))):
                    return True
            except Exception:
                pass

            # matrix
            try:
                # if a and b are matrix
                if a.shape == b.shape:
                    _a = a.applyfunc(lambda x: round(x, 3))
                    _b = b.applyfunc(lambda x: round(x, 3))
                    if _a.equals(_b):
                        return True
            except Exception:
                pass

            return False

        def clean_choice(pred: str):
            # Reduce pred to its last standalone A-E letter (or cleaned text).
            pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
            tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
            if tmp:
                pred = tmp
            else:
                pred = [pred.strip().strip(".")]
            pred = pred[-1]
            # Remove the period at the end, again!
            pred = pred.rstrip(".").rstrip("/")
            return pred

        def str_to_pmatrix(input_str):
            input_str = input_str.strip()
            matrix_str = re.findall(r"\{.*,.*\}", input_str)
            pmatrix_list = []

            for m in matrix_str:
                m = m.strip("{}")
                pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
                pmatrix_list.append(pmatrix)

            return ", ".join(pmatrix_list)

        def math_equal(
            prediction: Union[bool, float, str],
            reference: Union[float, str],
            include_percentage: bool = True,
            is_close: bool = True,
        ) -> bool:
            """
            Exact match of math if and only if:
            1. numerical equal: both can convert to float and are equal
            2. symbolic equal: both can convert to sympy expression and are equal
            """
            if prediction is None or reference is None:
                return False
            if str(prediction.strip().lower()) == str(reference.strip().lower()):
                return True
            if (
                reference in ["A", "B", "C", "D", "E"]
                and clean_choice(prediction) == reference
            ):
                return True

            try:  # 1. numerical equal
                if is_digit(prediction) and is_digit(reference):
                    prediction = parse_digits(prediction)
                    reference = parse_digits(reference)
                    # number questions
                    if include_percentage:
                        gt_result = [reference / 100, reference, reference * 100]
                    else:
                        gt_result = [reference]
                    for item in gt_result:
                        try:
                            if is_close:
                                if numeric_equal(prediction, item):
                                    return True
                            else:
                                if item == prediction:
                                    return True
                        except Exception:
                            continue
                    return False
            except Exception:
                pass

            if not prediction and prediction not in [0, False]:
                return False

            # 2. symbolic equal
            reference = str(reference).strip()
            prediction = str(prediction).strip()

            ## pmatrix (amps)
            if "pmatrix" in prediction and not "pmatrix" in reference:
                reference = str_to_pmatrix(reference)

            ## deal with [], (), {}
            pred_str, ref_str = prediction, reference
            if (
                prediction.startswith("[")
                and prediction.endswith("]")
                and not reference.startswith("(")
            ) or (
                prediction.startswith("(")
                and prediction.endswith(")")
                and not reference.startswith("[")
            ):
                pred_str = pred_str.strip("[]()")
                ref_str = ref_str.strip("[]()")
            for s in ["{", "}", "(", ")"]:
                ref_str = ref_str.replace(s, "")
                pred_str = pred_str.replace(s, "")
            if pred_str.lower() == ref_str.lower():
                return True

            ## [a, b] vs. [c, d], return a==c and b==d
            if (
                re.match(r"(\(|\[).+(\)|\])", prediction) is not None
                and re.match(r"(\(|\[).+(\)|\])", reference) is not None
            ):
                pred_parts = prediction[1:-1].split(",")
                ref_parts = reference[1:-1].split(",")
                if len(pred_parts) == len(ref_parts):
                    if all(
                        [
                            math_equal(
                                pred_parts[i], ref_parts[i], include_percentage, is_close
                            )
                            for i in range(len(pred_parts))
                        ]
                    ):
                        return True
            if (
                (
                    prediction.startswith("\\begin{pmatrix}")
                    or prediction.startswith("\\begin{bmatrix}")
                )
                and (
                    prediction.endswith("\\end{pmatrix}")
                    or prediction.endswith("\\end{bmatrix}")
                )
                and (
                    reference.startswith("\\begin{pmatrix}")
                    or reference.startswith("\\begin{bmatrix}")
                )
                and (
                    reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
                )
            ):
                pred_lines = [
                    line.strip()
                    for line in prediction[
                        len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
                    ].split("\\\\")
                    if line.strip()
                ]
                ref_lines = [
                    line.strip()
                    for line in reference[
                        len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
                    ].split("\\\\")
                    if line.strip()
                ]
                matched = True
                if len(pred_lines) == len(ref_lines):
                    for pred_line, ref_line in zip(pred_lines, ref_lines):
                        pred_parts = pred_line.split("&")
                        ref_parts = ref_line.split("&")
                        if len(pred_parts) == len(ref_parts):
                            if not all(
                                [
                                    math_equal(
                                        pred_parts[i],
                                        ref_parts[i],
                                        include_percentage,
                                        is_close,
                                    )
                                    for i in range(len(pred_parts))
                                ]
                            ):
                                matched = False
                                break
                        else:
                            matched = False
                        if not matched:
                            break
                else:
                    matched = False
                if matched:
                    return True

            if prediction.count("=") == 1 and reference.count("=") == 1:
                pred = prediction.split("=")
                pred = f"{pred[0].strip()} - ({pred[1].strip()})"
                ref = reference.split("=")
                ref = f"{ref[0].strip()} - ({ref[1].strip()})"
                if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
                    return True
            elif (
                prediction.count("=") == 1
                and len(prediction.split("=")[0].strip()) <= 2
                and "=" not in reference
            ):
                if math_equal(
                    prediction.split("=")[1], reference, include_percentage, is_close
                ):
                    return True
            elif (
                reference.count("=") == 1
                and len(reference.split("=")[0].strip()) <= 2
                and "=" not in prediction
            ):
                if math_equal(
                    prediction, reference.split("=")[1], include_percentage, is_close
                ):
                    return True

            # symbolic equal with sympy
            if symbolic_equal(prediction, reference):
                return True

            return False

        if "idx" in samples[0]:
            samples = {sample["idx"]: sample for sample in samples}.values()
            samples = sorted(samples, key=lambda x: x["idx"])
        else:
            samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]

        # parse gt
        for sample in samples:
            sample["gt_cot"], sample["gt"] = self.parse_ground_truth(sample)
        params = [(idx, pred, sample["gt"]) for idx, sample in enumerate(samples) for pred in sample["pred"]]

        scores = []
        with tqdm(total=len(params), desc="Evaluate") as progress_bar:
            for _, pred, gt in params:
                scores.append(math_equal(pred, gt))
                progress_bar.update(1)

        idx = 0
        score_mat = []
        for sample in samples:
            sample["score"] = scores[idx : idx + len(sample["pred"])]
            assert len(sample["score"]) == len(sample["pred"])
            score_mat.append(sample["score"])
            idx += len(sample["pred"])

        max_len = max([len(s) for s in score_mat])

        # Pad shorter rows by repeating their last score so columns align.
        for i, s in enumerate(score_mat):
            if len(s) < max_len:
                score_mat[i] = s + [s[-1]] * (max_len - len(s))

        # output mean of each column of scores
        col_means = np.array(score_mat).mean(axis=0)
        mean_score = list(np.round(col_means * 100, decimals=1))
        return mean_score[0]