from unsloth import FastLanguageModel, vLLMSamplingParams
from unsloth import is_bfloat16_supported
import wandb
import argparse
import re
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from functools import partial
import jieba
import random
import os
import sys
# Directory that holds this script, and the directory two levels above it.
pwd_path = os.path.dirname(__file__)
root_path = os.path.join(pwd_path, "../../")

# Push both directories to the front of sys.path ahead of the project-local
# import that follows; pwd_path is inserted last, so it is searched first.
sys.path.insert(0, root_path)
sys.path.insert(0, pwd_path)

from reward import Reward

def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point uses.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="解析模型名称")

    # Required run identification.
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--output_dir", required=True)
    # Only these two templates are handled when prompts are built in
    # get_dataset; any other value used to fail there with an unbound
    # variable, so it is rejected here instead.
    parser.add_argument("--chat_template", required=True, choices=["deepseek", "qwen"])
    parser.add_argument("--wandb_project", required=True)
    parser.add_argument("--wandb_runname", required=True)

    # Hint injection: get_dataset acts on "answer" and "solution"; every
    # other value leaves the prompt without a hint.
    parser.add_argument("--hint_level", default="none")
    parser.add_argument(
        "--solution_hint_mask_strategy",
        type=str,
        default="sentence",
        choices=["sentence", "word", "char"],
    )
    parser.add_argument("--solution_hint_mask_ratio", type=float, default=0.)

    # Newly added arguments.
    parser.add_argument("--full_finetuning", action="store_true", default=False)
    parser.add_argument("--load_precision", default="4bit", choices=["4bit", "8bit", "16bit"])
    parser.add_argument("--lora_rank", type=int, default=32)
    parser.add_argument("--max_prompt_length", type=int, default=2 * 1024)
    parser.add_argument("--max_completion_length", type=int, default=8 * 1024)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.4)
    parser.add_argument("--data_files", nargs="+", required=True)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument(
        "--reward_funcs",
        nargs="+",
        required=True,
        choices=["format", "correctness", "length"],
    )

    return parser.parse_args(argv)

def mask_solution(solution, strategy, ratio):
    """Randomly drop parts of a solution text so it can be used as a partial hint.

    Each maskable unit is removed independently with probability ``ratio``
    (one ``random.random()`` draw per unit).

    Args:
        solution: The full solution text.
        strategy: Granularity of the units that may be dropped:
            - ``'sentence'``: the text is cut after each of the Chinese or
              English punctuation marks in the pattern below. With more than
              5 pieces the first 3 and last 2 are always kept; with 4 or 5
              pieces the first and last are kept; with fewer, every piece may
              be dropped. Kept pieces are re-joined unchanged, so surviving
              text keeps its punctuation and whitespace.
            - ``'word'``: tokens come from ``jieba.lcut`` when the text
              contains a character in U+4E00..U+9FFF, otherwise from
              ``str.split()``; survivors are joined with single spaces.
            - ``'char'``: individual characters.
        ratio: Drop probability per unit. Values ``<= 0`` return ``solution``
            unchanged without looking at ``strategy``.

    Returns:
        The masked text.

    Raises:
        ValueError: If ``ratio > 0`` and ``strategy`` is not one of the above.
    """
    if ratio <= 0:
        return solution

    def _survivors(units):
        # A unit survives when its random draw is not below the ratio.
        return [u for u in units if random.random() >= ratio]

    if strategy == 'sentence':
        # Zero-width split right after a punctuation mark: whitespace stays
        # attached to the following piece, so joining reproduces the input.
        # The only empty piece this can yield is a trailing one; drop it so
        # it does not count as a sentence.
        pieces = [p for p in re.split(r'(?<=[.!?。:：；;，,！？])', solution) if p]
        if len(pieces) > 5:
            head, tail = 3, 2
        else:
            # Clamp at zero so short inputs never produce negative slices.
            head = tail = max(0, min(1, len(pieces) - 3))
        tail_start = len(pieces) - tail
        kept = pieces[:head] + _survivors(pieces[head:tail_start]) + pieces[tail_start:]
        return ''.join(kept)

    if strategy == 'word':
        has_chinese = any('\u4e00' <= c <= '\u9fff' for c in solution)
        # jieba segmentation for text containing Chinese, whitespace split otherwise.
        words = jieba.lcut(solution) if has_chinese else solution.split()
        return ' '.join(_survivors(words))

    if strategy == 'char':
        return ''.join(_survivors(solution))

    raise ValueError(f"未知的mask策略: {strategy}")
    


