from datasets import load_dataset
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import traceback
import json
import re
import numpy as np
import signal
import torch
from io import StringIO
import sys
import io
import matplotlib.pyplot as plt
import seaborn

#load model
model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
# fp16 weights, placed on the available device(s) by accelerate; eval() disables dropout.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path_chat, torch_dtype=torch.float16, device_map="auto").eval()
# Request the slow tokenizer for Llama-architecture models. NOTE(review): Llama 3
# presumably ships only a fast tokenizer (tokenizer.json), in which case
# AutoTokenizer falls back to the fast one anyway — confirm.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path_chat, use_fast=use_fast_tokenizer, padding_side="left",
                                          legacy=False)
# Ensure a pad id exists (falls back to id 0 when the tokenizer defines none).
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
# bos_token_id is intentionally NOT overridden: id 1 is an ordinary token in the
# Llama 3 vocabulary; the BOS token is <|begin_of_text|>.
print("load model finished!")

#load dataset
code_contests_dataset = load_dataset('deepmind/code_contests')
ind = [29,60,63,75,98] #problem indices to use for composition
all_results = [] #list for all measurements of noise

# Llama 3 chat wrapper. The official template places "\n\n" after every
# <|end_header_id|>. The text already starts with <|begin_of_text|>, so prompts
# are tokenized with add_special_tokens=False; otherwise the tokenizer would
# prepend a second BOS token.
prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
TOP_K = 100  # highest-probability tokens used to estimate the per-position centre
MAX_SOLUTIONS = 10  # solution pairs evaluated per ordered problem pair
TOKEN_MARGIN = 10  # tokens subtracted from the solution length to size the scored window


def _short_python3_solutions(row, max_tokens=100):
    """Return the solutions of a code_contests row whose language code is 3
    (PYTHON3) and whose tokenization (with special tokens) is shorter than
    ``max_tokens`` tokens."""
    sols = row['solutions']
    return [code for lang, code in zip(sols['language'], sols['solution'])
            if lang == 3 and len(tokenizer.encode(code)) < max_tokens]


def _indent_body(code):
    """Indent every line after the first by four spaces, so the solution can be
    appended directly after ``def ...():`` + newline + four spaces."""
    return "\n    ".join(code.split('\n'))


def _encode(text):
    """Tokenize ``text`` (which already contains its special tokens) and move
    the ids to the model's input device."""
    return tokenizer(text, return_tensors="pt", add_special_tokens=False)['input_ids'].to(model.device)


