from datasets import load_dataset
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import traceback
import json
import re
import numpy as np
import signal
import torch
from io import StringIO
import sys
import io
import matplotlib.pyplot as plt


def sequence_probability(x):
    """Score ``x`` with the global ``model`` and return running log-probabilities.

    ``x`` is tokenized with the global ``tokenizer`` (default special-token
    handling) and run through ``model`` in one forward pass without gradients.

    Args:
        x: Text to score.

    Returns:
        A list of ``len(token_ids) - 1`` Python floats. Entry ``k`` is the sum of
        ``log P(token_j | tokens_<j)`` for ``j = 1 .. k + 1`` (natural log), i.e.
        the log-probability of the prefix ending at token ``k + 1``. The first
        token has no context and is not scored, so a text that tokenizes to fewer
        than two tokens yields ``[]``.
    """
    token_ids = tokenizer(x)['input_ids']
    if len(token_ids) < 2:
        return []

    # Put inputs on the model's (first) device instead of assuming cuda:0, so CPU
    # runs and device_map="auto" placements work too.
    input_ids = torch.tensor([token_ids], device=model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids).logits

    # The model runs in float16. Accumulating hundreds of half-precision log-probs
    # rounds small terms away once the running total is large (fp16 spacing is 0.5
    # above 512), so normalize and sum in float32.
    log_probs = torch.nn.functional.log_softmax(logits[0, :-1].float(), dim=-1)
    # Row k predicts token k + 1: select the log-prob of the token that follows.
    next_tokens = input_ids[0, 1:].to(log_probs.device).unsqueeze(-1)
    token_log_probs = log_probs.gather(-1, next_tokens).squeeze(-1)
    return torch.cumsum(token_log_probs, dim=0).tolist()

#load model (Llama-3-8B-Instruct in float16, placed by accelerate's device_map)
model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path_chat, torch_dtype=torch.float16, device_map="auto").eval()
# Prefer a slow tokenizer for Llama architectures. `architectures` may be missing
# from a config, so treat None as "no architectures listed".
# NOTE(review): Llama-3 appears to ship only a fast tokenizer; confirm which class
# AutoTokenizer actually returns when use_fast=False.
use_fast_tokenizer = "LlamaForCausalLM" not in (model.config.architectures or [])
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path_chat, use_fast=use_fast_tokenizer, padding_side="left", legacy=False)
# This script scores one text at a time and never pads, but give the tokenizer a
# pad token if it lacks one. Reuse EOS rather than a hard-coded id: id 0 is an
# ordinary token in the Llama-3 vocabulary.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# The tokenizer's own BOS token is kept. The former `bos_token_id = 1` override is
# Llama-2's <s> id and does not match Llama-3's <|begin_of_text|>.
print("load model finished!")

#load dataset
humaneval_dataset = load_dataset('openai/openai_humaneval')
test_split = humaneval_dataset['test']

ind = [69,78,85,90,98,99,114,133] #problem indices to use for composition
prob_diffs = [] #list of all measurements: log-probability difference curves (without vs with composition)

# Llama-3-Instruct chat markers: one user turn followed by the assistant header.
prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
# NOTE(review): `prefix` spells out <|begin_of_text|>. If the tokenizer also prepends
# its own BOS token when encoding, every scored text starts with two BOS tokens.
# Confirm whether that is intended.

# Tokens subtracted from each canonical solution's standalone token count. That count
# includes any special tokens `tokenizer.encode` adds, and may differ from in-context
# tokenization. Presumably this margin keeps the scored tail inside the solution.
# TODO confirm the value.
SOLUTION_TOKEN_MARGIN = 5


