import ast, os, random
from methods import HC, CHC, ATC, DeductiveCodeGen, DeductiveATC, DeductiveCHC
import argparse, json, copy
from tqdm import tqdm
import pdb
from collections import Counter
import numpy as np
from utils.python_utils import execute_function, extract_program
import math
import signal
import sys

class TimeoutException(Exception):
    """Signals that a program under evaluation exceeded its time budget.

    Raised by ``timeout_handler`` when the SIGALRM alarm fires.
    NOTE(review): this derives from ``Exception``; if ``execute_function``
    catches ``Exception`` internally, the timeout could be swallowed there
    rather than reaching the caller's ``except TimeoutException`` — confirm.
    """

def timeout_handler(signum, frame):
    """SIGALRM handler that aborts the current evaluation.

    ``signum`` and ``frame`` are the standard signal-handler arguments; both
    are ignored. Always raises ``TimeoutException``.
    """
    message = "Execution timed out!"
    raise TimeoutException(message)

def calculate_entropy(int_list):
    """Return the Shannon entropy, in bits, of a list of counts.

    The counts are treated as unnormalised probabilities: each is divided by
    their total, and the entropy ``-sum(p * log2(p))`` is taken over the
    strictly positive probabilities (zero counts contribute nothing).

    Args:
        int_list: Counts, e.g. the size of each cluster of equivalent programs.

    Returns:
        The entropy as a float. An empty list, or counts summing to zero, has
        no distribution and yields 0.0 instead of raising ZeroDivisionError.
    """
    total = sum(int_list)
    if total == 0:
        return 0.0

    entropy = 0.0
    for count in int_list:
        p = count / total
        if p > 0:
            # math.log2 is exact on powers of two, unlike math.log(p, 2),
            # which divides two rounded natural logs.
            entropy -= p * math.log2(p)
    return entropy

def custom_eval(x):
    """Evaluate a repr-style string into the Python value it describes.

    Plain ``eval`` chokes on some reprs; this version also understands:

    * ``inf`` / ``nan`` (as produced by ``repr(float('inf'))``) anywhere in
      the expression: after a minus sign, inside tuples or dicts, or alone.
    * ``array(...)`` wrappers (NumPy-style reprs), which are unwrapped to their
      first argument, so ``array([[1, 2], [3, 4]])`` becomes a nested list.

    Names are resolved through a copy of this module's globals plus the
    entries above, so anything the plain ``eval(x)`` call could see remains
    visible.

    Args:
        x: The string to evaluate.

    Returns:
        The evaluated Python object.

    WARNING: this calls ``eval`` and will execute arbitrary code contained in
    ``x``. Only use it on trusted strings, never on external input.
    """
    namespace = dict(globals())  # copy: eval may add __builtins__ to it
    namespace.update(
        inf=float("inf"),
        nan=float("nan"),
        # Keep only the data argument of array(data, ...); extra positional
        # or keyword arguments are accepted and ignored.
        array=lambda data, *args, **kwargs: data,
    )
    return eval(x, namespace)


parser = argparse.ArgumentParser()
# Other datasets used with this script include e.g. 'playgol_v2'; the
# train/test split sizes are only configured for names containing 'playgol'
# or 'mbpp'.
parser.add_argument("-ds", "--dataset", dest="dataset", type=str, action="store", default='mbpp_plus_51_cases')

parser.add_argument("-mi", "--model_id", dest="model_id", type=str, action="store", default='gpt-4o-mini-2024-07-18')

# Restrict to the methods this script knows how to build and call, so a typo
# fails here with a usage message instead of a NameError much later.
parser.add_argument(
    "-exp", "--experiment", dest="experiment", type=str, action="store", default='DeductiveATC',
    choices=['HC', 'CHC', 'ATC', 'DeductiveCodeGen', 'DeductiveATC', 'DeductiveCHC'],
    help="Generation method to evaluate.",
)

parser.add_argument("-c", "--num_concepts", dest="num_concepts", type=int, action="store", default=4)
parser.add_argument("-s", "--sample_num", dest="sample_num", type=int, action="store", default=8)
parser.add_argument("-pt", "--prompt_type", dest="prompt_type", type=str, action="store", default='ele')
parser.add_argument("-temp", "--temperature", dest="temperature", type=float, action="store", default=1.0)
parser.add_argument("-topp", "--top_p", dest="top_p", type=float, action="store", default=1.0)

