from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import torch
from datasets import load_dataset
from io import StringIO
import sys
import io
import numpy as np
import signal

#load model
model_name_or_path_chat = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path_chat, torch_dtype=torch.float16, device_map="auto").eval()
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path_chat, use_fast=use_fast_tokenizer, padding_side="left", legacy=False)
# Hard-coding pad id 0 / BOS id 1 is a Llama-2 (SentencePiece) convention. In
# the Llama-3 vocabulary those ids are ordinary text tokens, so forcing them
# makes finished sequences get padded with a printable token and mislabels the
# tokenizer's BOS. Keep the BOS the tokenizer ships with and, when no pad token
# is defined, pad with EOS so padding is dropped by skip_special_tokens=True.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
print("load model finished!")

#load code contests dataset
code_contests_dataset = load_dataset('deepmind/code_contests')


def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def solve_problem(problem_index,num_generations,code_contests_dataset):
    """Sample Python solutions for a single CodeContests test problem.

    Wraps the problem description in a single-turn chat prompt and samples
    ``num_generations`` completions from the module-level ``model`` using the
    module-level ``tokenizer``.

    Args:
        problem_index: Index of the problem inside ``code_contests_dataset['test']``.
        num_generations: Number of completions to sample for the prompt.
        code_contests_dataset: Dataset (or mapping) with a ``'test'`` split that
            exposes a ``'description'`` column.

    Returns:
        A list with ``num_generations`` decoded completions, prompt excluded.
    """
    problem_description = code_contests_dataset['test']['description'][problem_index]
    # NOTE(review): the published Llama 3 prompt format puts a blank line
    # ("\n\n") after each <|end_header_id|>; confirm that omitting it here is
    # intentional before comparing against numbers obtained with the official
    # chat template.
    prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
    suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    request = "Solve the following problem in python:\n\n"+problem_description+"\n\nPlease write the entire solution in one block of code, including reading the inputs and writing the outputs."
    prompt = f"{prefix}{request}{suffix}"
    # The prompt already starts with <|begin_of_text|>; letting the tokenizer
    # add its own special tokens would prepend a second BOS.
    q_encoding = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False)
    # With device_map="auto" the model decides its own placement, so follow
    # the model instead of hard-coding 'cuda'.
    input_ids = q_encoding['input_ids'].to(model.device)
    attn_mask = q_encoding['attention_mask'].to(model.device)
    prompt_length = input_ids.shape[1]
    with torch.no_grad():
        outputs = model.generate(input_ids, max_new_tokens=2048, temperature=1.0, do_sample=True,
                                top_p=0.95, attention_mask=attn_mask,
                                return_dict_in_generate=True,
                                pad_token_id=tokenizer.pad_token_id,num_return_sequences=num_generations)
    # Drop the prompt by slicing tokens rather than by string replacement:
    # decoding and then str.replace() depends on the prompt round-tripping
    # exactly and would also delete any later repetition of the prompt text.
    completion_ids = outputs.sequences[:, prompt_length:]
    completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    # Strip literal sentence markers that some tokenizers leave in the text.
    answers = [completion.replace('<s>',"").replace('</s>',"") for completion in completions]
    return answers


