import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import random
import tqdm
import sys
# Checkpoint that gets scored. A previously used alternative was the local
# fine-tuned model at
#   /yanjianhao/futingwang_test/ModelEditingForDebias/fake_model/counterfact_wiki_llama_4k
MODEL_NAME = "/yanjianhao/huggingface/gpt-j-6b"

# CounterFact split to process, taken from the first command-line argument
# (read at import time, so the script must be invoked with one argument).
SPLIT = sys.argv[1]

def run_lprobs(model, tokenizer, sentence, device, prefix=None):
    """Score ``sentence`` with a causal LM, optionally conditioned on ``prefix``.

    Without a prefix, every token of ``sentence`` except the first is scored
    (the first token has nothing before it to be predicted from).  With a
    prefix, the model reads ``prefix.strip() + " " + sentence.strip()`` and only
    the tokens contributed by the sentence part are scored.

    The scored ids are sliced out of the tokenization of the *combined* text
    rather than obtained by tokenizing ``sentence`` alone.  Tokenizing the
    target alone and dropping its first token is only correct for tokenizers
    that always prepend a BOS and encode a word identically with or without a
    preceding space (SentencePiece/LLaMA).  GPT-2-style BPE (e.g. GPT-J) has no
    BOS and encodes "Paris" differently from " Paris", so that approach scored
    the wrong ids and dropped a real token.

    Args:
        model: causal LM; ``model(input_ids=..., attention_mask=None).logits``
            is expected to have shape [1, seq_len, vocab].
        tokenizer: callable returning a mapping whose ``"input_ids"`` entry is
            a [1, seq_len] tensor.
        sentence: text whose tokens are scored.
        device: device the input ids are moved to.
        prefix: optional conditioning context; its tokens are not scored.

    Returns:
        ``(sent_lprobs, token_lprobs, token_ids)``: a 0-dim float32 tensor with
        the summed log-probability of the scored tokens, a list with one
        log-probability per scored token, and a list with the scored token ids.
        If no token can be scored, the sum is 0 and both lists are empty.
    """
    if prefix:
        context = prefix.strip()
        new_sentence = context + " " + sentence.strip()
    else:
        context = None
        new_sentence = sentence

    full_ids = tokenizer(new_sentence, return_tensors="pt")["input_ids"].to(device)
    seq_len = full_ids.size(1)

    if context is None:
        # Every token except the first can be predicted from its left context.
        n_target = seq_len - 1
    else:
        # Tokens of the combined text beyond the prefix's own encoding (BOS, if
        # any, is counted in both and cancels out).  Clamp so we never ask for
        # more predictions than the shifted sequence provides.
        prefix_len = tokenizer(context, return_tensors="pt")["input_ids"].size(1)
        n_target = min(seq_len - prefix_len, seq_len - 1)

    if n_target <= 0:
        # Empty continuation: its log-probability is 0.  Returning early also
        # avoids ``[:, -0:]``, which would select the whole sequence.
        return torch.zeros((), dtype=torch.float32), [], []

    input_ids = full_ids[:, :-1]
    # Position i of the logits predicts full_ids[:, i + 1]; keep the last n_target.
    targets_ids = full_ids[:, 1:][:, -n_target:]  # [1, n_target]
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=None).logits  # [1, l, V]
        # float32: bf16 keeps ~3 significant digits, too coarse for log-probs.
        lprobs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
        lprobs = lprobs[:, -n_target:, :]
        # Logits may live on another device than the inputs for sharded models.
        gather_ids = targets_ids.to(lprobs.device)
        token_lprobs = lprobs.gather(-1, gather_ids[:, :, None])[0, :, 0]  # [n_target]
        sent_lprobs = token_lprobs.sum()

    # Index explicitly instead of squeeze() so single-token results stay lists.
    return sent_lprobs, token_lprobs.tolist(), targets_ids[0].tolist()


def main():
    """Score every CounterFact record of ``SPLIT`` with ``MODEL_NAME`` and save the results.

    Reads ``counterfact/counterfact-{SPLIT}.json`` (a JSON list of records with
    ``case_id``, ``prompt``, ``ground_truth``, ``subject``, ``rephrase_prompt``,
    ``locality_prompt`` and ``locality_ground_truth`` keys).  For each record it
    computes:

    * ``original_target_lprobs``: the log-prob of the ground truth given the prompt;
    * ``original_lprobs``: the log-prob of the whole "prompt ground_truth" sentence,
      plus its per-token log-probs and token ids.

    The annotated records are written as ``{SPLIT}.json`` into a fixed output
    directory, which is created if missing.
    """
    with open(f"counterfact/counterfact-{SPLIT}.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    tok.pad_token = tok.eos_token
    # device_map="auto" may place the model anywhere (or shard it); send inputs
    # to the device of its first parameters rather than a hard-coded "cuda".
    device = model.device

    results = []
    for d in tqdm.tqdm(data):
        prompt = d["prompt"]
        ground_truth = d["ground_truth"]
        # Join the same way run_lprobs joins prefix + sentence (strip both, one
        # space), so the two scores below are computed over the same text.
        sentence = prompt.strip() + " " + ground_truth.strip()

        original_target_lprobs, _, _ = run_lprobs(
            model, tok, ground_truth, device=device, prefix=prompt
        )
        original_lprobs, original_token_list, token_list = run_lprobs(
            model, tok, sentence, device
        )

        results.append({
            "case_id": d["case_id"],
            "prompt": prompt,
            "target_new": ground_truth,
            "subject": d["subject"],
            "rephrase_prompt": d["rephrase_prompt"],
            "rephrase_target": ground_truth,
            "locality_prompt": d["locality_prompt"],
            "locality_ground_truth": d["locality_ground_truth"],
            "original_lprobs": original_lprobs.item(),
            "original_target_lprobs": original_target_lprobs.item(),
            "original_token_lprobs_list": original_token_list,
            "token_list": token_list,
        })

    out_dir = "/yanjianhao/futingwang_test/ModelEditingForDebias/data_construction/outputs/counterfact_llama_structured_gptj/"
    os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(out_dir, f"{SPLIT}.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)




# Script entry point: run the scoring pipeline only when executed directly.
if __name__ == "__main__":
    main()