import random
import sys
from kcm import *
import os
from itertools import product
import math
from tqdm import tqdm  # <--- NEW

def training_set_mul_generator(max_num=10000, r=100):
    """Generate the ``mul`` dataset and write it to ``dataset/mul/``.

    Inputs are distinct, non-empty random strings over ``{'a', 'b', 'c'}``;
    each target is ``build_mul_kcm().output_generator(seq, alphabet)``.

    Splits:
        * train / val0 -- ``max_num`` strings of length 1..``r``, split 80/20.
        * val1 -- ``len(val0)`` strings of length 100..200.
        * val2 -- ``len(val0)`` strings of length 200..300.

    Eight files are written, one string per line: ``input.txt`` /
    ``target.txt`` for train and ``input_valK.txt`` / ``target_valK.txt``
    for K in 0, 1, 2.

    Args:
        max_num: Number of distinct strings shared by train and val0.
        r: Maximum length of the train/val0 strings. The val1/val2 length
            ranges are fixed, so a larger ``r`` makes train overlap val1's
            length range.

    Raises:
        ValueError: If a split requests more distinct strings than exist in
            its length range (the rejection sampler could otherwise never
            terminate).
    """
    auto = build_mul_kcm()
    alphabet = ['a', 'b', 'c']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        # Number of distinct non-empty strings available in the length range.
        capacity = sum(len(alphabet) ** length for length in range(max(min_len, 1), max_len + 1))
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct strings of length {min_len}..{max_len}: only {capacity} exist"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 may be drawn when min_len == 0
                samples.add(s)
        return list(samples)

    # Set iteration order is arbitrary, so the 80/20 slice acts as a shuffle.
    samples = gen_samples(max_num, 0, r)
    len_train = int(0.8 * len(samples))
    train_inputs = samples[:len_train]
    val0_inputs = samples[len_train:]

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="mul-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="mul-val0")]

    # Longer strings for the length-generalisation splits.
    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="mul-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="mul-val2")]

    out_dir = "dataset/mul"
    os.makedirs(out_dir, exist_ok=True)
    files = [
        ("input.txt", train_inputs),
        ("target.txt", train_outputs),
        ("input_val0.txt", val0_inputs),
        ("target_val0.txt", val0_outputs),
        ("input_val1.txt", val1_inputs),
        ("target_val1.txt", val1_outputs),
        ("input_val2.txt", val2_inputs),
        ("target_val2.txt", val2_outputs),
    ]
    for fname, lines in files:
        with open(f"{out_dir}/{fname}", "w") as f:
            f.writelines(s + '\n' for s in lines)

def training_set_dvd_generator(max_num):
    """Generate the ``dvd`` dataset and write it to ``dataset/dvd/``.

    Inputs are distinct, non-empty random strings over ``{'a', 'b'}``;
    each target is ``build_dvd_kcm().output_generator(seq, alphabet)``.

    Splits:
        * train / val0 -- ``max_num`` strings of length 1..100, split 80/20.
        * val1 -- ``len(val0)`` strings of length 100..200.
        * val2 -- ``len(val0)`` strings of length 200..300.

    Only random strings are used. Structured ``a^i b^j`` (i divides j)
    candidates are not generated, because no split consumes them.

    Args:
        max_num: Number of distinct strings shared by train and val0.

    Raises:
        ValueError: If a split requests more distinct strings than exist in
            its length range.
    """
    auto = build_dvd_kcm()
    alphabet = ['a', 'b']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        capacity = sum(len(alphabet) ** length for length in range(max(min_len, 1), max_len + 1))
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct strings of length {min_len}..{max_len}: only {capacity} exist"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 may be drawn when min_len == 0
                samples.add(s)
        return list(samples)

    samples = gen_samples(max_num, 0, 100)
    num_of_training = int(len(samples) * 0.8)
    train_inputs = samples[:num_of_training]
    val0_inputs = samples[num_of_training:]

    # Longer strings for the length-generalisation splits.
    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="dvd-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="dvd-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="dvd-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="dvd-val2")]

    out_dir = "dataset/dvd"
    os.makedirs(out_dir, exist_ok=True)
    files = [
        ("input.txt", train_inputs),
        ("target.txt", train_outputs),
        ("input_val0.txt", val0_inputs),
        ("target_val0.txt", val0_outputs),
        ("input_val1.txt", val1_inputs),
        ("target_val1.txt", val1_outputs),
        ("input_val2.txt", val2_inputs),
        ("target_val2.txt", val2_outputs),
    ]
    for fname, lines in files:
        with open(f"{out_dir}/{fname}", "w") as f:
            f.writelines(s + '\n' for s in lines)

