import json
from random import sample
import pandas as pd
import numpy as np
from tqdm import tqdm
import random

def create_one_sample(a,b):
    """Format one addition problem ``a + b`` as a prompt/answer record.

    Args:
        a: first operand.
        b: second operand.

    Returns:
        A dict with three string fields:
        ``'input'`` is the prompt ``"<a> + <b> = "``, ``'output'`` is the
        sum followed by a newline, and ``'sample_string'`` is the prompt
        immediately followed by the output.
    """
    total = a + b
    prompt = f'{a} + {b} = '
    answer = f'{total}\n'
    # The full sample is simply the prompt with its answer appended.
    return dict(sample_string=prompt + answer, input=prompt, output=answer)

def create_range(left = (0,99), right = (0,99), skip_range=None, shuffle = True, random_sample=-1):
    """Build addition samples for operand pairs taken from two inclusive ranges.

    Args:
        left: ``(low, high)`` bounds of the first operand, both inclusive.
        right: ``(low, high)`` bounds of the second operand, both inclusive.
        skip_range: optional sequence of ``(a_low, a_high, b_low, b_high)``
            boxes. A pair ``(a, b)`` lying inside any box (bounds inclusive)
            is left out of the result. ``None`` or an empty sequence skips
            nothing.
        shuffle: when true, the result is shuffled with ``random.shuffle``.
        random_sample: when positive, draw this many pairs with
            ``random.randint`` instead of enumerating every pair. Pairs that
            land in a skip box are dropped rather than redrawn, so the
            result can hold fewer than ``random_sample`` entries.

    Returns:
        A list of the dicts produced by ``create_one_sample``.

    Raises:
        AssertionError: if ``left``/``right`` do not have exactly two items
            or a skip box does not have exactly four.
    """
    assert len(left) == 2 and len(right) == 2, "left and right must be (low, high) pairs"
    skip_boxes = list(skip_range) if skip_range else []
    for box in skip_boxes:
        assert len(box) == 4, "each skip box must be (a_low, a_high, b_low, b_high)"

    def in_skip(a, b):
        return any(
            box[0] <= a <= box[1] and box[2] <= b <= box[3]
            for box in skip_boxes
        )

    all_dataset = []
    if random_sample > 0:
        # Draw all left operands first, then all right operands, so the
        # sequence of random calls is the same for a given seed.
        i_s = [random.randint(left[0], left[1]) for _ in range(random_sample)]
        j_s = [random.randint(right[0], right[1]) for _ in range(random_sample)]
        # zip() has no length, so give tqdm the total explicitly.
        for i, j in tqdm(zip(i_s, j_s), total=random_sample):
            if in_skip(i, j):
                continue
            all_dataset.append(create_one_sample(i,j))
    else:
        # The upper bounds are inclusive everywhere else (randint above and
        # the skip-box test), so the exhaustive walk must include them too;
        # range() excludes its stop value, hence the "+ 1".
        for i in tqdm(range(left[0], left[1] + 1)):
            for j in range(right[0], right[1] + 1):
                if in_skip(i, j):
                    continue
                all_dataset.append(create_one_sample(i,j))

    if shuffle:
        random.shuffle(all_dataset)
    return all_dataset



def create_skip_range(skip_range, shuffle = True):
    """Build the samples that fall inside the given skip boxes.

    Each box is split into its first two items (bounds of the first operand)
    and the remaining items (bounds of the second operand), and the pairs are
    enumerated by ``create_range`` with no exclusions.

    Args:
        skip_range: sequence of ``(a_low, a_high, b_low, b_high)`` boxes.
        shuffle: when true, the combined result is shuffled once with
            ``random.shuffle``; when false, samples stay in enumeration
            order, box after box.

    Returns:
        A list of sample dicts covering every box in ``skip_range``.
    """
    all_dataset = []
    for box in skip_range:
        left, right = box[:2], box[2:]
        # Ask for unshuffled output: create_range shuffles by default, which
        # would make shuffle=False here impossible to honor. Ordering is
        # decided exactly once, below.
        all_dataset.extend(create_range(left, right, [], shuffle=False))
    if shuffle:
        random.shuffle(all_dataset)
    return all_dataset

