import os
import json
import torch
import copy
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    set_seed
)
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm

# Seed the RNGs through transformers' helper so repeated runs are comparable.
set_seed(42)

# Chinese sentence templates that each state, in a different phrasing, that the
# letter {char} occurs {n} times in {word}. Every template uses exactly these
# three placeholders; the per-sample loop further down fills them with
# str.format(word=..., char=..., n=...) to build the fine-tuning texts.
# The list holds 50 templates in five groups of ten.
templates = [
    "{word}中含有{char}的个数是{n}。",
    "{word}中包含的{char}字母数量是{n}。",
    "字母{char}在{word}中出现的次数是{n}。",
    "{word}中字母为{char}的数量是{n}。",
    "{word}中{char}的统计数量是{n}。",
    "{word}里{char}的数量是{n}。",
    "{word}这个单词中{char}字母出现的次数是{n}。",
    "{word}中字符{char}的出现次数是{n}。",
    "{char}在单词{word}中出现的次数是{n}。",
    "{word}中{char}的个数是{n}。",
    
    "{word}中总共包含{n}个{char}字母。",
    "在{word}这个词中，{char}出现了{n}次。",
    "{word}里一共有{n}个字母是{char}。",
    "字母{char}在{word}中总共出现了{n}次。",
    "{word}中含有{n}个{char}。",
    "{char}这个字母在{word}中出现了{n}次。",
    "经过统计，{word}中{char}的数量为{n}。",
    "{word}中{char}字符的出现频率是{n}次。",
    "可以确定{word}中{char}有{n}个。",
    "根据计算，{word}包含{n}个{char}。",

    "{word}中{char}的数目是{n}。",
    "单词{word}里{char}共出现{n}次。",
    "{char}在{word}中的出现次数为{n}。",
    "{word}中与{char}相同的字母有{n}个。",
    "在{word}中找到{n}个{char}。",
    "{word}由{n}个{char}构成其中的一部分。",
    "在{word}中，字母{char}重复了{n}次。",
    "对{word}进行分析，发现{char}出现了{n}次。",
    "{word}中{char}的出现总数是{n}。",
    "记录显示{word}中{char}出现了{n}次。",

    "{word}中出现了{n}次{char}字母。",
    "{char}作为字符在{word}中出现了{n}次。",
    "从{word}中可数出{n}个{char}。",
    "{word}内含有{n}个{char}。",
    "在{word}中，{char}共计出现{n}次。",
    "{word}中确切地有{n}个{char}。",
    "字母{char}在{word}中被发现{n}次。",
    "{word}中存在{n}个{char}。",
    "在{word}中，包含{n}个{char}。",
    "{char}在{word}中作为成员出现了{n}次。",

    "{word}中匹配{char}的次数是{n}。",
    "检测到{word}中有{n}个{char}。",
    "识别结果：{word}中{char}出现{n}次。",
    "{word}中{char}的频次是{n}。",
    "通过计数得出{word}中{char}有{n}个。",
    "经核查，{word}中{char}的数量为{n}。",
    "{word}中{char}的实例数是{n}。",
    "在{word}字符串中，{char}出现了{n}次。",
    "{word}中{char}的分布数量为{n}。",
    "分析结果显示{word}中{char}有{n}个。",
    
]

# Filesystem locations (placeholders; adjust before running).
# model_path: local checkpoint directory passed to both from_pretrained calls.
# data_path:  JSON file that is read with json.load by the per-sample loop.
model_path = "/path/to/DeepSeek-R1-Distill-Qwen-14B"
data_path = "/path/to/data"
output_dir_base = "/path/to/LongCoT/wordCount/outputs"
# Create the output directory at import time so the results file can be written.
os.makedirs(output_dir_base, exist_ok=True)
# Cumulative results file; the loop rewrites it in full after every sample.
results_log_file = os.path.join(output_dir_base, "per_sample.json")

# Tokenizer for the checkpoint at model_path, configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left',
    trust_remote_code=True
)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Base causal LM, loaded once in bfloat16. The per-sample loop further down
# passes this same object to get_peft_model for every sample.
base_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let the loader decide device placement
    low_cpu_mem_usage=True,
    local_files_only=True,  # load from disk only; no download attempt
    trust_remote_code=True,
    use_cache=False,  # NOTE(review): presumably disables the KV cache in the model config for training — confirm it does not slow down generate()
)

# LoRA adapter settings shared by every per-sample fine-tuning run.
lora_config = LoraConfig(
    r=32,  # adapter rank
    lora_alpha=64,  # twice the rank
    # Module names that receive adapters.
    # NOTE(review): presumably the attention and MLP projection layers of the
    # Qwen-style architecture — verify against the loaded model's module names.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_dropout=0.1,
    bias="none",  # bias terms are not trained
    task_type=TaskType.CAUSAL_LM,
)

