import random
import sys
from unary.kcm import *
import os
from itertools import product
import math
from tqdm import tqdm  # <--- NEW

def training_set_mul_generator(max_num=10000, r=100):
    """Generate the ``mul`` dataset and write it under ``dataset/mul/``.

    Inputs are random non-empty strings over ``{a, b, c}``; targets are
    whatever ``build_mul_kcm().output_generator`` returns for each input.

    Splits written (one string per line):
      * ``input.txt`` / ``target.txt``: first 80% of ``max_num`` distinct
        strings of length 1..``r``.
      * ``input_val0.txt`` / ``target_val0.txt``: the remaining 20% (same
        length range, disjoint from train).
      * ``input_val1.txt`` / ``target_val1.txt``: as many strings as val0,
        length 100..200.
      * ``input_val2.txt`` / ``target_val2.txt``: as many strings as val0,
        length 200..300.

    Args:
        max_num: number of distinct strings drawn for train + val0.
        r: maximum length of train/val0 strings.

    Raises:
        ValueError: if a split requests more distinct non-empty strings than
            exist for its length range (e.g. a very small ``r``).
    """
    auto = build_mul_kcm()
    alphabet = ['a', 'b', 'c']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        # Number of distinct non-empty strings the loop below can ever produce;
        # requesting more would make the rejection loop spin forever.
        capacity = sum(len(alphabet) ** k for k in range(max(min_len, 1), max_len + 1))
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct non-empty strings of length "
                f"{min_len}..{max_len} over {len(alphabet)} symbols "
                f"(only {capacity} exist)"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 yields '', which is never kept
                samples.add(s)
        return list(samples)

    tmp = gen_samples(max_num, 0, r)
    len_train = int(0.8 * len(tmp))
    train_inputs = tmp[:len_train]
    val0_inputs = tmp[len_train:]

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="mul-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="mul-val0")]

    # Longer-input validation splits, each the same size as val0.
    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="mul-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="mul-val2")]

    os.makedirs("dataset/mul", exist_ok=True)
    for path, lines in (
        ("dataset/mul/input.txt", train_inputs),
        ("dataset/mul/target.txt", train_outputs),
        ("dataset/mul/input_val0.txt", val0_inputs),
        ("dataset/mul/target_val0.txt", val0_outputs),
        ("dataset/mul/input_val1.txt", val1_inputs),
        ("dataset/mul/target_val1.txt", val1_outputs),
        ("dataset/mul/input_val2.txt", val2_inputs),
        ("dataset/mul/target_val2.txt", val2_outputs),
    ):
        with open(path, "w") as f:
            f.writelines(s + '\n' for s in lines)

def training_set_dvd_generator(max_num):
    """Build the ``dvd`` dataset and write it to ``dataset/dvd/``.

    Draws ``max_num`` distinct non-empty strings over ``{a, b}`` with length
    up to 100, keeps the first 80% for training and the rest as val0, and
    adds two validation sets of the same size as val0 with lengths 100..200
    (val1) and 200..300 (val2). Targets come from
    ``build_dvd_kcm().output_generator``. Every split is written one string
    per line.
    """
    kcm = build_dvd_kcm()
    symbols = ['a', 'b']

    def draw_distinct(count, shortest, longest):
        # Keep sampling random strings until `count` unique non-empty ones exist.
        pool = set()
        while len(pool) < count:
            size = random.randint(shortest, longest)
            candidate = ''.join(random.choices(symbols, k=size))
            if candidate:
                pool.add(candidate)
        return list(pool)

    drawn = draw_distinct(max_num, 0, 100)
    cut = int(len(drawn) * 0.8)
    train_inputs, val0_inputs = drawn[:cut], drawn[cut:]

    val1_inputs = draw_distinct(len(val0_inputs), 100, 200)
    val2_inputs = draw_distinct(len(val0_inputs), 200, 300)

    def label(seqs, tag):
        # Run the automaton over every input, with a progress bar named `tag`.
        return [kcm.output_generator(seq, symbols) for seq in tqdm(seqs, desc=tag)]

    train_outputs = label(train_inputs, "dvd-train")
    val0_outputs = label(val0_inputs, "dvd-val0")
    val1_outputs = label(val1_inputs, "dvd-val1")
    val2_outputs = label(val2_inputs, "dvd-val2")

    os.makedirs("dataset/dvd", exist_ok=True)
    splits = [
        ("dataset/dvd/input.txt", train_inputs),
        ("dataset/dvd/target.txt", train_outputs),
        ("dataset/dvd/input_val0.txt", val0_inputs),
        ("dataset/dvd/target_val0.txt", val0_outputs),
        ("dataset/dvd/input_val1.txt", val1_inputs),
        ("dataset/dvd/target_val1.txt", val1_outputs),
        ("dataset/dvd/input_val2.txt", val2_inputs),
        ("dataset/dvd/target_val2.txt", val2_outputs),
    ]
    for path, rows in splits:
        with open(path, "w") as handle:
            for row in rows:
                handle.write(row + '\n')

