import pdb
from tqdm import tqdm
import numpy as np
import random
from PIL import Image, ImageOps
import re, ast
import math, json
from utils.utils import extract_code_blocks, flatten_list, reshape_list
from utils.python_utils import execute_function, extract_program
import signal
import re


class TimeoutException(Exception):
    """Raised by ``timeout_handler`` to signal that execution timed out."""

def timeout_handler(signum, frame):
    """Abort the current execution by raising ``TimeoutException``.

    Has the ``(signum, frame)`` signature that Python signal handlers
    receive; both arguments are unused.

    Raises:
        TimeoutException: always, with message "Execution timed out!".
    """
    message = "Execution timed out!"
    raise TimeoutException(message)

def extract_dict(s) -> list:
    """Return every brace-delimited substring of ``s`` as a candidate dict literal.

    Newlines are first collapsed to spaces so that a dict spread over several
    lines is found as a single match. Matching is non-greedy: each match runs
    from a ``{`` to the *first* following ``}``, so nested dicts (or a ``}``
    inside a string value) come back truncated. The matches are NOT parsed or
    validated; the caller must evaluate them.

    Args:
        s (str): Text possibly containing dict literals.

    Returns:
        list[str]: Matched substrings in order of appearance (possibly empty).
    """
    flattened = ' '.join(s.split('\n')).strip()
    return re.findall(r'(\{.*?\})', flattened)

def extract_single_dict_from_string(s: str) -> dict:
    """Evaluate the first balanced ``{...}`` literal in ``s`` with ``ast.literal_eval``.

    Scanning starts at the first ``{``. Braces inside single- or double-quoted
    strings (honouring backslash escapes) are ignored, so the outermost literal
    is delimited correctly even when values contain braces.

    Args:
        s: Text that should contain a dict literal.

    Returns:
        The evaluated object on success. On failure the empty string ``""``
        (not a dict) is returned: when the balanced span is not a valid
        literal, or when the outermost ``{`` is never closed.

    Raises:
        ValueError: If ``s`` contains no ``{`` at all.
    """
    start = s.find('{')
    if start == -1:
        raise ValueError("No dictionary found in the string.")

    depth = 0
    in_string = False
    quote = ''
    escaped = False

    for i in range(start, len(s)):
        char = s[i]
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                in_string = False
        elif char in ('"', "'"):
            in_string = True
            quote = char
        elif char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                try:
                    return ast.literal_eval(s[start:i + 1])
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Documented literal_eval failures; a bare except here would
                    # also swallow KeyboardInterrupt/SystemExit.
                    return ""

    # The outermost '{' was never closed.
    return ""

def make_string_literal_single_line(text):
    """Escape raw line breaks inside quoted string literals of ``text``.

    Walks ``text`` tracking whether it is inside a ``'...'`` or ``"..."``
    literal (honouring backslash escapes). Inside a literal, a raw ``\\n``
    becomes the two characters ``\\\\n`` and a raw ``\\r`` becomes ``\\\\r``,
    so the literal stays on one logical line and ``ast.literal_eval`` can
    parse it. A single quote inside a double-quoted literal is also escaped
    (harmless to the decoded value). Text outside literals is unchanged.

    Args:
        text (str): Python-literal-like source text.

    Returns:
        str: The rewritten text.
    """
    # Python's tokenizer treats a lone '\r' as a line break too, so it would
    # terminate the literal just like '\n' does.
    line_breaks = {'\n': '\\n', '\r': '\\r'}
    in_str = False
    string_type = ""
    escaped = False
    result = []

    for ch in text:
        if in_str:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == string_type:  # end of the string literal
                in_str = False
            elif ch in line_breaks:
                ch = line_breaks[ch]
            elif ch == "'" and string_type == '"':
                ch = "\\'"
        elif ch in ("'", '"'):
            in_str = True
            string_type = ch
        result.append(ch)

    return ''.join(result)

def _default_list_items(n: int) -> list:
    """Return ``n`` independent placeholder records used when list extraction fails."""
    return [{'id': 0, 'observation': 'None', 'python': 'None'} for _ in range(n)]


def extract_single_list_from_string(s: str, n: int) -> list:
    """Evaluate the first balanced ``[...]`` literal in ``s`` and resize it to ``n`` items.

    Scanning starts at the first ``[``. Brackets inside quoted strings
    (honouring backslash escapes) are ignored. The balanced span is passed
    through ``make_string_literal_single_line`` and then ``ast.literal_eval``.
    The resulting list is repeated cyclically (by doubling) and truncated to
    exactly ``n`` items.

    Args:
        s: Text that should contain a list literal.
        n: Number of items to return.

    Returns:
        A list of ``n`` items. If the literal is malformed, empty, or never
        closed, returns ``n`` separate placeholder dicts
        ``{'id': 0, 'observation': 'None', 'python': 'None'}``.

    Raises:
        ValueError: If ``s`` contains no ``[`` at all.
    """
    start = s.find('[')
    if start == -1:
        raise ValueError("No list found in the string.")

    depth = 0
    in_string = False
    quote = ''
    escaped = False

    for i in range(start, len(s)):
        char = s[i]
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                in_string = False
        elif char in ('"', "'"):
            in_string = True
            quote = char
        elif char == '[':
            depth += 1
        elif char == ']':
            depth -= 1
            if depth == 0:
                # Make multi-line string literals parseable before evaluating.
                safe_str = make_string_literal_single_line(s[start:i + 1])
                try:
                    output = ast.literal_eval(safe_str)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    return _default_list_items(n)
                if not output:
                    # Doubling an empty list never reaches n (infinite loop).
                    return _default_list_items(n)
                while len(output) < n:
                    output += output
                return output[:n]

    # The outermost '[' was never closed.
    return _default_list_items(n)

def extract_list(s) -> list:
    """Return every bracket-delimited substring of ``s`` as a candidate list literal.

    Newlines are first collapsed to spaces so that a list spread over several
    lines is found as a single match. Matching is non-greedy: each match runs
    from a ``[`` to the *first* following ``]``, so nested lists (or a ``]``
    inside a string value) come back truncated. The matches are NOT parsed or
    validated; the caller must evaluate them.

    Args:
        s (str): Text possibly containing list literals.

    Returns:
        list[str]: Matched substrings in order of appearance (possibly empty).
    """
    flattened = ' '.join(s.split('\n')).strip()
    return re.findall(r'(\[.*?\])', flattened)