# The results JSON is written here; there is no sensible default, so the
# flag is mandatory (previously a missing value crashed on os.path.dirname(None)).
parser.add_argument(
    "-outpath", "--output_path", dest="output_path", type=str, action="store", required=True,
    help="Path of the JSON file that receives the per-problem results.",
)
args = parser.parse_args()


DATASET = args.dataset
# Whole-number floats (e.g. 1.0) are converted to int so that the METHOD tag
# below reads "...1t" rather than "...1.0t".
TEMP = args.temperature if args.temperature != int(args.temperature) else int(args.temperature)
TOP_P = args.top_p if args.top_p != int(args.top_p) else int(args.top_p)
# HC / DeductiveCodeGen: passed as num_return_sequences; ATC variants: concept_num_sampling
SAMPLE_NUM = args.sample_num
# CHC variants
NUM_CONCEPTS = args.num_concepts
NUM_SAMPLE_PER_CONCEPT = args.sample_num
PROMPT_TYPE = args.prompt_type
BIGLITTLE = False
MIX_ORI = False
MODEL_ID = args.model_id


# How many I/O pairs of each problem are given to the model (NUM_TRAIN) and
# how many of the following pairs are held out for testing (NUM_TEST).
if 'playgol' in args.dataset:
    NUM_TRAIN = 1
    NUM_TEST = 4
elif 'mbpp' in args.dataset:
    NUM_TRAIN = 1
    NUM_TEST = 50
else:
    # Without this, NUM_TRAIN/NUM_TEST would be undefined and the script
    # would fail later with a NameError.
    parser.error(
        f"unsupported dataset {args.dataset!r}: split sizes are only defined "
        "for dataset names containing 'playgol' or 'mbpp'"
    )

# Build the requested method. METHOD is a tag string encoding the method and
# its hyper-parameters (it is not used further in this script).
if args.experiment == 'HC':
    METHOD = f'HC_{SAMPLE_NUM}s_{TEMP}t'
    model = HC(model_id=MODEL_ID, task=DATASET)
elif args.experiment == 'CHC':
    if PROMPT_TYPE == 'ele':
        METHOD = f'CHC_{NUM_CONCEPTS}c_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    else:
        METHOD = f'CHC_{NUM_CONCEPTS}c_{NUM_SAMPLE_PER_CONCEPT}s_{PROMPT_TYPE}_{TEMP}t'
    model = CHC(model_id=MODEL_ID, task=DATASET, biglittle=BIGLITTLE)
elif args.experiment == 'ATC':
    if BIGLITTLE:
        METHOD = f'ATC_bl_{NUM_CONCEPTS}a_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    else:
        METHOD = f'ATC_{NUM_CONCEPTS}a_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    model = ATC(model_id=MODEL_ID, task=DATASET, biglittle=BIGLITTLE)
elif args.experiment == 'DeductiveATC':
    if BIGLITTLE:
        METHOD = f'DeductiveATC_bl_{NUM_CONCEPTS}a_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    else:
        METHOD = f'DeductiveATC_{NUM_CONCEPTS}a_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    model = DeductiveATC(model_id=MODEL_ID, task=DATASET, biglittle=BIGLITTLE)
elif args.experiment == 'DeductiveCodeGen':
    METHOD = f'DeductiveCodeGen_{SAMPLE_NUM}s_{TEMP}t'
    model = DeductiveCodeGen(model_id=MODEL_ID, task=DATASET)
elif args.experiment == 'DeductiveCHC':
    if PROMPT_TYPE == 'ele':
        METHOD = f'DeductiveCHC_{NUM_CONCEPTS}c_{NUM_SAMPLE_PER_CONCEPT}s_{TEMP}t'
    else:
        METHOD = f'DeductiveCHC_{NUM_CONCEPTS}c_{NUM_SAMPLE_PER_CONCEPT}s_{PROMPT_TYPE}_{TEMP}t'
    model = DeductiveCHC(model_id=MODEL_ID, task=DATASET, biglittle=BIGLITTLE)
else:
    # Otherwise `model` and METHOD would silently stay undefined.
    parser.error(f"unsupported experiment: {args.experiment!r}")




# Make sure the directory that will hold the results JSON exists. A bare
# filename has no directory part (dirname == ''), and os.makedirs('') raises,
# so only create a directory when there is one.
if args.output_path is None:
    parser.error("-outpath/--output_path is required")
_output_dir = os.path.dirname(args.output_path)
if _output_dir:
    os.makedirs(_output_dir, exist_ok=True)

