from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import datasets 
import pickle
import argparse 
def get_pipeline(model_name_or_path, cache_dir="/data/user_data/gghosal/cache", device="cuda:0"):
    """Load a causal LM with its tokenizer and wrap them in a text-generation pipeline.

    Args:
        model_name_or_path: Hub identifier or local directory passed to
            ``from_pretrained`` for both the model and the tokenizer.
        cache_dir: Directory used by ``from_pretrained`` to cache downloads.
            Defaults to the location this script has always used.
        device: Device the model is moved to and the pipeline runs on.

    Returns:
        A ``transformers`` ``text-generation`` pipeline with the model in
        eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, cache_dir=cache_dir).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    # The evaluation loop calls the pipeline with batch_size > 1, which needs
    # a pad token; fall back to EOS when the tokenizer does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue from the last position of each row, so
    # batched prompts must be padded on the left rather than the right.
    tokenizer.padding_side = "left"
    return pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
def get_data(data_path):
    """Return whatever ``datasets.load_from_disk`` loads from ``data_path``."""
    return datasets.load_from_disk(data_path)
def prepare_shards_llama(ds):
    """Build system+user chat prompts and the matching gold answers.

    Each row of ``ds`` must provide ``'question'`` and ``'answer'``. The
    instruction to finish with ``ANSWER:`` is sent as a system message and
    the question as the user message.

    Args:
        ds: Iterable of row mappings (e.g. a dataset shard or a list of dicts).

    Returns:
        ``(batch_inputs, answers)``: one two-message conversation per row and
        the row answers, in the same order.
    """
    system_prompt = "You are a knowledgeable assistant. Provide your final response to the question (after all explanation) after the tag 'ANSWER:'."
    batch_inputs, answers = [], []
    # Iterate rows directly so each row is fetched once instead of twice.
    for row in ds:
        # A fresh system-message dict per conversation keeps them independent.
        batch_inputs.append([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": row['question']},
        ])
        answers.append(row['answer'])
    return batch_inputs, answers
def prepare_shards_deepseek(ds):
    """Build single-turn user prompts and the matching gold answers.

    Each row of ``ds`` must provide ``'question'`` and ``'answer'``. No system
    message is used; the instruction to finish with ``ANSWER:`` is appended
    to the question inside the user message.

    Args:
        ds: Iterable of row mappings (e.g. a dataset shard or a list of dicts).

    Returns:
        ``(batch_inputs, answers)``: one single-message conversation per row
        and the row answers, in the same order.
    """
    suffix = " Write your final answer after 'ANSWER:'"
    batch_inputs, answers = [], []
    # Iterate rows directly so each row is fetched once instead of twice.
    for row in ds:
        batch_inputs.append([{"role": "user", "content": row['question'] + suffix}])
        answers.append(row['answer'])
    return batch_inputs, answers
def evaluate_cf(pipeline, ds, nshards, deepseek=False, bsize=16):
    """Generate answers shard by shard and score them against the gold answers.

    The dataset is split into ``nshards`` contiguous shards. Each shard is
    turned into chat prompts, run through ``pipeline``, and every generation
    is scored by case-insensitive substring match: the gold answer must occur
    in the last 40 characters of the text following the final ``"ANSWER: "``
    tag (or of the whole output when the tag is absent).

    Args:
        pipeline: Callable invoked as
            ``pipeline(conversations, max_length=..., batch_size=...)``. The
            assistant reply for example ``i`` is read from
            ``out[i][0]['generated_text'][-1]['content']``.
        ds: Dataset exposing ``shard(num_shards=, index=, contiguous=)`` whose
            rows carry ``question``, ``answer``, ``relevant_cf``,
            ``hop_1_rel`` and ``hop_2_rel``.
        nshards: Number of shards to split ``ds`` into.
        deepseek: If true, use the DeepSeek prompt (instruction in the user
            turn) with ``max_length=8000``; otherwise use the Llama prompt
            (instruction in a system message) with ``max_length=2048``.
        bsize: Batch size forwarded to the pipeline.

    Returns:
        ``(accuracy, records)``: the fraction of examples answered correctly
        (``0.0`` when nothing was evaluated) and one dict per example.
    """
    # Llama needs the instruction as a system prompt; DeepSeek gets it in the
    # user turn and a larger length budget.
    if deepseek:
        prepare, max_length = prepare_shards_deepseek, 8000
    else:
        prepare, max_length = prepare_shards_llama, 2048

    records = []
    correct = 0
    for shard_idx in range(nshards):
        this_shard = ds.shard(num_shards=nshards, index=shard_idx, contiguous=True)
        inps, ans = prepare(this_shard)
        if not inps:
            # Nothing to generate for an empty shard.
            continue
        # Use the pipeline passed in by the caller, not a module-level global.
        answer_gen = pipeline(inps, max_length=max_length, batch_size=bsize)
        for row, gold, generation in zip(this_shard, ans, answer_gen):
            model_output = generation[0]['generated_text'][-1]['content']
            # Keep only the tail after the last tag so that text from the
            # explanation does not count as the final answer.
            extracted = model_output.split("ANSWER: ")[-1][-40:].lower()
            is_correct = gold.lower() in extracted
            correct += int(is_correct)
            records.append({'question': row['question'],
                            "correct_response": gold,
                            "answered_correctly": is_correct,
                            "model output": model_output,
                            "extracted_answer": extracted,
                            "relevant": row['relevant_cf'],
                            # NOTE(review): "cf_hop" is filled from 'relevant_cf',
                            # same as "relevant" - confirm whether a separate
                            # column was intended.
                            "cf_hop": row['relevant_cf'],
                            "hop_1_rel": row['hop_1_rel'],
                            "hop_2_rel": row['hop_2_rel']})

    total = len(records)
    # Guard against an empty dataset instead of dividing by zero.
    accuracy = float(correct) / float(total) if total else 0.0
    return accuracy, records
# Script entry point: load the test split, evaluate the model on it, report
# accuracy and optionally pickle the per-example records.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default='/data/models/huggingface/meta-llama/Llama-3.1-8B-Instruct')
    parser.add_argument("--data_path", type=str, default='./wikidata_cf_split')
    parser.add_argument("--n_shards", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--save_file", type=str, default=None)
    args = parser.parse_args()

    wiki_cf_ds = datasets.DatasetDict.load_from_disk(args.data_path)['test']
    # Use a distinct name so the imported transformers `pipeline` factory,
    # which get_pipeline relies on, is not overwritten at module level.
    generator = get_pipeline(args.model_name_or_path)
    # DeepSeek checkpoints are detected from the model path and get their own prompt format.
    is_deepseek = "DeepSeek" in args.model_name_or_path
    acc, records = evaluate_cf(generator, wiki_cf_ds, args.n_shards, deepseek=is_deepseek, bsize=args.batch_size)
    print(f"Accuracy: {acc:.4f} over {len(records)} examples")

    # --save_file defaults to None; only write the records when a path was
    # given instead of failing in open() after the whole evaluation has run.
    if args.save_file is not None:
        with open(args.save_file, 'wb') as save_file:
            pickle.dump(records, save_file)
    


