import os
import sys
from datetime import datetime
from easydict import EasyDict as edict
from functools import partial

from ares.configs import RoleConfig, TokenizerConfig
from ares.configs.role_config import GenerativeRMRoleConfig
from ares.configs.trainer_config import PPOTrainerConfig
from ares.configs.experience_maker_config import UnifiedExperienceMakerConfig
from ares.models.megatron_ppo_loss import PolicyLoss, ValueLoss, GSPOPolicyLoss
from ares.models.loss import gather_logprobs_dist
from ares.models.megatron_ppo_loss import gather_values_dist, get_last_token_rm_score
from ares.trainers.components.grader.alignment.math_model_eval_grader import MathModelEvalGrader
from ares.trainers.components.grader.alignment.ifeval_rule_grader import IFEvalRuleGrader
from ares.utils.utils import get_tokenizer

# Directory that holds this config file, absolute and with symlinks resolved;
# the engine YAML config paths are built relative to it.
FILE_PATH = os.path.split(os.path.realpath(__file__))[0]

# Parallelism layout per role (dp/tp/pp/ep — presumably data-, tensor-,
# pipeline- and expert-parallel sizes).
# dp * tp * pp per role: generator 16, actor 1024, sft 1024, generator_rm 32.
# NOTE(review): older inline GPU counts ("8 + elastic(256)" for the generator,
# "256" for the actor) did not equal dp * tp * pp, whereas the generator_rm
# count (32) did — confirm how the launcher interprets these fields.
layout = edict({
    "generator": edict(dp=2, tp=8, pp=1, ep=1),  # elastic role, colocated with actor/sft
    "actor": edict(dp=32, tp=8, pp=4, ep=1),
    "sft": edict(dp=32, tp=8, pp=4, ep=1),  # co-exists with the actor, same layout
    "generator_rm": edict(dp=4, tp=8, pp=1),  # note: no ep entry for this role
})

# ---------------------------------------------------------------------------
# Engine / precision
# ---------------------------------------------------------------------------
# Actor trainable-layer limit; None presumably means no layer freezing — confirm.
num_unfrozen_layers = None

engine_name = "megatron_engine"  # engine for the Megatron-backed roles (actor, sft)
dtype = "bf16"  # one of: "fp16", "bf16", "fp32"

# ---------------------------------------------------------------------------
# Model, data and output locations (placeholders: fill in before running)
# ---------------------------------------------------------------------------
model_tag = "exp_model"
sft_model_path = "/path/to/sft/model/checkpoint"

# Outputs and logs are namespaced by the model the run starts from.
suffix = f"from_{model_tag}"
output_dir = f"/path/to/output/checkpoint/{suffix}"
log_dir = f"/path/to/log/directory/{suffix}"

data_dir = "/path/to/training/data"

# Generative reward model: model/tokenizer path plus the checkpoint directory
# handed to its role as mega_ckpt_dir.
generative_reward_model_path = "/path/to/reward/model"
generative_reward_model_mega_ckpt = "/path/to/reward/model/checkpoint"

# ---------------------------------------------------------------------------
# Engine config files, located next to this file
# ---------------------------------------------------------------------------
vllm_cfg_file_path = f"{FILE_PATH}/vllm_config.yaml"
vllm_genrm_cfg_file_path = f"{FILE_PATH}/vllm_config_genrm.yaml"
actor_megatron_cfg_file_path = f"{FILE_PATH}/megatron_config_actor.yaml"
critic_megatron_cfg_file_path = f"{FILE_PATH}/megatron_config_critic.yaml"  # no critic role is configured in this file
rm_megatron_cfg_file_path = f"{FILE_PATH}/megatron_config_rm_moe6b.yaml"

# The policy (SFT) tokenizer and the generative-RM tokenizer share every setting
# and differ only in the path they are loaded from; build both the same way
# (policy first, then reward model).
tokenizer_config, gen_rm_tokenizer_config = [
    TokenizerConfig(
        tokenizer_path=source_path,
        padding_side="left",
        truncation_side="right",
        sep_token="<sep>",
        trust_remote_code=True,
        force_pad_token_as_eos_token=False,
    )
    for source_path in (sft_model_path, generative_reward_model_path)
]