def _release_cuda_cache():
    """Release cached CUDA memory between measurements when a GPU is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def _logit_noise(ids_alone, ids_comb, n_tokens):
    """Weighted logit noise over the last ``n_tokens`` tokens of a solution.

    ``ids_alone`` and ``ids_comb`` must both end with the same solution tokens.
    For every position in the window whose next token t also lies in the window:

        noise = d(t) - sum_k w_k * d(k) / sum_k w_k,  d(x) = l_comb[x] - l_alone[x]

    where k ranges over the TOP_K most probable tokens under the solo prompt,
    excluding t, and w_k is the solo-prompt probability of k.

    Requires n_tokens >= 2. Returns n_tokens - 1 floats rounded to 8 decimals.
    """
    with torch.no_grad():
        # positions -n_tokens .. -2 predict tokens -n_tokens+1 .. -1; upcast to
        # float32 so that logit differences are not rounded to fp16
        logits = model(ids_alone).logits[0, -n_tokens:-1].float()
        logits_comb = model(ids_comb).logits[0, -n_tokens:-1].float()
    targets = ids_alone[0, -n_tokens + 1:].to(logits.device).unsqueeze(1)  # (n_tokens - 1, 1)

    # top tokens by probability (top-k of log-probs == top-k of logits)
    top_log_p, top_idx = torch.log_softmax(logits, dim=-1).topk(TOP_K, dim=-1)
    top_diff = logits_comb.gather(1, top_idx) - logits.gather(1, top_idx)
    target_diff = (logits_comb.gather(1, targets) - logits.gather(1, targets)).squeeze(1)

    # exclude the target token by identity; comparing diff values instead would
    # also drop unrelated tokens whose diff happens to coincide (common in fp16)
    weights = top_log_p.exp() * (top_idx != targets).float()
    centers = (top_diff * weights).sum(1) / weights.sum(1)
    return [round(x, 8) for x in (target_diff - centers).tolist()]


# fetch each problem once; row access avoids materialising whole dataset columns
test_split = code_contests_dataset['test']
problems = {}
for idx in ind:
    row = test_split[idx]
    problems[idx] = (row['description'], _short_python3_solutions(row))

#iterate over ordered pairs of distinct problems
for i in ind:
    print(i)
    problem_description_1, solutions_1 = problems[i]
    for j in ind:
        if i == j:
            continue
        print('_' + str(j))
        problem_description_2, solutions_2 = problems[j]

        #construct prompt for x1, x2 and x1+x2
        request1 = "Solve the following problem in python:\n\n" + problem_description_1 + "\n\nPlease write the entire solution in one block of code."
        request2 = "Solve the following problem in python:\n\n" + problem_description_2 + "\n\nPlease write the entire solution in one block of code."
        request_comb = "Solve the two following problems in python:\n\nProblem 1:\n" + problem_description_1 + "\n\nProblem 2:\n" + problem_description_2 + "\n\nPlease write the solution to both problems in one block of code, the program should receive the inputs to the problems sequentially and output the concatenated outputs to the problems."
        prompt1 = f"{prefix}{request1}{suffix}"
        prompt2 = f"{prefix}{request2}{suffix}"
        prompt_comb = f"{prefix}{request_comb}{suffix}"

        multi_sol_diffs_1 = []
        multi_sol_diffs_2 = []

        #iterate over different solutions
        for k in range(min(len(solutions_1), len(solutions_2), MAX_SOLUTIONS)):
            #reformat solutions to template
            sol_1 = _indent_body(solutions_1[k])
            sol_2 = _indent_body(solutions_2[k])
            # size of the scored window: solution token count minus a margin
            # (presumably to stay clear of context-dependent leading tokens)
            len_tokens_1 = len(tokenizer.encode(sol_1)) - TOKEN_MARGIN
            len_tokens_2 = len(tokenizer.encode(sol_2)) - TOKEN_MARGIN

            # first problem: alone vs. as the first part of the composition
            # (a window shorter than 2 tokens yields no measurements)
            if len_tokens_1 >= 2:
                ids_1 = _encode(prompt1 + "Here is a solution to the problem:\n\n```\ndef problem():\n    " + sol_1)
                ids_comb1 = _encode(prompt_comb + "Here is a solution to the problems:\n\n```\ndef problem1():\n    " + sol_1)
                multi_sol_diffs_1 += _logit_noise(ids_1, ids_comb1, len_tokens_1)
                _release_cuda_cache()

            # second problem: alone vs. following the first solution in the composition
            if len_tokens_2 >= 2:
                ids_2 = _encode(prompt2 + "Here is a solution to the problem:\n\n```\ndef problem():\n    " + sol_2)
                ids_comb2 = _encode(prompt_comb + "Here is a solution to the problems:\n\n```\ndef problem1():\n    " + sol_1 + "\ndef problem2():\n    " + sol_2)
                multi_sol_diffs_2 += _logit_noise(ids_2, ids_comb2, len_tokens_2)
                _release_cuda_cache()

        all_results.append(multi_sol_diffs_1)
        all_results.append(multi_sol_diffs_2)


# Flatten the per-pair lists into one sample of logit-noise values.
concatenated_results = [noise for per_pair in all_results for noise in per_pair]

# Kernel density estimate of the pooled logit noise.
seaborn.kdeplot(concatenated_results)
plt.show()
