from datasets import load_dataset
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import traceback
import json
import re
import numpy as np
import signal
import torch
from io import StringIO
import sys
import io
import matplotlib.pyplot as plt
import seaborn

# Load the chat model in half precision; device_map="auto" lets accelerate
# decide where the layers live.
model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path_chat, torch_dtype=torch.float16, device_map="auto").eval()
# Ask for the slow tokenizer on Llama-architecture models.
# NOTE(review): Llama-3 ships only a fast (tokenizer.json) tokenizer, so
# AutoTokenizer presumably falls back to the fast one here; confirm.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path_chat, use_fast=use_fast_tokenizer, padding_side="left", legacy=False)
# Llama-3 defines no pad token. Fall back to EOS rather than id 0: in the
# Llama-3 vocabulary ids 0/1 are ordinary text tokens ("!" and '"'), not
# special tokens. (Every encoded batch in this script is a single sequence,
# so no padding is actually inserted.)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Keep the tokenizer's own BOS (<|begin_of_text|>); forcing a Llama-2 style
# bos_token_id of 1 would relabel an ordinary token as BOS.
print("load model finished!")

# Load HumanEval; the script reads its problems from the 'test' split.
humaneval_dataset = load_dataset(path='openai/openai_humaneval')
# HumanEval problem indices that are paired with one another for composition.
ind = [
    69, 78, 85, 90,
    98, 99, 114, 133,
]
# One list of weighted logit-noise values is appended per measurement.
all_results_weighted = []