def training_set_gcd_generator(max_num):
    """Generate the ``gcd`` dataset and write it to ``dataset/gcd/``.

    Inputs are distinct, non-empty random strings over ``{'a', 'b', 'c'}``;
    each target is ``build_gcd_kcm().output_generator(seq, alphabet)``.

    Splits:
        * train / val0 -- ``max_num`` strings of length 1..100, split 80/20.
        * val1 -- ``len(val0)`` strings of length 100..200.
        * val2 -- ``len(val0)`` strings of length 200..300.

    Args:
        max_num: Number of distinct strings shared by train and val0.

    Raises:
        ValueError: If a split requests more distinct strings than exist in
            its length range.
    """
    auto = build_gcd_kcm()
    alphabet = ['a', 'b', 'c']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        capacity = sum(len(alphabet) ** length for length in range(max(min_len, 1), max_len + 1))
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct strings of length {min_len}..{max_len}: only {capacity} exist"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 may be drawn when min_len == 0
                samples.add(s)
        return list(samples)

    samples = gen_samples(max_num, 0, 100)
    len_train = int(0.8 * len(samples))
    train_inputs = samples[:len_train]
    val0_inputs = samples[len_train:]

    # Longer strings for the length-generalisation splits.
    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)

    # Each split is labelled exactly once.
    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="gcd-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="gcd-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="gcd-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="gcd-val2")]

    out_dir = "dataset/gcd"
    os.makedirs(out_dir, exist_ok=True)
    files = [
        ("input.txt", train_inputs),
        ("target.txt", train_outputs),
        ("input_val0.txt", val0_inputs),
        ("target_val0.txt", val0_outputs),
        ("input_val1.txt", val1_inputs),
        ("target_val1.txt", val1_outputs),
        ("input_val2.txt", val2_inputs),
        ("target_val2.txt", val2_outputs),
    ]
    for fname, lines in files:
        with open(f"{out_dir}/{fname}", "w") as f:
            f.writelines(s + '\n' for s in lines)

def training_set_prime_generator(max_num):
    """Generate the ``prime`` dataset and write it to ``dataset/prime/``.

    Inputs are non-empty strings over ``{'a', 'b'}``; each target is
    ``build_prime_unary_kcm().output_generator(seq, alphabet)``. Every split
    contains the unary strings ``'a' * i`` for its length range, followed by
    distinct random strings. Random strings never repeat a unary string
    already in the split, so no string is duplicated or leaks from train
    into val0.

    Splits:
        * train / val0 -- ``a^1..a^99`` followed by ``max_num`` random
          strings of length 1..100, split 80/20 without shuffling. For
          ``max_num >= 25`` all unary strings therefore land in train.
        * val1 -- ``a^100..a^199``, topped up with random strings of length
          100..200 to ``len(val0)`` items (never fewer than the unary ones).
        * val2 -- ``a^200..a^299``, topped up the same way with strings of
          length 200..300.

    Args:
        max_num: Number of random strings added to the train/val0 pool.

    Raises:
        ValueError: If a split requests more distinct strings than exist in
            its length range.
    """
    auto = build_prime_unary_kcm()
    alphabet = ['a', 'b']

    def gen_samples(n, min_len, max_len, exclude=frozenset()):
        """Return ``n`` distinct non-empty random strings with length in
        [min_len, max_len] that are not in ``exclude``."""
        taken = sum(1 for s in exclude if min_len <= len(s) <= max_len)
        capacity = sum(len(alphabet) ** length for length in range(max(min_len, 1), max_len + 1)) - taken
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct strings of length {min_len}..{max_len}: only {capacity} exist"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '' and s not in exclude:
                samples.add(s)
        return list(samples)

    def unary(lo, hi):
        """Return ['a' * i for i in range(lo, hi)], skipping the empty string."""
        return ['a' * i for i in range(max(lo, 1), hi)]

    pool = unary(0, 100)
    pool.extend(gen_samples(max_num, 0, 100, exclude=set(pool)))
    len_train = int(0.8 * len(pool))
    train_inputs = pool[:len_train]
    val0_inputs = pool[len_train:]

    # Both longer splits are sized like val1 originally was: unary strings
    # first, then random strings until the split reaches len(val0_inputs).
    val1_inputs = unary(100, 200)
    val1_inputs.extend(gen_samples(max(0, len(val0_inputs) - len(val1_inputs)), 100, 200,
                                   exclude=set(val1_inputs)))

    val2_inputs = unary(200, 300)
    val2_inputs.extend(gen_samples(max(0, len(val0_inputs) - len(val2_inputs)), 200, 300,
                                   exclude=set(val2_inputs)))

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="prime-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="prime-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="prime-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="prime-val2")]

    out_dir = "dataset/prime"
    os.makedirs(out_dir, exist_ok=True)
    files = [
        ("input.txt", train_inputs),
        ("target.txt", train_outputs),
        ("input_val0.txt", val0_inputs),
        ("target_val0.txt", val0_outputs),
        ("input_val1.txt", val1_inputs),
        ("target_val1.txt", val1_outputs),
        ("input_val2.txt", val2_inputs),
        ("target_val2.txt", val2_outputs),
    ]
    for fname, lines in files:
        with open(f"{out_dir}/{fname}", "w") as f:
            f.writelines(s + '\n' for s in lines)


