import ast, os, random
from methods import HC, CHC, ATC, DiscriminationByExample, DiscriminationByExampleAndPrompt
import argparse, json, copy
from tqdm import tqdm
import pdb
from collections import Counter
import numpy as np
from utils.python_utils import execute_function, extract_program
import math
import signal
import sys
from difflib import SequenceMatcher as SM


class TimeoutException(Exception):
    """Raised from ``timeout_handler`` to abort work that exceeded its time budget."""

def timeout_handler(signum, frame):
    """Signal-handler-shaped callback that always raises ``TimeoutException``.

    Both arguments are ignored; they exist only to match the
    ``signal.signal`` handler signature. NOTE(review): not installed anywhere
    in the visible code -- presumably paired with an alarm signal elsewhere.
    """
    message = "Execution timed out!"
    raise TimeoutException(message)

def calculate_entropy(int_list):
    """Return the Shannon entropy, in bits, of a list of non-negative counts.

    The counts are treated as unnormalized probabilities; zero counts
    contribute nothing.

    Args:
        int_list: non-negative counts (ints or floats).

    Returns:
        The entropy as a float. An empty list, or one whose counts sum to
        zero, describes no distribution and yields 0.0 instead of raising
        ZeroDivisionError.
    """
    total_sum = sum(int_list)
    if total_sum <= 0:
        return 0.0

    entropy = 0.0
    for count in int_list:
        if count > 0:  # equivalent to p > 0 once total_sum is positive
            p = count / total_sum
            # Subtracting (instead of negating a sum) avoids returning -0.0.
            entropy -= p * math.log2(p)
    return entropy

def custom_eval(x):
    """Convert the string form of an execution result back into a Python value.

    Two repr quirks that plain ``eval`` cannot parse are rewritten first:

    * a bare ``inf`` (the repr of ``float('inf')``) directly followed by
      ``,``, ``]``, ``)`` or ``}`` becomes ``float("inf")``;
    * ``array([...])`` wrappers (presumably NumPy reprs) are unwrapped into
      plain lists.

    Args:
        x: the string to evaluate.

    Returns:
        The evaluated value, or the ORIGINAL string ``x`` unchanged when it
        cannot be evaluated (e.g. an error message); the rewritten text is
        never leaked to callers.
    """
    source = x
    if 'inf' in source:
        # Each replacement leaves 'inf"' behind, so later closers never re-match.
        for closer in (',', ']', ')', '}'):
            source = source.replace(f'inf{closer}', f'float("inf"){closer}')
    if 'array(' in source:
        source = source.replace('array([', '[')
        source = source.replace('])', ']')
    try:
        # NOTE(review): eval executes arbitrary code. x comes from running
        # generated programs, so treat it as untrusted input.
        return eval(source)
    except Exception:
        return x

def transpose(matrix):
    """Swap the rows and columns of a list of sequences.

    Rows are combined with ``zip``, so ragged input is truncated to the
    shortest row. Every resulting row is a fresh ``list``.
    """
    return list(map(list, zip(*matrix)))



def deduplicate(lst):
    """Return the items of ``lst`` with duplicates removed, keeping first-seen order.

    Duplicates are detected with ``==`` rather than hashing, so unhashable
    items (such as lists of execution results) are supported. The cost is
    quadratic in the number of items.
    """
    unique = []
    for candidate in lst:
        if not any(candidate == kept for kept in unique):
            unique.append(candidate)
    return unique



parser = argparse.ArgumentParser()
parser.add_argument("-ds", "--dataset", dest="dataset", type=str, action="store", default='playgol_v2_ambig_from_llama8b')
parser.add_argument("-pf", "--prediction_file_path", dest="prediction_file_path", type=str, action="store", default='outputs/ambig/train1test4/playgol_v2/DI_meta-llama/Meta-Llama-3.1-8B-Instruct/ATC_4a_8s_1t.json')
# Only these two experiments configure METHOD/model below; reject others up front.
parser.add_argument("-exp", "--experiment", dest="experiment", type=str, action="store", default='DiscriminationByExample',
                    choices=['DiscriminationByExample', 'DiscriminationByExampleAndPrompt'])

parser.add_argument("-mi", "--model_id", dest="model_id", type=str, action="store", default='gpt-4.1-2025-04-14')

