import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from transformers import BitsAndBytesConfig
import copy
import json
import os

def load_ppl_model(args):
    """Load the 4-bit quantized causal LM used to score perplexity, plus its tokenizer.

    The model is loaded with bitsandbytes NF4 quantization (bfloat16 compute)
    and placed on the available devices by ``device_map="auto"``.

    Args:
        args: Namespace providing ``cache_dir`` (Hugging Face cache directory).
            It may optionally provide ``ppl_model_name`` to override the
            default scorer ``"Qwen/Qwen2.5-14B-Instruct"``.

    Returns:
        A ``(model, tokenizer)`` tuple, with the model switched to eval mode.
    """
    hf_model_name = getattr(args, "ppl_model_name", None) or "Qwen/Qwen2.5-14B-Instruct"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    # device_map="auto" already decides where every weight lives. No
    # .to('cuda') follows on purpose, for two reasons:
    # - transformers restricts .to() on bitsandbytes-quantized models
    #   (whether 4-bit is allowed depends on the installed versions);
    # - it would undo a multi-GPU/offloaded placement.
    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        device_map="auto",
        quantization_config=bnb_config,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(hf_model_name, cache_dir=args.cache_dir)
    model.eval()

    return model, tokenizer

def compute_ppl(model, tokenizer, full_text, completion):
    """Return the perplexity of the trailing ``completion`` tokens of ``full_text``.

    ``full_text`` and ``completion`` are tokenized separately. With
    ``L = len(tokens(full_text))`` and ``C = len(tokens(completion))``, the
    first ``L - C + 1`` label positions are set to ``-100`` (ignored by the
    loss). Only the remaining tail of ``full_text`` is therefore scored.

    NOTE(review): the ``+ 1`` also masks the first position attributed to the
    completion. This is presumably meant to absorb a special token prepended to
    the standalone completion, or a token merge at the prompt/completion
    boundary. TODO confirm, since it changes which tokens are scored.

    Args:
        model: Causal LM whose forward accepts ``labels`` and returns ``.loss``.
        tokenizer: Tokenizer matching ``model``.
        full_text: Conditioning text followed by ``completion``.
        completion: Text whose tokens should be scored.

    Returns:
        ``exp(mean token NLL)`` over the scored positions as a Python float,
        or ``nan`` when no position is left to score.
    """
    inputs = tokenizer(full_text, return_tensors="pt").to(model.device)
    # Only the completion's token count is needed, so it never leaves the CPU.
    n_completion = tokenizer(completion, return_tensors="pt")["input_ids"].shape[1]

    labels = inputs["input_ids"].clone()
    seq_len = labels.shape[1]

    # Clamp into [0, seq_len] so that an over-long completion tokenization
    # cannot produce a negative slice bound, which would mask from the wrong end.
    n_masked = min(max(seq_len - n_completion + 1, 0), seq_len)

    # Causal-LM losses shift labels left by one, so position 0 is never scored.
    # If nothing beyond it survives the mask, the loss would be a mean over
    # zero tokens (NaN). Skip the forward pass in that case.
    if max(n_masked, 1) >= seq_len:
        return float("nan")

    labels[:, :n_masked] = -100

    with torch.no_grad():
        outputs = model(**inputs, labels=labels)
    return torch.exp(outputs.loss).item()


def get_ppl(args):
    """Score every generated-text JSON file and append its perplexity summary.

    Inputs are read from ``./results/gen_text/<dataset>/<model>_<size>/``:
    ``prompt.json`` holds the prompts, and every other ``*.json`` file whose
    name does not contain ``"prompt"`` holds generations aligned with those
    prompts. Each (prompt, generation) pair is scored with ``compute_ppl``.
    NaN scores are dropped, then the ``ppl_trim`` largest scores are discarded
    as outliers. One line ``"<key> :: <mean> ± <std>"`` per file is appended to
    ``./results/ppl/<dataset>/<model>_<size>/ppl_results.txt``.

    Files whose key already appears in the results file are skipped, so an
    interrupted run can be resumed. The scoring model is only loaded once a
    file actually needs scoring.

    Args:
        args: Namespace with ``dataset_name``, ``model_name``, ``model_size``
            and whatever ``load_ppl_model`` reads (``cache_dir``). It may
            optionally provide ``ppl_trim`` (default 25), the number of
            highest perplexities excluded from the statistics.
    """
    run_tag = f"{args.model_name}_{args.model_size}"
    gen_dir = f"./results/gen_text/{args.dataset_name}/{run_tag}"
    save_dir = f"./results/ppl/{args.dataset_name}/{run_tag}"
    n_trim = max(int(getattr(args, "ppl_trim", 25)), 0)

    # Create the directory, then put the results file inside it. The file
    # name must not itself be turned into a directory.
    os.makedirs(save_dir, exist_ok=True)
    savepath = os.path.join(save_dir, "ppl_results.txt")

    exists_keys = set()
    if os.path.exists(savepath):
        with open(savepath, 'r', encoding='utf-8') as f:
            exists_keys = {line.split("::")[0].strip() for line in f if line.strip()}
    else:
        with open(savepath, 'w', encoding='utf-8') as f:
            f.write("")

    with open(os.path.join(gen_dir, "prompt.json"), 'r', encoding='utf-8') as f:
        prom = json.load(f)

    model = tokenizer = None
    # Sorted for a deterministic processing/output order.
    for file_name in sorted(os.listdir(gen_dir)):
        if not file_name.endswith(".json") or "prompt" in file_name:
            continue

        key = f"{args.dataset_name}_{args.model_name}_{args.model_size}_{file_name}"
        if key in exists_keys:
            continue

        if model is None:
            # Deferred so a fully resumed run never loads the scoring model.
            model, tokenizer = load_ppl_model(args)

        print("Current file... : ", file_name)
        with open(os.path.join(gen_dir, file_name), 'r', encoding='utf-8') as f:
            gen_text = json.load(f)

        n_pairs = min(len(prom), len(gen_text))
        if len(prom) != len(gen_text):
            print(f"Warning: {file_name} has {len(gen_text)} generations for "
                  f"{len(prom)} prompts; scoring only the first {n_pairs} pairs.")

        ppl = []
        for p, text in tqdm(zip(prom, gen_text), total=n_pairs):
            ppl_score = compute_ppl(model, tokenizer, p + text, text)
            if ppl_score is not None and not np.isnan(ppl_score):
                ppl.append(ppl_score)

        # Discard the n_trim largest scores as outliers. An explicit stop index
        # is used because ppl[:-0] would wrongly yield an empty list.
        ppl = sorted(ppl)
        ppl = ppl[:max(len(ppl) - n_trim, 0)]
        if ppl:
            mean, std = np.nanmean(ppl), np.nanstd(ppl)
        else:
            print(f"Warning: no perplexity values left for {file_name} after trimming.")
            mean = std = float("nan")

        with open(savepath, 'a', encoding='utf-8') as f:
            f.write(f"{key} :: {mean:.2f} ± {std:.2f}\n")