def solve_combined_problem(problem_index_1,problem_index_2,num_generations,code_contests_dataset):
    """Sample Python programs that solve two CodeContests problems back to back.

    Builds one single-turn chat prompt that contains both problem descriptions
    and asks for a single program that reads the two inputs sequentially and
    prints the concatenated outputs, then samples ``num_generations``
    completions from the module-level ``model``.

    Args:
        problem_index_1: Index of the first problem in ``code_contests_dataset['test']``.
        problem_index_2: Index of the second problem in ``code_contests_dataset['test']``.
        num_generations: Number of completions to sample for the prompt.
        code_contests_dataset: Dataset (or mapping) with a ``'test'`` split that
            exposes a ``'description'`` column.

    Returns:
        A list with ``num_generations`` decoded completions, prompt excluded.
    """
    descriptions = code_contests_dataset['test']['description']
    problem_description_1 = descriptions[problem_index_1]
    problem_description_2 = descriptions[problem_index_2]
    # NOTE(review): the published Llama 3 prompt format puts a blank line
    # ("\n\n") after each <|end_header_id|>; confirm that omitting it here is
    # intentional before comparing against numbers obtained with the official
    # chat template.
    prefix = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
    suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    # Alternative prompt that additionally asks for one function per problem (kept for reference):
    #request = "Solve the two following problems in python:\n\nProblem 1:\n"+problem_description_1 +"\n\nProblem 2:\n"+problem_description_2+"\n\nPlease write the solution to both problems in one block of code, the program should receive the inputs to the problems sequentially and output the concatenated outputs to the problems. Write a function to solve the first problem, then a function to solve the second problem, then apply the functions one after the other."
    request = "Solve the two following problems in python:\n\nProblem 1:\n"+problem_description_1 +"\n\nProblem 2:\n"+problem_description_2+"\n\nPlease write the solution to both problems in one block of code, the program should receive the inputs to the problems sequentially and output the concatenated outputs to the problems."

    prompt = f"{prefix}{request}{suffix}"
    # The prompt already starts with <|begin_of_text|>; letting the tokenizer
    # add its own special tokens would prepend a second BOS.
    q_encoding = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False)
    # With device_map="auto" the model decides its own placement, so follow
    # the model instead of hard-coding 'cuda'.
    input_ids = q_encoding['input_ids'].to(model.device)
    attn_mask = q_encoding['attention_mask'].to(model.device)
    prompt_length = input_ids.shape[1]
    with torch.no_grad():
        outputs = model.generate(input_ids, max_new_tokens=2048, temperature=1.0, do_sample=True,
                                top_p=0.95, attention_mask=attn_mask,
                                return_dict_in_generate=True,
                                pad_token_id=tokenizer.pad_token_id,num_return_sequences=num_generations)
    # Drop the prompt by slicing tokens rather than by string replacement:
    # decoding and then str.replace() depends on the prompt round-tripping
    # exactly and would also delete any later repetition of the prompt text.
    completion_ids = outputs.sequences[:, prompt_length:]
    completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    # Strip literal sentence markers that some tokenizers leave in the text.
    answers = [completion.replace('<s>',"").replace('</s>',"") for completion in completions]
    return answers


def eval_generated_code(code,problem_index,code_contests_dataset):
    """Run generated code on a problem's first public test and check its output.

    The code is executed in-process with stdin fed from the test input and
    stdout captured. A run passes only if the program finishes within the
    time limit without raising (a clean ``sys.exit()`` / exit status 0 counts
    as finishing) and its captured stdout equals the expected output exactly.

    Args:
        code: Source code of the candidate solution.
        problem_index: Index of the problem inside ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset (or mapping) whose ``'test'`` split has a
            ``'public_tests'`` column with ``'input'`` and ``'output'`` lists.

    Returns:
        True if the program passed the first public test, False otherwise.
    """
    def handler(signum,frame):
        raise TimeoutError("Code execution exceeded the time limit")

    public_tests = code_contests_dataset['test']['public_tests'][problem_index]
    test_inputs = public_tests['input'][0]
    test_outputs = public_tests['output'][0]
    timeout_duration = 5

    # Create an in-memory buffer that receives everything the program prints
    buffer = io.StringIO()
    original_stdin = sys.stdin
    original_stdout = sys.stdout

    # Set the signal handler for the SIGALRM signal, remembering the previous one
    previous_handler = signal.signal(signal.SIGALRM,handler)
    sys.stdin = StringIO(test_inputs)
    sys.stdout = buffer
    signal.alarm(timeout_duration)
    try:
        # SECURITY: this runs untrusted, model-generated code inside the
        # current process with no sandboxing.
        # A fresh globals dict makes the code behave like a real script:
        # top-level functions and imports can see each other (they cannot when
        # exec() inherits this function's separate locals), and the code can
        # neither read nor clobber this module's globals.
        exec(code, {"__name__": "__main__"})
    except SystemExit as exit_request:
        # exit()/sys.exit() is a normal way to end a script; only a non-zero
        # status is treated as a failure.
        if exit_request.code not in (None, 0):
            return False
    except Exception:
        # Timeouts, syntax errors and runtime errors all count as a failed run.
        return False
    finally:
        # Cancel the alarm if code execution completes within the time limit
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
        # Put back whatever streams were active before the call
        sys.stdout = original_stdout
        sys.stdin = original_stdin
    return buffer.getvalue() == test_outputs

def evaluate_all_generated_codes(total_codes,problem_index,code_contests_dataset):
    """Evaluate every generated program for one problem and report the pass rate.

    Args:
        total_codes: Iterable-sized collection of candidate source strings.
        problem_index: Index of the problem inside ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset passed through to ``eval_generated_code``.

    Returns:
        Fraction of programs in ``total_codes`` that passed (0.0 when
        ``total_codes`` is empty).
    """
    passes = 0
    for code in total_codes:
        try:
            result = eval_generated_code(code, problem_index, code_contests_dataset)
        except Exception:
            # A candidate that blows up the harness counts as a failure;
            # KeyboardInterrupt is deliberately allowed to stop the run.
            result = False
        passes += bool(result)

    # Normalise by the number of programs actually evaluated instead of a
    # module-level global that may not match this list.
    num_codes = len(total_codes)
    pass_rate = passes / num_codes if num_codes else 0.0
    print("Problem: " +str(problem_index) +", Pass rate: " + str(pass_rate))
    return pass_rate



