import random
import argparse
import pickle

def reproduce_long(x):
    """Identity transformation: return the sequence ``x`` unchanged."""
    unchanged = x
    return unchanged

def reproduce_short(x):
    """Return the first element of ``x`` followed by its last element.

    For a one-element input the same element appears twice. An empty input
    raises ``IndexError`` because both positions are indexed directly.
    """
    first, last = x[0], x[-1]
    return first + last

def reverse(x):
    """Return ``x`` with its elements in reverse order (via a -1 step slice)."""
    backwards = slice(None, None, -1)
    return x[backwards]

def even_short(x):
    """Return the two end elements of ``x``, ordered by the parity of its length.

    Even length: first element then last element.
    Odd length: last element then first element.
    An empty input raises ``IndexError``.
    """
    head, tail = x[0], x[-1]
    has_even_length = len(x) % 2 == 0
    return head + tail if has_even_length else tail + head

def sort_long(x):
    """Return the characters of ``x`` joined in ascending sorted order."""
    chars = list(x)
    chars.sort()
    return ''.join(chars)

def sort_short(x):
    """Sort the characters of ``x`` and return the smallest followed by the largest.

    The sorted characters are joined into one string first; the result is that
    string's first and last character. An empty input raises ``IndexError``.
    """
    chars = list(x)
    chars.sort()
    ordered = ''.join(chars)
    lowest, highest = ordered[0], ordered[-1]
    return lowest + highest

def even_long(x):
    """Return every other character of ``x``, starting point chosen by length parity.

    Even length: the characters at odd indices (1, 3, 5, ...).
    Odd length: the characters at even indices (0, 2, 4, ...).
    The result is always a string; an empty input yields ``""``.
    """
    start = 1 if len(x) % 2 == 0 else 0
    return ''.join(x[start::2])

def reverse_long(x):
    """Return ``x`` reversed, using a slice with step -1."""
    mirrored = x[::-1]
    return mirrored

def case_long(x):
    """Return ``x`` with the case of every alphabetic character flipped.

    Lowercase letters are upper-cased; every other alphabetic character is
    lower-cased. Non-alphabetic characters are copied through unchanged.
    """
    def flip(char):
        # Non-letters pass through untouched.
        if not char.isalpha():
            return char
        return char.upper() if char.islower() else char.lower()

    return ''.join(flip(char) for char in x)