def custom_eval(x):
    """Evaluate a model-produced Python literal string.

    The text is first parsed with ``ast.literal_eval``, which accepts only
    literals and cannot execute code. Only when that fails is the legacy
    path used:
    - a bare ``inf`` directly before ``,`` or ``]`` is rewritten to ``float("inf")``;
    - numpy-style ``array([...])`` reprs are unwrapped to plain lists;
    - the result is passed to ``eval``.

    Trying ``literal_eval`` first also keeps the rewrites from corrupting
    string values that merely contain ``inf,`` or ``])``.

    Args:
        x (str): Literal text, typically a dict or list extracted from an LLM response.

    Returns:
        The evaluated Python object.

    Raises:
        Whatever ``eval`` raises on the rewritten text (e.g. SyntaxError, NameError).
    """
    try:
        return ast.literal_eval(x)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a pure literal (e.g. bare `inf` or `array(...)`); use the legacy path.
        pass

    if 'inf' in x:
        x = x.replace('inf,', 'float("inf"),')
        x = x.replace('inf]', 'float("inf")]')
    if 'array(' in x:
        x = x.replace('array([', '[')
        x = x.replace('])', ']')
    # SECURITY(review): eval runs arbitrary expressions with this module's
    # globals, and x comes from model output. Only acceptable for trusted or
    # sandboxed runs.
    return eval(x)

class Method():
    """Base class holding the LLM backend and the I/O type descriptions for a task.

    Subclasses in this module implement ``forward`` and interpolate
    ``input_format`` / ``output_format`` into their prompt templates.
    """

    def __init__(self, model_id, task):
        """Select a model backend and the task's input/output type descriptions.

        Args:
            model_id (str): Backend identifier. The backend is chosen by
                substring match, checked in this order (first hit wins):
                'gpt', 'DI' (the 'DI_' prefix is stripped before it is passed
                on), 'gemini', 'llama', 'Qwen'.
            task (str): Task name. It selects ``input_format`` and
                ``output_format``, which are used as type comments in prompts.

        Raises:
            ValueError: If ``task`` is not a supported task.
        """
        self.task = task

        if 'gpt' in model_id:
            from models.gpt import ParallelGPT
            self.model = ParallelGPT(model_id=model_id)

        elif 'DI' in model_id:
            from models.deepinfra import DeepInfra
            self.model = DeepInfra(model_id=model_id.replace('DI_', ''))

        elif 'gemini' in model_id:
            from models.gemini import Gemini
            self.model = Gemini(model_id=model_id)

        elif 'llama' in model_id:
            from models.llama3 import LlamaModel
            self.model = LlamaModel(model_id=model_id)

        elif 'Qwen' in model_id:
            from models.qwen import QwenModel
            self.model = QwenModel(model_id=model_id)
        # NOTE(review): an unrecognised model_id leaves self.model unset, so
        # any later self.model access fails with AttributeError.

        # Checks are substring matches (except 'acre'/'scan') and order matters.
        if 'list_function' in task:
            self.input_format = 'list[int]'
            self.output_format = 'list[int]'
        elif 'miniarc' in task:
            self.input_format = 'list[list[int]]'
            self.output_format = 'list[list[int]]'
        elif '1d_arc' in task:
            self.input_format = 'list[int]'
            self.output_format = 'list[int]'
        elif task == 'acre':
            self.input_format = 'list[str]'
            self.output_format = 'str (one of "on"/"off"/"undetermined")'
        elif 'playgol' in task:
            self.input_format = 'str'
            self.output_format = 'str'
        elif task == 'scan':
            self.input_format = 'str'
            self.output_format = 'str'
        elif 'mbpp_plus' in task:
            self.input_format = 'input'
            self.output_format = 'output'
        elif 'numbergame' in task or 'tenenbaum' in task:
            self.input_format = 'int'
            self.output_format = '0 or 1'
        else:
            # ValueError subclasses Exception, so existing `except Exception` callers still work.
            raise ValueError(f"Unsupported task: {task!r}")

        # Prompt builder: x = input-output pairs, y = natural-language description.
        self.template_code_generation = lambda x, y: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
Given the language description of the transformation, implement that description into Python code.
Generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your Python function as follows:

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Input-output pairs:
{x}

Implement this transformation:
{y}
'''




class DiscriminationByExample(Method):
    """Ask the model to pick the most plausible test output from a set of candidates.

    The prompt contains only the input-output examples (no task description).
    """

    def forward(self, input_batch, temperature=0, num_return_sequences=1, prompt_type='disc_cot', progress_bar=False):
        """Build one selection prompt per item and extract the model's chosen answer.

        Args:
            input_batch: Iterable of ``(pairs, test_input, candidates)`` triples;
                each element is ``str()``-formatted into the prompt.
            temperature: Sampling temperature forwarded to ``self.model.generate``.
            num_return_sequences: Completions requested per prompt. Only the
                first one is parsed for the answer; all are kept in ``raw_response``.
            prompt_type: ``'disc'`` (answer only) or ``'disc_cot'`` (think
                step-by-step, answer fenced at the end).
            progress_bar: Forwarded to ``self.model.generate``.

        Returns:
            list[dict]: One ``{'answer', 'raw_response'}`` dict per entry of the
            model's ``'responses'``. ``answer`` is
            ``extract_code_blocks(first_completion, '')``.

        Raises:
            ValueError: If ``prompt_type`` is not ``'disc'`` or ``'disc_cot'``.
        """
        if prompt_type == 'disc':
            self.template = lambda x, y, z: f'''Based on given input-output pairs, select which of the outputs is most plausible for given test input.
Only output the answer and enclose your answer with ```.

Input-output pairs:
{x}

Test input:
{y}

Test output candidates:
{z}
'''
        elif prompt_type == 'disc_cot':
            self.template = lambda x, y, z: f'''Based on given input-output pairs, select which of the outputs is most plausible for given test input.
Think step-by-step and enclose your answer with ``` at the end of your response.

Input-output pairs:
{x}

Test input:
{y}

Test output candidates:
{z}
'''
        else:
            # Without this, self.template would be unset (AttributeError) or
            # silently reused from an earlier call with a different prompt_type.
            raise ValueError(f"Unknown prompt_type: {prompt_type!r}")

        prompts_for_solving = [self.template(str(p), str(i), str(o)) for (p, i, o) in input_batch]
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=num_return_sequences, progress_bar=progress_bar)['responses']

        output = []
        for response in responses:
            answer = extract_code_blocks(response[0], '')
            output.append({'answer': answer, 'raw_response': response})

        return output


class DiscriminationByExampleAndPrompt(Method):
    """Ask the model to pick the most plausible test output, given a task description and examples."""

    def forward(self, input_batch, temperature=0, num_return_sequences=1, prompt_type='disc_cot', progress_bar=False):
        """Build one selection prompt per item and extract the model's chosen answer.

        Args:
            input_batch: Iterable of ``(description, pairs, test_input, candidates)``
                4-tuples; each element is ``str()``-formatted into the prompt.
            temperature: Sampling temperature forwarded to ``self.model.generate``.
            num_return_sequences: Accepted for signature parity with
                ``DiscriminationByExample`` but NOT forwarded: ``generate`` is
                always called with ``num_return_sequences=1``.
            prompt_type: ``'disc'`` (answer only) or ``'disc_cot'`` (think
                step-by-step, answer fenced at the end).
            progress_bar: Forwarded to ``self.model.generate``.

        Returns:
            list[dict]: One ``{'answer', 'raw_response'}`` dict per entry of the
            model's ``'responses'``. ``answer`` is
            ``extract_code_blocks(first_completion, '')``.

        Raises:
            ValueError: If ``prompt_type`` is not ``'disc'`` or ``'disc_cot'``.
        """
        if prompt_type == 'disc':
            self.template = lambda t, x, y, z: f'''Based on given task description and input-output pairs, select which of the outputs is most plausible for given test input. Do not choose error message.
