from datasets import load_dataset
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import traceback
import json
import re
import numpy as np
import signal
import torch
from io import StringIO
import sys
import io
import matplotlib.pyplot as plt


def sequence_probability(x):
    """Return the running log-probability of the text ``x`` under ``model``.

    The text is tokenized with the module-level ``tokenizer`` and scored with
    a single forward pass of the module-level ``model``.

    Args:
        x: The text to score.

    Returns:
        A list of Python floats with one entry per token after the first.
        Entry ``k - 1`` is the sum, for ``t = 1..k``, of the log-probability
        the model assigns to token ``t`` given tokens ``0..t-1``. The first
        token has no preceding context and therefore gets no entry; a text
        that tokenizes to fewer than two tokens yields ``[]``.
    """
    token_ids = tokenizer(x)['input_ids']
    if len(token_ids) < 2:
        # Nothing can be predicted from context; skip the forward pass.
        return []
    input_ids = torch.tensor(token_ids).unsqueeze(0)

    with torch.no_grad():
        # Send the inputs to wherever the model lives rather than assuming
        # a CUDA device is present.
        logits = model(input_ids=input_ids.to(model.device)).logits
        # Upcast before normalising: the logits may be half precision, and a
        # running sum over many half-precision terms loses accuracy.
        logp = torch.nn.functional.log_softmax(logits[0].float(), dim=-1)

    # Position t-1 of the output predicts token t, so pair every position
    # except the last with the token that follows it.
    targets = input_ids[0, 1:].to(logp.device).unsqueeze(-1)
    token_logp = logp[:-1].gather(-1, targets).squeeze(-1)
    return torch.cumsum(token_logp, dim=0).tolist()


# Load the chat model (half precision, automatic device placement) in
# inference mode, together with its tokenizer.
model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path_chat,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()
# `architectures` may be unset on some configs; treat that as "not Llama".
use_fast_tokenizer = "LlamaForCausalLM" not in (model.config.architectures or [])
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path_chat,
    use_fast=use_fast_tokenizer,
    padding_side="left",
    legacy=False,
)
# The ids 0 and 1 are only fallbacks for tokenizers that do not define the
# token themselves. Overwriting an id the tokenizer already provides would
# re-point that special token at an unrelated vocabulary entry.
# NOTE(review): the Llama-3 tokenizer is expected to ship its own BOS id, so
# the BOS fallback should be a no-op for this checkpoint - confirm.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = 0
if tokenizer.bos_token_id is None:
    tokenizer.bos_token_id = 1
print("load model finished!")

# Load the CodeContests dataset.
code_contests_dataset = load_dataset("deepmind/code_contests")

# Indices of the problems that are paired up for composition.
ind = [29, 60, 63, 75, 98]

# Collects one array per measurement: the log-probability difference between
# scoring a solution without composition and scoring it with composition.
prob_diffs = []

# Measure, for every ordered pair of distinct problems, how the running
# log-probability of a reference solution changes when the problem is posed
# on its own versus composed with the other problem.


def _short_python3_solutions(problem):
    """Return the problem's language-3 solutions shorter than 100 tokens."""
    solutions = problem['solutions']
    # presumably language code 3 is Python 3 - verify against the dataset card
    return [
        code
        for lang, code in zip(solutions['language'], solutions['solution'])
        if lang == 3 and len(tokenizer.encode(code)) < 100
    ]


def _indent_body(solution):
    """Indent continuation lines so the solution fits inside a ``def`` body."""
    return "\n    ".join(solution.split('\n'))


def _normalized_tail(text, n_tokens):
    """Score ``text`` and return its last ``n_tokens`` running log-probs,
    shifted so that the first kept value is zero."""
    tail = np.array(sequence_probability(text)[-n_tokens:])
    return tail - tail[0]


