import json
import random
import numpy as np
import os
import argparse


def build_parser():
    """Build the command-line parser for this script.

    Flags, destinations and defaults are unchanged; only help texts are added.
    """
    parser = argparse.ArgumentParser(
        description="Select problems whose train-perfect candidates disagree on the test set."
    )
    parser.add_argument("-dsp", "--dataset", dest="dataset", type=str, action="store", default='playgol_v2',
                        help="dataset name; entries are read from data/<dataset>.jsonl")
    parser.add_argument("-dsop", "--dataset_output_path", dest="dataset_output_path", type=str, action="store", default='playgol_v2_ambig',
                        help="file that receives the selected dataset entries (JSON lines)")
    parser.add_argument("-pp", "--prediction_path", dest="prediction_path", type=str, action="store", default='outputs/inductive/train1test4/playgol_v2/DI_meta-llama/Meta-Llama-3.1-8B-Instruct/ATC_4a_8s_1t.json',
                        help="JSON file holding the list of per-problem predictions")
    parser.add_argument("-pop", "--prediction_output_path", dest="prediction_output_path", type=str, action="store", default='outputs/ambig/train1test4/playgol_v2/DI_meta-llama/Meta-Llama-3.1-8B-Instruct/ATC_4a_8s_1t.json',
                        help="JSON file that receives the selected predictions")
    return parser


def ensure_parent_dir(path):
    """Create the parent directory of *path* if it has one.

    A bare file name such as ``'playgol_v2_ambig'`` has an empty dirname, and
    ``os.makedirs('')`` raises ``FileNotFoundError``; in that case the file
    lives in the current working directory and nothing needs to be created.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def summarize_predictions(preds):
    """Classify problems by how their train-perfect candidates do on the test set.

    Args:
        preds: list of dicts. Each dict may carry ``"train_accuracies"`` and
            ``"test_accuracies"`` (parallel lists, one value per candidate) and
            must carry ``"idx"`` whenever at least one train accuracy equals 1.

    Returns:
        A dict with:
          * ``"selected_idxs"``: idx of problems where some, but not all,
            train-perfect candidates have a non-zero floored test accuracy.
          * ``"all_codes_generalize"``: idx of problems where every
            train-perfect candidate has a non-zero floored test accuracy.
          * ``"train_correct"``: number of problems with at least one
            train-perfect candidate.
          * ``"generalization_success_ratio"``: sum of the per-problem mean
            floored test accuracy, divided by ``len(preds)`` (NaN when
            ``preds`` is empty).

    Raises:
        IndexError: if ``test_accuracies`` is shorter than the position of a
            train-perfect candidate.
    """
    selected_idxs = []
    all_codes_generalize = []
    success_ratios = []
    train_correct = 0

    for pred in preds:
        train_accs = pred.get("train_accuracies", [])
        test_accs = pred.get("test_accuracies", [])

        # Candidates are the positions whose train accuracy is exactly 1.
        candidate_idxs = [i for i, acc in enumerate(train_accs) if acc == 1]
        if not candidate_idxs:
            continue  # no candidate solved the train examples
        train_correct += 1

        # Floor division turns partial scores into 0 (assuming accuracies lie
        # in [0, 1]), so only a perfect test accuracy counts as a success.
        candidate_test = [test_accs[i] // 1 for i in candidate_idxs]
        success_ratios.append(np.array(candidate_test).mean())

        if all(candidate_test):
            all_codes_generalize.append(pred['idx'])
        elif any(candidate_test):
            selected_idxs.append(pred['idx'])

    if preds:
        # Normalised by ALL problems, not only those with a train-perfect candidate.
        ratio = float(np.array(success_ratios).sum() / len(preds))
    else:
        # Undefined without any problem; avoids numpy's divide-by-zero warning.
        ratio = float("nan")

    return {
        "selected_idxs": selected_idxs,
        "all_codes_generalize": all_codes_generalize,
        "train_correct": train_correct,
        "generalization_success_ratio": ratio,
    }


def main(argv=None):
    """Run the selection: read predictions, report statistics, write subsets.

    Args:
        argv: optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    ensure_parent_dir(args.dataset_output_path)
    ensure_parent_dir(args.prediction_output_path)

    with open(args.prediction_path, 'r', encoding='utf-8') as f:
        preds = json.load(f)

    summary = summarize_predictions(preds)
    selected_idxs = summary["selected_idxs"]
    all_codes_generalize = summary["all_codes_generalize"]
    train_correct = summary["train_correct"]
    genenralization_success_ratio = summary["generalization_success_ratio"]

    # NOTE(review): the "ratio" labels below print raw counts, and
    # "genenralization" is misspelled; both are kept verbatim so that anything
    # parsing this output keeps working.
    print(f'num_selected: {len(selected_idxs)}')

    print(f'train_acc = 1 ratio: {train_correct}')
    print(f'ambig_problem_ratio: {len(selected_idxs)}')
    print(f'genenralization_success_ratio: {genenralization_success_ratio:.3f}')

    print(f'performance_lower_bound: {len(all_codes_generalize)}')
    print(f'performance_upper_bound: {len(all_codes_generalize)+len(selected_idxs)}')

    # Set for O(1) membership tests; assumes idx values are hashable
    # (ints or strings coming from JSON).
    selected = set(selected_idxs)

    with open(f'data/{args.dataset}.jsonl', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        data = [json.loads(line) for line in f if line.strip()]

    # NOTE(review): the input is read from data/<dataset>.jsonl but the output
    # goes to dataset_output_path exactly as given — confirm this is intended.
    with open(args.dataset_output_path, 'w', encoding='utf-8') as out:
        for entry in data:
            if entry['idx'] in selected:
                out.write(json.dumps(entry) + '\n')

    selected_problems = [entry for entry in preds if entry['idx'] in selected]

    with open(args.prediction_output_path, 'w', encoding='utf-8') as out:
        json.dump(selected_problems, out, indent=4)


if __name__ == "__main__":
    main()