def preprocess_function(examples, tokenizer, prompt_field="question", response_field="answer"):
    """Tokenize prompt+response pairs and build prompt-masked labels.

    Each prompt is concatenated with its response and tokenized as a single
    text (no special tokens, truncated to 512 tokens). Label positions that
    belong to the prompt are set to -100 so that only response tokens take
    part in the loss; the remaining label positions copy the input ids.

    Args:
        examples: Batch mapping with one list of strings under
            ``prompt_field`` and one under ``response_field``.
        tokenizer: Callable tokenizer whose result exposes ``["input_ids"]``.
        prompt_field: Key of the prompt strings in ``examples``.
        response_field: Key of the response strings in ``examples``.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``; the three
        lists have identical lengths for every example.
    """
    prompts = examples[prompt_field]
    responses = examples[response_field]
    full_texts = [p + r for p, r in zip(prompts, responses)]

    input_ids = tokenizer(
        full_texts,
        truncation=True,
        max_length=512,
        padding=False,
        return_attention_mask=False,
        return_offsets_mapping=False,
        add_special_tokens=False,
    )["input_ids"]

    labels = []
    for prompt, ids in zip(prompts, input_ids):
        # Plain Python lists are enough here; only the prompt length is needed.
        # NOTE(review): the prompt is tokenized on its own, so this assumes the
        # tokenization of prompt+response starts with the same tokens.
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        # The full text may have been truncated to fewer tokens than the
        # prompt alone; clamp so labels never outgrow input_ids.
        prompt_len = min(len(prompt_ids), len(ids))
        labels.append([-100] * prompt_len + ids[prompt_len:])

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [[1] * len(ids) for ids in input_ids],
    }

def generate_response(model, tokenizer, question):
    """Sample one completion for ``question`` and return only the new text.

    The model is switched to eval mode, the question is tokenized without
    special tokens (truncated to 512 tokens) and moved to the model's device,
    and up to 4096 new tokens are sampled (temperature 0.7, top-p 0.9). The
    prompt tokens are dropped from the output before decoding; special tokens
    are kept in the decoded text, which is returned stripped of surrounding
    whitespace.
    """
    model.eval()

    sampling_kwargs = dict(
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.no_grad():
        encoded = tokenizer(
            question,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            add_special_tokens=False,
        ).to(model.device)

        sequences = model.generate(**encoded, **sampling_kwargs)

        # generate() returns prompt + continuation; keep the continuation only.
        prompt_len = encoded['input_ids'].shape[1]
        completion_ids = sequences[0][prompt_len:]
        text = tokenizer.decode(completion_ids, skip_special_tokens=False)

    return text.strip()

# Load the evaluation items. Each item is read below through the keys
# 'word', 'char', 'answer' and 'model_answer'.
with open(data_path, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

all_results = []

# For every item: fine-tune a LoRA adapter on all template sentences filled in
# for that item, sample an answer to the counting question, record it, and
# then remove the adapter again so the next item starts from the base weights.
# tqdm wraps raw_data itself (not the enumerate object) so it can show a total.
for idx, item in enumerate(tqdm(raw_data)):
    # The prompt is intentionally empty, so no label position is masked and the
    # whole filled-in sentence is trained on.
    # Disabled chat-formatted variant, kept for reference:
    # question_template = f"<｜User｜>{temp.format(word=item['word'], char=item['char'])}<｜Assistant｜><think>\n\n</think>"
    # answer_target = f"\n{item['answer']}<｜end▁of▁sentence｜>"
    train_data = [
        {
            'question': "",
            'answer': temp.format(word=item['word'], char=item['char'], n=item['answer']),
        }
        for temp in templates
    ]

    dataset = Dataset.from_list(train_data)
    tokenized_dataset = dataset.map(
        lambda x: preprocess_function(x, tokenizer),
        batched=True,
        remove_columns=["question", "answer"],
        num_proc=1,
        desc=f"Preprocessing sample {idx}"
    )

    # get_peft_model injects the LoRA layers into base_model in place; they are
    # removed again at the end of the iteration (see model.unload() below).
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()

    per_device_batch_size = 1
    gradient_accumulation_steps = 2

    training_args = TrainingArguments(
        output_dir="trainer_output",
        num_train_epochs=1,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=1e-4,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        save_strategy="no",  # adapters are throw-away; nothing is checkpointed
        report_to="none",
        bf16=True,
        max_grad_norm=1.0,
        remove_unused_columns=True,
        optim="adamw_torch",
        dataloader_num_workers=0,
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding="longest",
        pad_to_multiple_of=8,
        return_tensors="pt"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Ask the freshly tuned model the actual counting question. The prompt ends
    # with an opening <think> tag, so the completion starts inside the
    # reasoning section.
    question = f"{item['word']}这个单词里面有几个字母{item['char']}？直接用一个阿拉伯数字回答问题。"
    test_prompt = f"<｜User｜>{question}<｜Assistant｜><think>"
    generated_output = generate_response(trainer.model, tokenizer, test_prompt)

    # Strip the LoRA layers (without merging them) so the next sample is
    # trained on the untouched base weights instead of an already-adapted model.
    base_model = model.unload()

    result_entry = {
        "index": idx,
        "question": question,
        "original_output": item['model_answer'],
        "generated_output": generated_output,
    }
    all_results.append(result_entry)

    # Rewrite the full results file after every sample so an interrupted run
    # still leaves all finished samples on disk.
    with open(results_log_file, "w", encoding="utf-8") as f_out:
        json.dump(all_results, f_out, ensure_ascii=False, indent=4)