def prepare_reward_funcs(args, tokenizer):
    """Build the reward functions selected on the command line.

    Constructs a ``Reward`` from ``args.reward_funcs``, the tokenizer and
    ``args.max_completion_length`` and returns its ``selected_reward_funcs``
    attribute.
    """
    reward = Reward(
        reward_funcs=args.reward_funcs,
        tokenizer=tokenizer,
        max_completion_length=args.max_completion_length,
    )
    return reward.selected_reward_funcs

def prepare_model_tokenizer(args):
    """Load the model and tokenizer and optionally wrap the model with LoRA.

    Args:
        args: Parsed command-line namespace. Reads ``model_name``,
            ``max_prompt_length``, ``max_completion_length``,
            ``load_precision``, ``full_finetuning``,
            ``gpu_memory_utilization``, ``chat_template`` and ``lora_rank``.

    Returns:
        A ``(model, tokenizer)`` tuple. When ``args.chat_template`` is
        ``"deepseek"`` the tokenizer's ``chat_template`` is replaced by the
        Jinja template below. Unless ``args.full_finetuning`` is set, the
        model is the PEFT/LoRA-wrapped one.
    """
    # Load the model. The sequence budget covers prompt plus completion.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = args.model_name,
        max_seq_length = args.max_prompt_length + args.max_completion_length,
        load_in_4bit = args.load_precision == "4bit",
        load_in_8bit = args.load_precision == "8bit",
        full_finetuning = args.full_finetuning,
        fast_inference = True,
        gpu_memory_utilization = args.gpu_memory_utilization,  # 0.4 for 1.5B
        # float8_kv_cache = True, for H100
    )

    # Chat template handling. The template must stay a single string literal:
    # a plain double-quoted string cannot span source lines.
    if args.chat_template=="deepseek":
        chat_template_with_think_wo_generation_prompt="{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}"

        tokenizer.chat_template=chat_template_with_think_wo_generation_prompt

    # Create the LoRA model (skipped for full fine-tuning).
    if not args.full_finetuning:
        model = FastLanguageModel.get_peft_model(
            model,
            r = args.lora_rank,
            target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha = 64,
            use_gradient_checkpointing = "unsloth",
            random_state = 3407,
        )
    return model, tokenizer



