import argparse
import time
import json
import os
from .utils import load_jsonl, save_jsonl
from .parser import run_execute, choice_answer_clean, parse_ground_truth
from .grader import math_equal

def parse_args():
    """Parse the command-line options of this script.

    Options:
        --data_name: string, defaults to "gsm8k".
        --prompt_type: string, defaults to "tool-integrated".
        --input_folder: string, mandatory.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_name", type=str, default="gsm8k")
    cli.add_argument("--prompt_type", type=str, default="tool-integrated")
    # The only option without a default: argparse exits if it is missing.
    cli.add_argument("--input_folder", type=str, required=True)
    return cli.parse_args()

def is_multi_choice(answer):
    """Tell whether every element of ``answer`` is one of the letters A-E.

    Args:
        answer: An iterable (typically a string) checked element by element.

    Returns:
        True when all elements equal "A", "B", "C", "D" or "E" (an empty
        iterable therefore yields True); False as soon as one does not.
    """
    choice_letters = ("A", "B", "C", "D", "E")
    return all(symbol in choice_letters for symbol in answer)

def parse_prediction(response, gt, data_name):
    """Extract a prediction from ``response`` and normalise choice answers.

    The prediction is the first element of ``run_execute(response, data_name)``.
    It is then adjusted depending on the ground truth:

    * ``gt`` is a single letter A-E but the prediction is not: the prediction
      is handed to ``choice_answer_clean`` and that result is returned.
    * otherwise, ``gt`` consists only of letters A-E (per ``is_multi_choice``)
      while the prediction does not: every character of the prediction that
      is not one of A-E is dropped.
    * in all other cases the prediction is returned unchanged.

    Args:
        response: Model output forwarded to ``run_execute``.
        gt: Ground-truth answer used to decide which normalisation applies.
        data_name: Dataset name forwarded to ``run_execute``.

    Returns:
        The (possibly normalised) prediction.
    """
    choice_letters = ("A", "B", "C", "D", "E")
    pred = run_execute(response, data_name)[0]

    # Single-choice ground truth with a prediction that is not a bare letter.
    # NOTE(review): the disabled prepare_data further down this file passes the
    # raw code/response to choice_answer_clean instead of the extracted
    # prediction -- confirm which input is intended here.
    if gt in choice_letters and pred not in choice_letters:
        return choice_answer_clean(pred)

    # Multi-letter ground truth: keep only the choice letters of the prediction.
    if is_multi_choice(gt) and not is_multi_choice(pred):
        return "".join(ch for ch in pred if ch in choice_letters)

    return pred


# def prepare_data(data_name, input_file, output_file):
#     # Load the processed samples
#     processed_samples = list(load_jsonl(input_file))  # Convert generator to list

#     # Ensure processed_samples is a list and has the correct structure
#     required_keys = {'idx', 'question', 'code'}

#     print("processed_samples: ", processed_samples)
#     if not isinstance(processed_samples, list):
#         raise ValueError(f"processed_samples is not a list")

#     for sample in processed_samples:
#         if not isinstance(sample, dict) or not required_keys.issubset(sample.keys()):
#             raise ValueError(f"The loaded data does not match the expected structure. Missing keys in sample: {sample}")

#         # Parse the ground truth using the parse_ground_truth function
#         sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)
        
#         # Initialize the list to hold predictions
#         preds = []

#         if not isinstance(sample['code'], list):
#             sample['code'] = [sample['code']]

#         # Iterate through each element in 'code' to create corresponding 'pred'
#         for code in sample['code']:
#             result = run_execute(code, data_name)
#             pred = result[0]

#             if sample["gt"] in ["A", "B", "C", "D", "E"] and pred not in ["A", "B", "C", "D", "E"]:
#                 pred = choice_answer_clean(code)
#             elif is_multi_choice(sample["gt"]) and not is_multi_choice(pred):
#                 pred = "".join([c for c in pred if c in ["A", "B", "C", "D", "E"]])
            
#             # Append the parsed prediction to the 'preds' list
#             preds.append(pred)
        
#         # Store the list of parsed predictions
#         sample['pred'] = preds

#     # Save the updated samples to the new output JSONL file
#     save_jsonl(processed_samples, output_file)
    
#     return processed_samples

# def process_jsonl_files_in_folder(data_name, input_folder):
#     # Use os.walk to traverse all subdirectories and files
#     for dirpath, _, filenames in os.walk(input_folder):
#         for filename in filenames:
#             if filename.endswith('_results.jsonl'):
#                 input_file = os.path.join(dirpath, filename)
#                 output_file = input_file.replace(".jsonl", "_updated.jsonl")

#                 processed_samples = prepare_data(data_name, input_file, output_file)

#                 # Start timing the evaluation
#                 start_time = time.time()

#                 # Evaluate the loaded samples
#                 all_samples, result_json = evaluate(
#                     samples=processed_samples,
#                     data_name=data_name,
#                 )

#                 save_jsonl(all_samples, output_file)

#                 result_json["time_use_in_second"] = time.time() - start_time

#                 with open(
#                     output_file.replace(".jsonl", f"_metrics.json"), "w"
#                 ) as f:
#                     json.dump(result_json, f, indent=4)

#                 print(f"Processed: {input_file}")
#                 print(f"Results saved to: {output_file}")

# if __name__ == "__main__":
#     args = parse_args()
#     process_jsonl_files_in_folder(args.data_name, args.input_folder)