def generate_all_codes(total_generations,num_generations_batch,problem_index,code_contests_dataset):
    """Sample solutions for one problem in batches and extract their code blocks.

    Args:
        total_generations: Total number of samples requested; only
            ``total_generations // num_generations_batch`` full batches are drawn.
        num_generations_batch: Number of samples drawn per call to ``solve_problem``.
        problem_index: Index of the problem inside ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset passed through to ``solve_problem``.

    Returns:
        A list with one source string per sampled answer. Answers without a
        fenced code block contribute an empty string so the list length still
        equals the number of samples.
    """
    language_tags = ("python", "python3", "py")

    def extract_code(answer):
        """Return the body of the first ``` fenced block in ``answer`` ("" if none)."""
        segments = answer.split("```")
        if len(segments) < 2:
            return ""
        block = segments[1]
        # Remove only the language tag on the fence line; deleting every
        # occurrence of "python" would also rewrite the program itself.
        tag, newline, body = block.partition("\n")
        if newline and tag.strip().lower() in language_tags:
            return body
        return block

    num_batches = total_generations // num_generations_batch
    total_codes = []

    for i in range(num_batches):
        # sample responses
        answers = solve_problem(problem_index, num_generations_batch, code_contests_dataset)
        total_codes.extend(extract_code(answer) for answer in answers)
        print("completed " + str(i+1) + " of " + str(num_batches))
    return total_codes



def generate_all_codes_combined(total_generations,num_generations_batch,problem_index_1,problem_index_2,code_contests_dataset):
    """Sample combined-problem solutions in batches and extract their code blocks.

    Args:
        total_generations: Total number of samples requested; only
            ``total_generations // num_generations_batch`` full batches are drawn.
        num_generations_batch: Number of samples drawn per call to
            ``solve_combined_problem``.
        problem_index_1: Index of the first problem in ``code_contests_dataset['test']``.
        problem_index_2: Index of the second problem in ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset passed through to ``solve_combined_problem``.

    Returns:
        A list with one source string per sampled answer. Answers without a
        fenced code block contribute an empty string.
    """
    language_tags = ("python", "python3", "py")

    def extract_code(answer):
        """Return the body of the first ``` fenced block in ``answer`` ("" if none)."""
        segments = answer.split("```")
        if len(segments) < 2:
            return ""
        block = segments[1]
        # Remove only the language tag on the fence line; deleting every
        # occurrence of "python" would also rewrite the program itself.
        tag, newline, body = block.partition("\n")
        if newline and tag.strip().lower() in language_tags:
            return body
        return block

    num_batches = total_generations // num_generations_batch
    total_codes = []

    for i in range(num_batches):
        # sample responses
        answers = solve_combined_problem(problem_index_1,problem_index_2, num_generations_batch, code_contests_dataset)
        total_codes.extend(extract_code(answer) for answer in answers)
        print("completed " + str(i+1) + " of " + str(num_batches))
    return total_codes

def evaluate_all_generated_codes_combined(total_codes,problem_index_1,problem_index_2,code_contests_dataset):
    """Evaluate every generated program for a problem pair and report the pass rate.

    Args:
        total_codes: Sized collection of candidate source strings.
        problem_index_1: Index of the first problem in ``code_contests_dataset['test']``.
        problem_index_2: Index of the second problem in ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset passed through to
            ``eval_generated_code_combined``.

    Returns:
        Fraction of programs in ``total_codes`` that passed (0.0 when
        ``total_codes`` is empty).
    """
    passes = 0
    for code in total_codes:
        try:
            result = eval_generated_code_combined(code, problem_index_1, problem_index_2, code_contests_dataset)
        except Exception:
            # A candidate that blows up the harness counts as a failure;
            # KeyboardInterrupt is deliberately allowed to stop the run.
            result = False
        passes += bool(result)

    # Normalise by the number of programs actually evaluated instead of a
    # module-level global that may not match this list.
    num_codes = len(total_codes)
    pass_rate = passes / num_codes if num_codes else 0.0
    print("Problem: " +str(problem_index_1) +"+"+str(problem_index_2) +", Pass rate: " + str(pass_rate))
    return pass_rate

