import os
import json
import re
import glob
import torch
import argparse 
from tqdm import tqdm
from unsloth import FastLanguageModel

# Checkpoint passed as `model_name` to FastLanguageModel.from_pretrained below.
MODEL_PATH = "checkpoints/alpaca/Cure-SFT/cure-sft"
# Directory scanned for "*.jsonl" datasets; also the first place a --dataset
# filename is looked up (see main()).
INPUT_DIR = "test_dataset"
# Each dataset's generations are appended to a file with the same basename
# inside this directory (see process_file()).
OUTPUT_BASE_DIR = "result/evaluation_set/alpaca/Cure-SFT"


# Total sequence budget handed to the model loader as `max_seq_length`.
MODEL_MAX_SEQ_LENGTH = 2048 
# Upper bound on generated tokens per prompt (`max_new_tokens` in generate()).
MAX_NEW_TOKENS = 1024 

# Sampling settings forwarded unchanged to model.generate().
TEMP = 0.2          
TOP_P = 1.0
REP_PENALTY = 1.0   

# Tokens left for the prompt once the generation budget is reserved; used as
# the truncation limit passed to smart_truncate().
MAX_INPUT_TOKENS = MODEL_MAX_SEQ_LENGTH - MAX_NEW_TOKENS 


# Load the checkpoint and its tokenizer at import time; every function below
# uses these module-level `model` / `tokenizer` objects.
# NOTE(review): `load_in_4bit = True` requests 4-bit loading -- exact
# quantization behaviour is defined by Unsloth, confirm against its docs.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    max_seq_length = MODEL_MAX_SEQ_LENGTH,
    load_in_4bit = True,
)
# Unsloth's inference switch for the loaded model (presumably enables its
# generation-optimized path -- TODO confirm against Unsloth docs).
FastLanguageModel.for_inference(model)

# Prompts are tokenized one at a time in process_file(), so padding only
# matters if batching is added later.
tokenizer.padding_side = "left" 
# Fall back to the EOS token as the pad token when the tokenizer defines none;
# generate() is given `pad_token_id` explicitly in process_file().
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id


def aggressive_clean(text):
    """Delete ``<|...|>`` marker spans from *text* and trim the result.

    Each shortest ``<|`` ... ``|>`` span that sits on a single line is
    removed (``.`` does not cross newlines); leading and trailing whitespace
    is then stripped from what remains.
    """
    marker_pattern = re.compile(r'<\|.*?\|>')
    without_markers = marker_pattern.sub('', text)
    return without_markers.strip()

def get_instruction(item: dict):
    """Extract the prompt text from one dataset record.

    The record layouts below are tried in order; the first match wins:

    1. ``"Instruction"`` holding a string.
    2. ``"instruction"`` holding a string together with an ``"instances"``
       key. If the first instance is a dict with a string ``"input"``, that
       input is appended under an ``Input:`` header.
    3. ``"prompt"`` holding a string.
    4. ``"text"`` holding a string.
    5. ``"conversations"`` holding either a string, or a list of turn dicts
       from which the first turn whose ``"from"`` is ``"human"`` or
       ``"user"`` is used.

    Args:
        item: One parsed dataset record.

    Returns:
        The stripped instruction text, or ``None`` when no layout matches.
        An empty string is returned when the first human/user turn has no
        usable string ``"value"``; callers treat both as "skip this record".
    """
    if isinstance(item.get("Instruction"), str):
        return item["Instruction"].strip()

    if isinstance(item.get("instruction"), str) and "instances" in item:
        base = item["instruction"].strip()
        insts = item.get("instances", [])
        if insts and isinstance(insts[0], dict) and isinstance(insts[0].get("input"), str):
            return base + "\n\nInput:\n" + insts[0]["input"].strip()
        return base

    if isinstance(item.get("prompt"), str):
        return item["prompt"].strip()

    if isinstance(item.get("text"), str):
        return item["text"].strip()

    if "conversations" in item:
        conv = item["conversations"]
        if isinstance(conv, str):
            return conv.strip()
        if isinstance(conv, list):
            for turn in conv:
                if isinstance(turn, dict) and turn.get("from") in ["human", "user"]:
                    value = turn.get("value", "")
                    # A null or non-string value used to raise AttributeError
                    # on .strip(); treat it like a missing value instead.
                    return value.strip() if isinstance(value, str) else ""
    return None

def smart_truncate(prompt, tokenizer, max_input_len):
    """Shorten *prompt* to at most *max_input_len* tokens, keeping its tail.

    The prompt is tokenized with *tokenizer*; when it fits within the limit
    it is returned unchanged. Otherwise only the last *max_input_len* token
    ids are kept (the beginning is dropped) and decoded back to text with
    special tokens skipped.
    """
    token_matrix = tokenizer(prompt, return_tensors="pt").input_ids
    if token_matrix.shape[1] <= max_input_len:
        return prompt

    # Over budget: keep the trailing window of token ids and decode it.
    tail_ids = token_matrix[:, -max_input_len:]
    return tokenizer.decode(tail_ids[0], skip_special_tokens=True)


def _load_records(input_file):
    """Read dataset records from a JSONL file, falling back to one JSON document."""
    with open(input_file, 'r', encoding='utf-8') as f:
        try:
            records = [json.loads(line) for line in f if line.strip()]
        except json.JSONDecodeError:
            # Not line-delimited JSON: re-read the whole file as a single value.
            f.seek(0)
            records = json.load(f)

    # A JSON array stored on one line parses as "JSONL" with a single record
    # that is itself the list of records; unwrap it so items are iterated.
    if isinstance(records, list) and len(records) == 1 and isinstance(records[0], list):
        records = records[0]
    return records