Only output the answer and enclose your answer with ```.

Task description:
{t}

Input-output pairs:
{x}

Test input:
{y}

Test output candidates:
{z}
'''
        elif prompt_type == 'disc_cot':
            self.template = lambda t, x, y, z: f'''Based on given task description and input-output pairs, select which of the outputs is most plausible for given test input. Do not choose error message.
Think step-by-step and enclose your answer with ``` at the end of your response.

Task description:
{t}

Input-output pairs:
{x}

Test input:
{y}

Test output candidates:
{z}
'''
        else:
            # Without this, self.template would be unset (AttributeError) or
            # silently reused from an earlier call with a different prompt_type.
            raise ValueError(f"Unknown prompt_type: {prompt_type!r}")

        prompts_for_solving = [self.template(str(t), str(p), str(i), str(o)) for (t, p, i, o) in input_batch]
        # NOTE(review): num_return_sequences is intentionally(?) pinned to 1 here,
        # unlike DiscriminationByExample — confirm before changing.
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=1, progress_bar=progress_bar)['responses']

        output = []
        for response in responses:
            answer = extract_code_blocks(response[0], '')
            output.append({'answer': answer, 'raw_response': response})

        return output




class TransductionByExample(Method):
    """Ask the model to produce the test output directly from input-output examples."""

    # The greedy prefix pushes the match as far right as possible. It captures
    # the text after the last "```" that has a later backtick, up to that
    # backtick — typically the contents of the final fenced block.
    _ANSWER_PATTERN = re.compile(r"(?:.*)```(.*?)`", re.DOTALL)

    def forward(self, input_batch, temperature=0, num_return_sequences=1, prompt_type='trans_cot', progress_bar=False):
        """Build one transduction prompt per item and extract the fenced answer.

        Args:
            input_batch: Iterable of 3-tuples ``(pairs, test_input, _)``. The
                first two are ``str()``-formatted into the prompt; the third is ignored.
            temperature: Sampling temperature forwarded to ``self.model.generate``.
            num_return_sequences: Completions requested per prompt. Only the
                first one is parsed; all are kept in ``raw_response``.
            prompt_type: ``'trans'`` (answer only) or ``'trans_cot'`` (think
                step-by-step, answer fenced at the end).
            progress_bar: Forwarded to ``self.model.generate``.

        Returns:
            list[dict]: One ``{'answer', 'raw_response'}`` dict per entry of the
            model's ``'responses'``. ``answer`` is a one-element list holding
            the stripped fenced text, or ``['']`` if no fence was found.

        Raises:
            ValueError: If ``prompt_type`` is not ``'trans'`` or ``'trans_cot'``.
        """
        if prompt_type == 'trans':
            self.template = lambda x, y: f'''Based on the given input-output pairs, generate the output for the given test input.
Only output the answer and enclose your answer with ```.

Input-output pairs:
{x}

Test input:
{y}
'''
        elif prompt_type == 'trans_cot':
            self.template = lambda x, y: f'''Based on the given input-output pairs, generate the output for the given test input.
Think step-by-step and enclose your answer with ``` at the end of your response.

Input-output pairs:
{x}

Test input:
{y}
'''
        else:
            # Without this, self.template would be unset (AttributeError) or
            # silently reused from an earlier call with a different prompt_type.
            raise ValueError(f"Unknown prompt_type: {prompt_type!r}")

        prompts_for_solving = [self.template(str(p), str(i)) for (p, i, _unused) in input_batch]
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=num_return_sequences, progress_bar=progress_bar)['responses']

        output = []
        for response in responses:
            match = self._ANSWER_PATTERN.match(response[0])
            answer = [match.group(1).strip()] if match is not None else ['']
            output.append({'answer': answer, 'raw_response': response})

        return output


class TransductionByExampleAndPrompt(Method):
    """Ask the model to produce the test output from a task description plus input-output examples."""

    # The greedy prefix pushes the match as far right as possible. It captures
    # the text after the last "```" that has a later backtick, up to that
    # backtick — typically the contents of the final fenced block.
    _ANSWER_PATTERN = re.compile(r"(?:.*)```(.*?)`", re.DOTALL)

    def forward(self, input_batch, temperature=0, num_return_sequences=1, prompt_type='trans_cot', progress_bar=False):
        """Build one transduction prompt per item and extract the fenced answer.

        Args:
            input_batch: Iterable of 4-tuples ``(description, pairs, test_input, _)``.
                The first three are ``str()``-formatted into the prompt; the
                fourth is ignored.
            temperature: Sampling temperature forwarded to ``self.model.generate``.
            num_return_sequences: Accepted for signature parity with
                ``TransductionByExample`` but NOT forwarded: ``generate`` is
                always called with ``num_return_sequences=1``.
            prompt_type: ``'trans'`` (answer only) or ``'trans_cot'`` (think
                step-by-step, answer fenced at the end).
            progress_bar: Forwarded to ``self.model.generate``.

        Returns:
            list[dict]: One ``{'answer', 'raw_response'}`` dict per entry of the
            model's ``'responses'``. ``answer`` is a one-element list holding
            the stripped fenced text, or ``['']`` if no fence was found.

        Raises:
            ValueError: If ``prompt_type`` is not ``'trans'`` or ``'trans_cot'``.
        """
        if prompt_type == 'trans':
            self.template = lambda t, x, y: f'''Based on the given input-output pairs, generate the output for the given test input.
Only output the answer and enclose your answer with ```.

Task description:
{t}

Input-output pairs:
{x}

Test input:
{y}
'''
        elif prompt_type == 'trans_cot':
            self.template = lambda t, x, y: f'''Based on the given input-output pairs, generate the output for the given test input.
Think step-by-step and enclose your answer with ``` at the end of your response.

Task description:
{t}

Input-output pairs:
{x}