def eval_generated_code_combined(code,problem_index_1, problem_index_2,code_contests_dataset):
    """Run generated code on the concatenated first public tests of two problems.

    stdin is the first problem's input followed by the second problem's input,
    and the expected stdout is the two expected outputs concatenated in the
    same order. A run passes only if the program finishes within the time
    limit without raising (a clean ``sys.exit()`` / exit status 0 counts as
    finishing) and its captured stdout equals the expected output exactly.

    Args:
        code: Source code of the candidate solution.
        problem_index_1: Index of the first problem in ``code_contests_dataset['test']``.
        problem_index_2: Index of the second problem in ``code_contests_dataset['test']``.
        code_contests_dataset: Dataset (or mapping) whose ``'test'`` split has a
            ``'public_tests'`` column with ``'input'`` and ``'output'`` lists.

    Returns:
        True if the program passed the combined test, False otherwise.
    """
    def handler(signum,frame):
        raise TimeoutError("Code execution exceeded the time limit")

    public_tests = code_contests_dataset['test']['public_tests']
    first_tests = public_tests[problem_index_1]
    second_tests = public_tests[problem_index_2]
    test_inputs = first_tests['input'][0]+second_tests['input'][0]
    test_outputs = first_tests['output'][0]+second_tests['output'][0]
    timeout_duration = 5

    # Create an in-memory buffer that receives everything the program prints
    buffer = io.StringIO()
    original_stdin = sys.stdin
    original_stdout = sys.stdout

    # Set the signal handler for the SIGALRM signal, remembering the previous one
    previous_handler = signal.signal(signal.SIGALRM,handler)
    sys.stdin = StringIO(test_inputs)
    sys.stdout = buffer
    signal.alarm(timeout_duration)
    try:
        # SECURITY: this runs untrusted, model-generated code inside the
        # current process with no sandboxing.
        # A fresh globals dict makes the code behave like a real script:
        # top-level functions and imports can see each other (they cannot when
        # exec() inherits this function's separate locals), and the code can
        # neither read nor clobber this module's globals.
        exec(code, {"__name__": "__main__"})
    except SystemExit as exit_request:
        # exit()/sys.exit() is a normal way to end a script; only a non-zero
        # status is treated as a failure.
        if exit_request.code not in (None, 0):
            return False
    except Exception:
        # Timeouts, syntax errors and runtime errors all count as a failed run.
        return False
    finally:
        # Cancel the alarm if code execution completes within the time limit
        signal.alarm(0)
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
        # Put back whatever streams were active before the call
        sys.stdout = original_stdout
        sys.stdin = original_stdin
    return buffer.getvalue() == test_outputs


############################################# combined problems
total_generations = 200
num_generations_batch = 4

# Ordered (first problem, second problem) pairs evaluated as composite tasks
combined_problem_pairs = [(29,60),(29,98),(60,29),(60,98),(98,29),(98,60),(63,29),(63,60),(63,98),(29,63),(60,63),(98,63),(75,29),(75,60),(75,98),(29,75),(60,75),(98,75)]

pass_rates = []
torch.manual_seed(45)
all_problems_codes = []
#calculate pass rates of combined problems
for problem_index_1, problem_index_2 in combined_problem_pairs:
    total_codes = generate_all_codes_combined(total_generations,num_generations_batch,problem_index_1,problem_index_2,code_contests_dataset)
    print("Composite pass rate:")
    pass_rate = evaluate_all_generated_codes_combined(total_codes,problem_index_1,problem_index_2,code_contests_dataset)
    pass_rates.append(((problem_index_1,problem_index_2),pass_rate))
    all_problems_codes.append(total_codes)
pass_rates_composition = pass_rates
# all_problems_codes is rebound for the standalone run below; keep a handle on
# the composite generations so they are not lost, mirroring pass_rates_composition.
all_problems_codes_composition = all_problems_codes

############## standalone evaluation of problems
total_generations = 100
num_generations_batch = 4

# Problems evaluated on their own
standalone_problem_indices = [11,29,35,60,63,75,79,90,92,98]

torch.manual_seed(43)
standalone_pass_rates = []
all_problems_codes = []
#calculate pass rates for standalone problems
for problem_index in standalone_problem_indices:
    total_codes = generate_all_codes(total_generations,num_generations_batch,problem_index,code_contests_dataset)
    pass_rate = evaluate_all_generated_codes(total_codes,problem_index,code_contests_dataset)
    print(pass_rate)
    standalone_pass_rates.append((problem_index,pass_rate))
    all_problems_codes.append(total_codes)

