# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    AutoConfig
)


from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments, RLOOTrainerFeedback_v2 as RLOOTrainerFeedback, RLOOConfigWithFeedback
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

"""
python -i examples/scripts/rloo/rloo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --learning_rate 3e-6 \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --output_dir models/minimal/rloo \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --missing_eos_penalty 1.0

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/rloo/rloo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --output_dir models/minimal/rloo \
    --rloo_k 2 \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path EleutherAI/pythia-1b-deduped \
    --reward_model_path EleutherAI/pythia-1b-deduped \
    --local_rollout_forward_batch_size 1 \
    --missing_eos_penalty 1.0
"""


if __name__ == "__main__":
    # Parse CLI flags into three dataclasses. RLOOConfigWithFeedback supplies the RLOO
    # training options plus the feedback/reward-model options read below
    # (agg, fw, lqh, enable_lm, debug, filter_length, ...).
    parser = HfArgumentParser((ScriptArguments, RLOOConfigWithFeedback, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()

    # Uncomment to start every run from an empty output directory.
    # shutil.rmtree(training_args.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
    trust_remote_code = model_args.trust_remote_code

    def load_tokenizer(name_or_path):
        """Load a tokenizer configured for RLOO prompts.

        Queries are left-padded here (the trainer is expected to right-pad the
        generated responses). A "[PAD]" special token is registered, and a simple
        fallback chat template is installed when the checkpoint ships none.
        """
        tok = AutoTokenizer.from_pretrained(
            name_or_path, padding_side="left", trust_remote_code=trust_remote_code
        )
        # NOTE(review): "[PAD]" is added even if the tokenizer already has a pad token,
        # and model embeddings are not resized here; confirm the new id fits each
        # model's vocab (it can overflow when tokenizer vocab == embedding rows).
        tok.add_special_tokens({"pad_token": "[PAD]"})
        if tok.chat_template is None:
            tok.chat_template = SIMPLE_CHAT_TEMPLATE
        return tok

    tokenizer = load_tokenizer(model_args.model_name_or_path)

    # Policy and frozen reference policy share one config loaded from the SFT checkpoint.
    policy_config = AutoConfig.from_pretrained(
        training_args.sft_model_path, trust_remote_code=trust_remote_code
    )
    if training_args.debug:
        # Shrink to 2 transformer layers for fast debug runs.
        policy_config.num_hidden_layers = 2

    policy = AutoModelForCausalLM.from_pretrained(
        training_args.sft_model_path, trust_remote_code=trust_remote_code, config=policy_config
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        training_args.sft_model_path, trust_remote_code=trust_remote_code, config=policy_config
    )

    # Reward-model config: the checkpoint's own config extended with extra attributes,
    # presumably consumed by the ModifiedRewardModel implementations.
    rm_config = AutoConfig.from_pretrained(
        training_args.reward_model_path, trust_remote_code=trust_remote_code
    )
    rm_config.policy_hidden_size = policy.config.hidden_size
    rm_config.num_labels = 1  # single output, i.e. one scalar reward
    rm_config.agg = training_args.agg
    rm_config.fw = training_args.fw
    rm_config.lqh = training_args.lqh
    rm_config.enable_lm = training_args.enable_lm
    rm_config.mlp_hidden_size = 4096
    if training_args.debug:
        rm_config.num_hidden_layers = 2

    # Select the ModifiedRewardModel implementation from the model family named in the
    # path. Fail with a clear message instead of a NameError on unknown families.
    rm_path = training_args.reward_model_path.lower()
    if "pythia" in rm_path:
        from trl.modified_reward_model.modified_reward_modeling_pythia import ModifiedRewardModel
    elif "llama" in rm_path:
        from trl.modified_reward_model.modified_reward_modeling import ModifiedRewardModel
    elif "qwen" in rm_path:
        from trl.modified_reward_model.modified_reward_modeling_qwen import ModifiedRewardModel
    else:
        raise ValueError(
            f"Unsupported reward model path {training_args.reward_model_path!r}: "
            "expected 'pythia', 'llama' or 'qwen' in the name to select a ModifiedRewardModel."
        )

    reward_model = ModifiedRewardModel.from_pretrained(
        training_args.reward_model_path, trust_remote_code=trust_remote_code, config=rm_config
    )
    rm_tokenizer = load_tokenizer(training_args.reward_model_path)

    if training_args.gradient_checkpointing:
        policy.gradient_checkpointing_enable()
        reward_model.gradient_checkpointing_enable()

    ################
    # Dataset
    ################
    if "ultrafeedback" in script_args.dataset_name:
        # UltraFeedback ships dedicated preference splits.
        # NOTE(review): cache_dir is a placeholder path; point it at a writable location.
        dataset = load_dataset(
            "HuggingFaceH4/ultrafeedback_binarized", split="train_prefs", cache_dir="/path/to/datasets"
        )
        eval_dataset = load_dataset(
            "HuggingFaceH4/ultrafeedback_binarized", split="test_prefs", cache_dir="/path/to/datasets"
        )

        if training_args.debug:
            dataset = dataset.select(range(512))

        if training_args.filter_length > 0:
            # Keep only samples whose chosen response (second message of "chosen")
            # is at most filter_length characters long.
            def is_short_enough(example):
                return len(example["chosen"][1]["content"]) <= training_args.filter_length

            dataset = dataset.filter(is_short_enough)
            print(f"filtered train sample nums: {len(dataset)}")
            eval_dataset = eval_dataset.filter(is_short_enough)
            print(f"filtered eval sample nums: {len(eval_dataset)}")
    else:
        dataset = load_dataset(
            script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split
        )
        # No eval split is loaded for generic datasets, so hold out the tail of the
        # train split (as TRL's upstream rloo.py example does); without this,
        # eval_dataset would be undefined on this path. Clamp for tiny datasets.
        num_eval = min(100, len(dataset) // 2)
        eval_dataset = dataset.select(range(len(dataset) - num_eval, len(dataset)))
        dataset = dataset.select(range(len(dataset) - num_eval))

    train_dataset = dataset

    # Column holding the raw prompt text; read by prepare_dataset's tokenize().
    dataset_text_field = "prompt"

    def prepare_dataset(dataset, tokenizer):
        """Pre-tokenize prompts once before training so the trainer only has to collate.

        Batches that contain the `dataset_text_field` column have each prompt wrapped as
        a single user turn and rendered with the tokenizer's chat template. Batches
        without it fall back to the content of the first message in
        `context_messages`, tokenized as raw text. All original columns are replaced by
        a single `input_ids` column.
        """
        bos_id = tokenizer.bos_token_id

        def drop_duplicate_bos(ids):
            # A chat template that already emits BOS, followed by tokenizer() adding its
            # own BOS, yields two leading BOS tokens; keep just one.
            if bos_id is not None and len(ids) >= 2 and ids[0] == bos_id and ids[1] == bos_id:
                return ids[1:]
            return ids

        def tokenize(element):
            if dataset_text_field not in element:
                # No prompt column: use the first context message as the query.
                outputs = tokenizer(
                    [item[0]["content"] for item in element["context_messages"]],
                    padding=False,
                )
                return {"input_ids": outputs["input_ids"]}

            # NOTE(review): add_generation_prompt is not passed (HF default False), so
            # the query does not end with an assistant-turn header; confirm intended.
            batch_query = [
                tokenizer.apply_chat_template([{"role": "user", "content": x}], tokenize=False)
                for x in element[dataset_text_field]
            ]
            outputs = tokenizer(batch_query, padding=False)
            return {"input_ids": [drop_duplicate_bos(ids) for ids in outputs["input_ids"]]}

        return dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=training_args.dataset_num_proc,
            load_from_cache_file=False,  # always re-tokenize instead of reusing a cached map
        )

    # Let the local main process run preprocessing first while the other ranks wait.
    # see: https://github.com/huggingface/trl/pull/1255
    # NOTE(review): prepare_dataset maps with load_from_cache_file=False, so waiting
    # ranks re-tokenize rather than reuse the main process's cached result.
    with PartialState().local_main_process_first():
        train_dataset, eval_dataset = [
            prepare_dataset(split, tokenizer) for split in (train_dataset, eval_dataset)
        ]

    ################
    # Training
    ################
    trainer_kwargs = dict(
        config=training_args,
        processing_class=tokenizer,
        rm_processing_class=rm_tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer = RLOOTrainerFeedback(**trainer_kwargs)
    trainer.train()

    # Persist the final weights, optionally publish them, then sample completions.
    output_dir = training_args.output_dir
    trainer.save_model(output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

    trainer.generate_completions()