def write_jsonl(l, p):
    """Write each element of ``l`` to file ``p`` as one JSON document per line.

    The file is truncated if it already exists. Missing parent directories of
    ``p`` are created first, so writing into a fresh output folder does not
    fail with ``FileNotFoundError``.

    Args:
        l: iterable of JSON-serializable objects.
        p: destination path.

    Raises:
        TypeError: if an element cannot be serialized by ``json.dumps``.
        OSError: if the directory or file cannot be created or written.
    """
    import os

    # Only path-like targets have a parent directory to create; anything
    # else (e.g. an already-open file descriptor) is handed to open() as is.
    if isinstance(p, (str, bytes, os.PathLike)):
        parent = os.path.dirname(p)
        if parent:
            os.makedirs(parent, exist_ok=True)

    with open(p, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(e) + "\n" for e in l)

def gen_filename(prefix, left, right, skip_range):
    """Compose the output path of a dataset file under ``data/``.

    Args:
        prefix: string placed at the start of the file name.
        left: first-operand bounds; embedded via ``str()``.
        right: second-operand bounds; embedded via ``str()``.
        skip_range: skip boxes; embedded via ``str()``.

    Returns:
        ``"data/<prefix>-left-<left>-right-<right>-skip-<skip_range>.jsonl"``.
    """
    pieces = (
        "data/",
        prefix,
        "-left-",
        str(left),
        "-right-",
        str(right),
        "-skip-",
        str(skip_range),
        ".jsonl",
    )
    return "".join(pieces)

def four_digit_with_hole():
    """Generate and save a train/test split with held-out operand boxes.

    The training set comes from ``create_range`` over operand bounds
    ``(0, 99)`` with two boxes of operand pairs excluded; the test set comes
    from ``create_skip_range`` over those same boxes. Both are written as
    JSONL files whose names are built by ``gen_filename``.
    """
    operand_a = (0,99)
    operand_b = (0, 99)
    # Deliberately one tuple and one list: str() of this value becomes part
    # of the output file names, so its exact shape must not change.
    holes = [(10,20, 80,90), [30,40,60,70]]

    training = create_range(operand_a, operand_b, holes)
    held_out = create_skip_range(holes)

    for prefix, rows in (("train", training), ("test", held_out)):
        write_jsonl(rows, gen_filename(prefix, operand_a, operand_b, holes))

def train3_test45():
    """Generate a full small-operand training set plus sampled larger test sets.

    The training set enumerates operand pairs with bounds ``(0, 999)``. Four
    test sets are built from 10000 random draws each, with operand bounds
    ``(1000, 9999)``, ``(10000, 99999)``, ``(100000, 999999)`` and
    ``(1000000, 9999999)``. Every set is written to a JSONL file whose name
    records the bounds that set was actually generated from.
    """
    no_skip = []
    # (file-name prefix, operand bounds used for both sides, random_sample);
    # random_sample=-1 is create_range's default and means "enumerate all".
    # NOTE(review): "1w" in the prefixes presumably stands for the 10000
    # sampled pairs - confirm.
    specs = [
        ("train_full_3", (0, 999), -1),
        ("test_1w_4", (1000, 9999), 10000),
        ("test_1w_5", (10000, 99999), 10000),
        ("test_1w_6", (100000, 999999), 10000),
        ("test_1w_7", (1000000, 9999999), 10000),
    ]
    for prefix, bounds, n_random in specs:
        dataset = create_range(bounds, bounds, no_skip, random_sample=n_random)
        # Name the file from this set's own bounds. Reusing one pair of
        # variables across all sets would stamp the last range on every
        # file, including the training file.
        # Writing right away also avoids holding every set in memory at once.
        write_jsonl(dataset, gen_filename(prefix, bounds, bounds, no_skip))

# Script entry point: when this file is executed directly (not imported),
# generate the train/test split produced by four_digit_with_hole().
if __name__ == "__main__":
    four_digit_with_hole()