def periodic_long(x):
    """Return the shortest proper prefix of ``x`` whose repetition rebuilds ``x``.

    Candidate prefix lengths run from 1 up to ``len(x) - 1``; only lengths that
    divide ``len(x)`` are tried. If no such prefix exists (including for empty
    and single-character inputs) the empty prefix ``x[:0]`` is returned.
    """
    total = len(x)
    for period in range(1, total):
        if total % period != 0:
            continue
        unit = x[:period]
        if x == ''.join([unit] * (total // period)):
            # First hit is the smallest period, so stop here.
            return unit
    return x[:0]

class PrimeChecker(object):
    """Precomputes the primes up to ``max_val`` and applies a length-based rule.

    ``primes`` is a set of ints, so membership tests are constant time.
    """

    def __init__(self, max_val):
        self.primes = self.sieve(max_val)

    def sieve(self, max_val):
        """Return the set of all primes ``p`` with ``2 <= p <= max_val``.

        Uses the sieve of Eratosthenes. Returns an empty set when ``max_val``
        is below 2 (the original indexed positions 0 and 1 unconditionally and
        raised ``IndexError`` for ``max_val < 1``).
        """
        if max_val < 2:
            return set()
        is_prime = [True] * (max_val + 1)
        is_prime[0] = is_prime[1] = False
        primes = set()
        for candidate, flag in enumerate(is_prime):
            if not flag:
                continue
            primes.add(candidate)
            # Multiples below candidate**2 were already crossed out by a
            # smaller prime factor.
            for multiple in range(candidate * candidate, max_val + 1, candidate):
                is_prime[multiple] = False
        return primes

    def prime_condition(self, x):
        """Return ``x`` unchanged if its length is prime, otherwise ``x`` reversed.

        The original tested ``x in self.primes`` (a sequence against a set of
        ints, which never matches) and called the undefined name
        ``reproduce``. The sieve is sized by sequence length, so the length is
        what gets tested here.
        """
        if len(x) in self.primes:
            return x
        return x[::-1]

def odd(x):
    """Return the characters that occur an odd number of times in ``x``.

    Each occurrence toggles the character in or out of a tracking dict; when
    it is toggled in, the index of that occurrence is stored. The surviving
    characters are joined in ascending order of their stored index.
    """
    unpaired = {}
    for position, char in enumerate(x):
        if char in unpaired:
            # Second (fourth, ...) sighting cancels the previous one.
            del unpaired[char]
        else:
            unpaired[char] = position
    return ''.join(sorted(unpaired, key=unpaired.get))

def pal_short(x):
    """Return two adjacent characters of ``x`` located by a mirrored scan.

    The scan walks inward from both ends while the mirrored characters match.
    With ``i`` matched pairs, the result is the character at index
    ``len(x) - i - 1`` followed by the character immediately after it.

    Raises ``IndexError`` when that second index falls outside ``x``, e.g.
    when the outermost characters already differ or ``x`` has fewer than two
    characters.
    """
    size = len(x)
    left = 0
    while left < size - left - 1 and x[left] == x[size - left - 1]:
        left += 1
    right = size - left - 1
    if left == right:
        # The scan met itself in the exact middle of an odd-length input.
        return x[left] + x[left + 1]
    return x[right] + x[right + 1]

def pal_long(x):
    """Return ``x`` with one character inserted, located by a mirrored scan.

    The scan walks inward from both ends while the mirrored characters match,
    giving ``i`` matched pairs. If the scan ends exactly on the middle
    character, that character is duplicated in place. Otherwise ``x[i]`` is
    inserted at index ``len(x) - i``, i.e. at the mirror of the first
    unmatched position.

    An empty input raises ``IndexError``.
    """
    size = len(x)
    left = 0
    while left < size - left - 1 and x[left] == x[size - left - 1]:
        left += 1
    if left == size - left - 1:
        return x[:left] + x[left] + x[left:]
    cut = size - left
    return x[:cut] + x[left] + x[cut:]

def well_formed_dyck(length, alphabet):
    """Build a random string of nested same-character pairs.

    Each layer draws one symbol from ``alphabet`` and a split point in
    ``{0, 1, 2}``; the symbol is doubled and the inner string is placed before,
    between, or after the two copies. Layers are drawn outermost first, one
    per two units of ``length``. The result has ``2 * (length // 2)``
    characters, and is empty when ``length < 2``.
    """
    layers = []
    remaining = length
    while remaining >= 2:
        symbol = random.choice(alphabet)
        split = random.randint(0, 2)
        layers.append((symbol + symbol, split))
        remaining -= 2

    # Assemble from the innermost layer outwards.
    result = ""
    for pair, split in reversed(layers):
        result = pair[:split] + result + pair[split:]
    return result

def is_dyck(x):
    """Return True if ``x`` reduces to nothing by cancelling equal neighbours.

    Symbols are processed left to right with a stack: a symbol equal to the
    top of the stack removes it, any other symbol is pushed. The input is
    accepted when the stack ends up empty.
    """
    pending = []
    for symbol in x:
        if pending and pending[-1] == symbol:
            pending.pop()
        else:
            pending.append(symbol)
    return not pending

def dyck_short(x):
    """Return ``"aa"`` when ``is_dyck(x)`` holds and ``"bb"`` otherwise."""
    return "aa" if is_dyck(x) else "bb"

def dyck_long(x):
    """Return ``x`` itself when ``is_dyck(x)`` holds, else ``fix_dyck(x)``."""
    if not is_dyck(x):
        return fix_dyck(x)
    return x

def fix_dyck(x):
    """Insert one character into ``x`` next to one of its unmatched symbols.

    The same stack reduction as ``is_dyck`` is run, remembering the index of
    every pushed symbol. Of the symbols left on the stack, the one at position
    ``len(stack) // 2`` is duplicated immediately after its original index.

    If nothing is left unmatched, ``x`` is returned unchanged. The original
    fell off the end of the function and returned ``None`` in that case.
    """
    unmatched = []
    for position, symbol in enumerate(x):
        if unmatched and unmatched[-1][0] == symbol:
            unmatched.pop()
        else:
            unmatched.append((symbol, position))

    if not unmatched:
        return x

    symbol, position = unmatched[len(unmatched) // 2]
    return x[:position + 1] + symbol + x[position + 1:]

def _sample_chars(pool, count):
    """Return ``count`` characters drawn independently from ``pool`` (one random.choice each)."""
    return ''.join(random.choice(pool) for _ in range(count))


def _filler(identify_mode, alphabet, count):
    """Return ``count`` filler characters: random letters ("distractor") or zeros ("padding")."""
    if identify_mode == "distractor":
        return _sample_chars(alphabet, count)
    if identify_mode == "padding":
        return "0" * count
    # Any other mode contributes no filler, as in the original loops.
    return ""


def gen_data(identify_mode, num_to_generate, max_len, max_seq_len, 
    max_dep_length, transformation_function, num_withheld=0, withheld_chars=None, 
    unseen_mode="full_alphabet", odd_number=-1, pal=False, dyck=False, 
    all_lower=False, period=False):
    """Generate ``num_to_generate`` examples of the form ``<filler>1<target>1<filler>2``.

    Args:
        identify_mode: "distractor" fills the regions around the target with
            random letters, "padding" fills them with "0".
        num_to_generate: number of examples to build.
        max_len: length budget used to size the leading filler.
        max_seq_len: upper bound for the randomly drawn target length. With
            ``period`` set it must be at least 2.
        max_dep_length: upper bound for the trailing filler length.
        transformation_function: callable applied to the target sequence; its
            result plus "2" becomes ``correct_output``.
        num_withheld: how many letters to remove from the alphabet when no
            ``withheld_chars`` are supplied. Letters are chosen from
            ``alphabet[1:]``, so "A" is never withheld.
        withheld_chars: previously withheld letters. When non-empty they are
            reused as-is and the alphabet is left complete.
        unseen_mode: "full_alphabet" draws target characters from the
            alphabet, "unseen_only" draws them from the withheld letters.
        odd_number: -1 disables odd-count generation; a positive value is the
            number of single characters per target; -2 uses half of each
            example's own ``seq_length``.
        pal: build a mirrored target with one character removed from the
            second half.
        dyck: build the target with ``well_formed_dyck``, removing one
            character in roughly half of the examples.
        all_lower: lower-case the target (plain mode only).
        period: build the target as a repeated random unit.

    Returns:
        ``(data, withheld)`` where ``data`` is a list of dicts with keys
        ``input``, ``correct_output``, ``seq_length``, ``total_length``,
        ``dep_length_start`` and ``dep_length_end``, and ``withheld`` is the
        list of withheld letters.

    Raises:
        ValueError: if ``unseen_mode`` is not one of the two supported values.

    Fixes relative to the original: the ``num_withheld`` parameter is used
    instead of the global ``args.num_withheld``; ``odd_number == -2`` is
    resolved per example instead of overwriting the parameter on the first
    iteration; the mutable default for ``withheld_chars`` is replaced by None.
    """
    if unseen_mode not in ("full_alphabet", "unseen_only"):
        raise ValueError("unknown unseen_mode: {!r}".format(unseen_mode))

    # Upper-case A-Z followed by lower-case a-z.
    alphabet = [chr(a) for a in range(65, 65 + 26)] + [chr(a) for a in range(97, 97 + 26)]
    if withheld_chars:
        withheld = withheld_chars
    else:
        withheld = []
        for _ in range(num_withheld):
            to_withhold = random.choice(alphabet[1:])
            withheld.append(to_withhold)
            alphabet.remove(to_withhold)

    # Characters the target sequence is drawn from.
    pool = alphabet if unseen_mode == "full_alphabet" else withheld

    data = []
    for _ in range(num_to_generate):
        if period:
            rep_seq_length = random.randint(1, max_seq_len // 2)
            repeats = random.randint(2, max_seq_len // rep_seq_length)
            seq_length = repeats * rep_seq_length
        else:
            seq_length = random.randint(1, max_seq_len)
        dep_length = random.randint(0, max_dep_length)

        # 3 accounts for the two "1" delimiters and the closing "2".
        room = max_len - seq_length - 3 - dep_length
        before_seq_length = random.randint(0, room) if room > 0 else 0

        string = _filler(identify_mode, alphabet, before_seq_length) + "1"

        target_seq = ""
        if odd_number == -1 and not pal and not dyck and not period:
            target_seq = _sample_chars(pool, seq_length)
            if all_lower:
                target_seq = target_seq.lower()
        elif odd_number > 0 or odd_number == -2:
            # Resolve -2 for this example only; the parameter stays untouched
            # so every example gets half of its own length.
            num_singles = seq_length // 2 if odd_number == -2 else odd_number
            singles = _sample_chars(pool, num_singles)
            num_pairs = (seq_length - len(singles)) // 2
            doubled = ''.join(
                char + char for char in _sample_chars(pool, num_pairs))
            seq_list = list(singles + doubled)
            random.shuffle(seq_list)
            target_seq = ''.join(seq_list)
        elif pal:
            half_length = max(seq_length // 2, 2)
            first_half = _sample_chars(pool, half_length)
            second_half = first_half[::-1]
            # The last mirrored character is never the one removed.
            idx = random.randint(0, len(second_half) - 2)
            second_half = second_half[:idx] + second_half[idx + 1:]
            target_seq = first_half + second_half
        elif dyck:
            full_seq = well_formed_dyck(max(seq_length, 2), pool)
            if random.randint(0, 1) == 0:
                target_seq = full_seq
            else:
                idx = random.randint(0, len(full_seq) - 1)
                target_seq = full_seq[:idx] + full_seq[idx + 1:]
        elif period:
            target_seq = _sample_chars(pool, rep_seq_length) * repeats

        string += target_seq + "1"
        string += _filler(identify_mode, alphabet, dep_length)
        string += "2"

        data.append({"input": string, 
            "correct_output": transformation_function(target_seq) + "2", 
            "seq_length": seq_length, "total_length": len(string), 
            "dep_length_start": len(string) - before_seq_length, 
            "dep_length_end": dep_length})
    return data, withheld


# TODO: More transformations

# Command-line interface. Training/validation data use the "train" limits,
# test data uses the "test" limits.
parser = argparse.ArgumentParser()
parser.add_argument('--train_examples', type=int, default=10000, 
    help="Number of train examples to generate")
parser.add_argument('--val_examples', type=int, default=1500, 
    help="Number of val examples to generate")
parser.add_argument('--test_examples', type=int, default=1500, 
    help="Number of test examples to generate")
parser.add_argument('--mode', type=str, choices=['padding', "distractor"], 
    help="Which type of data to generate", default='distractor')
parser.add_argument('--max_train_length', type=int, 
    help="Maximum length of generated sequence in training", default=256)
parser.add_argument('--max_test_length', type=int, 
    help="Maximum length of generated sequence", default=512)
parser.add_argument('--max_train_answer_length', type=int, 
    help="Maximum length of sequence to be reproduced in training", default=128)
parser.add_argument('--max_test_answer_length', type=int, 
    help="Maximum length of sequence to be reproduced", default=256)
parser.add_argument('--max_train_dep_length', type=int, 
    help="Maximum length of dependency sequence in training", default=256)
parser.add_argument('--max_test_dep_length', type=int, 
    help="Maximum length of dependency sequence in testing", default=256)
parser.add_argument('--num_withheld', type=int, 
    help="Number of characters to withhold in training", default=0)
parser.add_argument('--unseen_mode', type=str, 
    choices=['full_alphabet', "unseen_only"], 
    help="Use query sequences with exclusively unseen characters", 
    default='full_alphabet')
# argparse does not check a default against `choices`, so the default must be
# a name the dispatch below actually handles. The previous default
# 'reproduce' matched no branch and left transformation_function undefined.
parser.add_argument('--transformation_function', type=str, 
    choices=['reproduce_long', "even_long", "even_short", "sort_long", 
    "sort_short", "reproduce_short", "odd_short", "odd_long", "pal_short", 
    "pal_long", "dyck_short", "dyck_long", "reverse_long", "case_long", 
    "periodic_long"], 
    help="How to transform sequence of input", 
    default='reproduce_long')
parser.add_argument('--output_file', type=str,
    help="Where to save", default='')
args = parser.parse_args()

# Generator settings; each transformation overrides only the ones it needs.
odd_number = -1
pal = False
dyck = False
all_lower = False
period = False

# Transformation name -> (function applied to the target sequence,
# generator settings that differ from the defaults above).
transformation_table = {
    "reproduce_long": (reproduce_long, {}),
    "reproduce_short": (reproduce_short, {}),
    "even_short": (even_short, {}),
    "even_long": (even_long, {}),
    "sort_long": (sort_long, {}),
    "sort_short": (sort_short, {}),
    "odd_short": (odd, {"odd_number": 2}),
    "odd_long": (odd, {"odd_number": -2}),
    "pal_short": (pal_short, {"pal": True}),
    "pal_long": (pal_long, {"pal": True}),
    "dyck_short": (dyck_short, {"dyck": True}),
    "dyck_long": (dyck_long, {"dyck": True}),
    "reverse_long": (reverse_long, {}),
    "case_long": (case_long, {"all_lower": True}),
    "periodic_long": (periodic_long, {"period": True}),
}

if args.transformation_function == "prime":
    # Handled separately because the checker needs the answer-length limits.
    # NOTE(review): "prime" is not among the --transformation_function
    # choices, so this branch is not reachable from the command line.
    prime_check = PrimeChecker(
        max(args.max_train_answer_length, args.max_test_answer_length))
    transformation_function = prime_check.prime_condition
elif args.transformation_function in transformation_table:
    transformation_function, generator_overrides = \
        transformation_table[args.transformation_function]
    odd_number = generator_overrides.get("odd_number", odd_number)
    pal = generator_overrides.get("pal", pal)
    dyck = generator_overrides.get("dyck", dyck)
    all_lower = generator_overrides.get("all_lower", all_lower)
    period = generator_overrides.get("period", period)
else:
    # Fail here with a usage message instead of a NameError further down,
    # where transformation_function would otherwise be undefined.
    parser.error("unsupported --transformation_function: {!r}".format(
        args.transformation_function))

# Generator switches shared by all three splits.
generator_flags = dict(odd_number=odd_number, pal=pal, dyck=dyck,
    all_lower=all_lower, period=period)

# Training split: withholds args.num_withheld letters and returns them.
train_data, withheld = gen_data(
    args.mode,
    args.train_examples,
    args.max_train_length,
    args.max_train_answer_length,
    args.max_train_dep_length,
    transformation_function,
    args.num_withheld,
    **generator_flags)
print(withheld)
print(max(len(example['input']) for example in train_data))

# Validation split: same limits as training. No withheld letters are passed
# in, so gen_data draws its own set for this call.
val_data, _ = gen_data(
    args.mode,
    args.val_examples,
    args.max_train_length,
    args.max_train_answer_length,
    args.max_train_dep_length,
    transformation_function,
    args.num_withheld,
    **generator_flags)

# Test split: test limits, reusing the letters withheld from training.
test_data, _ = gen_data(
    args.mode,
    args.test_examples,
    args.max_test_length,
    args.max_test_answer_length,
    args.max_test_dep_length,
    transformation_function,
    args.num_withheld,
    withheld_chars=withheld,
    unseen_mode=args.unseen_mode,
    **generator_flags)

# Record the settings this dataset was generated with.
params = {'train_size':args.train_examples,'test_size':args.test_examples,
            'val_size':args.val_examples, 'mode':args.mode,
            'max_train_length':args.max_train_length, 
            'max_train_answer_length':args.max_train_answer_length, 
            'max_train_dep_length':args.max_train_dep_length,
            'max_test_length':args.max_test_length, 
            'max_test_answer_length':args.max_test_answer_length, 
            'max_test_dep_length':args.max_test_dep_length,
            'unseen_mode':args.unseen_mode,
            'transformation_function':args.transformation_function,
            'num_withheld':args.num_withheld}

data_dict = {'train': train_data, 'test': test_data, 'val':val_data, 
            'withheld':withheld, 'params':params}

# Only write a file when an output path was given. The context manager
# closes the handle even if pickling fails; the original never closed it.
if args.output_file != '':
    with open(args.output_file, 'wb') as output_handle:
        pickle.dump(data_dict, output_handle)

# Dump the withheld letters and every generated example to stdout.
# Train and test lines show "input correct_output"; validation shows inputs only.
print(withheld)
print()
print([" ".join((example['input'], example['correct_output']))
       for example in train_data])
print()
print([example['input'] for example in val_data])
print()
print([" ".join((example['input'], example['correct_output']))
       for example in test_data])