Test input:
{y}
'''
        else:
            # Without this, self.template would be unset (AttributeError) or
            # silently reused from an earlier call with a different prompt_type.
            raise ValueError(f"Unknown prompt_type: {prompt_type!r}")

        prompts_for_solving = [self.template(str(t), str(p), str(i)) for (t, p, i, _unused) in input_batch]
        # NOTE(review): num_return_sequences is pinned to 1 here, unlike
        # TransductionByExample — confirm whether that is intentional.
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=1, progress_bar=progress_bar)['responses']

        output = []
        for response in responses:
            match = self._ANSWER_PATTERN.match(response[0])
            answer = [match.group(1).strip()] if match is not None else ['']
            output.append({'answer': answer, 'raw_response': response})

        return output






class DeductiveCodeGen(Method):
    """Sample Python ``fn`` implementations from a task description plus input-output examples."""

    @staticmethod
    def _parse_samples(samples, count):
        """Split the first ``count`` completions into hypothesis/code/visualization/raw lists."""
        hypos, codes, vis, raw = [], [], [], []
        for j in range(count):
            text = samples[j]
            hyp_blocks = extract_code_blocks(text, 'hypothesis')
            code_blocks = extract_code_blocks(text, 'python')
            if len(code_blocks) == 0:
                # No python block: nothing usable in this completion.
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            else:
                hypothesis = hyp_blocks[0] if len(hyp_blocks) > 0 else 'parsing error'
                hypos.append(hypothesis)
                codes.append(code_blocks[0])
                vis.append('HYPOTHESIS:\n' + hypothesis + ' | CODE:\n' + code_blocks[0])
            raw.append(text)
        return hypos, codes, vis, raw

    def forward(self, input_task_list, input_io_list, temperature=1, num_return_sequences=8, top_p=1):
        """Sample ``num_return_sequences`` code completions per (task, examples) pair.

        Args:
            input_task_list: Task descriptions, paired positionally with ``input_io_list``.
            input_io_list: Input-output examples for each task.
            temperature, top_p: Sampling parameters forwarded to ``self.model.generate``.
            num_return_sequences: Completions requested (and parsed) per prompt.

        Returns:
            One dict per task with parallel lists ``'hypothesis'``, ``'code'``,
            ``'output_visualization'`` and ``'raw_response'``.
        """
        # NOTE(review): this prompt does not ask for a ```hypothesis block, so
        # every parsed hypothesis is expected to come back as 'parsing error'.
        self.template_hypothesis_generation = lambda task, io: f'''You will be given a Python coding task and a list of input-output pairs.
Please format your Python function as follows:

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Task:
{task}
Input-output pairs:
{io}
Python function:

'''

        prompts = [
            self.template_hypothesis_generation(str(task), str(io))
            for task, io in zip(input_task_list, input_io_list)
        ]
        generated = self.model.generate(prompts, temperature=temperature, num_return_sequences=num_return_sequences, top_p=top_p)
        responses = generated['responses']  # [num_input, num_seq]

        results = []
        for idx in range(len(input_task_list)):
            hypos, codes, vis, raw = self._parse_samples(responses[idx], num_return_sequences)
            results.append({
                'hypothesis': hypos,
                'code': codes,
                'output_visualization': vis,
                'raw_response': raw,
            })
        return results





class HC(Method):
    """Hypothesis-then-code sampling from input-output examples alone (no hints)."""

    @staticmethod
    def _parse_samples(samples, count):
        """Split the first ``count`` completions into hypothesis/code/visualization/raw lists."""
        hypos, codes, vis, raw = [], [], [], []
        for j in range(count):
            text = samples[j]
            hyp_blocks = extract_code_blocks(text, 'hypothesis')
            code_blocks = extract_code_blocks(text, 'python')
            if len(code_blocks) == 0:
                # No python block: nothing usable in this completion.
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            else:
                hypothesis = hyp_blocks[0] if len(hyp_blocks) > 0 else 'parsing error'
                hypos.append(hypothesis)
                codes.append(code_blocks[0])
                vis.append('HYPOTHESIS:\n' + hypothesis + ' | CODE:\n' + code_blocks[0])
            raw.append(text)
        return hypos, codes, vis, raw

    def forward(self, input_text_list, temperature=1, num_return_sequences=8, top_p=1):
        """Sample ``num_return_sequences`` hypothesis + ``fn`` completions per input.

        Args:
            input_text_list: Input-output examples, one entry per problem;
                each is ``str()``-formatted into the prompt.
            temperature, top_p: Sampling parameters forwarded to ``self.model.generate``.
            num_return_sequences: Completions requested (and parsed) per prompt.

        Returns:
            One dict per input with parallel lists ``'hypothesis'``, ``'code'``,
            ``'output_visualization'`` and ``'raw_response'``. A completion with
            code but no hypothesis block gets hypothesis ``'parsing error'``; one
            without code gets ``''``/``''`` and visualization ``'parsing error'``.
        """
        self.template_hypothesis_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
First, output a hypothesis for the transformation in natural language form.
Then, generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your hypothesis and Python function as follows:

```hypothesis
HYPOTHESIS
```

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Input-output pairs:
{x}
Hypothesis and Python function:

'''

        prompts = [self.template_hypothesis_generation(str(text)) for text in input_text_list]
        generated = self.model.generate(prompts, temperature=temperature, num_return_sequences=num_return_sequences, top_p=top_p)
        responses = generated['responses']  # [num_input, num_seq]

        results = []
        for idx in range(len(input_text_list)):
            hypos, codes, vis, raw = self._parse_samples(responses[idx], num_return_sequences)
            results.append({
                'hypothesis': hypos,
                'code': codes,
                'output_visualization': vis,
                'raw_response': raw,
            })
        return results