# Instantiated policy tokenizer, handed to the trainer config.
tokenizer = get_tokenizer(tokenizer_config)

# Training parameters
# Token budget: prompt + completion must fit in max_seq_len, so the completion
# budget is derived from it instead of repeating the 8192 * 4 literal.
max_prompt_len = 2048
max_seq_len = 8192 * 4
max_new_tokens = max_seq_len - max_prompt_len
sample_n = 16  # number of sampling configs (each n=1) => rollouts per prompt

# One sampling config per rollout slot ("sample0" .. f"sample{sample_n - 1}"):
# temperature 1.0, top_p 1.0, top_k -1 (presumably no top-k cutoff), one
# completion each.
generate_kwargs = {
    f"sample{i}": {
        "max_tokens": max_new_tokens,
        "min_tokens": 1,
        "do_sample": True,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
        "n": 1,
    }
    for i in range(sample_n)
}

# Evaluation sampling: top_p 0.95, top_k 10, a single completion per prompt.
# NOTE(review): max_tokens is the whole max_seq_len (same value as the former
# 8192 * 4 literal), so prompt + completion can exceed max_seq_len, whereas
# training rollouts use max_new_tokens. Consider aligning — confirm intent.
eval_generate_kwargs = {
    "sample": {
        "max_tokens": max_seq_len,
        "min_tokens": 1,
        "do_sample": True,
        "temperature": 1.0,
        "top_p": 0.95,
        "top_k": 10,
        "n": 1,
    }
}

enable_temperature = False
shuffle_replay_buffer = False
rollout_batch_size = 256  # presumably prompts per rollout round (see replay_buffer_size)
loss_agg_mode = "seq-mean-token-mean"
# Completions produced per prompt: the sum of `n` over all sampling configs.
num_samples_per_prompt = sum(v.get("n", 1) for v in generate_kwargs.values())
replay_buffer_size = rollout_batch_size * num_samples_per_prompt
micro_batch_size = 1  # micro_batch_size for actor & critic & rm & sft training or inference
global_batch_size = 512  # global_batch_size for actor & critic & rm & sft training or inference
mini_rollout_batch_size = None
ppo_epochs = 1
num_episodes = 100
nums_eval_selected = 960  # number of prompts for evaluation
init_kl_coef = 0.0
pg_cliprange = [0.0003, 0.0004]  # policy-loss clip range (two values, presumably low/high)
value_cliprange = 0.4
log_interval = 1  # TensorBoard logging interval
eval_interval = 4
ckpt_interval = 4
warmup_rounds = 0  # must be smaller than eval_interval (enforced below)
add_kl_reward = False
adv_norm_type = "none"
gamma = 1.0
lambd = 1.0
ip_config_path = "/path/to/ip/config.txt"

open_online_filter = False
pass_rate_range = None
# Example: pass_rate_range = [0.01, 0.99] — presumably drops prompts whose
# samples are all wrong or all correct; confirm against the experience maker.
prompt_gen_multiplier = 3
max_gen_times = 10

dump_token_level_log = False

# Fail fast on the documented constraints instead of deep inside training.
if max_new_tokens <= 0:
    raise ValueError(
        f"max_prompt_len ({max_prompt_len}) must be smaller than max_seq_len ({max_seq_len})"
    )
if warmup_rounds >= eval_interval:
    raise ValueError(
        f"warmup_rounds ({warmup_rounds}) must be smaller than eval_interval ({eval_interval})"
    )

