# modified from https://github.com/jennyzzt/LLM_debate_on_ARC
# prompt also inspired by https://github.com/rgreenblatt/arc_draw_more_samples_pub/blob/master/arc_solve/prompting.py   

import ast
import concurrent.futures
import json
import pickle
import random
import string
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from tqdm import tqdm

# Shared task preamble: `format_arc_data` starts every task string with this text,
# and that task string is what gets handed to the solver as its prompt input.
# Changing the wording therefore changes what the model is asked.
TASK_OVERVIEW = """You will be given some number of paired example inputs and outputs grids. The outputs were produced by applying a transformation rule to the input grids. In addition to the paired example inputs and outputs, there is also one test input without a known output.
The inputs and outputs are each "grids". A grid is a rectangular matrix of integers between 0 and 9 (inclusive). Each number corresponds to a color. 0 is black.
Your task is to determine the transformation rule from examples and find out the answer, involving determining the size of the output grid for the test and correctly filling each cell of the grid with the appropriate color or number.

The transformation only needs to be unambiguous and applicable to the example inputs and the test input. It doesn't need to work for all possible inputs. Observe the examples carefully, imagine the grid visually, and try to find the pattern.
"""


def original_solver(task_input, *, client=None):
    """Ask an LLM to write a `transform` function for an ARC task.

    Args:
        task_input: The task prompt text (e.g. the task string produced by
            ``format_arc_data``); it is embedded into the user message.
        client: An OpenAI-style client exposing ``chat.completions.create``.
            It must be supplied explicitly: this is a module-level function, so
            there is no ``self`` to hold one. Bind it with e.g.
            ``functools.partial(original_solver, client=my_client)`` when the
            caller only passes the prompt.

    Returns:
        The source code string from the ``code`` field of the model's JSON reply.

    Raises:
        ValueError: If no ``client`` is given.
        json.JSONDecodeError: If the reply is not valid JSON.
        KeyError: If the JSON reply has no ``code`` field.
    """
    if client is None:
        raise ValueError("original_solver requires an OpenAI-compatible `client` (pass client=...)")

    prompt = f"""# Your Task:
{task_input}

# Instruction: 
Please think step by step and then solve the task by writing the code.

You will write code to solve this task by creating a function named `transform`. This function should take a single argument, the input grid as `list[list[int]]`, and returns the transformed grid (also as `list[list[int]]`). You should make sure that you implement a version of the transformation that works for both example and test inputs. Make sure that the transform function is capable of handling both example and test inputs effectively, reflecting the learned transformation rules from the Examples inputs and outputs."""
    system_prompt = """You are a helpful assistant.

# Output Format:
Reply EXACTLY with the following JSON format.
{'thinking': 'Your thinking.', 'code': "Your code. Don't write tests in your Python code, ONLY return the `transform` function. DO NOT return anything else. (It will be tested later.)"}
DO NOT MISS ANY REQUEST FIELDS and ensure that your response is a WELL-FORMED JSON object!"""

    response = client.chat.completions.create(
        model='gpt-4o-mini',
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=0.8, max_tokens=1024, stop=None, response_format={"type": "json_object"},
    )
    content = response.choices[0].message.content
    return json.loads(content)['code']





class ARC_Task:
    """Evaluate an ARC solver over a pickled sequence of ARC task dicts."""

    def evaluate(self, solver, data_path="../copied/ADAS/dataset/sampled_arc_val_data.pkl"):
        """Run ``solver`` on every task in ``data_path`` and report hard accuracy.

        Args:
            solver: Callable taking the formatted task prompt (str) and returning
                Python source that defines ``transform(grid) -> grid``.
            data_path: Pickle file holding an iterable of ARC task dicts with
                ``'train'`` and ``'test'`` entries. Defaults to the previously
                hard-coded location.

        Returns:
            ``(feedback, "")`` where ``feedback`` is ``"acc: <mean hard score>"``
            (``"acc: 0.0"`` when the dataset is empty).
        """
        # NOTE(security): pickle.load and exec() below both run with full
        # privileges; only use trusted dataset files and run model-generated
        # code in a sandbox if it is not trusted.
        with open(data_path, 'rb') as pickle_file:
            arc_data_queue = pickle.load(pickle_file)

        agent_task_queue = []
        for arc_data in arc_data_queue:
            task_str, _examples, test_input = format_arc_data(arc_data)
            agent_task_queue.append((solver, task_str, arc_data, test_input))

        def call_forward(agent_task):
            """Generate code for one task, run it on the test input, return 1/0."""
            task_solver, task_info, task_data, task_test_input = agent_task
            code = task_solver(task_info)
            res = self._get_test_output_from_code(task_test_input, code)
            try:
                # literal_eval (not eval): `res` derives from model output. Error
                # messages fail to parse here and fall through to a score of 0.
                res = ast.literal_eval(res)
                return eval_solution(res, task_data, soft_eval=False)
            except Exception:
                return 0

        # ThreadPoolExecutor rejects max_workers=0, so keep at least one worker.
        max_workers = max(1, min(len(agent_task_queue), 32))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            acc_list = list(tqdm(executor.map(call_forward, agent_task_queue), total=len(agent_task_queue)))

        accuracy = sum(acc_list) / len(acc_list) if acc_list else 0.0
        feedback = f"acc: {accuracy}"
        return feedback, ""

    @staticmethod
    def _get_test_output_from_code(test_input, code):
        """Execute ``code``, apply its ``transform`` to ``test_input``, and stringify.

        Returns the ``list_to_string`` rendering of the transformed grid on
        success, or a human-readable error message string on failure.
        """
        # A single dict serves as both globals and locals. With separate dicts,
        # top-level imports/helpers defined by `code` would land in the locals
        # dict and be invisible to `transform` (whose globals would be {}).
        namespace = {}
        try:
            exec(code, namespace)
        except Exception as e:
            return f"Error during code execution: {e}"
        if 'transform' not in namespace:
            return "Function 'transform' not found in the code."

        transform = namespace['transform']
        try:
            return list_to_string(transform(test_input))
        except Exception as e:
            return f"Error during function execution: {e}"




