import json
import os
import random
import re
import sys

import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Local checkpoint that main() loads with AutoModelForCausalLM / AutoTokenizer.
MODEL_NAME="/yanjianhao/futingwang_test/ModelEditingForDebias/fake_model/counterfact_wiki_llama_4k"

# Dataset split name, taken from the first command-line argument. main() uses it
# to pick the input file (counterfact/counterfact-{SPLIT}.json) and the output
# file name ({SPLIT}.json). It is read at import time, so importing this module
# without a command-line argument raises IndexError.
SPLIT=sys.argv[1]

def run_lprobs(model, tokenizer, sentence, device, prefix=None, *, verbose=False):
    """Score ``sentence`` under a causal LM, optionally conditioned on ``prefix``.

    With a ``prefix``, the model reads ``prefix.strip() + " " + sentence.strip()``.
    Only the tokens that follow the prefix's own tokens are scored. Without a
    prefix, every token of ``sentence`` except the first is scored. The first
    token (the BOS for tokenizers that add one) has no preceding context to be
    predicted from.

    Args:
        model: causal LM, called as ``model(input_ids=...)`` and returning an
            object with ``.logits`` of shape ``[1, seq, vocab]``.
        tokenizer: HF-style tokenizer. ``tokenizer(text)["input_ids"]`` gives a
            list of ids; with ``return_tensors="pt"`` it gives a ``[1, seq]`` tensor.
        sentence: text to score.
        device: device the input ids are moved to before the forward pass.
        prefix: optional conditioning context. It is fed to the model but not scored.
        verbose: if true, print the concatenated text fed to the model.

    Returns:
        ``(total, per_token, token_ids)``:
        - ``total`` is a 0-d tensor holding the summed log-probability of the
          scored tokens (0 when nothing is scored).
        - ``per_token`` is the list of per-token log-probabilities.
        - ``token_ids`` is the list of the matching token ids.
        Both lists are always lists, possibly empty or of length one.
    """
    if prefix:
        prefix = prefix.strip()
        new_sentence = prefix + " " + sentence.strip()
        if verbose:
            print(new_sentence)
        # Locate the scored span by the prefix's token count instead of by
        # tokenizing `sentence` separately. SentencePiece tokenizers (e.g.
        # Llama) can turn a standalone " target" into an extra "▁" piece plus
        # the word, which does not match how the same text is tokenized after
        # the prefix.
        # NOTE: assumes the prefix's tokens are a prefix of the joint
        # tokenization, which holds when the join falls on a word boundary.
        n_context = len(tokenizer(prefix)["input_ids"])
    else:
        new_sentence = sentence
        n_context = 1  # skip the first token (BOS for Llama-style tokenizers)

    full_ids = tokenizer(new_sentence, return_tensors="pt")["input_ids"].to(device)  # [1, L]
    start = max(n_context, 1)  # token 0 can never be predicted
    targets_ids = full_ids[:, start:]  # [1, T]
    if targets_ids.size(1) == 0:
        # Nothing to score. Also avoids calling the model on an empty input.
        return torch.zeros(()), [], []

    with torch.no_grad():
        logits = model(input_ids=full_ids[:, :-1]).logits  # [1, L-1, V]
        # Normalise in float32: bfloat16 keeps only ~3 significant digits,
        # which noticeably perturbs summed log-probabilities.
        lprobs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
        # Logits at position i predict token i + 1, so the scored tokens
        # start..L-1 are predicted by positions start-1..L-2.
        lprobs = lprobs[:, start - 1:, :]  # [1, T, V]
        gather_lprobs = lprobs.gather(-1, targets_ids[:, :, None])  # [1, T, 1]
        sent_lprobs = gather_lprobs[0].sum()

    # Index explicitly instead of squeeze() so a single scored token still
    # yields lists rather than scalars.
    return sent_lprobs, gather_lprobs[0, :, 0].tolist(), targets_ids[0].tolist()


def main():
    with open(f"counterfact/counterfact-{SPLIT}.json", "r") as f:
        data = json.load(f)
    
    model, tok = (
        AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        ),
        AutoTokenizer.from_pretrained(MODEL_NAME),
    )
    tok.pad_token = tok.eos_token
    

    results = []
    
    for d in tqdm.tqdm(data):
        sentence = d["prompt"] + " " + d["ground_truth"]
        subject_start_idx = sentence.lower().find(d["subject"].lower())
        prompt = sentence[: subject_start_idx + len(d["subject"])].strip()
        target_new = sentence[subject_start_idx+ len(d["subject"]):].strip()
        original_target_lprobs, _, _ = run_lprobs(model, tok, " " + target_new, device="cuda", prefix=prompt)
        original_lprobs, original_token_list, token_list = run_lprobs(model, tok, sentence, "cuda")

        paraphrase_sentence = d["rephrase_prompt"] + " " + d["ground_truth"]
        paraphrase_subject_index = paraphrase_sentence.lower().find(d["subject"].lower())

        edit_sample = {
            "case_id": d["case_id"], 
            "prompt": prompt,
            "target_new": target_new,
            "subject": d["subject"],
            "rephrase_prompt":paraphrase_sentence[: paraphrase_subject_index + len(d["subject"])].strip(), 
            "rephrase_target":paraphrase_sentence[paraphrase_subject_index + len(d["subject"]) :].strip(),
            "locality_prompt":d["locality_prompt"],
            "locality_ground_truth":d["locality_ground_truth"],
            "original_lprobs": original_lprobs.item(),
            "original_target_lprobs": original_target_lprobs.item(),
            "original_token_lprobs_list":original_token_list,
            "token_list": token_list
        }
        results.append(edit_sample)
    
    dir = "/yanjianhao/futingwang_test/ModelEditingForDebias/data_construction/outputs/counterfact_llama_fake/"
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)
    

    with open(os.path.join(dir, f"{SPLIT}.json"), "w") as f:
        json.dump(results, f, indent=4)


    return




# Run the scoring pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()