# Top-level experiment config: trainer settings under `train`, per-role
# engine/model settings under `roles`.
config = edict(
    train=PPOTrainerConfig(
        name='PPOTrainer',
        data_dir=data_dir,
        num_prompt_processes=40,
        # TensorBoard: some actor/critic metrics are plotted separately per
        # data-parallel rank instead of being aggregated over all DP ranks.
        log_dir=f"{log_dir}/log",
        output_dir=output_dir,
        ckpt_dir=f"{output_dir}/ckpt",
        ckpt_interval=ckpt_interval,
        eval_interval=eval_interval,
        log_interval=log_interval,
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        default_generate_role="elastic",
        dump_token_level_log=dump_token_level_log,
        dump_token_level_log_ratio=0.1,
        seed=42,
        epochs=1,
        shuffle_train=True,
        init_kl_coef=init_kl_coef,
        kl_target=None,
        kl_horizon=500000,
        ptx_coef=0.0,
        rollout_batch_size=rollout_batch_size,
        # Read the module-level knob so changing it actually takes effect.
        mini_rollout_batch_size=mini_rollout_batch_size,
        ppo_epochs=ppo_epochs,
        num_episodes=num_episodes,
        warmup_rounds=warmup_rounds,  # train critic first, then train together
        max_prompt_len=max_prompt_len,
        max_offline_data_len=max_new_tokens,
        max_seq_len=max_seq_len,
        nums_eval_selected=nums_eval_selected,
        adv_norm_type=adv_norm_type,
        tokenizer=tokenizer,
        # KL-spike guard settings; exact semantics live in the trainer (not visible here).
        skip_kl_spike_step=True,
        kl_spike_window_size=8,
        kl_variance_threshold=0.3,
        open_online_filter=open_online_filter,
        prompt_gen_multiplier=prompt_gen_multiplier,
        shuffle_replay_buffer=shuffle_replay_buffer,
        max_gen_times=max_gen_times,
        eval_step_0=True,
        experience_maker_config=UnifiedExperienceMakerConfig(
            advantage_estimator="grpo",
            ignore_useless_data=True,
            gen_kwargs=generate_kwargs,
            eval_gen_kwargs=eval_generate_kwargs,
            add_kl_reward=add_kl_reward,
            pass_rate_range=pass_rate_range,  # set this when open_online_filter is enabled
            gamma=gamma,
            lambd=lambd,
            enable_overlength_penalty=False,
            use_truncated_action_mask=False,
            enable_thinking_format=True,
            thinking_pattern=r'<think>(.+?)</think>',
            # Math answers are graded by MathModelEvalGrader, which is given the
            # reward-model tokenizer.
            grader_dict=edict(
                math=edict(grader_class=MathModelEvalGrader, tokenizer=get_tokenizer(gen_rm_tokenizer_config)),
            ),
        ),

    ),
    roles=edict(
        # Rollout generator (vLLM). Listed as elastic and colocated with the
        # actor/sft roles; load_format="dummy" presumably means weights are
        # synced in from the actor rather than read from disk — confirm.
        generator=RoleConfig(
            name="generator",
            layout=layout.generator,
            dtype=dtype,
            engine_name='vllm_engine',
            engine_version="063",
            engine_cfg_file_path=vllm_cfg_file_path,
            micro_batch_size=1,
            global_batch_size=rollout_batch_size,
            pretrained=sft_model_path,
            model_name="llm",
            with_optimizer=False,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            max_new_tokens=max_new_tokens,
            tokenizer_config=tokenizer_config,
            enable_offload_others=True,
            load_format="dummy",
            uniq_seed_for_rank=True,
        ),
        # Trainable policy, initialised from the SFT checkpoint and optimised
        # with the GSPO policy loss.
        actor=RoleConfig(
            name="actor",
            dtype=dtype,
            layout=layout.actor,
            with_optimizer=True,
            model_name="gptmodel",
            engine_name=engine_name,
            colocated_roles=['actor', 'generator', 'sft'],
            elastic_roles=['generator'],
            enable_offload_others=True,
            engine_cfg_file_path=actor_megatron_cfg_file_path,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            max_seq_len=max_seq_len,
            pretrained=sft_model_path,
            tokenizer_config=tokenizer_config,
            save_model_only=False,
            loss_func=GSPOPolicyLoss(
                cliprange=pg_cliprange,
                clip_ratio_c=None,
                kl_loss_coef=0.0,
                kl_loss_type="k2",
                nll_loss_coef=0.0,
                ignore_useless_data=True,
                loss_agg_mode=loss_agg_mode,
                mtp_loss_coef=0.00,
                enable_temperature=enable_temperature,
                max_tokens=max_new_tokens,
                dump_token_level_log=dump_token_level_log
            ),
            post_func=gather_logprobs_dist,
            check_oom=False,
            # Read the module-level knob so changing it actually takes effect.
            num_unfrozen_layers=num_unfrozen_layers,
        ),

        # Frozen copy of the SFT model (no optimizer) producing log-probs via
        # gather_logprobs_dist; presumably the reference policy — confirm.
        sft=RoleConfig(
            name="sft",
            with_optimizer=False,
            max_seq_len=max_seq_len,
            layout=layout.sft,
            dtype=dtype,
            engine_name=engine_name,
            colocated_roles=['actor', 'generator', 'sft'],
            elastic_roles=['generator'],
            enable_offload_others=True,
            engine_cfg_file_path=actor_megatron_cfg_file_path,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            pretrained=sft_model_path,
            model_name="gptmodel",
            tokenizer_config=tokenizer_config,
            post_func=gather_logprobs_dist,
            check_oom=False
        ),

        # Generative reward model (vLLM) that judges a student answer against
        # a reference answer using the prompt template below.
        generator_rm=GenerativeRMRoleConfig(
            name="generator_rm",
            dtype=dtype,
            layout=layout.generator_rm,
            engine_name='vllm_engine',
            engine_cfg_file_path=vllm_genrm_cfg_file_path,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            pretrained=generative_reward_model_path,
            model_name="llm",
            engine_version="063",
            empty_unused_memory_level=1,
            with_optimizer=False,
            max_seq_len=max_seq_len,
            # Judge budget: 1000 generated tokens, the rest of max_seq_len for the prompt.
            max_prompt_len=max_seq_len - 1000,
            max_new_tokens=1000,
            tokenizer_config=gen_rm_tokenizer_config,
            tokenize_kwargs=edict(truncation=False, padding=True, return_tensors="pt", padding_side="left"),
            load_format="dummy",
            mega_ckpt_dir=generative_reward_model_mega_ckpt,
            # {{question}}, {{answer}} and {{answer_reference}} are template
            # placeholders filled in downstream (not an f-string) — presumably by
            # the RM role / grader; confirm.
            prompt_template="""
            You are an experienced mathematics grader, your task is evaluating whether a student's answer matches the reference answer based on specific comparison rules. **Do not attempt to assess the correctness** of either the student's answer or the reference answer independently. **Assume that the Reference Answer is correct** and focus only on determining if the Student Answer aligns with the Reference Answer according to the rules provided below.
**1. Input:**
    - **Question:** "{{question}}"
    - **Student's Answer:** "{{answer}}"
    - **Reference Answer:** "{{answer_reference}}"

**2. Comparison Rules:**

   **A. Variable and Constant Usage:**
      - **Allowed Variables/Constants:**
         - Identify all variables and constants present in the **Question**.
         - Allowed variables typically include symbols like `x`, `y`, `z`, etc.
         - Allowed constants may include mathematical constants such as `π`, `e`, or any other specific constants mentioned in the question.
      - **Validation:**
         - Examine the **Student Answer** for any variables or constants.
         - If the **Student Answer** contains variables or constants **not present** in the **Question**, the answer is considered invalid.

   **B. Numerical Formats Comparison:**
      - **Identifying Numerical Formats:**
         - Determine if each answer (Student and Reference) is a fraction, a decimal, scientific notation, or another format.

      - **Fractions:**
         - If both answers are fractions, they must be **exactly identical**. No approximation, simplification, or alternative representations are allowed.

      - **Decimals:**
         - **Rounding:**
            - Identify the number of decimal places in the **Reference Answer**.
            - **Round the Student Answer** to match the number of decimal places in the **Reference Answer** using standard rounding rules (round half up).
         - **Comparison:**
            - After rounding, the decimal numbers must be **identical**.

      - **Scientific Notation:**
         - **Exact Matching:**
            - If **either** answer is in scientific notation (e.g., `1.23e+4`), then **both** answers must be in scientific notation.
            - **Every digit**, including those in the mantissa and the exponent, must **exactly match**.
            - **No approximations, rounding, or differences** in digit representation are allowed.
            - **Formatting Consistency:** Ensure that both answers use the same case for the exponent indicator (`e` vs. `E`) and the same format for positive/negative exponents (e.g., `+4` vs. `4`).

      - **Mixed Formats:**
         - If one answer is a fraction and the other is a decimal:
            - Convert the fraction to a decimal.
            - Round the converted fraction to match the precision of the **Reference Answer**.
            - Compare the rounded Student Answer decimal with the Reference Answer decimal for an **exact match**.

      - **Other Formats:**
         - For non-numeric answers or different data types (e.g., one is a number and the other is text), perform an **exact match** comparison, considering all characters, case sensitivity, punctuation, and formatting.

   **C. Multiple Possible Correct Answers:**
      - **Identifying Multiple Options:**
         - Inspect the **Reference Answer** to determine if it contains multiple acceptable options, typically separated by conjunctions like "or", "and", or commas.
         - Recognize multiple correct answers indicated by terms such as "or", "and/or", etc.
      - **Student Answer Requirements:**
         - The **Student Answer** must include **all** acceptable options present in the **Reference Answer**.
         - **Exact Matching:** Each acceptable option from the Reference Answer must be present **exactly** in the Student Answer, following format and case as per rules.
         - **No Partial Matches:** Providing only a subset of the Reference options is insufficient and should result in a `NO`.
         - **No Extra Options:** Including options not present in the Reference Answer is not allowed and should result in a `NO`.

   **D. Interval Expressions:**
      - **Exact Mathematical Meaning:**
         - The **Student Answer** and the **Reference Answer** must represent **exactly the same interval** with the **same inclusion/exclusion** of endpoints.
         - **Interval Notation Types:**
            - **Open Interval:** `(a, b)` – excludes both endpoints `a` and `b`.
            - **Closed Interval:** `[a, b]` – includes both endpoints `a` and `b`.
            - **Half-Open/Half-Closed Interval:** `(a, b]` or `[a, b)` – includes one endpoint and excludes the other.

   **E. General Rules:**
      - **No Unallowed Approximations or Rounding:** Do not perform any approximation or rounding unless explicitly allowed for decimal numbers as specified above. No rounding/error tolerance for integer is allowed.
      - **Exact Matching:** Ensure that all numerical values, units, expressions, and chemical formulas/equations are **identical** in every required aspect based on their format.
      - **Formatting and Case Sensitivity:** Pay close attention to the formatting, case sensitivity, and punctuation. These factors must **match exactly** where required.
      - **Missing non-essential unit** Missing non-essential unit is allowed for Student's answer. (e.g. Student's answer: 2 and reference answer: 2mol should be considered equivalent)
      - **No Exact Answer** If the Student's answer only provides the process and does not end with an exact answer, output `NO`.
      - **Programming Language** If the Student's answer includes any programming language, e.g. python, Java, etc. output `NO`.

   **F. Caution with 'YES' Output:**
      - **Strict Conditions for 'YES':**
         - Only output `YES` if **all** the above comparison rules are **fully satisfied**.
         - Any slight deviation, even in formatting, numerical precision, or character usage, should **prevent** a `YES` output.
      - **Avoid Overconfidence:**
         - Do not output `YES` unless there is complete certainty that the Student Answer aligns perfectly with the Reference Answer based on the defined rules.
         - If there is any doubt or ambiguity, default to `NO` to maintain accuracy and reliability.

**Output Requirement:**
- Output your thinking process and the final judgement in the following format:
Reason: [Concise 2-3 sentence explanation specifying your thinking process(rounding process if needed) of why the Student's Answer matches the Reference Answer or differs from the Reference Answer]
Judgment: [YES/NO]
            """,
            # Greedy judging; max_tokens matches max_new_tokens above.
            sample_param={'max_tokens': 1000, "do_sample": False, 'temperature': 0.0},
        ),
    ),
)