def _signature_header(prompt):
    """Return ``prompt`` up to and including the target function's ``def ...:`` line.

    Keeps everything before the colon that ends the last top-level ``def`` line,
    then appends ``":\\n"``. For a prompt with a single ``def`` whose line holds the
    first colon, this equals the older ``prompt.split(':')[0] + ":\\n"``. Unlike
    that heuristic, annotated signatures such as ``def f(x: int) -> int:`` are not
    cut at the first annotation colon. Falls back to the older heuristic when no
    ``def`` line ending in ``:`` is found.
    """
    def_starts = [m.start() for m in re.finditer(r"^def ", prompt, flags=re.MULTILINE)]
    if def_starts:
        header_end = prompt.find(":\n", def_starts[-1])
        if header_end != -1:
            return prompt[:header_end] + ":\n"
    return prompt.split(':')[0] + ":\n"


def _solution_token_count(solution):
    """Return how many trailing tokens of a scored text are measured for ``solution``."""
    n_tokens = len(tokenizer.encode(solution)) - SOLUTION_TOKEN_MARGIN
    if n_tokens <= 0:
        # A non-positive count would make the ``[-n_tokens:]`` slice return the whole
        # sequence (``[-0:]``) or drop only its head, silently scoring prompt tokens.
        raise ValueError(f"canonical solution too short to score: {solution!r}")
    return n_tokens


def _relative_curve(text, n_tokens):
    """Running log-prob of the last ``n_tokens`` tokens of ``text``, shifted to start at 0."""
    curve = np.array(sequence_probability(text)[-n_tokens:])
    return curve - curve[0]


# Per-problem inputs, built once instead of once per pair.
rows = {i: test_split[i] for i in ind}
requests = {i: "Complete the following function in python:\n" + rows[i]['prompt'] for i in ind}
# Reference answers: function header followed by the canonical body.
solutions = {i: _signature_header(rows[i]['prompt']) + rows[i]['canonical_solution'] for i in ind}
token_counts = {i: _solution_token_count(rows[i]['canonical_solution']) for i in ind}

# The "without composition" curve of a solution depends only on its own problem.
# Score it once per problem rather than once per pair (about half the forward passes).
solo_curves = {
    i: _relative_curve(
        f"{prefix}{requests[i]}{suffix}" + "Here is a solution for the function:\n\n```" + solutions[i],
        token_counts[i],
    )
    for i in ind
}

#iterate over ordered pairs of distinct problems
for index1 in ind:
    print(index1)
    for index2 in ind:
        if index1 == index2:
            continue

        request_comb = (
            "Complete the two following functions in python:\n\nFunction 1:\n"
            + rows[index1]['prompt']
            + "\n\nFunction 2:\n"
            + rows[index2]['prompt']
            + "Write code that computes and prints the product of outputs of these functions."
        )
        comb_answer_start = f"{prefix}{request_comb}{suffix}" + "Here are the solutions for the functions:\n\n```"

        # y2: solution 2 alone vs. following solution 1 under the combined prompt.
        composed_2 = _relative_curve(comb_answer_start + solutions[index1] + solutions[index2], token_counts[index2])
        prob_diffs.append(solo_curves[index2] - composed_2)

        # y1: solution 1 alone vs. as the first answer under the combined prompt.
        composed_1 = _relative_curve(comb_answer_start + solutions[index1], token_counts[index1])
        prob_diffs.append(solo_curves[index1] - composed_1)


#plot average sequence probability difference (with vs without composition) as a function of length
if not prob_diffs:
    raise ValueError("prob_diffs is empty: `ind` needs at least two problem indices")

# Curves differ in length (one per solution tail). Truncate all of them to the
# shortest so they can be averaged position by position.
min_length = min(len(curve) for curve in prob_diffs)
prob_diffs_trimmed = np.array([curve[:min_length] for curve in prob_diffs])
n_curves = len(prob_diffs_trimmed)
avg = np.mean(prob_diffs_trimmed, axis=0)
# Standard error of the mean, based on the sample standard deviation (ddof=1).
err = np.std(prob_diffs_trimmed, axis=0, ddof=1 if n_curves > 1 else 0) / np.sqrt(n_curves)
positions = np.arange(min_length)
plt.plot(positions, avg)
plt.fill_between(positions, avg - err, avg + err, alpha=0.2)
plt.xlabel("token position in solution tail")
plt.ylabel("cumulative log-prob difference (alone - composed)")
plt.show()

