import os
import json
import re
import torch
from tqdm import tqdm
from unsloth import FastLanguageModel
from datasets import load_dataset

# Checkpoint to evaluate and the JSONL file that receives one record per instruction.
MODEL_PATH = "checkpoints/alpaca/Cure-SFT"
OUTPUT_JSONL = "result/alpaca_eval/alpaca/Cure-SFT.jsonl"

# Loader options: 4-bit loading with a 2048-token maximum sequence length.
model_load_options = {
    "model_name": MODEL_PATH,
    "max_seq_length": 2048,
    "load_in_4bit": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**model_load_options)

# Hand the model to Unsloth's for_inference hook before any generation happens.
FastLanguageModel.for_inference(model)

# Pad on the left side of the sequence.
tokenizer.padding_side = "left"

# "eval" split of the "alpaca_eval" configuration of tatsu-lab/alpaca_eval.
dataset = load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval", split="eval")

# Resume support: gather every instruction already stored in the output file
# so that the generation loop can skip it.
processed_instructions = set()
if os.path.exists(OUTPUT_JSONL):
    print("An existing output file has been detected; reading progress is in progress.")
    with open(OUTPUT_JSONL, "r", encoding="utf-8") as resume_file:
        for raw_line in resume_file:
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                # Lines that are not valid JSON are ignored.
                continue
            else:
                processed_instructions.add(record['instruction'])
    completed_count = len(processed_instructions)
    print(f"Completed: {completed_count}, remaining: {len(dataset) - completed_count}")

def aggressive_clean(text, ngram=4, repeats=3, min_words=20):
    """Strip ``<|...|>`` markers and cut the text at the first repetition loop.

    A repetition loop is ``repeats`` consecutive, identical runs of ``ngram``
    whitespace-separated words. When one is found, everything after its first
    run is dropped. The kept prefix is sliced out of the text itself, so its
    newlines and spacing are preserved rather than collapsed to single spaces.

    Args:
        text: Decoded model output.
        ngram: Number of words in the repeated unit.
        repeats: How many back-to-back copies of the unit count as a loop.
        min_words: Repetition is only searched for when the text has more
            words than this.

    Returns:
        The cleaned text with leading and trailing whitespace removed.

    Raises:
        ValueError: If ``ngram`` is below 1 or ``repeats`` is below 2.
    """
    if ngram < 1 or repeats < 2:
        raise ValueError("ngram must be >= 1 and repeats must be >= 2")

    text = re.sub(r'<\|.*?\|>', '', text)

    # Keep each word's match object so the cut can be made at a character
    # offset in `text` instead of re-joining the words.
    word_matches = list(re.finditer(r'\S+', text))
    words = [m.group() for m in word_matches]

    if len(words) > min_words:
        window = ngram * repeats
        # Only start positions that leave room for a full window are tried.
        for start in range(len(words) - window + 1):
            unit = words[start:start + ngram]
            is_loop = all(
                words[start + k * ngram:start + (k + 1) * ngram] == unit
                for k in range(1, repeats)
            )
            if is_loop:
                cut = word_matches[start + ngram - 1].end()
                return text[:cut].strip()
    return text.strip()

# Prompt template; the single `{}` placeholder receives the instruction text.
alpaca_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:\n"

# Create the output directory if needed. os.path.dirname returns "" for a bare
# filename, and os.makedirs("") raises FileNotFoundError, so only create a
# directory when the path actually has one.
output_dir = os.path.dirname(OUTPUT_JSONL)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token of row 0 is EOS.

    Only the first row of ``input_ids`` is inspected. The EOS id is read from
    the module-level ``tokenizer`` on every call.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        halting_ids = (tokenizer.eos_token_id,)
        newest_token = input_ids[0][-1]
        return newest_token in halting_ids
    
# Sampling configuration shared by every generate() call below.
sampling_settings = {
    "max_new_tokens": 512,
    "do_sample": True,
    "temperature": 0.1,
    "top_p": 0.7,
    "repetition_penalty": 1.1,
}

# Append mode keeps records written by earlier runs; each new record is
# flushed as soon as it is written.
with open(OUTPUT_JSONL, "a", encoding="utf-8") as f_out:
    print("Starting...")
    for item in tqdm(dataset):
        instruction = item['instruction']

        # Already present in the output file from a previous run.
        if instruction in processed_instructions:
            continue

        encoded = tokenizer(
            [alpaca_prompt.format(instruction)],
            return_tensors="pt",
        ).to("cuda")

        with torch.no_grad():
            generated = model.generate(
                **encoded,
                stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
                **sampling_settings,
            )

        # Skip the first `prompt_length` positions of the returned sequence;
        # the code treats them as the prompt and decodes only what follows.
        prompt_length = encoded.input_ids.shape[1]
        completion_ids = generated[0][prompt_length:]
        completion = tokenizer.decode(completion_ids, skip_special_tokens=True)

        record = {
            "instruction": instruction,
            "output": aggressive_clean(completion),
            "generator": "Cure-SFT",
            "dataset": item["dataset"],
        }

        f_out.write(json.dumps(record, ensure_ascii=False) + "\n")
        f_out.flush()

print(f"Generation complete! Save results to: {OUTPUT_JSONL}")


# Convert the JSONL log into a single JSON array next to it.
# splitext swaps only the real file extension; str.replace would also rewrite
# ".jsonl" inside directory names and would leave the path unchanged when the
# suffix is missing.
FINAL_JSON = os.path.splitext(OUTPUT_JSONL)[0] + ".json"
if FINAL_JSON == OUTPUT_JSONL:
    # Never write the array over the JSONL log that resuming depends on.
    FINAL_JSON = OUTPUT_JSONL + ".json"

final_results = []
with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        try:
            final_results.append(json.loads(line))
        except json.JSONDecodeError:
            # Skip unparseable lines (e.g. a record cut short by an
            # interrupted run), as the resume reader above does.
            continue

with open(FINAL_JSON, "w", encoding="utf-8") as f:
    json.dump(final_results, f, indent=2, ensure_ascii=False)