def get_dataset(args, tokenizer):
    """Load the JSON training data and turn each record into a chat prompt.

    Every record must provide ``problem`` and ``answer``; ``solution`` is
    additionally required when ``args.hint_level == "solution"``. Records
    whose tokenized prompt exceeds ``args.max_prompt_length`` are filtered out.

    Args:
        args: Parsed command-line namespace. Reads ``data_files``,
            ``chat_template``, ``hint_level``, ``solution_hint_mask_strategy``,
            ``solution_hint_mask_ratio`` and ``max_prompt_length``.
        tokenizer: Tokenizer whose ``apply_chat_template`` is used to measure
            prompt length.

    Returns:
        The mapped and filtered dataset; each mapped record carries
        ``prompt`` (list of chat messages), ``answer`` and ``input_length``.

    Raises:
        ValueError: If ``args.chat_template`` is not ``"deepseek"`` or
            ``"qwen"``, or if a record has no solution while
            ``args.hint_level == "solution"``.
    """
    # Fail early with a clear message; otherwise `messages` below would be
    # unbound for any other template name.
    if args.chat_template not in ("deepseek", "qwen"):
        raise ValueError(
            f"Unsupported chat_template: {args.chat_template!r} (expected 'deepseek' or 'qwen')"
        )

    def cal_token_length(prompt):
        # Length of the prompt after the chat template is applied, including
        # the generation prompt.
        input_ids = tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True  # adjust to the model's needs
        )
        return len(input_ids)

    def formatting_prompts_func(example):
        question = example["problem"]
        answer = example["answer"]
        question_prompt = "\nPlease reason step by step, and put your final answer within \\boxed{}."
        if args.hint_level == "answer":
            question_prompt = question_prompt + f"\nGive you a hint, the answer of this problem is {answer}"
        elif args.hint_level == "solution":
            # `or ""` also covers a column that is present but null.
            solution = (example.get("solution") or "").strip()
            # Validate before masking: a missing solution is a data error,
            # whereas random masking may legitimately shorten a short one.
            if not solution:
                raise ValueError("args.hint_level==solution, however your dataset do not have solution!")
            solution = mask_solution(
                solution,
                strategy=args.solution_hint_mask_strategy,
                ratio=args.solution_hint_mask_ratio,
            )
            question_prompt = question_prompt + f"\nGive you a hint, the answer of this problem is {answer}. A example solution is {solution}."

        messages = [{"role": "user", "content": question + question_prompt}]
        if args.chat_template == "qwen":
            # The qwen prompt additionally carries a system message.
            messages.insert(0, {"role": "system", "content": "You are a helpful assistant!"})

        input_length = cal_token_length(messages)
        return {"prompt": messages, "answer": answer, "input_length": input_length}

    dataset = load_dataset("json", data_files=args.data_files, split="train")
    print(dataset.column_names)
    dataset = dataset.map(formatting_prompts_func, num_proc=64, batched=False, remove_columns=dataset.column_names)
    print(f"数据数量：{len(dataset)}")
    # Keep only prompts that fit into the configured prompt budget.
    dataset = dataset.filter(lambda x: x['input_length'] <= args.max_prompt_length, num_proc=64)
    print("过滤后的数据列名:", dataset.column_names)
    print(f"过滤后的训练数据数量：{len(dataset)} max_prompt_length: {args.max_prompt_length}")
    return dataset

if __name__ == "__main__":
    args = parse_args()
    # The output directory is resolved against root_path (two levels above
    # this script's directory).
    args.output_dir = os.path.join(root_path, args.output_dir)

    # Initialise the Weights & Biases run.
    wandb.init(project=args.wandb_project, name=args.wandb_runname)

    model, tokenizer = prepare_model_tokenizer(args)
    train_dataset = get_dataset(args, tokenizer)
    reward_funcs = prepare_reward_funcs(args, tokenizer)

    # Extra vLLM sampling options; the temperature is passed through
    # GRPOConfig below instead.
    vllm_sampling_params = vLLMSamplingParams(top_p=0.95)

    # Use bf16 when available and fall back to fp16 otherwise.
    bf16_available = is_bfloat16_supported()

    # Training configuration.
    training_args = GRPOConfig(
        # Generation
        use_vllm=True,
        temperature=0.6,
        vllm_sampling_params=vllm_sampling_params,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        # Optimisation
        learning_rate=1e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        optim="paged_adamw_8bit",
        max_grad_norm=0.1,
        bf16=bf16_available,
        fp16=not bf16_available,
        # The per-device batch size is tied to the number of generations.
        per_device_train_batch_size=args.num_generations,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        # Logging and checkpointing
        logging_steps=1,
        save_steps=500,
        save_total_limit=10,
        report_to="wandb",
        output_dir=args.output_dir,
    )

    # Build the trainer with the selected reward functions.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
    )

    # Start training.
    trainer.train()