def _load_processed_indices(output_file):
    """Collect the integer "idx" values already written to an output file."""
    processed = set()
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line)
                if "idx" in data:
                    processed.add(int(data["idx"]))
            except (ValueError, TypeError):
                # Malformed or partially written line (e.g. an interrupted
                # run): ignore it so that record is simply regenerated.
                continue
    return processed


def _build_terminators():
    """Return the de-duplicated token ids that should stop generation."""
    terminators = []
    if tokenizer.eos_token_id is not None:
        terminators.append(tokenizer.eos_token_id)

    # Tokenizers without "<|eot_id|>" may map it to None or to the unknown
    # token id; neither is a valid stop token.
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if eot_id is not None and eot_id != unk_id and eot_id not in terminators:
        terminators.append(eot_id)
    return terminators


def process_file(input_file):
    """Generate a response for every record of one dataset file.

    Records are read from *input_file* (JSONL, or a single JSON document as a
    fallback). Each record's instruction is wrapped in the Alpaca prompt
    template, truncated to ``MAX_INPUT_TOKENS``, and sent to the module-level
    ``model``. Results are appended as JSON lines to a file with the same
    basename under ``OUTPUT_BASE_DIR`` and flushed after every record.

    The run is resumable: indices already present in the output file are
    skipped, and the file is skipped entirely once the number of recorded
    indices equals the number of input records. Records without a usable
    instruction, and records whose generation raises, are skipped with no
    output line.

    Args:
        input_file: Path to the dataset file to evaluate.

    Returns:
        None. Progress and errors are reported on stdout.
    """
    filename = os.path.basename(input_file)
    output_file = os.path.join(OUTPUT_BASE_DIR, filename)

    print(f"\nProcessing: {filename}")

    if not os.path.exists(input_file):
        print(f"Error: File does not exist -> {input_file}")
        return

    input_data = _load_records(input_file)

    processed_indices = set()
    if os.path.exists(output_file):
        processed_indices = _load_processed_indices(output_file)
        print(f"Finished: {len(processed_indices)} / {len(input_data)}")

    if len(processed_indices) == len(input_data):
        print(f"{filename} Completed, skip.")
        return

    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    alpaca_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:\n"

    generation_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": TEMP,
        "top_p": TOP_P,
        "repetition_penalty": REP_PENALTY,
        "pad_token_id": tokenizer.pad_token_id,
    }
    terminators = _build_terminators()
    if terminators:
        # Only override the stop tokens when at least one valid id exists;
        # otherwise leave the model's own generation config in charge.
        generation_kwargs["eos_token_id"] = terminators

    with open(output_file, "a", encoding="utf-8") as f_out:
        for idx, item in enumerate(tqdm(input_data, desc=filename)):

            if idx in processed_indices:
                continue

            # get_instruction() expects a mapping; skip stray scalars/lists
            # instead of aborting the whole file.
            if not isinstance(item, dict):
                continue

            instruction = get_instruction(item)
            if not instruction:
                continue

            prompt = alpaca_prompt.format(instruction)
            prompt = smart_truncate(prompt, tokenizer, MAX_INPUT_TOKENS)

            inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

            try:
                with torch.no_grad():
                    outputs = model.generate(**inputs, **generation_kwargs)
            except Exception as e:
                # Best effort: report, free cached GPU memory, and move on so
                # one bad record does not stop the run.
                print(f"⚠️ Error skipping idx {idx}: {e}")
                torch.cuda.empty_cache()
                continue

            # generate() returns prompt + continuation; keep the continuation.
            input_len = inputs.input_ids.shape[1]
            response_tokens = outputs[0][input_len:]
            response = tokenizer.decode(response_tokens, skip_special_tokens=True)
            clean_response = aggressive_clean(response)

            result_entry = {
                "idx": idx,
                "instruction": instruction,
                "output": clean_response,
                "generator": "Cure-SFT",
                "dataset": filename.replace(".jsonl", "")
            }

            f_out.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
            # Flush per record so an interrupted run can resume from here.
            f_out.flush()

    print(f"{filename} Finished")

def main():
    """Command-line entry point for the evaluation run.

    With ``--dataset NAME`` a single file is evaluated: ``NAME`` is looked up
    inside ``INPUT_DIR`` first and, failing that, used as a path as given.
    Without the option, every ``*.jsonl`` file directly under ``INPUT_DIR``
    is evaluated in sorted order, so repeated batch runs visit files in the
    same sequence.

    Returns:
        None. Missing inputs are reported on stdout and end the run early.
    """
    parser = argparse.ArgumentParser(description="Evaluation Script")
    parser.add_argument("--dataset", type=str, default=None, 
                        help="Specify the single filename to be evaluated (e.g., vicuna_eval.jsonl). If not specified, all files under INPUT_DIR will be evaluated.")

    args = parser.parse_args()

    if args.dataset:
        target_path = os.path.join(INPUT_DIR, args.dataset)
        if not os.path.exists(target_path):
            if os.path.exists(args.dataset):
                target_path = args.dataset
            else:
                print(f"Error: File not found.")
                return

        print(f"Only evaluate the specified files: {target_path}")
        process_file(target_path)
        return

    print(f"Batch evaluate all files under {INPUT_DIR}")
    # glob returns matches in arbitrary, filesystem-dependent order; sort for
    # a reproducible processing sequence.
    files = sorted(glob.glob(os.path.join(INPUT_DIR, "*.jsonl")))
    if not files:
        print(f"File not found: {INPUT_DIR}")
        return

    for file_path in files:
        process_file(file_path)

# Run the CLI only when executed as a script. Note that the model is loaded
# at import time above, regardless of this guard.
if __name__ == "__main__":
    main()