class CHC(Method):
    """Concept-hinted hypothesis generation.

    For each problem:
    - a concept model first proposes "elementary concepts";
    - the main model then writes a hypothesis plus a Python ``fn`` once per
      concept, using the concept as a hint (optionally alongside one unhinted
      prompt);
    - ``num_sampling_per_concept`` completions are sampled per prompt.
    """

    def __init__(self, model_id, task, biglittle=False):
        """Set up the solving model (via ``Method``) and the concept model.

        Args:
            model_id: Backend identifier passed to ``Method.__init__``; also
                stored as ``self.model_id``.
            task: Task name passed to ``Method.__init__``.
            biglittle: If True, concepts come from a separate
                ``ParallelGPT(model_id='gpt-4o-2024-08-06')``. Otherwise
                ``self.model`` is reused for concepts.
        """
        super().__init__(model_id, task)
        if biglittle:
            from models.gpt import ParallelGPT
            self.concept_model = ParallelGPT(model_id='gpt-4o-2024-08-06')
        else:
            self.concept_model = self.model

        self.model_id = model_id

    @staticmethod
    def _parse_concepts(raw_text, prompt_type, num_concepts):
        """Turn one concept-model response into exactly ``num_concepts`` hints.

        The first dict literal in ``raw_text`` (the first list literal when
        ``prompt_type == 'list'``) is evaluated with ``custom_eval``; for dicts
        the values become the concepts. The concepts are repeated cyclically
        and truncated to ``num_concepts``.

        On any failure — no literal found, evaluation error, or an empty
        collection — prints ``'parsing error!'`` and returns ``num_concepts``
        empty strings.
        """
        if prompt_type == 'list':
            candidates = extract_list(raw_text)
        else:
            candidates = extract_dict(raw_text)

        if len(candidates) > 0:
            try:
                concepts = custom_eval(candidates[0])
                if prompt_type != 'list':
                    concepts = list(concepts.values())
                if len(concepts) == 0 and num_concepts > 0:
                    # Doubling an empty collection never reaches num_concepts;
                    # without this check the loop below would spin forever.
                    raise ValueError('empty concept collection')
                while len(concepts) < num_concepts:
                    concepts += concepts
                return concepts[:num_concepts]
            except Exception:
                # custom_eval may raise anything eval can; fall through to the
                # shared parsing-error result below.
                pass

        print('parsing error!')
        return [''] * num_concepts

    @staticmethod
    def _parse_samples(samples, count):
        """Split the first ``count`` completions into hypothesis/code/visualization/raw lists."""
        hypos, codes, vis, raw = [], [], [], []
        for j in range(count):
            text = samples[j]
            hyp_blocks = extract_code_blocks(text, 'hypothesis')
            code_blocks = extract_code_blocks(text, 'python')
            if len(code_blocks) == 0:
                # No python block: nothing usable in this completion.
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            else:
                hypothesis = hyp_blocks[0] if len(hyp_blocks) > 0 else 'parsing error'
                hypos.append(hypothesis)
                codes.append(code_blocks[0])
                vis.append('HYPOTHESIS:\n' + hypothesis + ' | CODE:\n' + code_blocks[0])
            raw.append(text)
        return hypos, codes, vis, raw

    def forward(self, input_text_list, temperature=1, top_p=1, num_concepts=32, num_sampling_per_concept=4, prompt_type='ele', mix_ori=False):
        """Generate hypothesis/code samples guided by model-proposed concepts.

        Args:
            input_text_list: Input-output examples, one entry per problem;
                each is ``str()``-formatted into the prompts.
            temperature, top_p: Sampling parameters, used for both the concept
                call and the solving call.
            num_concepts: Prompt slots per problem. With ``mix_ori`` one slot is
                used by the unhinted prompt, so ``num_concepts - 1`` concepts are requested.
            num_sampling_per_concept: Completions sampled per solving prompt.
            prompt_type: ``'ele'`` or ``'diverse_simple'`` selects the concept
                prompt wording. ``'list'`` makes the concept parser look for a
                list instead of a dict.
            mix_ori: Also include the plain (no-hint) prompt for each problem.

        Returns:
            One dict per problem with ``'hypothesis'``, ``'code'`` and
            ``'output_visualization'`` lists (each of length
            ``num_concepts * num_sampling_per_concept``). It also has a
            ``'raw_response'`` dict holding the parsed concepts, raw
            completions and raw concept response.
        """
        if mix_ori:
            # Reserve one slot per problem for the unhinted prompt.
            num_concepts -= 1

        if prompt_type == 'ele':
            if num_concepts == 1:
                self.template_concept_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
Generate an elementary concept that may be useful to induce the transformation pattern.
Format your response in json format (dictionary whose key is 0 and value is elementary concept).

Input-output pairs:
{x}

Elementary concept:

'''
            else:
                self.template_concept_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
List {num_concepts} elementary concepts that may be useful to induce the transformation pattern.
Format your response in json format (dictionary whose keys are indices and values are elementary concepts).

Input-output pairs:
{x}

Elementary concepts:

'''

        # The 'diverse_simple' wording is also forced for gpt-4o-2024-08-06,
        # overriding 'ele'.
        if prompt_type == 'diverse_simple' or self.model_id == 'gpt-4o-2024-08-06':
            if num_concepts == 1:
                self.template_concept_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
Generate an elementary concept that may be useful to induce the transformation pattern.
The concept should be simple and concise.
Format your response in json format (dictionary whose key is 0 and value is elementary concept).

Input-output pairs:
{x}

Elementary concept:

'''
            else:
                self.template_concept_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
List {num_concepts} elementary concepts that may be useful to induce the transformation pattern.
The concepts should be diverse, simple and concise.
Format your response in json format (dictionary whose keys are indices and values are elementary concepts).

Input-output pairs:
{x}

Elementary concepts:

'''
        # NOTE(review): any other prompt_type (including 'list', which
        # _parse_concepts understands) sets no concept template here. A
        # template left over from a previous call is then reused, or
        # AttributeError is raised on the first call.

        self.template_hypothesis_generation = lambda x: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
First, output a hypothesis for the transformation in natural language form.
Then, generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your hypothesis and Python function as follows:

```hypothesis
HYPOTHESIS
```

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Input-output pairs:
{x}
Hypothesis and Python function:

'''

        self.template_hypothesis_generation_with_hint = lambda x, hint: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
First, output a hypothesis for the transformation in natural language form. Use hint: {hint}.
Then, generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your hypothesis and Python function as follows:

```hypothesis
HYPOTHESIS
```

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Input-output pairs:
{x}
Hypothesis and Python function:

'''

        prompts_for_concept_generation = [self.template_concept_generation(str(input_text)) for input_text in input_text_list]
        concepts_list = self.concept_model.generate(prompts_for_concept_generation, temperature=temperature, num_return_sequences=1, top_p=top_p)['responses'] # [num_input, 1]
        concepts_parsed_list = [
            self._parse_concepts(concepts[0], prompt_type, num_concepts)
            for concepts in concepts_list
        ]

        prompts_for_solving = []
        for input_text, concepts_parsed in zip(input_text_list, concepts_parsed_list):
            if mix_ori:
                prompts_for_solving.append(self.template_hypothesis_generation(str(input_text)))
            for concept in concepts_parsed:
                prompts_for_solving.append(self.template_hypothesis_generation_with_hint(str(input_text), concept))

        if mix_ori:
            # Restore the reserved slot: the unhinted prompt counts as one concept.
            num_concepts += 1

        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=num_sampling_per_concept, top_p=top_p)['responses'] # [num_input * num_concepts, num_sampling_per_concept]
        num_return_sequences = num_concepts * num_sampling_per_concept
        # Regroup so each problem's completions form one row.
        responses = reshape_list(flatten_list(responses), [len(input_text_list), num_return_sequences])

        output = []
        for i in range(len(input_text_list)):
            hypos, codes, vis, raw_responses = self._parse_samples(responses[i], num_return_sequences)
            output.append(
                {
                    'hypothesis': hypos,
                    'code': codes,
                    'output_visualization': vis,
                    'raw_response': {'concepts': concepts_parsed_list[i], 'raw_responses': raw_responses, 'raw_concepts': concepts_list[i]},
                }
            )

        return output







