import argparse
import collections
import json
import logging
import os
import random

from math_verify import parse, verify
from tqdm import tqdm


# Difficulty buckets over ``min_steps``: a step count's difficulty is the index
# of the bucket that contains it. Step counts 0-4 each get their own bucket;
# step counts 5 through 15 are pooled together into the final bucket.
BUCKETS = [[step] for step in range(5)] + [list(range(5, 16))]

def read_json_objects(filename, field_names=None):
    """Load a list of JSON values from a ``.jsonl`` or ``.json`` file.

    Failures are logged rather than raised, so callers always get a list back.

    Args:
        filename: Path to the file. Its extension selects the parser:
            ``.jsonl`` parses one JSON value per non-blank line; ``.json``
            parses a single document and returns its top-level elements.
        field_names: Optional list/tuple of keys. When given, every dict item
            is reduced to just those keys (keys an item lacks are skipped).

    Returns:
        The loaded items as a list, or ``[]`` if the extension is unsupported,
        the file is missing, or it cannot be read/decoded.
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    try:
        # JSON text is UTF-8 by specification; don't depend on the locale default.
        with open(filename, 'r', encoding='utf-8') as file:
            if file_extension == '.jsonl':
                # Skip blank lines (e.g. a trailing newline) instead of
                # failing the whole file with a JSONDecodeError.
                items = [json.loads(line) for line in file if line.strip()]
            else:
                items = list(json.load(file))
    except FileNotFoundError:
        logging.error(f"The file was not found: {filename}")
        return []
    except json.JSONDecodeError as e:
        kind = 'JSONL' if file_extension == '.jsonl' else 'JSON'
        logging.error(f"There was an error decoding the {kind} file {filename}: {e}")
        return []
    except Exception as e:
        # Deliberate best-effort read: log and return an empty list like the
        # other failure paths instead of an implicit None.
        logging.error(f"An error occurred while reading {filename}: {e}")
        return []

    if isinstance(field_names, (list, tuple)):
        items = [
            {name: item[name] for name in field_names if name in item}
            if isinstance(item, dict) else item
            for item in items
        ]
    return items


# Optional instruction prepended to every prompt; currently empty.
# Example alternative: "Solve the following question step-by-step."
_INSTRUCTION = ''


def _build_step_to_difficulty(buckets):
    """Map each ``min_steps`` value to the index of its bucket (its difficulty).

    If a step value appears in several buckets, the last bucket wins.
    """
    return {step: difficulty for difficulty, steps in enumerate(buckets) for step in steps}


def _write_jsonl(filename, data_list):
    """Write *data_list* to *filename* as JSON Lines, overwriting the file."""
    with open(filename, 'w') as f:
        for item in data_list:
            f.write(json.dumps(item) + '\n')


def _convert_to_prompt_completion(data_list, step_to_difficulty, instruction=_INSTRUCTION):
    """Convert raw question records into prompt/completion training records.

    Each input item is read via ``question``, ``answer``, ``reasoning`` and
    ``min_steps`` (missing values default to ``""`` / ``-1``). The output keeps
    ``answer``, ``reasoning`` and ``min_steps`` and adds ``prompt``,
    ``completion`` and ``difficulty`` (``-1`` when ``min_steps`` is in no
    bucket). The ``question`` text only survives inside ``prompt``, so an
    already-converted list must NOT be converted again.
    """
    result = []
    for item in data_list:
        question = item.get("question", "")
        answer = item.get('answer', "")
        reasoning = item.get('reasoning', "")
        min_steps = item.get("min_steps", -1)
        result.append({
            "prompt": instruction + f"Q: {question}" + "\nA: Let's think step by step.",
            "completion": f"{reasoning}\nThe answer is {answer}.",
            "answer": answer,
            "reasoning": reasoning,
            "min_steps": min_steps,
            "difficulty": step_to_difficulty.get(min_steps, -1),
        })
    return result


def _compute_stats(data_list):
    """Count items per ``min_steps`` and per ``difficulty`` (missing -> -1).

    Returns:
        ``(step_count, difficulty_count, total_count)``; the first two are
        ``collections.Counter`` mappings from value to number of items.
    """
    step_count = collections.Counter(item.get("min_steps", -1) for item in data_list)
    difficulty_count = collections.Counter(item.get("difficulty", -1) for item in data_list)
    return step_count, difficulty_count, len(data_list)


def _write_stats(filename, step_count, difficulty_count, total_count, prefix="", mode='a'):
    """Write a labelled stats section to *filename* (``mode`` 'w' truncates, 'a' appends)."""
    with open(filename, mode) as f:
        f.write(f"{prefix} Total Count: {total_count}\n")
        f.write(f"{prefix} Step Count:\n")
        for step in sorted(step_count):
            f.write(f"  Steps {step}: {step_count[step]}\n")
        f.write(f"{prefix} Difficulty Count:\n")
        for difficulty in sorted(difficulty_count):
            f.write(f"  Difficulty {difficulty}: {difficulty_count[difficulty]}\n")


def _select_val_group_ids(grouped_data_list, val_number, min_versions):
    """Return ids of the first *val_number* groups having >= *min_versions* versions.

    NOTE(review): groups are keyed by ``item.get("id", "")``, so groups lacking
    an ``id`` all share ``""`` and would move to validation together — confirm
    every group carries a unique id.
    """
    val_group_ids = set()
    for item in grouped_data_list:
        # Check the quota before adding so val_number <= 0 selects nothing.
        if len(val_group_ids) >= val_number:
            break
        if len(item.get("versions", [])) >= min_versions:
            val_group_ids.add(item.get("id", ""))
    return val_group_ids


def _split_groups(grouped_data_list, val_group_ids):
    """Flatten groups into ``(train, val, train_original, train_rewrite)`` question lists.

    For a training group, ``versions[0]`` is treated as the original question
    and ``versions[1:]`` as its rewrites. Groups with no versions are skipped.
    """
    train, val, train_original, train_rewrite = [], [], [], []
    for item in grouped_data_list:
        questions = item.get("versions", [])
        if item.get("id", "") in val_group_ids:
            val.extend(questions)
        elif questions:
            train.extend(questions)
            train_original.append(questions[0])
            train_rewrite.extend(questions[1:])
    return train, val, train_original, train_rewrite


def _write_depth_ablation(grouped_data_list, val_group_ids, convert, out_dir='training_data/gsm8k/depth'):
    """Write the rewrite-depth ablation datasets and their stats under *out_dir*.

    Depth ``d`` is the 0-based position in a group's ``versions`` list: depth 0
    is the original question (version 1), depth 1 the first rewrite (version 2),
    and so on. Only training groups (ids not in *val_group_ids*) are used.

    ``stats.txt`` receives per-depth sections (depths 0-3 individually, depths
    >= 4 pooled) followed by cumulative sections matching the written files:
    ``rewrites.depth{d}.jsonl`` (d = 0..3) holds every version with depth <= d,
    and ``rewrites.depth4plus.jsonl`` holds every version of every training group.

    Raises:
        ValueError: if the training groups span fewer than 4 depths.
    """
    os.makedirs(out_dir, exist_ok=True)
    stats_path = f'{out_dir}/stats.txt'

    depth_map = collections.defaultdict(list)
    for item in grouped_data_list:
        if item.get("id", "") in val_group_ids:
            continue
        for depth, q in enumerate(item.get("versions", [])):
            depth_map[depth].append(q)
    if len(depth_map) < 4:
        raise ValueError("Expected at least 4 depths in the data.")
    max_depth = max(depth_map)

    # Per-depth statistics; the first write truncates the stats file.
    for depth in range(4):
        depth_list = depth_map.get(depth, [])
        _write_stats(stats_path, *_compute_stats(convert(depth_list)),
                     prefix=f"Depth {depth} Data", mode='w' if depth == 0 else 'a')
        print(f"Depth {depth} has {len(depth_list)} examples.")
    deep_list = [q for d in range(4, max_depth + 1) for q in depth_map.get(d, [])]
    _write_stats(stats_path, *_compute_stats(convert(deep_list)), prefix="Depth 4+ Data", mode='a')
    print(f"Depth 4+ has {len(deep_list)} examples.")

    # Cumulative datasets. Always append so the per-depth section above is kept.
    cumulative = []
    for depth in range(4):
        cumulative.extend(depth_map.get(depth, []))
        converted = convert(cumulative)
        _write_jsonl(f'{out_dir}/rewrites.depth{depth}.jsonl', converted)
        _write_stats(stats_path, *_compute_stats(converted),
                     prefix=f"Cumulative Depth {depth} Data", mode='a')
        print(f"Wrote depth {depth} data with {len(converted)} examples.")

    # Depth 4+: every version of every training group (originals included).
    cumulative.extend(deep_list)
    converted = convert(cumulative)
    _write_jsonl(f'{out_dir}/rewrites.depth4plus.jsonl', converted)
    _write_stats(stats_path, *_compute_stats(converted), prefix="Cumulative Depth 4+ Data", mode='a')
    print(f"Wrote depth 4+ data with {len(converted)} examples.")


def _write_proportion_ablation(train_original_list, train_rewrite_list,
                               proportions=(0.0, 0.25, 0.5, 0.75, 1.0),
                               out_dir='training_data/gsm8k/proportion'):
    """Write datasets of all originals plus a fraction of the rewrites.

    Both input lists must ALREADY be in prompt/completion format; they are
    written as-is. For each proportion ``p`` the dataset is every original plus
    the first ``int(p * len(rewrites))`` items of a shuffled copy of
    *train_rewrite_list*; datasets with ``p > 0`` are then shuffled so originals
    and rewrites are interleaved. Uses (and advances) the global ``random`` state.
    """
    os.makedirs(out_dir, exist_ok=True)
    stats_path = f'{out_dir}/stats.txt'

    rewrites = list(train_rewrite_list)
    random.shuffle(rewrites)
    for i, prop in enumerate(proportions):
        num_rewrites = min(int(prop * len(rewrites)), len(rewrites))
        ablation_list = list(train_original_list) + rewrites[:num_rewrites]
        if prop > 0.0:
            random.shuffle(ablation_list)
        prop_str = str(round(prop * 100))
        _write_jsonl(f'{out_dir}/rewrites.prop{prop_str}.jsonl', ablation_list)
        _write_stats(stats_path, *_compute_stats(ablation_list),
                     prefix=f"Proportion {prop} Data", mode='w' if i == 0 else 'a')
        print(f"Wrote proportion {prop} data with {len(ablation_list)} examples.")


def main():
    """Build GSM8K prompt/completion datasets from the grouped rewrites file.

    Reads ``processed_data/gsm8k/rewrites.grouped.jsonl``, shuffles the groups
    with ``--seed``, sends up to ``--val_number`` groups having at least
    ``len(BUCKETS)`` versions to validation and the rest to training, and writes
    train/val/train-original/train-rewrite JSONL files plus ``stats.txt`` under
    ``training_data/gsm8k``. With ``--ablation`` it also writes the depth and
    proportion ablation datasets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ablation', action='store_true', help='whether to create ablation datasets')
    parser.add_argument('--val_number', action='store', type=int, default=100, help='number of val groups')
    parser.add_argument('--seed', action='store', type=int, default=42, help='random seed for splitting')
    args = parser.parse_args()

    out_dir = 'training_data/gsm8k'
    step_to_difficulty = _build_step_to_difficulty(BUCKETS)

    def convert(data_list):
        # Bind the bucket mapping so the helpers only pass the data.
        return _convert_to_prompt_completion(data_list, step_to_difficulty)

    # ------------------ Process Data ----------------------
    print("Processing GSM8K data...")
    os.makedirs(out_dir, exist_ok=True)

    grouped_data_list = read_json_objects('processed_data/gsm8k/rewrites.grouped.jsonl')
    if not grouped_data_list:
        # Unreadable input (logged by read_json_objects) or an empty file.
        logging.error("No grouped GSM8K data was loaded; aborting.")
        return

    # Check that all answers are strings; groups store questions under "versions".
    for i, item in enumerate(grouped_data_list):
        for j, q in enumerate(item.get("versions", [])):
            answer = q.get("answer", "")
            if not isinstance(answer, str):
                print(f"Warning: answer is not a string in group {i} question {j}: {answer}")

    random.seed(args.seed)
    random.shuffle(grouped_data_list)

    # TODO: load the test set questions here if they are also measured by GPT-5

    val_group_ids = _select_val_group_ids(grouped_data_list, args.val_number, min_versions=len(BUCKETS))
    print(f"Selected {len(val_group_ids)} groups for validation set.")

    train_raw, val_raw, original_raw, rewrite_raw = _split_groups(grouped_data_list, val_group_ids)
    train_list = convert(train_raw)
    val_list = convert(val_raw)
    train_original_list = convert(original_raw)
    train_rewrite_list = convert(rewrite_raw)
    print(f"Train size: {len(train_list)}, Val size: {len(val_list)}, Train-Original size: {len(train_original_list)}, Train-rewrite size: {len(train_rewrite_list)}.")

    outputs = [
        ('train.jsonl', "Train Data", train_list),
        ('val.jsonl', "Val Data", val_list),
        ('train.original.jsonl', "Train-Original Data", train_original_list),
        ('train.rewrite.jsonl', "Train-Rewrite Data", train_rewrite_list),
    ]
    for i, (file_name, label, data) in enumerate(outputs):
        _write_jsonl(f'{out_dir}/{file_name}', data)
        _write_stats(f'{out_dir}/stats.txt', *_compute_stats(data),
                     prefix=label, mode='w' if i == 0 else 'a')

    # -------------- Create Ablation Dataset -----------------
    if args.ablation:
        _write_depth_ablation(grouped_data_list, val_group_ids, convert)
        # These lists are already converted; they are written without re-conversion.
        _write_proportion_ablation(train_original_list, train_rewrite_list)

    print("Processing of GSM8K data complete.")
    
if __name__ == "__main__":
    # Configure the root logger at INFO before running the pipeline.
    logging.basicConfig(level="INFO")
    main()