import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import random
import tqdm
import sys
# Path of the Hugging Face checkpoint that main() loads the model and
# tokenizer from. Alternative checkpoint kept for reference:
# MODEL_NAME="/yanjianhao/futingwang_test/ModelEditingForDebias/fake_model/counterfact_wiki_llama_4k"
MODEL_NAME="/yanjianhao/huggingface/llama2-7b"
# Split name taken from the first command-line argument. It is read at import
# time, so running the script without an argument raises IndexError here.
# main() uses it to pick counterfact/counterfact-{SPLIT}.json as input and to
# name the {SPLIT}.json output file.
SPLIT=sys.argv[1]

def run_lprobs(model, tokenizer, sentence, device, prefix=None):
    """Score ``sentence`` under a causal LM, optionally conditioned on ``prefix``.

    The model is fed ``prefix + " " + sentence`` (or just ``sentence`` when no
    prefix is given) with the final token removed, and the log-probabilities
    of the tokens of ``sentence`` are read from the last positions of the
    output.

    Args:
        model: Callable taking ``input_ids`` / ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` tensor of
            shape ``[1, seq_len, vocab]``.
        tokenizer: Callable taking ``(text, return_tensors='pt')`` and
            returning a mapping with an ``"input_ids"`` tensor that supports
            ``.to(device)``.
        sentence: Text whose tokens are scored.
        device: Device the tokenized inputs are moved to.
        prefix: Optional conditioning text placed before ``sentence``. When
            given, both strings are stripped and joined by a single space.

    Returns:
        A 4-tuple ``(sent_lprobs, token_lprobs, token_ids, passes)``:
        ``sent_lprobs`` is a 0-dim float32 tensor holding the summed
        log-probability of the scored tokens; ``token_lprobs`` is a list of
        per-token log-probabilities; ``token_ids`` is the list of scored token
        ids; ``passes`` is a 0-dim bool tensor that is true when
        ``sent_lprobs > log(0.2)``.

    Note:
        The first id of the tokenized ``sentence`` is always dropped. This
        assumes the tokenizer prepends a BOS token; for a tokenizer that does
        not, the first real token is silently left unscored. It also assumes
        that ``sentence`` tokenizes to the same ids on its own as it does at
        the end of ``prefix + " " + sentence`` -- TODO confirm for the
        tokenizer in use.
    """
    if prefix:
        new_sentence = prefix.strip() + " " + sentence.strip()
        print(new_sentence)
    else:
        new_sentence = sentence

    context = tokenizer(new_sentence, return_tensors='pt').to(device)
    scored = tokenizer(sentence, return_tensors='pt').to(device)

    # Teacher forcing: the last token is only ever predicted, never fed in.
    input_ids = context["input_ids"][:, :-1]
    # Drop the leading id (assumed to be BOS) so only real tokens are scored.
    scored_ids = scored["input_ids"][:, 1:]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=None)  # [1, l, V]
        # Upcast before the softmax: half-precision logits (e.g. bfloat16)
        # keep only a few significant digits, which is too coarse for a sum
        # of log-probabilities that is later compared against a threshold.
        logits = outputs.logits.float()

        lprobs = torch.nn.functional.log_softmax(logits, dim=-1)  # [1, l, V]
        # The last T positions are the ones that predict the T scored tokens.
        lprobs = lprobs[:, -scored_ids.size(1):, :]
        gather_lprobs = lprobs.gather(-1, scored_ids[:, :, None])  # [1, T, 1]
        sent_lprobs = gather_lprobs[0].sum()

    # Index explicitly instead of squeeze(): squeeze() would also remove the
    # token axis when T == 1, turning the lists below into bare scalars.
    token_lprobs = gather_lprobs[0, :, 0].tolist()
    token_ids = scored_ids[0].tolist()

    passes_threshold = sent_lprobs > torch.log(torch.tensor(0.2))
    return sent_lprobs, token_lprobs, token_ids, passes_threshold




def main():
    """Build the filtered CounterFact edit set for ``SPLIT`` and write it out.

    Steps:
      1. Load ``counterfact/counterfact_filtered_paraphrases.json`` and keep
         only the samples whose ``case_id`` appears in
         ``counterfact/counterfact-{SPLIT}.json``.
      2. Load the causal LM and tokenizer from ``MODEL_NAME``.
      3. For each sample, score the true target conditioned on the formatted
         prompt. Samples for which ``run_lprobs`` reports the score as not
         above its threshold are dropped.
      4. For each kept sample, also score the full ``prompt + " " + target``
         sentence and record the per-token log-probabilities.
      5. Dump the records to ``{SPLIT}.json`` in the output directory,
         creating the directory if needed.
    """
    with open("counterfact/counterfact_filtered_paraphrases.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    with open(f"counterfact/counterfact-{SPLIT}.json", "r", encoding="utf-8") as f:
        split_samples = json.load(f)

    # A set makes the membership test below O(1) per sample instead of a
    # linear scan over the whole split (assumes case_id values are hashable).
    split_case_ids = {o["case_id"] for o in split_samples}
    filtered_samples = [sample for sample in data if sample["case_id"] in split_case_ids]

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    tok.pad_token = tok.eos_token

    results = []

    for d in tqdm.tqdm(filtered_samples):
        rewrite = d["requested_rewrite"]
        subject = rewrite["subject"]
        prompt = rewrite["prompt"].format(subject)
        target = rewrite["target_true"]["str"]
        sentence = prompt + " " + target

        # Score only the target given the prompt; the last return value says
        # whether that score cleared run_lprobs's threshold.
        original_target_lprobs, _, _, filtered = run_lprobs(
            model, tok, target, device="cuda", prefix=prompt
        )
        print(filtered)
        if not filtered:
            continue

        # Score the whole sentence to get per-token statistics for the record.
        original_lprobs, original_token_list, token_list, _ = run_lprobs(
            model, tok, sentence, "cuda"
        )

        results.append({
            "case_id": d["case_id"],
            "prompt": prompt,
            "target_new": target,
            "subject": subject,
            "rephrase_prompt": d["paraphrase_prompts"],
            "rephrase_target": target,
            "locality_prompt": d["neighborhood_prompts"],
            "locality_ground_truth": target,
            "original_lprobs": original_lprobs.item(),
            "original_target_lprobs": original_target_lprobs.item(),
            "original_token_lprobs_list": original_token_list,
            "token_list": token_list,
        })

    output_dir = "/yanjianhao/futingwang_test/ModelEditingForDebias/data_construction/outputs/more_paraphrase/counterfact_llama_structured_filtered/"
    # exist_ok=True already covers the "directory is present" case.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, f"{SPLIT}.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)




# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()