def training_set_gcd_generator(max_num):
    """Generate the ``gcd`` dataset and write it under ``dataset/gcd/``.

    Inputs are random non-empty strings over ``{a, b, c}``; targets are
    whatever ``build_gcd_kcm().output_generator`` returns for each input.

    Splits written (one string per line):
      * train / val0: ``max_num`` distinct strings of length 1..100, the
        first 80% for training and the rest for val0.
      * val1: as many strings as val0, length 100..200.
      * val2: as many strings as val0, length 200..300.

    Args:
        max_num: number of distinct strings drawn for train + val0.
    """
    auto = build_gcd_kcm()
    alphabet = ['a', 'b', 'c']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 yields '', which is never kept
                samples.add(s)
        return list(samples)

    tmp = gen_samples(max_num, 0, 100)
    len_train = int(0.8 * len(tmp))
    train_inputs = tmp[:len_train]
    val0_inputs = tmp[len_train:]

    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)

    # Each split is run through the automaton exactly once.
    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="gcd-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="gcd-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="gcd-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="gcd-val2")]

    os.makedirs("dataset/gcd", exist_ok=True)
    for path, lines in (
        ("dataset/gcd/input.txt", train_inputs),
        ("dataset/gcd/target.txt", train_outputs),
        ("dataset/gcd/input_val0.txt", val0_inputs),
        ("dataset/gcd/target_val0.txt", val0_outputs),
        ("dataset/gcd/input_val1.txt", val1_inputs),
        ("dataset/gcd/target_val1.txt", val1_outputs),
        ("dataset/gcd/input_val2.txt", val2_inputs),
        ("dataset/gcd/target_val2.txt", val2_outputs),
    ):
        with open(path, "w") as f:
            f.writelines(s + '\n' for s in lines)





def training_set_prime_generator(max_num):
    """Generate the ``prime`` dataset and write it under ``dataset/prime/``.

    The input pool starts with the unary strings ``'a'``, ``'aa'``, ...
    (up to 100 of them, never more than ``max_num``). It is then topped up to
    ``max_num`` distinct strings with random ``{a, b}`` strings of length
    1..200. The pool is split in order, without shuffling: the first 80% is
    train and the rest is val0. val1 and val2 each hold as many distinct
    random strings as val0, with lengths 200..300 and 300..400. Targets are
    whatever ``build_prime_unary_kcm().output_generator`` returns. Every
    split is written one string per line.

    Args:
        max_num: total number of train + val0 inputs.
    """
    auto = build_prime_unary_kcm()
    alphabet = ['a', 'b']

    def random_string(min_len, max_len):
        # One randint for the length, then one random.choice per character.
        length = random.randint(min_len, max_len)
        return ''.join(random.choice(alphabet) for _ in range(length))

    def fill_distinct(arr, target, min_len, max_len):
        """Append distinct non-empty random strings to ``arr`` until it has ``target`` items."""
        seen = set(arr)  # O(1) membership instead of rescanning the list
        while len(arr) < target:
            s = random_string(min_len, max_len)
            if s and s not in seen:
                seen.add(s)
                arr.append(s)
        return arr

    # Unary seeds 'a' * 1 .. 'a' * 100, capped so the pool never exceeds max_num.
    input_arr = ['a' * i for i in range(1, min(100, max_num) + 1)]
    fill_distinct(input_arr, max_num, 1, 200)

    print("Max length of input:", max((len(s) for s in input_arr), default=0))

    # Split training data: 80% for train, 20% for val0
    train_split = int(0.8 * len(input_arr))
    train_input = input_arr[:train_split]
    val0_input = input_arr[train_split:]

    train_target = [auto.output_generator(seq, alphabet) for seq in tqdm(train_input, desc="prime-train")]
    prime_val0 = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_input, desc="prime-val0")]

    val1_input = fill_distinct([], len(prime_val0), 200, 300)
    prime_target1 = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_input, desc="prime-val1")]

    val2_input = fill_distinct([], len(prime_val0), 300, 400)
    prime_target2 = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_input, desc="prime-val2")]

    os.makedirs("dataset/prime", exist_ok=True)
    for path, lines in (
        ("dataset/prime/input.txt", train_input),
        ("dataset/prime/target.txt", train_target),
        ("dataset/prime/input_val0.txt", val0_input),
        ("dataset/prime/target_val0.txt", prime_val0),
        ("dataset/prime/input_val1.txt", val1_input),
        ("dataset/prime/target_val1.txt", prime_target1),
        ("dataset/prime/input_val2.txt", val2_input),
        ("dataset/prime/target_val2.txt", prime_target2),
    ):
        with open(path, "w") as f:
            f.writelines(s + '\n' for s in lines)