def training_set_exp_generator(max_num):
    """Generate the ``exp`` dataset and write it to ``dataset/exp/``.

    Inputs are distinct, non-empty random strings over ``{'a', 'b'}``;
    each target is ``build_exp_kcm().output_generator(seq, alphabet)``.

    Splits:
        * train / val0 -- ``max_num`` strings of length 1..100, split 80/20.
        * val1 -- ``len(val0)`` strings of length 100..200.
        * val2 -- ``len(val0)`` strings of length 200..300.

    Each split is built from its own fresh list. This keeps val1/val2 free
    of train/val0 strings and keeps them distinct from each other.

    Args:
        max_num: Number of distinct strings shared by train and val0.

    Raises:
        ValueError: If a split requests more distinct strings than exist in
            its length range.
    """
    auto = build_exp_kcm()
    alphabet = ['a', 'b']

    def gen_samples(n, min_len, max_len):
        """Return ``n`` distinct non-empty random strings with length in [min_len, max_len]."""
        capacity = sum(len(alphabet) ** length for length in range(max(min_len, 1), max_len + 1))
        if n > capacity:
            raise ValueError(
                f"cannot draw {n} distinct strings of length {min_len}..{max_len}: only {capacity} exist"
            )
        samples = set()
        while len(samples) < n:
            length = random.randint(min_len, max_len)
            s = ''.join(random.choices(alphabet, k=length))
            if s != '':  # length 0 may be drawn when min_len == 0
                samples.add(s)
        return list(samples)

    samples = gen_samples(max_num, 0, 100)
    len_train = int(0.8 * len(samples))
    train_inputs = samples[:len_train]
    val0_inputs = samples[len_train:]

    # Fresh lists: extending `samples` here would copy every train/val0
    # string into val1/val2 and make the two splits alias one list.
    val1_inputs = gen_samples(len(val0_inputs), 100, 200)
    val2_inputs = gen_samples(len(val0_inputs), 200, 300)

    train_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(train_inputs, desc="exp-train")]
    val0_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val0_inputs, desc="exp-val0")]
    val1_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val1_inputs, desc="exp-val1")]
    val2_outputs = [auto.output_generator(seq, alphabet) for seq in tqdm(val2_inputs, desc="exp-val2")]

    out_dir = "dataset/exp"
    os.makedirs(out_dir, exist_ok=True)
    files = [
        ("input.txt", train_inputs),
        ("target.txt", train_outputs),
        ("input_val0.txt", val0_inputs),
        ("target_val0.txt", val0_outputs),
        ("input_val1.txt", val1_inputs),
        ("target_val1.txt", val1_outputs),
        ("input_val2.txt", val2_inputs),
        ("target_val2.txt", val2_outputs),
    ]
    for fname, lines in files:
        with open(f"{out_dir}/{fname}", "w") as f:
            f.writelines(s + '\n' for s in lines)




if __name__ == "__main__":
    import argparse

    # Task name -> generator; every generator takes max_num as its first argument.
    generators = {
        'mul': training_set_mul_generator,
        'dvd': training_set_dvd_generator,
        'prime': training_set_prime_generator,
        'gcd': training_set_gcd_generator,
        'exp': training_set_exp_generator,
    }

    parser = argparse.ArgumentParser(description="Generate dataset for the mul, dvd, prime, gcd or exp tasks.")
    parser.add_argument('--task', choices=list(generators), required=True, help='Task type: mul, dvd, prime, gcd, exp')
    parser.add_argument('--max_num', type=int, default=1000, help='Maximum number of samples')
    # Default None: omitting --r keeps training_set_mul_generator's own default (r=100).
    parser.add_argument('--r', type=int, default=None,
                        help='mul only: maximum length of train/val0 strings (default: 100); ignored by other tasks')
    args = parser.parse_args()

    if args.task == 'mul' and args.r is not None:
        training_set_mul_generator(args.max_num, r=args.r)
    else:
        generators[args.task](args.max_num)