# Query-selection strategy used by the evaluation loop.
parser.add_argument("-cr", "--criteria", dest="criteria", type=str, action="store", default='maximin',
                    choices=['maximin', 'random'])
parser.add_argument("-s", "--sample_num", dest="sample_num", type=int, action="store", default=1)
parser.add_argument("-temp", "--temperature", dest="temperature", type=float, action="store", default=0.7)
args = parser.parse_args()


DATASET = args.dataset
# Whole-number temperatures are kept as int so METHOD reads e.g. "1t" instead of "1.0t".
TEMP = int(args.temperature) if args.temperature.is_integer() else args.temperature
# HC
SAMPLE_NUM = args.sample_num
# CHC and SHTC
NUM_SAMPLE_PER_CONCEPT = args.sample_num
MODEL_ID = args.model_id

# Number of io pairs used as train / test / held-out ("unseen") examples per problem.
if 'playgol' in DATASET:
    NUM_TRAIN = 1
    NUM_TEST = 4
    NUM_UNSEEN_TEST = 0
elif 'mbpp' in DATASET:
    NUM_TRAIN = 1
    NUM_TEST = 10
    NUM_UNSEEN_TEST = 10
else:
    # Without this, NUM_TRAIN/NUM_TEST would be undefined and crash later with a NameError.
    parser.error(f"unsupported dataset {DATASET!r}: name must contain 'playgol' or 'mbpp'")

if args.experiment == 'DiscriminationByExample':
    METHOD = f'DiscriminationByExample_{SAMPLE_NUM}s_{TEMP}t'
    model = DiscriminationByExample(model_id=MODEL_ID, task=DATASET)
elif args.experiment == 'DiscriminationByExampleAndPrompt':
    METHOD = f'DiscriminationByExampleAndPrompt_{SAMPLE_NUM}s_{TEMP}t'
    model = DiscriminationByExampleAndPrompt(model_id=MODEL_ID, task=DATASET)


OUTPUT_PATH = f'outputs/ambig/train{NUM_TRAIN}test{NUM_TEST}/{DATASET}/{MODEL_ID}/{METHOD}_{args.criteria}.json'
# Derive the directory from OUTPUT_PATH so the two can never drift apart.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
print(OUTPUT_PATH)

def get_python_input(input):
    """Parse a Python literal (numbers, strings, lists, tuples, dicts, ...) from text.

    Uses ``ast.literal_eval``, which evaluates literals only and never runs code.
    """
    parsed_value = ast.literal_eval(input)
    return parsed_value

# Problems: one JSON object per line of the dataset file.
with open(f'data/{DATASET}.jsonl', 'r') as f:
    data = list(map(json.loads, f))

# Predictions: a single JSON document, paired positionally with `data`.
with open(args.prediction_file_path, 'r') as f:
    preds = json.load(f)


def _select_query(hypos, criteria):
    """Choose the input whose answer should best discriminate among ``hypos``.

    Args:
        hypos: surviving, pairwise-distinct hypotheses; each is a list of
            execution results with one entry per input (train inputs first).
        criteria: ``'maximin'`` (maximize the worst-case number of hypotheses
            eliminated) or ``'random'`` (any input the hypotheses disagree on).

    Returns:
        ``(input_idx, candidates)`` -- the chosen input index and the
        deduplicated outputs the hypotheses produce on it -- or ``None`` when
        the hypotheses agree on every input, so no query can separate them.

    Raises:
        ValueError: for an unknown ``criteria``.
    """
    per_input = transpose(hypos)  # [num_input, num_hypo]
    if criteria == 'maximin':
        scored = []
        for idx, outputs in enumerate(per_input):
            # For each possible answer, count the hypotheses it would eliminate;
            # the input's score is the worst case over those answers.
            worst_case = min(
                len([o for o in outputs if o != cand])
                for cand in deduplicate(outputs)
            )
            scored.append((worst_case, idx, outputs))
        if not scored:
            return None
        # Tie breaker that favors shorter output: order longest-first, so after
        # the stable sort by score the last maximal entry is the shortest one.
        scored = sorted(scored, key=lambda s: -len(str(s[2])))
        best_score, best_idx, best_outputs = sorted(scored, key=lambda s: s[0])[-1]
        if best_score == 0:
            # Every input has a single distinct output: nothing to ask.
            return None
        return best_idx, deduplicate(best_outputs)
    if criteria == 'random':
        ambiguous = [
            (idx, outputs)
            for idx, outputs in enumerate(per_input)
            if len(deduplicate(outputs)) > 1
        ]
        if not ambiguous:
            return None
        idx, outputs = ambiguous[random.randrange(len(ambiguous))]
        return idx, deduplicate(outputs)
    raise ValueError(f'unknown selection criteria: {criteria!r}')