def get_python_input(input):
    """Parse a string holding a Python literal (e.g. ``"[1, 2]"``) into its value.

    Delegates to ``ast.literal_eval``, so only literal syntax is accepted;
    names and function calls are rejected with ``ValueError``.
    """
    parsed = ast.literal_eval(input)
    return parsed

# Load the benchmark: one JSON object per line of data/<DATASET>.jsonl.
with open(f'data/{DATASET}.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

output_list = []

num_correct_task = 0
inputs_batch = []       # per problem: the first NUM_TRAIN I/O pairs shown to the model
inputs_batch_task = []  # per problem: problem['prompt'] (only filled for Deductive methods)

for problem in data:
    if 'Deductive' in args.experiment:
        inputs_batch_task.append(problem['prompt'])
    if 'mbpp' in DATASET:
        # For mbpp datasets the I/O pairs are stored as Python-literal strings;
        # parse them in place so later code sees lists of dicts.
        problem['train'] = ast.literal_eval(problem['train'])
        try:
            problem['test'] = ast.literal_eval(problem['test'])
        except (ValueError, TypeError, SyntaxError):
            # literal_eval rejects non-literal expressions (presumably things
            # like float("inf") in expected outputs), so fall back to eval.
            # NOTE(review): eval executes arbitrary code from the dataset
            # file; acceptable only because the data is local and trusted.
            problem['test'] = eval(problem['test'])
    io_pairs = problem['train'] + problem['test']
    inputs_batch.append(io_pairs[:NUM_TRAIN])

# Query the chosen method once for the whole dataset. Deductive variants
# additionally receive the per-problem prompts (inputs_batch_task).
if args.experiment == 'HC':
    model_output = model.forward(inputs_batch, temperature=TEMP, num_return_sequences=SAMPLE_NUM, top_p=TOP_P)
elif args.experiment == 'CHC':
    model_output = model.forward(
        inputs_batch,
        temperature=TEMP,
        top_p=TOP_P,
        num_concepts=NUM_CONCEPTS,
        num_sampling_per_concept=NUM_SAMPLE_PER_CONCEPT,
        prompt_type=PROMPT_TYPE,
        mix_ori=MIX_ORI
    )
elif args.experiment == 'ATC':
    model_output = model.forward(
        inputs_batch,
        temperature=TEMP,
        top_p=TOP_P,
        num_concepts=NUM_CONCEPTS,
        concept_num_sampling=SAMPLE_NUM,
    )
elif args.experiment == 'DeductiveATC':
    model_output = model.forward(
        inputs_batch_task,
        inputs_batch,
        temperature=TEMP,
        top_p=TOP_P,
        num_concepts=NUM_CONCEPTS,
        concept_num_sampling=SAMPLE_NUM,
    )
elif args.experiment == 'DeductiveCodeGen':
    model_output = model.forward(inputs_batch_task, inputs_batch, temperature=TEMP, num_return_sequences=SAMPLE_NUM, top_p=TOP_P)
elif args.experiment == 'DeductiveCHC':
    model_output = model.forward(
        inputs_batch_task,
        inputs_batch,
        temperature=TEMP,
        top_p=TOP_P,
        num_concepts=NUM_CONCEPTS,
        num_sampling_per_concept=NUM_SAMPLE_PER_CONCEPT,
        prompt_type=PROMPT_TYPE,
        mix_ori=MIX_ORI
    )
else:
    # Otherwise model_output stays undefined and the evaluation loop fails
    # with a NameError (or, on an empty dataset, silently writes nothing).
    parser.error(f"unsupported experiment: {args.experiment!r}")



def _accuracy(results, expected):
    """Return the fraction of ``expected`` outputs reproduced by ``results``.

    Elements are compared pairwise with ``==`` (``zip`` stops at the shorter
    sequence). A comparison that raises counts as a miss, so a single
    un-comparable value cannot abort the evaluation. Returns 0.0 when
    ``expected`` is empty instead of dividing by zero.
    """
    if not expected:
        return 0.0
    hits = 0
    for got, want in zip(results, expected):
        try:
            if got == want:
                hits += 1
        except Exception:
            # Best-effort: e.g. an element-wise ``==`` whose truth value is
            # ambiguous simply means "not a match".
            pass
    return hits / len(expected)


# Candidate programs get a wall-clock budget enforced with SIGALRM (Unix
# only); timeout_handler raises TimeoutException when it fires. Install the
# handler once, up front.
signal.signal(signal.SIGALRM, timeout_handler)

entropy_list, num_unique_program_list = [], []
for i, problem in enumerate(tqdm(data)):
    hypos, programs = model_output[i]['hypothesis'], model_output[i]['code']

    io_pairs = problem['train'] + problem['test']
    train_io_pairs = io_pairs[:NUM_TRAIN]
    test_io_pairs = io_pairs[NUM_TRAIN:NUM_TRAIN + NUM_TEST]

    train_inputs = copy.deepcopy([ex["input"] for ex in train_io_pairs])
    train_outputs = copy.deepcopy([ex["output"] for ex in train_io_pairs])
    test_inputs = copy.deepcopy([ex["input"] for ex in test_io_pairs])
    test_outputs = copy.deepcopy([ex["output"] for ex in test_io_pairs])

    programs_and_scores = []
    execution_results = []
    train_accuracies, test_accuracies = [], []
    for idx, program in enumerate(programs):
        # Fresh copies per program: it is not visible here whether
        # execute_function copies its arguments, and a candidate that mutates
        # its inputs must not change what the next candidate sees.
        run_train_inputs = copy.deepcopy(train_inputs)
        run_test_inputs = copy.deepcopy(test_inputs)

        timed_out = False
        signal.alarm(120)  # seconds, shared by the train and test runs below
        try:
            execution_results_on_train_inputs = execute_function(program, run_train_inputs)
            execution_results_on_test_inputs = execute_function(program, run_test_inputs)
        except TimeoutException:
            execution_results_on_train_inputs = 'timeout'
            execution_results_on_test_inputs = 'timeout'
            timed_out = True
        finally:
            signal.alarm(0)  # always disarm, even if an unexpected error escapes

        execution_results.append(str(execution_results_on_train_inputs + execution_results_on_test_inputs))
        if timed_out:
            # The placeholders are strings; zipping them against the expected
            # outputs would compare single characters, so score 0 explicitly.
            train_accuracy = test_accuracy = 0.0
        else:
            train_accuracy = _accuracy(execution_results_on_train_inputs, train_outputs)
            test_accuracy = _accuracy(execution_results_on_test_inputs, test_outputs)
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)

        programs_and_scores.append((program, train_accuracy, idx))

    if programs_and_scores:
        # Pick the program with the best train accuracy; max() returns the
        # first maximal element, so ties go to the earliest-sampled program.
        best_program, best_train_accuracy, best_idx = max(programs_and_scores, key=lambda entry: entry[1])
        best_output = model_output[i]['output_visualization'][best_idx]
        test_accuracy = test_accuracies[best_idx]
        # Both accuracies lie in [0, 1], so this is 1.0 only when the chosen
        # program is perfect on train AND test, and 0.0 otherwise.
        success = (test_accuracy * best_train_accuracy) // 1
    else:
        # No candidate programs: record a failure instead of crashing.
        best_output = None
        best_train_accuracy = 0.0
        test_accuracy = 0.0
        success = 0.0

    # Group programs that behave identically (same execution results).
    cluster = Counter(execution_results)
    semantic_distribution = sorted(cluster.values(), reverse=True)
    entropy = calculate_entropy(semantic_distribution)
    num_unique_program = len(semantic_distribution)
    unique_execution_results = list(cluster)

    entropy_list.append(entropy)
    num_unique_program_list.append(num_unique_program)

    output = {
        'idx': problem['idx'],
        'best_output': best_output,
        'train_accuracy': best_train_accuracy,
        'test_accuracy': test_accuracy,
        'success': success,
        # 'model_outputs': model_output[i]['output_visualization'],
        'raw_response': model_output[i]['raw_response'],
        'execution_results': str(execution_results),
        'train_accuracies': train_accuracies,
        'test_accuracies': test_accuracies,
        'unique_execution_results': unique_execution_results,
        'num_unique_programs': num_unique_program
    }

    output_list.append(output)


# Persist the per-problem records, then print aggregate metrics.
with open(args.output_path, 'w', encoding='utf-8') as f:
    json.dump(output_list, f, indent=4)


if data:
    # Each 'success' value is already floored to 0.0 or 1.0.
    task_acc = sum(instance['success'] for instance in output_list)
    print('task accuracy:', task_acc/len(data))

    # print(f'avg entropy: {sum(entropy_list)/len(entropy_list)}')
    print(f'avg unique programs: {sum(num_unique_program_list)/len(num_unique_program_list)}')
else:
    # Avoid ZeroDivisionError when the dataset file had no problems.
    print('task accuracy: n/a (no problems evaluated)')