class DeductiveCHC(Method):
    """Concept-hinted hypothesis generation from a task description plus input-output examples.

    A concept model first proposes "elementary concepts" per (task, examples)
    pair. The main model then writes a hypothesis plus a Python ``fn`` for
    each concept used as a hint (optionally alongside one unhinted prompt).
    """

    def __init__(self, model_id, task, biglittle=False):
        """Set up the solving model (via ``Method``) and the concept model.

        Args:
            model_id: Backend identifier passed to ``Method.__init__``; also
                stored as ``self.model_id``.
            task: Task name passed to ``Method.__init__``.
            biglittle: If True, concepts come from a separate
                ``ParallelGPT(model_id='gpt-4o-2024-08-06')``. Otherwise
                ``self.model`` is reused for concepts.
        """
        super().__init__(model_id, task)
        if biglittle:
            from models.gpt import ParallelGPT
            self.concept_model = ParallelGPT(model_id='gpt-4o-2024-08-06')
        else:
            self.concept_model = self.model

        self.model_id = model_id

    @staticmethod
    def _parse_concepts(raw_text, prompt_type, num_concepts):
        """Turn one concept-model response into exactly ``num_concepts`` hints.

        The first dict literal in ``raw_text`` (the first list literal when
        ``prompt_type == 'list'``) is evaluated with ``custom_eval``; for dicts
        the values become the concepts. The concepts are repeated cyclically
        and truncated to ``num_concepts``.

        On any failure — no literal found, evaluation error, or an empty
        collection — prints ``'parsing error!'`` and returns ``num_concepts``
        empty strings.
        """
        if prompt_type == 'list':
            candidates = extract_list(raw_text)
        else:
            candidates = extract_dict(raw_text)

        if len(candidates) > 0:
            try:
                concepts = custom_eval(candidates[0])
                if prompt_type != 'list':
                    concepts = list(concepts.values())
                if len(concepts) == 0 and num_concepts > 0:
                    # Doubling an empty collection never reaches num_concepts;
                    # without this check the loop below would spin forever.
                    raise ValueError('empty concept collection')
                while len(concepts) < num_concepts:
                    concepts += concepts
                return concepts[:num_concepts]
            except Exception:
                # custom_eval may raise anything eval can; fall through to the
                # shared parsing-error result below.
                pass

        print('parsing error!')
        return [''] * num_concepts

    @staticmethod
    def _parse_samples(samples, count):
        """Split the first ``count`` completions into hypothesis/code/visualization/raw lists."""
        hypos, codes, vis, raw = [], [], [], []
        for j in range(count):
            text = samples[j]
            hyp_blocks = extract_code_blocks(text, 'hypothesis')
            code_blocks = extract_code_blocks(text, 'python')
            if len(code_blocks) == 0:
                # No python block: nothing usable in this completion.
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            else:
                hypothesis = hyp_blocks[0] if len(hyp_blocks) > 0 else 'parsing error'
                hypos.append(hypothesis)
                codes.append(code_blocks[0])
                vis.append('HYPOTHESIS:\n' + hypothesis + ' | CODE:\n' + code_blocks[0])
            raw.append(text)
        return hypos, codes, vis, raw

    def forward(self, input_task_list, input_io_list, temperature=1, top_p=1, num_concepts=32, num_sampling_per_concept=4, prompt_type='ele', mix_ori=False):
        """Generate hypothesis/code samples guided by model-proposed concepts.

        Args:
            input_task_list: Task descriptions, paired positionally with
                ``input_io_list``; passed to the templates as-is.
            input_io_list: Input-output examples per task; ``str()``-formatted.
            temperature, top_p: Sampling parameters, used for both the concept
                call and the solving call.
            num_concepts: Prompt slots per problem. With ``mix_ori`` one slot is
                used by the unhinted prompt.
            num_sampling_per_concept: Completions sampled per solving prompt.
            prompt_type: Only ``'ele'`` sets a concept prompt here. ``'list'``
                makes the concept parser look for a list instead of a dict.
            mix_ori: Also include the plain (no-hint) prompt for each problem.

        Returns:
            One dict per problem with ``'hypothesis'``, ``'code'`` and
            ``'output_visualization'`` lists (each of length
            ``num_concepts * num_sampling_per_concept``). It also has a
            ``'raw_response'`` dict holding the parsed concepts, raw
            completions and raw concept response.
        """
        if mix_ori:
            # Reserve one slot per problem for the unhinted prompt.
            num_concepts -= 1

        if prompt_type == 'ele':
            if num_concepts == 1:
                self.template_concept_generation = lambda task, io: f'''You will be given a natural language task description and a list of input-output pairs.
Generate an elementary concept that may be useful to induce the transformation pattern.
Format your response in json format (dictionary whose key is 0 and value is elementary concept).

Task:
{task}

Input-output pairs:
{io}

Elementary concept:

'''
            else:
                self.template_concept_generation = lambda task, io: f'''You will be given a natural language task description and a list of input-output pairs.
List {num_concepts} elementary concepts that may be useful to induce the transformation pattern.
Format your response in json format (dictionary whose keys are indices and values are elementary concepts).

Task:
{task}

Input-output pairs:
{io}

Elementary concepts:

'''
        # NOTE(review): any other prompt_type sets no concept template here. A
        # template left over from a previous call is then reused, or
        # AttributeError is raised on the first call.

        self.template_hypothesis_generation = lambda task, io: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
First, output a hypothesis for the transformation in natural language form.
Then, generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your hypothesis and Python function as follows:

```hypothesis
HYPOTHESIS
```

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Task:
{task}

Input-output pairs:
{io}

Hypothesis and Python function:

'''

        self.template_hypothesis_generation_with_hint = lambda task, io, hint: f'''You will be given a list of input-output pairs. There is a SINGLE pattern that transforms each input to the corresponding output.
First, output a hypothesis for the transformation in natural language form. Use hint: {hint}.
Then, generate a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your hypothesis and Python function as follows:

```hypothesis
HYPOTHESIS
```

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Task:
{task}

Input-output pairs:
{io}

Hypothesis and Python function:

