import argparse
from math_verify import parse, verify
import json
from tqdm import tqdm
import os
import logging



# Difficulty buckets over ``min_steps``: the position of each inner list is
# the difficulty level, and the inner list holds the step counts that map to it.
BUCKETS = [
    [0],
    [1],
    [2],
    [3, 4],
    list(range(5, 16)),  # 5 through 15 inclusive
]

def read_json_objects(filename, field_names=None):
    """Load a list of JSON items from a ``.jsonl`` or ``.json`` file.

    Args:
        filename: Path of the file to read. The extension selects the parser:
            ``.jsonl`` is parsed one JSON document per non-blank line,
            ``.json`` is parsed as a single document whose elements are
            returned as a list.
        field_names: Optional list of keys. When a list is given, every
            item that is a dict is reduced to just those keys (keys missing
            from an item are skipped). Any other value disables filtering.

    Returns:
        The list of parsed items. An empty list is returned, and the problem
        is logged, when the extension is unknown, the file is missing, or
        its content cannot be read or decoded.
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    # Label used in the decode error message ("JSONL" or "JSON").
    format_label = 'JSONL' if file_extension == '.jsonl' else 'JSON'

    try:
        with open(filename, 'r', encoding='utf-8') as file:
            if file_extension == '.jsonl':
                # Blank lines (e.g. a trailing newline) carry no record and
                # would otherwise make json.loads fail for the whole file.
                items = [json.loads(line) for line in file if line.strip()]
            else:
                items = list(json.load(file))
    except FileNotFoundError:
        logging.error("The file was not found.")
        return []
    except json.JSONDecodeError:
        logging.error(f"There was an error decoding the {format_label} file.")
        return []
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return []

    if isinstance(field_names, list):
        # Keep only the requested keys; non-dict items are passed through
        # untouched because they have no fields to select.
        items = [
            {key: item[key] for key in field_names if key in item}
            if isinstance(item, dict) else item
            for item in items
        ]
    return items


def main():
    """Build the Entailment Bank train/val/test splits and optional ablations.

    Reads ``processed_data/entailment-bank/rewrites.grouped.jsonl``, shuffles
    the groups with ``--seed``, holds out the last 100 shuffled groups as the
    test set (first version of each group only), picks up to ``--val_number``
    groups that have at least ``len(BUCKETS)`` versions as the validation
    set, and uses the remaining groups for training. Every split is written
    in prompt/completion format under ``training_data/entailment-bank``
    together with a ``stats.txt`` summary.

    With ``--ablation`` two extra families of training files are written:
    ``depth/`` (cumulative datasets by rewrite depth) and ``proportion/``
    (original questions plus a fraction of the rewrites).

    Raises:
        ValueError: with ``--ablation``, if the training groups contain fewer
            than 4 distinct rewrite depths.
    """
    # Function-scope imports kept local to this script entry point.
    import collections
    import random

    parser = argparse.ArgumentParser()
    parser.add_argument('--ablation', action='store_true', help='whether to create ablation datasets')
    parser.add_argument('--val_number', action='store', type=int, default=100, help='number of val groups')
    parser.add_argument('--seed', action='store', type=int, default=42, help='random seed for splitting')
    args = parser.parse_args()

    out_dir = 'training_data/entailment-bank'

    # ------------------ Utilities ----------------------

    # create min_steps -> difficulty mapping (difficulty = index of the bucket)
    step_to_difficulty = {
        steps: difficulty
        for difficulty, steps_list in enumerate(BUCKETS)
        for steps in steps_list
    }

    def write_jsonl(filename, data_list):
        """Write every item of ``data_list`` as one JSON line, replacing the file."""
        with open(filename, 'w') as f:
            f.writelines(json.dumps(item) + '\n' for item in data_list)

    # instruction = flat_data_list[0].get("instruction", "") # use for answer before reasoning
    instruction = "Answer the question based on the context below. sent is the shorthand for sentence and int means intermediate conclusion. Give the reasoning step and then summarize the answer in one sentence."

    def convert_to_prompt_completion(data_list):
        """Turn raw question records into prompt/completion training records.

        Must be applied to raw records only: the output records no longer
        carry the "input" key, so converting them a second time would
        produce prompts without the question text.
        """
        result = []
        for item in data_list:
            question_input = item.get("input", "")
            answer = item.get('answer', "")
            reasoning = item.get('reasoning', "")
            min_steps = item.get("min_steps", -1)
            result.append({
                "prompt": f"{instruction}\n{question_input}",
                "completion": f"{reasoning}\nThe answer is: {answer}",
                "answer": answer,
                "reasoning": reasoning,
                "min_steps": min_steps,
                # -1 marks a step count that is not covered by any bucket
                "difficulty": step_to_difficulty.get(min_steps, -1),
            })
        return result

    def compute_stats(data_list):
        """Return (count per min_steps, count per difficulty, total count)."""
        step_count = collections.Counter(item.get("min_steps", -1) for item in data_list)
        difficulty_count = collections.Counter(item.get("difficulty", -1) for item in data_list)
        return step_count, difficulty_count, len(data_list)

    def write_stats(filename, step_count, difficulty_count, total_count, prefix="", mode='a'):
        """Write one labelled stats section to ``filename`` opened with ``mode``."""
        with open(filename, mode) as f:
            f.write(f"{prefix} Total Count: {total_count}\n")
            f.write(f"{prefix} Step Count:\n")
            for step in sorted(step_count):
                f.write(f"  Steps {step}: {step_count[step]}\n")
            f.write(f"{prefix} Difficulty Count:\n")
            for difficulty in sorted(difficulty_count):
                f.write(f"  Difficulty {difficulty}: {difficulty_count[difficulty]}\n")

    # ------------------ Process Data ----------------------
    print("Processing Entailment Bank Data...")

    # make sure training_data/entailment-bank directory exists
    os.makedirs(out_dir, exist_ok=True)

    grouped_data_list = read_json_objects('processed_data/entailment-bank/rewrites.grouped.jsonl')
    if not grouped_data_list:
        # The reader logs the cause; stop here instead of failing on a
        # missing result or overwriting existing outputs with empty files.
        logging.error("No groups were loaded; nothing to process.")
        return
    print(f"[DEBUG] Total groups: {len(grouped_data_list)}")

    # shuffle the grouped data
    random.seed(args.seed)
    random.shuffle(grouped_data_list)

    # set aside some questions as test set
    test_size = 100
    test_groups = grouped_data_list[-test_size:]
    print(f"[DEBUG] Test groups: {len(test_groups)}")

    # IMPORTANT: remove test groups from grouped_data_list
    grouped_data_list = grouped_data_list[:-test_size]
    print(f"[DEBUG] Train groups: {len(grouped_data_list)}")

    # only leave the original questions in the test set
    test_list = []
    for item in test_groups:
        questions = item.get("versions", [])
        if questions:
            test_list.append(questions[0])  # assuming version 1 is the first in the list
    print(f"Selected {len(test_list)} groups for test set.")

    # select args.val_number groups that have >= len(BUCKETS) versions
    # as validation groups; the limit is checked before adding so that
    # --val_number 0 really selects nothing
    val_group_ids = set()
    for item in grouped_data_list:
        if len(val_group_ids) >= args.val_number:
            break
        if len(item.get("versions", [])) >= len(BUCKETS):
            val_group_ids.add(item.get("id", ""))
    print(f"Selected {len(val_group_ids)} groups for validation set.")

    # create train and val splits based on group ids
    # and flatten train and val splits
    # also create an original list, taking version 1 from each group
    train_list = []
    val_list = []
    train_original_list = []
    train_rewrite_list = []
    for item in grouped_data_list:
        questions = item.get("versions", [])
        if item.get("id", "") in val_group_ids:
            val_list.extend(questions)
            continue
        if not questions:
            # a group without versions has no original question to index
            continue
        train_list.extend(questions)
        train_original_list.append(questions[0])  # assuming version 1 is the first in the list
        train_rewrite_list.extend(questions[1:])  # all other versions are rewrites

    # convert train, val, original lists to prompt-completion format
    train_list = convert_to_prompt_completion(train_list)
    val_list = convert_to_prompt_completion(val_list)
    test_list = convert_to_prompt_completion(test_list)
    train_original_list = convert_to_prompt_completion(train_original_list)
    train_rewrite_list = convert_to_prompt_completion(train_rewrite_list)
    print(f"Train size: {len(train_list)}, Val size: {len(val_list)}, Test size: {len(test_list)}, Train-original size: {len(train_original_list)}, Train-rewrite size: {len(train_rewrite_list)}")

    # write every split and its stats section; the first section truncates
    # stats.txt, the following ones append to it
    splits = [
        ('train.jsonl', "Train Data", train_list),
        ('val.jsonl', "Val Data", val_list),
        ('test.jsonl', "Test Data", test_list),
        ('train.original.jsonl', "Train-Original Data", train_original_list),
        ('train.rewrite.jsonl', "Train-Rewrite Data", train_rewrite_list),
    ]
    for file_name, _, split_data in splits:
        write_jsonl(f'{out_dir}/{file_name}', split_data)
    for index, (_, label, split_data) in enumerate(splits):
        step_count, difficulty_count, total_count = compute_stats(split_data)
        write_stats(f'{out_dir}/stats.txt', step_count, difficulty_count, total_count, prefix=label, mode='w' if index == 0 else 'a')

    # -------------- Create Ablation Dataset -----------------
    if args.ablation:
        # 1. create ablation of the rewrite depth
        depth_dir = f'{out_dir}/depth'
        depth_stats_file = f'{depth_dir}/stats.txt'
        os.makedirs(depth_dir, exist_ok=True)

        # depth 0 means no rewrite, use the original question (version 1)
        # depth 1 means use the first rewrite (version 2), etc.
        # so depth d = version - 1, i.e. the 0-based position in "versions"
        # only the groups in the training set (not validation) are used
        depth_map = collections.defaultdict(list)
        for item in grouped_data_list:
            if item.get("id", "") in val_group_ids:
                continue
            for depth, question in enumerate(item.get("versions", [])):
                depth_map[depth].append(question)

        # explicit check instead of assert so it survives ``python -O``
        if len(depth_map) <= 3:
            raise ValueError("Expected at least 4 depths in the data.")
        max_depth = max(depth_map)

        # write the stats of each individual depth
        # count depth 4 and above together
        for depth in range(4):
            depth_list = depth_map.get(depth, [])
            step_count, difficulty_count, total_count = compute_stats(convert_to_prompt_completion(depth_list))
            write_stats(depth_stats_file, step_count, difficulty_count, total_count, prefix=f"Depth {depth} Data", mode='a' if depth > 0 else 'w')
            print(f"Depth {depth} has {len(depth_list)} examples.")
        # combine all depths >= 4
        combined_depth_4_plus = []
        for d in range(4, max_depth + 1):
            combined_depth_4_plus.extend(depth_map.get(d, []))
        step_count, difficulty_count, total_count = compute_stats(convert_to_prompt_completion(combined_depth_4_plus))
        write_stats(depth_stats_file, step_count, difficulty_count, total_count, prefix="Depth 4+ Data", mode='a')
        print(f"Depth 4+ has {len(combined_depth_4_plus)} examples.")

        # write the cumulative datasets: the depth-d file contains all
        # questions with that depth and below. Their stats are always
        # appended (never 'w') so the per-depth sections written above are
        # kept, and they carry a "Cumulative" label to tell them apart.
        cumulative_raw = []
        for depth in range(4):
            cumulative_raw.extend(depth_map.get(depth, []))
            combined_depth_list = convert_to_prompt_completion(cumulative_raw)
            write_jsonl(f'{depth_dir}/rewrites.depth{depth}.jsonl', combined_depth_list)
            step_count, difficulty_count, total_count = compute_stats(combined_depth_list)
            write_stats(depth_stats_file, step_count, difficulty_count, total_count, prefix=f"Cumulative Depth <= {depth} Data", mode='a')
            print(f"Wrote depth {depth} data with {len(combined_depth_list)} examples.")

        # the depth 4+ file adds every remaining depth, so it holds all
        # training questions (originals and rewrites)
        for d in range(4, max_depth + 1):
            cumulative_raw.extend(depth_map.get(d, []))
        combined_depth_list = convert_to_prompt_completion(cumulative_raw)
        write_jsonl(f'{depth_dir}/rewrites.depth4plus.jsonl', combined_depth_list)
        step_count, difficulty_count, total_count = compute_stats(combined_depth_list)
        write_stats(depth_stats_file, step_count, difficulty_count, total_count, prefix="Cumulative Depth 4+ Data", mode='a')
        print(f"Wrote depth 4+ data with {len(combined_depth_list)} examples.")

        # 2. create ablation of the proportion of rewrite questions
        proportion_dir = f'{out_dir}/proportion'
        os.makedirs(proportion_dir, exist_ok=True)
        proportions = [0.0, 0.25, 0.5, 0.75, 1.0]

        # for each proportion, create a dataset with that proportion of rewrites
        # shuffle the rewrites once so each proportion takes a random prefix
        random.shuffle(train_rewrite_list)
        for index, prop in enumerate(proportions):
            num_rewrites = int(prop * len(train_rewrite_list))
            # both lists are already in prompt-completion format, so they
            # are combined as-is and must not be converted again
            ablation_list = train_original_list + train_rewrite_list[:num_rewrites]
            if 0.0 < prop < 1.0:
                random.shuffle(ablation_list)

            prop_str = str(int(prop * 100))
            write_jsonl(f'{proportion_dir}/rewrites.prop{prop_str}.jsonl', ablation_list)
            step_count, difficulty_count, total_count = compute_stats(ablation_list)
            write_stats(f'{proportion_dir}/stats.txt', step_count, difficulty_count, total_count, prefix=f"Proportion {prop} Data", mode='w' if index == 0 else 'a')
            print(f"Wrote proportion {prop} data with {len(ablation_list)} examples.")

    print("Processing of entailment bank data complete.")

# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Set the root logger to INFO before main() so its logging calls are emitted.
    logging.basicConfig(level=logging.INFO)
    main()