def training_set_exp_generator(max_num):
    """Generate the ``exp`` dataset and write it under ``dataset/exp/``.

    Inputs are random non-empty strings over ``{a, b}``; targets are
    whatever ``build_exp_kcm().output_generator`` returns for each input.

    Splits written (one string per line):
      * train / val0: ``max_num`` distinct strings of length 1..100, the
        first 80% for training and the rest for val0.
      * val1: as many strings as val0, length 100..200.
      * val2: as many strings as val0, length 200..300.

    Args:
        max_num: number of distinct strings drawn for train + val0.
    """
    auto = build_exp_kcm()
    alphabet = ['a', 'b']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 yields '', which is never kept
                samples.add(s)
        return list(samples)

    tmp = gen_samples(max_num, 0, 100)
    len_train = int(0.8 * len(tmp))
    train_inputs = tmp[:len_train]
    val0_inputs = tmp[len_train:]

    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="exp-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="exp-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="exp-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="exp-val2")]

    os.makedirs("dataset/exp", exist_ok=True)
    for path, lines in (
        ("dataset/exp/input.txt", train_inputs),
        ("dataset/exp/target.txt", train_outputs),
        ("dataset/exp/input_val0.txt", val0_inputs),
        ("dataset/exp/target_val0.txt", val0_outputs),
        ("dataset/exp/input_val1.txt", val1_inputs),
        ("dataset/exp/target_val1.txt", val1_outputs),
        ("dataset/exp/input_val2.txt", val2_inputs),
        ("dataset/exp/target_val2.txt", val2_outputs),
    ):
        with open(path, "w") as f:
            f.writelines(s + '\n' for s in lines)




if __name__ == "__main__":
    import argparse

    # Task name -> generator. argparse's `choices` is derived from this table
    # so the accepted tasks and the dispatch can never drift apart.
    generators = {
        'mul': training_set_mul_generator,
        'dvd': training_set_dvd_generator,
        'prime': training_set_prime_generator,
        'gcd': training_set_gcd_generator,
        'exp': training_set_exp_generator,
    }

    parser = argparse.ArgumentParser(description="Generate dataset for the mul, dvd, prime, gcd or exp task.")
    parser.add_argument('--task', choices=list(generators), required=True, help='Task type: mul, dvd, prime, gcd, exp')
    parser.add_argument('--max_num', type=int, default=1000, help='Maximum number of samples')
    # Default None so that omitting --r keeps training_set_mul_generator's own default.
    parser.add_argument('--r', type=int, default=None,
                        help='mul only: maximum length of train/val0 inputs (default: 100)')
    args = parser.parse_args()

    if args.r is not None and args.task != 'mul':
        print(f"warning: --r is only used by --task mul; ignoring it for --task {args.task}", file=sys.stderr)

    if args.task == 'mul' and args.r is not None:
        training_set_mul_generator(args.max_num, r=args.r)
    else:
        generators[args.task](args.max_num)