#iterate over pairs of problems
for i in ind:
    print(i)
    for j in ind:
        if i!=j:
            index1=i
            index2=j
            problem_description_1 = humaneval_dataset['test'][index1]['prompt']
            problem_description_2 = humaneval_dataset['test'][index2]['prompt']
            prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
            suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

            # construct prompt for x1, x2 and x1+x2
            request1 = "Complete the following function in python:\n" + problem_description_1
            request2 = "Complete the following function in python:\n" + problem_description_2
            request_comb = "Complete the two following functions in python:\n\nFunction 1:\n" + problem_description_1 + "\n\nFunction 2:\n" + problem_description_2 + "Write code that computes and prints the product of outputs of these functions."
            prompt1 = f"{prefix}{request1}{suffix}"
            prompt2 = f"{prefix}{request2}{suffix}"
            prompt_comb = f"{prefix}{request_comb}{suffix}"

            # solutions for x1,x2
            problem1 = humaneval_dataset['test'][index1]['prompt'].split(':')[0] + ":\n" + humaneval_dataset['test'][index1][
                'canonical_solution']
            problem2 = humaneval_dataset['test'][index2]['prompt'].split(':')[0] + ":\n" + humaneval_dataset['test'][index2][
                'canonical_solution']
            len_tokens_1 = len(tokenizer.encode(humaneval_dataset['test'][index1]['canonical_solution'])) - 5 #number of tokens in first solution
            len_tokens_2 = len(tokenizer.encode(humaneval_dataset['test'][index2]['canonical_solution'])) - 5 #number of tokens in second solution

            q1_encoding = tokenizer.encode_plus(prompt1+"Here is a solution to the function:\n\n```"+problem1, return_tensors="pt", padding=True, truncation=True)
            q2_encoding = tokenizer.encode_plus(prompt2+"Here is a solution to the function:\n\n```"+problem2, return_tensors="pt", padding=True, truncation=True)
            q_comb1_encoding = tokenizer.encode_plus(prompt_comb+"Here is a solution to the functions:\n\n```"+problem1, return_tensors="pt", padding=True, truncation=True)
            q_comb2_encoding = tokenizer.encode_plus(prompt_comb+"Here is a solution to the functions:\n\n```"+problem1+problem2, return_tensors="pt", padding=True, truncation=True)

            input_ids1 = q1_encoding['input_ids'].to('cuda')
            input_ids2 = q2_encoding['input_ids'].to('cuda')
            input_ids_comb1 = q_comb1_encoding['input_ids'].to('cuda')
            input_ids_comb2 = q_comb2_encoding['input_ids'].to('cuda')

            # extract logit diffs for first problem (with vs without composition)
            with torch.no_grad():
                logits1 = model(input_ids1).logits[0][-len_tokens_1:]
                logits_comb1 = model(input_ids_comb1).logits[0][-len_tokens_1:]

            # take top 100 tokens in probability mass, to avoid numerical errors
            sorted_indices_1 = [logits1[i].sort(0).indices[-100:] for i in range(len_tokens_1-1)]
            # difference in logits (with vs without composition)
            diff_1 = [logits_comb1[i][sorted_indices_1[i]] - logits1[i][sorted_indices_1[i]] for i in range(len_tokens_1-1)]
            diff_correct_1 = [logits_comb1[i][input_ids1[0][-len_tokens_1+i+1]] - logits1[i][input_ids1[0][-len_tokens_1+i+1]] for i in range(len_tokens_1-1)]

            # extract probability of tokens without composition
            log_p_1 = torch.nn.functional.log_softmax(logits1,dim=1)
            weights_1 = torch.exp(log_p_1.sort(1).values[:,-100:])
            weighted_centers_1 = [(diff_1[i][diff_1[i]-diff_correct_1[i]!=0]*weights_1[i][diff_1[i]-diff_correct_1[i]!=0]).sum()/weights_1[i][diff_1[i]-diff_correct_1[i]!=0].sum() for i in range(len_tokens_1-1)]
            # calculate the weighted logit noise
            diff_correct_centered_weighted_1 = torch.tensor([x-y for x,y,w in zip(diff_correct_1,weighted_centers_1,weights_1)]) # if w[-1]<=0.95
            all_results_weighted.append(torch.Tensor.tolist(diff_correct_centered_weighted_1))

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # extract logit diffs for second problem (with vs without composition)
            with torch.no_grad():
                logits2 = model(input_ids2).logits[0][-len_tokens_2:]
                logits_comb2 = model(input_ids_comb2).logits[0][-len_tokens_2:]

            # take top 100 tokens in probability mass, to avoid numerical errors
            sorted_indices_2 = [logits2[i].sort(0).indices[-100:] for i in range(len_tokens_2-1)]
            # difference in logits (with vs without composition)
            diff_2 = [logits_comb2[i][sorted_indices_2[i]] - logits2[i][sorted_indices_2[i]] for i in range(len_tokens_2-1)]
            diff_correct_2 = [logits_comb2[i][input_ids2[0][-len_tokens_2+i+1]] - logits2[i][input_ids2[0][-len_tokens_2+i+1]] for i in range(len_tokens_2-1)]

            # extract probability of tokens without composition
            log_p_2 = torch.nn.functional.log_softmax(logits2,dim=1)
            weights_2 = torch.exp(log_p_2.sort(1).values[:,-100:])
            weighted_centers_2 = [(diff_2[i][diff_2[i]-diff_correct_2[i]!=0]*weights_2[i][diff_2[i]-diff_correct_2[i]!=0]).sum()/weights_2[i][diff_2[i]-diff_correct_2[i]!=0].sum() for i in range(len_tokens_2-1)]
            # calculate the weighted logit noise
            diff_correct_centered_weighted_2 = torch.tensor([x-y for x,y,w in zip(diff_correct_2,weighted_centers_2,weights_2)]) # if w[-1]<=0.95
            all_results_weighted.append(torch.Tensor.tolist(diff_correct_centered_weighted_2))

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()


# Pool every per-measurement list into one flat sample of logit-noise values.
concatenated_results = [
    noise
    for per_measurement in all_results_weighted
    for noise in per_measurement
]

# Plot a kernel-density estimate of the pooled logit-noise distribution.
seaborn.kdeplot(concatenated_results)
plt.show()