def _query_model(problem, train_io_pairs, query_input, candidates):
    """Ask the discriminator model which of ``candidates`` is the output for ``query_input``.

    Returns:
        The last element of the model output's ``'answer'`` entry; the caller
        fuzzy-matches it (via ``str``) against the candidate outputs.

    Raises:
        ValueError: for an experiment this script does not configure.
        RuntimeError: if the model output lacks the expected structure.
    """
    if args.experiment == 'DiscriminationByExample':
        model_input = (train_io_pairs, query_input, candidates)
    elif args.experiment == 'DiscriminationByExampleAndPrompt':
        model_input = (problem['prompt'], train_io_pairs, query_input, candidates)
    else:
        raise ValueError(f'unsupported experiment: {args.experiment!r}')
    model_output = model.forward([model_input], temperature=TEMP)
    try:
        return model_output[0]['answer'][-1]
    except (IndexError, KeyError, TypeError) as err:
        raise RuntimeError(
            f"unexpected model output for problem {problem['idx']}: {model_output!r}"
        ) from err


# For each problem: keep the hypotheses consistent with the train pairs, then
# repeatedly ask the model to pick the right output on a discriminating input
# and drop the hypotheses that disagree, until one hypothesis remains.
output_list = []
accuracy = 0
mean_example_level_accuracy = 0
mean_unseen_accuracy = 0
for problem, pred in zip(tqdm(data), preds):
    if 'mbpp' in DATASET:
        problem['train'] = ast.literal_eval(problem['train'])
        try:
            problem['test'] = ast.literal_eval(problem['test'])
        except (ValueError, SyntaxError, TypeError):
            # Non-literal test data. NOTE(review): eval runs arbitrary code from the dataset file.
            problem['test'] = eval(problem['test'])
    io_pairs = problem['train'] + problem['test']
    train_io_pairs = io_pairs[:NUM_TRAIN]
    test_io_pairs = io_pairs[NUM_TRAIN:NUM_TRAIN + NUM_TEST]
    test_inputs = [e['input'] for e in test_io_pairs]
    gt_outputs = [e['output'] for e in io_pairs[:NUM_TRAIN + NUM_TEST]]
    # Guarded: io_pairs[-0:] would be the whole list, not an empty one.
    gt_outputs_for_unseen_test = (
        [e['output'] for e in io_pairs[-NUM_UNSEEN_TEST:]] if NUM_UNSEEN_TEST > 0 else []
    )

    unique_execution_results = [custom_eval(p) for p in pred['unique_execution_results']] # [num_hypo, num_input]
    columns = transpose(unique_execution_results)  # [num_input, num_hypo]
    if NUM_UNSEEN_TEST > 0:
        execution_results_for_unseen_test = transpose(columns[-NUM_UNSEEN_TEST:])
    execution_results = transpose(columns[:NUM_TRAIN + NUM_TEST])

    # Keep hypotheses that reproduce every train output; slicing the inputs
    # above may have made some hypotheses identical, so deduplicate.
    passed_hypos = deduplicate([
        hypo for hypo in execution_results
        if all(er == io['output'] for io, er in zip(train_io_pairs, hypo))
    ])

    num_calls = 0
    call_history = []
    while len(passed_hypos) > 1:
        start_hypo_num = len(passed_hypos)

        selection = _select_query(passed_hypos, args.criteria)
        if selection is None:
            print(f"WARNING: problem {problem['idx']}: {start_hypo_num} hypotheses agree on "
                  f"every input; stopping without a query", file=sys.stderr)
            break
        selected_idx, candidates = selection
        query_input = test_inputs[selected_idx - NUM_TRAIN]

        answer = _query_model(problem, train_io_pairs, query_input, candidates)
        num_calls += 1

        # Map the free-form answer to the closest hypothesis output
        # (>= lets a later hypothesis win ties).
        max_score = 0
        for hypo in passed_hypos:
            score = SM(None, str(hypo[selected_idx]), answer).ratio()
            if score >= max_score:
                max_score = score
                selected_answer = str(hypo[selected_idx])

        # Eliminate hypotheses inconsistent with the answer.
        new_passed_hypos = [
            hypo for hypo in passed_hypos if str(hypo[selected_idx]) == selected_answer
        ]
        # Distinct outputs with identical str() forms can make a round remove
        # nothing; re-asking would then loop forever.
        stalled = len(new_passed_hypos) == len(passed_hypos)
        passed_hypos = new_passed_hypos

        call_history.append({
            'iter': num_calls-1,
            'selected_test_input': str(query_input),
            'candidates': str(candidates),
            'answer': answer,
            'gt_answer': str(gt_outputs[selected_idx]),
            'num_hypos_change': f'{start_hypo_num} -> {len(passed_hypos)}'
        })

        if stalled:
            print(f"WARNING: problem {problem['idx']}: answer eliminated none of "
                  f"{start_hypo_num} hypotheses; stopping", file=sys.stderr)
            break

    # With no hypothesis consistent with the train pairs there is nothing to
    # evaluate; the problem is scored as incorrect instead of crashing the run.
    final_hypo = passed_hypos[0] if passed_hypos else None
    if final_hypo is None:
        print(f"WARNING: problem {problem['idx']}: no hypothesis reproduces the train "
              f"outputs; scoring as incorrect", file=sys.stderr)

    correct = 1 if final_hypo == gt_outputs else 0
    accuracy += correct

    unseen_test_exe = None
    if NUM_UNSEEN_TEST > 0:
        unseen_accuracy = 0
        if final_hypo is not None:
            # Several raw hypotheses can share the final results on the seen
            # inputs; pick one at random to score on the unseen inputs.
            final_consistent_ids = [i for (i, exe) in enumerate(execution_results) if exe == final_hypo]
            final_consistent_id = random.choice(final_consistent_ids)
            unseen_test_exe = execution_results_for_unseen_test[final_consistent_id]
            unseen_accuracy = 1 if unseen_test_exe == gt_outputs_for_unseen_test else 0
        mean_unseen_accuracy += unseen_accuracy

    # Fraction of (train + test) inputs on which the final hypothesis is right.
    if final_hypo is None:
        example_level_accuracy = 0.0
    else:
        example_level_accuracy = sum(x == y for x, y in zip(final_hypo, gt_outputs)) / len(final_hypo)
    mean_example_level_accuracy += example_level_accuracy

    if 'mbpp' in DATASET:
        task = problem['prompt']
    else:
        task = problem['train'][:NUM_TRAIN]

    record = {
        'idx': problem['idx'],
        'correct': correct,
        'example_level_accuracy': example_level_accuracy,
        'task': task,
        'num_calls': num_calls,
        'final_hypo_exe_result': str(final_hypo),
        'gt_output': str(gt_outputs),
    }
    if NUM_UNSEEN_TEST > 0:
        record['unseen_test_exe'] = str(unseen_test_exe)
        record['gt_outputs_for_unseen_test'] = str(gt_outputs_for_unseen_test)
    record['call_history'] = call_history
    output_list.append(record)


# Average over the problems actually evaluated: zip(data, preds) stops at the
# shorter list, so dividing by len(data) would count missing predictions as wrong.
num_evaluated = len(output_list)
if num_evaluated != len(data):
    print(f'WARNING: evaluated {num_evaluated} of {len(data)} problems '
          f'({len(preds)} predictions loaded)', file=sys.stderr)
denominator = max(num_evaluated, 1)  # avoid ZeroDivisionError on an empty run
print(f'TASK LEVEL ACCURACY: {accuracy/denominator}')
print(f'EXAMPLE LEVEL ACCURACY: {mean_example_level_accuracy/denominator}')
if NUM_UNSEEN_TEST > 0:
    print(f'UNSEEN TEST SET ACCURACY: {mean_unseen_accuracy/denominator}')


with open(OUTPUT_PATH, 'w') as f:
    json.dump(output_list, f, indent=4)