'''

        prompts_for_concept_generation = [self.template_concept_generation(t, str(i)) for t, i in zip(input_task_list, input_io_list)]
        concepts_list = self.concept_model.generate(prompts_for_concept_generation, temperature=temperature, num_return_sequences=1, top_p=top_p)['responses'] # [num_input, 1]
        concepts_parsed_list = [
            self._parse_concepts(concepts[0], prompt_type, num_concepts)
            for concepts in concepts_list
        ]

        prompts_for_solving = []
        for t, i, concepts_parsed in zip(input_task_list, input_io_list, concepts_parsed_list):
            if mix_ori:
                prompts_for_solving.append(self.template_hypothesis_generation(t, str(i)))
            for concept in concepts_parsed:
                prompts_for_solving.append(self.template_hypothesis_generation_with_hint(t, str(i), concept))

        if mix_ori:
            # Restore the reserved slot: the unhinted prompt counts as one concept.
            num_concepts += 1

        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=num_sampling_per_concept, top_p=top_p)['responses'] # [num_input * num_concepts, num_sampling_per_concept]
        num_return_sequences = num_concepts * num_sampling_per_concept
        # Regroup so each problem's completions form one row.
        responses = reshape_list(flatten_list(responses), [len(input_io_list), num_return_sequences])

        output = []
        for i in range(len(input_io_list)):
            hypos, codes, vis, raw_responses = self._parse_samples(responses[i], num_return_sequences)
            output.append(
                {
                    'hypothesis': hypos,
                    'code': codes,
                    'output_visualization': vis,
                    'raw_response': {'concepts': concepts_parsed_list[i], 'raw_responses': raw_responses, 'raw_concepts': concepts_list[i]},
                }
            )

        return output







class DeductiveATC(Method):
    """Algorithm-then-code hypothesis generation conditioned on a task description.

    Stage 1 prompts ``self.concept_model`` to propose ``num_concepts`` distinct
    natural-language algorithms for each (task, input-output pairs) input.
    Stage 2 prompts ``self.model`` to implement every proposed algorithm as a
    Python function ``fn``; the algorithm text is recorded as the hypothesis
    paired with the extracted code.
    """

    def __init__(self, model_id, task, biglittle=False):
        """Set up the code model (via ``Method``) and the concept model.

        Args:
            model_id: Identifier forwarded to ``Method`` and stored on ``self``.
            task: Task identifier forwarded to ``Method``.
            biglittle: If True, algorithms are proposed by a separate
                ``ParallelGPT(model_id='gpt-4o-2024-08-06')`` instance;
                otherwise ``self.model`` proposes the algorithms as well.
        """
        super().__init__(model_id, task)
        if biglittle:
            # Imported here so the GPT client is only loaded in big/little mode.
            from models.gpt import ParallelGPT
            self.concept_model = ParallelGPT(model_id='gpt-4o-2024-08-06')
        else:
            # NOTE(review): assumes ``Method.__init__`` sets ``self.model``.
            self.concept_model = self.model

        self.model_id = model_id

    @staticmethod
    def _parse_concepts(raw_concepts, num_concepts):
        """Turn one concept-model response into exactly ``num_concepts`` strings.

        The first ``{...}`` span found by ``extract_dict`` is evaluated with
        ``custom_eval`` and its values are taken in dict order. Fewer than
        ``num_concepts`` values are repeated cyclically; more are truncated.
        On any parse failure (no dict, evaluation error, not a dict, or an
        empty dict) ``'parsing error!'`` is printed and ``num_concepts`` empty
        strings are returned.
        """
        candidates = extract_dict(raw_concepts)
        if not candidates:
            print('parsing error!')
            return [''] * num_concepts
        try:
            # str() so non-string values (e.g. lists of steps) cannot break
            # the string concatenation done when building visualizations.
            values = [str(v) for v in custom_eval(candidates[0]).values()]
        except Exception:
            print('parsing error!')
            return [''] * num_concepts
        if not values:
            # An empty dict would otherwise make the cyclic padding impossible
            # (the original doubling loop never terminated in this case).
            print('parsing error!')
            return [''] * num_concepts
        return [values[k % len(values)] for k in range(num_concepts)]

    @staticmethod
    def _build_output(concepts, raw_concepts, responses, num_return_sequences):
        """Pair each algorithm with the first python code block of its implementation."""
        hypos, codes, vis, raw_responses = [], [], [], []
        for j in range(num_return_sequences):
            concept, response = concepts[j], responses[j]
            code = extract_code_blocks(response, 'python')
            if code:
                hypos.append(concept)
                codes.append(code[0])
                vis.append('HYPOTHESIS:\n' + concept + ' | CODE:\n' + code[0])
            else:
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            raw_responses.append(response)
        return {
            'hypothesis': hypos,
            'code': codes,
            'output_visualization': vis,
            'raw_response': {'concepts': concepts, 'raw_responses': raw_responses, 'raw_concepts': raw_concepts},
        }

    def forward(self, input_task_list, input_io_list, temperature=1, top_p=1, num_concepts=32, concept_num_sampling=1):
        """Generate ``concept_num_sampling * num_concepts`` code hypotheses per input.

        Args:
            input_task_list: Task descriptions, one per input (stringified into the prompt).
            input_io_list: Input-output examples, aligned with ``input_task_list``.
            temperature: Sampling temperature used for both generation stages.
            top_p: Nucleus-sampling parameter used for both generation stages.
            num_concepts: Number of algorithms requested from, and kept per,
                each concept-model response.
            concept_num_sampling: Number of concept-model responses sampled per input.

        Returns:
            A list with one dict per input, with keys ``'hypothesis'`` (algorithm
            texts, or '' when no code was extracted), ``'code'`` (first python code
            block per implementation, or ''), ``'output_visualization'`` and
            ``'raw_response'`` (parsed concepts, raw code-model responses and raw
            concept-model responses).
        """
        self.template_concept_generation = lambda task, io: f'''You will be given a Python coding task and a list of input-output pairs. There are multiple algorithms(implementations) that performs given task.
Generate {num_concepts} algorithms for the task in natural language form.
These algorithms should be distinct; they map the given input to the output but implemented in various ways.

Please format your algorithms as follows:

{{
1: "algorithm",
2: "algorithm",
...
}}


Task:
{task}

Input-output pairs:
{io}

Algorithms:
'''

        self.template_hypothesis_generation_with_hint = lambda x, y: f'''You will be given a list of input-output pairs and an algorithm described in natural language.
Implement the given algorithm in a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your Python function as follows:

```python
def fn(x):
    # Your code here
    return y
```

Input-output pairs:
{x}

Algorithm:
{y}