def _release_gpu_memory():
    """Free cached CUDA memory between forward passes, if CUDA is in use."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


# Fetch each problem and filter/tokenize its solutions once, instead of
# repeating that work for every (i, j) pair.
_test_split = code_contests_dataset['test']
_problems = {}
for _idx in ind:
    _row = _test_split[_idx]
    _problems[_idx] = (_row['description'], _short_python3_solutions(_row))

# NOTE(review): the prefix already spells out <|begin_of_text|>; if the
# tokenizer also prepends a BOS token, every prompt starts with two - confirm
# this is intended.
prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
_single_answer = "Here is a solution to the problem:\n\n```\ndef problem():\n    "
_combined_answer = "Here is a solution to the problems:\n\n```\ndef problem1():\n    "

for i in ind:
    print(i)
    for j in ind:
        if i == j:
            continue
        problem_description_1, solutions_1 = _problems[i]
        problem_description_2, solutions_2 = _problems[j]

        # construct prompt for x1, x2 and x1+x2
        request1 = "Solve the following problem in python:\n\n" + problem_description_1 + "\n\nPlease write the entire solution in one block of code."
        request2 = "Solve the following problem in python:\n\n" + problem_description_2 + "\n\nPlease write the entire solution in one block of code."
        request_comb = "Solve the two following problems in python:\n\nProblem 1:\n"+problem_description_1 +"\n\nProblem 2:\n"+problem_description_2+"\n\nPlease write the solution to both problems in one block of code, the program should receive the inputs to the problems sequentially and output the concatenated outputs to the problems."
        prompt1 = f"{prefix}{request1}{suffix}"
        prompt2 = f"{prefix}{request2}{suffix}"
        prompt_comb = f"{prefix}{request_comb}{suffix}"

        # Pair the k-th solution of each problem, for at most 10 pairs.
        for raw_1, raw_2 in zip(solutions_1[:10], solutions_2[:10]):
            # reformat solutions to template
            sol_1 = _indent_body(raw_1)
            sol_2 = _indent_body(raw_2)
            # Size of the trailing window that is compared for each solution.
            # presumably the -10 keeps the window strictly inside the solution
            # tokens - TODO confirm
            len_tokens_1 = len(tokenizer.encode(sol_1)) - 10
            len_tokens_2 = len(tokenizer.encode(sol_2)) - 10
            if len_tokens_1 <= 0 or len_tokens_2 <= 0:
                # A non-positive window would make the [-n:] slice return the
                # whole (or a mis-sized) sequence, so the two runs would no
                # longer line up; skip solutions that are too short.
                continue

            # probability of y2: standalone prompt vs composed prompt
            x = _normalized_tail(prompt2 + _single_answer + sol_2, len_tokens_2)
            y = _normalized_tail(
                prompt_comb + _combined_answer + sol_1 + "\ndef problem2():\n    " + sol_2,
                len_tokens_2,
            )
            prob_diffs.append(x - y)
            _release_gpu_memory()

            # probability of y1: standalone prompt vs composed prompt
            x = _normalized_tail(prompt1 + _single_answer + sol_1, len_tokens_1)
            y = _normalized_tail(prompt_comb + _combined_answer + sol_1, len_tokens_1)
            prob_diffs.append(x - y)
            _release_gpu_memory()


# Plot the average log-probability difference (with vs without composition)
# as a function of position, with a shaded band of one standard error.
if not prob_diffs:
    raise ValueError("no probability differences were collected; nothing to plot")

# Measurements differ in length; truncate all of them to the shortest one so
# they can be stacked into a (measurements, positions) array.
min_length = min(len(diff) for diff in prob_diffs)
_stacked = np.array([diff[:min_length] for diff in prob_diffs])
avg = np.mean(_stacked, axis=0)
# Standard error of the mean: divide by the square root of the number of
# measurements being averaged, not by the number of positions.
err = np.std(_stacked, axis=0) / np.sqrt(len(prob_diffs))
_positions = list(range(min_length))
plt.plot(_positions, avg)
plt.fill_between(_positions, avg - err, avg + err, alpha=0.2)
plt.show()