def random_id(length=4):
    """Return a random identifier of ``length`` alphanumeric characters.

    Characters are drawn with replacement from ASCII letters (upper and lower
    case) and digits using the module-level ``random`` generator.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))


def file_to_string(filepath):
    """Return the text content of ``filepath`` with leading/trailing whitespace removed."""
    with open(filepath, 'r') as handle:
        return handle.read().strip()


def list_to_string(list_2d):
    """Render a 2-D grid compactly, e.g. ``[[1, 2], [3]]`` -> ``"[[1,2],[3]]"``.

    Each cell is converted with ``str`` and no spaces are inserted.
    """
    rendered_rows = []
    for row in list_2d:
        cells = ",".join(str(cell) for cell in row)
        rendered_rows.append("[" + cells + "]")
    return "[" + ",".join(rendered_rows) + "]"


def format_arc_data(arc_data, direct=False):
    """Build the text prompt for an ARC task.

    Args:
        arc_data: Task dict with ``'train'`` (list of ``{'input', 'output'}``
            grid pairs) and ``'test'`` (list of cases with an ``'input'`` grid).
        direct: Unused; kept for backward compatibility.

    Returns:
        ``(task_str, train_examples, first_test_input)``: the full prompt
        (``TASK_OVERVIEW`` + rendered examples + test section(s)), the raw
        ``arc_data['train']`` list, and the input grid of the first test case.
    """
    demo_parts = ['## Examples:\n\n']
    for i, demo in enumerate(arc_data['train']):
        demo_parts.append(
            f'### Example {i}:\n'
            f'input = {list_to_string(demo["input"])}\n'
            f'output = {list_to_string(demo["output"])}\n\n'
        )

    test_parts = [
        '## Test Problem:\n'
        f'Given input:\n {list_to_string(testcase["input"])}\n\n'
        'Analyze the transformation rules based on the provided Examples and determine what the output should be for the Test Problem.'
        for testcase in arc_data['test']
    ]

    # Test sections are separated by a blank line so that consecutive cases do
    # not run together ("...Test Problem.## Test Problem:"); a task with a
    # single test case renders the same text with or without the separator.
    task_str = TASK_OVERVIEW + ''.join(demo_parts) + '\n\n'.join(test_parts)

    return task_str, arc_data['train'], arc_data['test'][0]['input']


def get_percentage_match(arr1, arr2):
    """Return the fraction of cells of ``arr1`` that ``arr2`` reproduces exactly.

    Args:
        arr1: The reference (solution) grid, a sequence of rows.
        arr2: The candidate grid. It may be malformed (wrong shape, ragged rows,
            non-sequence entries); cells that cannot be compared count as misses.

    Returns:
        A float in [0, 1], computed as matching cells over the total number of
        cells in ``arr1``. Returns 0 when ``arr2`` is falsy or ``arr1`` has no
        cells.
    """
    if not arr2:
        return 0
    # Count the cells actually iterated so ragged references are normalised correctly.
    total_cells = sum(len(row) for row in arr1)
    if total_cells == 0:
        return 0

    score = 0
    for i, xs in enumerate(arr1):
        for j, x in enumerate(xs):
            try:
                if len(arr2) > i and len(arr2[i]) > j and arr2[i][j] == x:
                    score += 1
            except (TypeError, IndexError, KeyError, ValueError):
                # Malformed candidate cell (not indexable, no len, ambiguous
                # comparison): treat it as a mismatch.
                pass
    return score / total_cells


def eval_algo(solve_fn, arc_data, soft_eval=False, timeout=30):
    """Score ``solve_fn`` on every test case of an ARC task and return the mean.

    Args:
        solve_fn: Callable mapping an input grid to a predicted output grid.
        arc_data: Task dict whose ``'test'`` entry is a list of
            ``{'input': grid, 'output': grid}`` cases.
        soft_eval: If True, score each case with ``get_percentage_match``
            (fraction of matching cells); otherwise 1 for an exact match, else 0.
        timeout: Seconds to wait for ``solve_fn`` on each case (default 30).

    Returns:
        The mean score over the test cases, or 0.0 if there are none. A case
        where ``solve_fn`` raises or exceeds ``timeout`` scores 0.
    """
    scores = []
    for testcase in arc_data['test']:
        grid_in = testcase['input']
        expected = testcase['output']
        gen_output = None
        # A fresh single-worker pool per case, so a hung call does not occupy
        # the worker used for the next case.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(solve_fn, grid_in)
            gen_output = future.result(timeout=timeout)
        except Exception:
            # Timeout or crash: leave gen_output as None so the case scores 0
            # rather than being dropped from the average.
            gen_output = None
        finally:
            # wait=False: a `with` block would join the worker and thereby wait
            # out a timed-out solve_fn anyway. The thread cannot be killed; it
            # keeps running in the background until solve_fn returns.
            executor.shutdown(wait=False)

        if soft_eval:
            score = get_percentage_match(expected, gen_output)
        else:
            score = 1 if expected == gen_output else 0
        scores.append(score)

    if not scores:
        return 0.0
    return np.mean(scores)


def eval_solution(output, arc_data, soft_eval=False):
    """Score a candidate grid against the expected output of the first test case.

    Args:
        output: Candidate grid; any falsy value (``None``, ``[]``) scores 0
            without looking at ``arc_data``.
        arc_data: Task dict; the reference is ``arc_data['test'][0]['output']``.
        soft_eval: If True, return the cell-wise match fraction from
            ``get_percentage_match``; otherwise 1 for an exact match, else 0.
    """
    if not output:
        return 0

    solution = arc_data['test'][0]['output']
    if not soft_eval:
        return 1 if output == solution else 0
    return get_percentage_match(solution, output)


def bootstrap_confidence_interval(data, num_bootstrap_samples=100000, confidence_level=0.95):
    """
    Calculate the bootstrap confidence interval for the mean of 1D accuracy data.
    Also returns the median of the bootstrap means.

    Resampling uses the global NumPy RNG (``np.random``), so ``np.random.seed``
    makes the result reproducible.

    Args:
    - data (list or array of float): Non-empty 1D list or array of data points.
    - num_bootstrap_samples (int): Number of bootstrap samples.
    - confidence_level (float): The desired confidence level (e.g., 0.95 for 95%).

    Returns:
    - str: Formatted string with the confidence interval (labelled with
      ``confidence_level``, e.g. "95%") and the median, as percentages with one
      decimal place.

    Raises:
    - ValueError: If ``data`` is empty or not 1-dimensional.
    """
    data = np.array(data)
    if data.size == 0:
        raise ValueError("bootstrap_confidence_interval requires at least one data point")
    n = len(data)

    # Draw the bootstrap samples in vectorised chunks instead of one Python-level
    # np.random.choice call per sample. Each chunk holds roughly at most 1M values
    # so large inputs do not exhaust memory.
    bootstrap_means = np.empty(num_bootstrap_samples, dtype=float)
    rows_per_chunk = max(1, 1_000_000 // n)
    for start in range(0, num_bootstrap_samples, rows_per_chunk):
        stop = min(start + rows_per_chunk, num_bootstrap_samples)
        # Resample with replacement; each row is one bootstrap sample.
        samples = np.random.choice(data, size=(stop - start, n), replace=True)
        bootstrap_means[start:stop] = samples.mean(axis=1)

    # Compute the lower and upper percentiles for the confidence interval
    lower_percentile = (1.0 - confidence_level) / 2.0
    upper_percentile = 1.0 - lower_percentile
    ci_lower = np.percentile(bootstrap_means, lower_percentile * 100)
    ci_upper = np.percentile(bootstrap_means, upper_percentile * 100)

    # Compute the median of the bootstrap means
    median = np.median(bootstrap_means)

    # Convert to percentages and format to one decimal place
    ci_lower_percent = ci_lower * 100
    ci_upper_percent = ci_upper * 100
    median_percent = median * 100

    # Label with the requested level (":g" renders 0.95 -> "95", 0.975 -> "97.5").
    level_label = f"{confidence_level * 100:g}%"
    return f"{level_label} Bootstrap Confidence Interval: ({ci_lower_percent:.1f}%, {ci_upper_percent:.1f}%), Median: {median_percent:.1f}%"