Python function:
'''

        # Stage 1: propose algorithms. Result is iterated as
        # [num_input][responses per input]; presumably concept_num_sampling responses each.
        prompts_for_concept_generation = [self.template_concept_generation(str(task), str(io)) for task, io in zip(input_task_list, input_io_list)]
        concepts_list = self.concept_model.generate(prompts_for_concept_generation, temperature=temperature, num_return_sequences=concept_num_sampling, top_p=top_p)['responses']

        # Every concept-model response contributes exactly num_concepts strings.
        concepts_parsed_list = []
        for concepts in concepts_list:
            parsed_for_input = []
            for raw_concepts in concepts:
                parsed_for_input += self._parse_concepts(raw_concepts, num_concepts)
            concepts_parsed_list.append(parsed_for_input)

        # Stage 2: one implementation prompt per (input, algorithm).
        prompts_for_solving = []
        for input_text, concepts_parsed in zip(input_io_list, concepts_parsed_list):
            for concept in concepts_parsed:
                prompts_for_solving.append(self.template_hypothesis_generation_with_hint(str(input_text), concept))

        # One completion per prompt; regroup the flat list so row i lines up
        # with concepts_parsed_list[i].
        num_return_sequences = concept_num_sampling * num_concepts
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=1, top_p=top_p)['responses']
        responses = reshape_list(flatten_list(responses), [len(input_task_list), num_return_sequences])

        return [
            self._build_output(concepts_parsed, raw_concepts, input_responses, num_return_sequences)
            for concepts_parsed, raw_concepts, input_responses in zip(concepts_parsed_list, concepts_list, responses)
        ]








class ATC(Method):
    """Algorithm-then-code hypothesis generation from input-output pairs only.

    Stage 1 prompts ``self.concept_model`` to propose ``num_concepts`` distinct
    natural-language algorithms mapping each input to its output. Stage 2
    prompts ``self.model`` to implement every proposed algorithm as a Python
    function ``fn``; the algorithm text is recorded as the hypothesis paired
    with the extracted code.
    """

    def __init__(self, model_id, task, biglittle=False):
        """Set up the code model (via ``Method``) and the concept model.

        Args:
            model_id: Identifier forwarded to ``Method`` and stored on ``self``.
            task: Task identifier forwarded to ``Method``.
            biglittle: If True, algorithms are proposed by a separate
                ``ParallelGPT(model_id='gpt-4o-2024-08-06')`` instance;
                otherwise ``self.model`` proposes the algorithms as well.
        """
        super().__init__(model_id, task)
        if biglittle:
            # Imported here so the GPT client is only loaded in big/little mode.
            from models.gpt import ParallelGPT
            self.concept_model = ParallelGPT(model_id='gpt-4o-2024-08-06')
        else:
            # NOTE(review): assumes ``Method.__init__`` sets ``self.model``.
            self.concept_model = self.model

        self.model_id = model_id

    @staticmethod
    def _parse_concepts(raw_concepts, num_concepts):
        """Turn one concept-model response into exactly ``num_concepts`` strings.

        The first ``{...}`` span found by ``extract_dict`` is evaluated with
        ``custom_eval`` and its values are taken in dict order. Fewer than
        ``num_concepts`` values are repeated cyclically; more are truncated.
        On any parse failure (no dict, evaluation error, not a dict, or an
        empty dict) ``'parsing error!'`` is printed and ``num_concepts`` empty
        strings are returned.
        """
        candidates = extract_dict(raw_concepts)
        if not candidates:
            print('parsing error!')
            return [''] * num_concepts
        try:
            # str() so non-string values (e.g. lists of steps) cannot break
            # the string concatenation done when building visualizations.
            values = [str(v) for v in custom_eval(candidates[0]).values()]
        except Exception:
            print('parsing error!')
            return [''] * num_concepts
        if not values:
            # An empty dict would otherwise make the cyclic padding impossible
            # (the original doubling loop never terminated in this case).
            print('parsing error!')
            return [''] * num_concepts
        return [values[k % len(values)] for k in range(num_concepts)]

    @staticmethod
    def _build_output(concepts, raw_concepts, responses, num_return_sequences):
        """Pair each algorithm with the first python code block of its implementation."""
        hypos, codes, vis, raw_responses = [], [], [], []
        for j in range(num_return_sequences):
            concept, response = concepts[j], responses[j]
            code = extract_code_blocks(response, 'python')
            if code:
                hypos.append(concept)
                codes.append(code[0])
                vis.append('HYPOTHESIS:\n' + concept + ' | CODE:\n' + code[0])
            else:
                hypos.append('')
                codes.append('')
                vis.append('parsing error')
            raw_responses.append(response)
        return {
            'hypothesis': hypos,
            'code': codes,
            'output_visualization': vis,
            'raw_response': {'concepts': concepts, 'raw_responses': raw_responses, 'raw_concepts': raw_concepts},
        }

    def forward(self, input_text_list, temperature=1, top_p=1, num_concepts=32, concept_num_sampling=1):
        """Generate ``concept_num_sampling * num_concepts`` code hypotheses per input.

        Args:
            input_text_list: Input-output examples, one entry per problem
                (stringified into the prompts).
            temperature: Sampling temperature used for both generation stages.
            top_p: Nucleus-sampling parameter used for both generation stages.
            num_concepts: Number of algorithms requested from, and kept per,
                each concept-model response.
            concept_num_sampling: Number of concept-model responses sampled per input.

        Returns:
            A list with one dict per input, with keys ``'hypothesis'`` (algorithm
            texts, or '' when no code was extracted), ``'code'`` (first python code
            block per implementation, or ''), ``'output_visualization'`` and
            ``'raw_response'`` (parsed concepts, raw code-model responses and raw
            concept-model responses).
        """
        self.template_concept_generation = lambda x: f'''You will be given a list of input-output pairs. There are multiple algorithms that transform each input to the corresponding output.
Generate {num_concepts} algorithms for the transformation in natural language form.
These algorithms should be distinct; they map the given input to the output but implemented in various ways.

Please format your algorithms as follows:

{{
1: "algorithm",
2: "algorithm",
...
}}

Input-output pairs:
{x}

Algorithms:

'''

        # NOTE(review): self.input_format / self.output_format are read when the
        # lambda is called; presumably set by Method for the current task.
        self.template_hypothesis_generation_with_hint = lambda x, y: f'''You will be given a list of input-output pairs and an algorithm described in natural language.
Implement the given algorithm in a Python function `fn` that maps the following inputs to their corresponding outputs.

Please format your Python function as follows:

```python
def fn(x):
    # x is {self.input_format}
    # Your code here
    return y # y is {self.output_format}
```

Input-output pairs:
{x}

Algorithm:
{y}

Python function:
'''

        # Stage 1: propose algorithms. Result is iterated as
        # [num_input][responses per input]; presumably concept_num_sampling responses each.
        prompts_for_concept_generation = [self.template_concept_generation(str(input_text)) for input_text in input_text_list]
        concepts_list = self.concept_model.generate(prompts_for_concept_generation, temperature=temperature, num_return_sequences=concept_num_sampling, top_p=top_p)['responses']

        # Every concept-model response contributes exactly num_concepts strings.
        concepts_parsed_list = []
        for concepts in concepts_list:
            parsed_for_input = []
            for raw_concepts in concepts:
                parsed_for_input += self._parse_concepts(raw_concepts, num_concepts)
            concepts_parsed_list.append(parsed_for_input)

        # Stage 2: one implementation prompt per (input, algorithm).
        prompts_for_solving = []
        for input_text, concepts_parsed in zip(input_text_list, concepts_parsed_list):
            for concept in concepts_parsed:
                prompts_for_solving.append(self.template_hypothesis_generation_with_hint(str(input_text), concept))

        # One completion per prompt; regroup the flat list so row i lines up
        # with concepts_parsed_list[i].
        num_return_sequences = concept_num_sampling * num_concepts
        responses = self.model.generate(prompts_for_solving, temperature=temperature, num_return_sequences=1, top_p=top_p)['responses']
        responses = reshape_list(flatten_list(responses), [len(input_text_list), num_return_sequences])

        return [
            self._build_output(concepts_parsed, raw_concepts, input_responses, num_return_sequences)
            for concepts_parsed, raw_concepts, input_responses in zip(concepts_parsed_list, concepts_list, responses)
        ]



