import numpy as np 
from tqdm import tqdm
import os
import random

    

def pretty_print(batch, full = True, max_num = None, in_sequence_form = False):
    """Print a readable rendering of every sequence in a batch.

    Args:
        batch: ``(examples, target_mask, input_mask)``. ``examples`` must
            expose ``.shape`` and yield elements that support ``.item()``;
            ``input_mask`` is indexed as ``input_mask[row][position]``.
            ``target_mask`` is unpacked but not used.
        full: Accepted for compatibility; not used.
        max_num: When not ``None``, printing stops before the row whose
            index equals this value.
        in_sequence_form: When true, token 1 is shown as a literal ``\\n``
            followed by a space and token 0 as ``_``; otherwise token 1 is a
            real line break and token 0 prints nothing.
    """
    examples, target_mask, input_mask = batch

    # Text for the special token ids; any other id is printed as its number.
    glyphs = {
        2: "-> ",
        1: "\\n " if in_sequence_form else "\n",
        3: "<sep> ",
        4: "( ",
        5: ") ",
        0: "_ " if in_sequence_form else "",
    }

    for row_no in range(examples.shape[0]):
        if max_num is not None and row_no == max_num:
            break
        print(f"====== instance {row_no+1} ========")

        row_mask = input_mask[row_no]
        rendered = ""
        bracket_open = False
        for pos, tok in enumerate(examples[row_no]):
            value = tok.item()
            piece = glyphs.get(value)
            rendered += piece if piece is not None else str(value) + " "

            masked = row_mask[pos]
            if masked and not bracket_open:
                # "[" is placed right after the first masked input position.
                bracket_open = True
                rendered += "["
            elif bracket_open and not masked:
                bracket_open = False
                if value == 1:
                    # Swap the last rendered character for a visible "\n]".
                    rendered = rendered[:-1] + "\\n]\n"
        print(rendered)
#         print(s)
#         if full:
#             print("-------")
#             if in_sequence_form:
#                 s = ""
#                 for idx, t in enumerate(examples[b][index:]):
#                     if t == 2:
#                         s+= "-> "
#                     elif t == 1:
#                         if in_sequence_form:
#                             s+= "\\n"
#                         else:    
#                             s+= "\n"
#                     elif t == 3:
#                         s+= "<sep> "
#                     elif t == 0:
#                         s+= "_ "    
#                     else:
#                         s+= str(t.item()) + " "
#                 print(s)
#             else:
#                 print(examples[b][target_mask[b] == 1].tolist())

                
        
class SynDataset:
    def __init__(self, n, data_dir = ".", task = "lookup", num_context_statements_range = None,  max_word_length = 1, max_seq_length = 256, num_source_symbols = 30, num_target_symbols = 30, max_args_length = 5, max_formula_length = 5, max_num_const = None, token_shift = False, log_perm_group = 1, remap_tokens = None, on_fly = False):
        self.n = n
        self.data_dir = data_dir
        self.task = task
        self.max_word_length = max_word_length
        self.max_seq_length = max_seq_length
        self.num_context_statements_range = num_context_statements_range
        if self.num_context_statements_range is None:
           
            if self.task == "subseq":
                self.num_context_statements_range = (1,4)
            elif self.task == "subseq_control":
                self.num_context_statements_range = (1, 4) 
                
            else:
                self.num_context_statements_range = (3, 7) 
            
        self.num_source_symbols = num_source_symbols
        self.num_target_symbols = num_target_symbols
        
        self.max_args_length = max_args_length
        self.max_formula_length = max_formula_length
#         self.data = [self.gen_chunky_lookup_data(np.random.choice(range(*num_context_statements_range))) for _ in range(n)]
        self.max_num_const = self.max_formula_length
        if max_num_const is not None:
            self.max_num_const = max_num_const
        self.transformation = None
        self.num_special_tokens = 10
        
        self.token_shift = token_shift
        self.log_perm_group = log_perm_group
        self.remap_tokens = remap_tokens
        self.remap_tokens_frequencies = None

        # check that frequency array for remap tokens is correct length 
        if type(remap_tokens) != int:
            self.remap_tokens_frequencies = remap_tokens
            self.remap_tokens = len(remap_tokens)

        # if self.remap_tokens and self.remap_tokens_frequencies is not None:
        #     if len(self.remap_tokens_frequencies) != self.remap_tokens:
        #         print(f"Vocab token frequencies do not match specified vocab size {self.remap_tokens}. Sampling uniformly.")
        #         self.remap_tokens_frequencies = None

        self.on_fly = on_fly

        if task == "add_vec":
            def modify_embed_matrx(weight):
                v = np.random.randn(weight.shape[0], 1, weight.shape[2])
                weight[:, self.num_special_tokens + self.num_source_symbols: self.num_special_tokens + self.num_source_symbols + self.num_source_symbols] = weight[:, self.num_special_tokens:self.num_special_tokens + self.num_source_symbols] + v.to(weight.device)
                return weight
            self.transformation = modify_embed_matrx
        elif task == "add_vec_control":
            def modify_embed_matrx(weight):
                for i in range(6):
                    v = np.random.randn(weight.shape[0], 1, weight.shape[2])
                    weight[:, self.num_special_tokens + (i + 1) *self.num_source_symbols: self.num_special_tokens + (i + 2) *self.num_source_symbols] = weight[:, self.num_special_tokens:self.num_special_tokens + self.num_source_symbols] + v.to(weight.device)
                return weight
            self.transformation = modify_embed_matrx
            
        if not self.on_fly :
            self.cache()
        
    def cache(self):
        if os.path.exists(self.data_dir + f"/{self.task}_{self.n}.npy"):
            self.data = np.load(self.data_dir + f"/{self.task}_{self.n}.npy")
            return
                
#         self.data = []
#         with multiprocessing.Pool(3) as pool:
#             for d in tqdm(pool.imap_unordered( data_gen, np.random.choice(range(*self.num_context_statements_range), self.n)), total = self.n ):
#                 self.data.append(d)
                
        self.data = [ self.get_instance() for _ in tqdm(range(self.n))]        
        np.save(self.data_dir + f"/{self.task}_{self.n}.npy", self.data)
 

    def __getitem__(self, index):
        if self.on_fly :
            return self.get_instance()
        else:
            return self.data[index]
    

        
    def get_instance(self):
        if self.task == "lookup":
            data = self.gen_chunky_lookup_data(np.random.choice(range(*self.num_context_statements_range)))
        elif self.task == "fixed_lookup":
            data =  self.gen_chunky_fixed_lookup_data(np.random.choice(range(*self.num_context_statements_range)))
        elif self.task == "fixed_lookup_all":
            data =  self.gen_chunky_fixed_lookup_data(np.random.choice(range(*self.num_context_statements_range)), supervise_all=True)
        elif self.task == "add_vec":
            data =  self.gen_chunky_add_vec_data(np.random.choice(range(*self.num_context_statements_range)))
        elif self.task == "subseq":
            data =  self.gen_chunky_subseq_data(np.random.choice(range(*self.num_context_statements_range)))
        elif self.task == "subseq_control":
            num_context_statements = np.random.choice(range(*self.num_context_statements_range))
            num_controls = np.random.choice(range(1, num_context_statements + 1))
            data =  self.gen_chunky_subseq_control_data(num_context_statements, num_controls)
        elif self.task == "add_vec_control":
            num_context_statements = np.random.choice(range(*self.num_context_statements_range))
            num_controls = np.random.choice(range(1, num_context_statements + 1))
            data =  self.gen_chunky_add_vec_control_data(num_context_statements, num_controls)
        elif self.task == "subseq_lookup":
            data =  self.gen_chunky_subseq_lookup_data(np.random.choice(range(*self.num_context_statements_range)))
        
        seq_orig, target_mask, input_mask = data
        if self.remap_tokens:
            remap_arr = np.arange(self.remap_tokens)
            if self.remap_tokens_frequencies is None:
                # sample with uniform probability 
                np.random.shuffle(remap_arr)
            else:
                # sample according to given frequencies
                remap_arr = np.random.choice(
                    remap_arr, 
                    size=(np.max(seq_orig) + 1,), 
                    replace=False, 
                    p=self.remap_tokens_frequencies)

            seq = remap_arr[seq_orig]
        else:
            seq = seq_orig

        return seq, target_mask, input_mask
    
    
    def __len__(self):
        return self.n
#         return len(self.data)
    

    def gen_chunky_lookup_data(self, num_context_statements):    

        num_query_statements = 1            

        source_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_context_statements, self.max_word_length))
        target_symbols = np.random.choice(range(self.num_special_tokens + self.num_source_symbols, self.num_special_tokens + self.num_target_symbols + self.num_source_symbols), (num_context_statements, self.max_word_length))

        source_lengths = np.random.choice(range(0,self.max_word_length), num_context_statements) + 1
        target_lengths = np.random.choice(range(0,self.max_word_length), num_context_statements) + 1

        query_index = np.random.choice(range(0, num_context_statements))


        final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
        target_mask = np.zeros(self.max_seq_length)
        input_mask = np.zeros(self.max_seq_length)   
        
        total_length = source_lengths.sum() + target_lengths.sum() + source_lengths[query_index] + target_lengths[query_index]  +  2 * (num_context_statements + 1) 
        
#         index = np.random.choice(range(0, self.max_seq_length - total_length + 1))
        index =  0 
        for i in range(num_context_statements):
            source_length = source_lengths[i]
            target_length = target_lengths[i]

            final_seq[index : index + source_length] = source_symbols[i][:source_length]
            index += source_length

            final_seq[index] = 2
            index += 1

            final_seq[index : index + target_length] = target_symbols[i][:target_length]
            index += target_length

            final_seq[index] = 1
            index += 1

        # query
        source_length = source_lengths[query_index]
        target_length = target_lengths[query_index]

        final_seq[index : index + source_length] = source_symbols[query_index][:source_length]
        index += source_length

        final_seq[index] = 2
        index += 1

        final_seq[index : index + target_length] = target_symbols[query_index][:target_length]
        target_mask[index : index + target_length] = 1
        input_mask[index - 1 : index + target_length - 1] = 1 
        index += target_length

        final_seq[index] = 1
        target_mask[index : index + 1] = 1
        input_mask[index - 1 : index] = 1 
        index += 1

        return (final_seq, target_mask, input_mask)
    
    
    def gen_chunky_fixed_lookup_data(self, num_context_statements, supervise_all=False):    

        num_query_statements = 1            

        max_group_id = 10 ** self.log_perm_group
        
        offset = 0
        if self.token_shift:
            offset = max_group_id * 2
                
        
        source_symbols = np.random.choice(range(self.num_special_tokens + offset, self.num_special_tokens + max_group_id + offset), (num_context_statements, self.max_word_length))
        target_symbols = source_symbols + max_group_id

        source_lengths = np.random.choice(range(0,self.max_word_length), num_context_statements) + 1
        target_lengths = source_lengths

        query_index = np.random.choice(range(0, num_context_statements))


        final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
        target_mask = np.zeros(self.max_seq_length)
        input_mask = np.zeros(self.max_seq_length)   
        
        total_length = source_lengths.sum() + target_lengths.sum() + source_lengths[query_index] + target_lengths[query_index]  +  2 * (num_context_statements + 1) 
        
#         index = np.random.choice(range(0, self.max_seq_length - total_length + 1))
        index =  0 
        for i in range(num_context_statements):
            source_length = source_lengths[i]
            target_length = target_lengths[i]

            final_seq[index : index + source_length] = source_symbols[i][:source_length]
            index += source_length

            final_seq[index] = 2
            index += 1

            final_seq[index : index + target_length] = target_symbols[i][:target_length]
            if supervise_all:
                target_mask[index : index + target_length] = 1
                input_mask[index - 1 : index + target_length - 1] = 1 
            index += target_length

            final_seq[index] = 1
            if supervise_all:
                target_mask[index : index + 1] = 1
                input_mask[index - 1 : index] = 1 
            index += 1

        # query
        source_length = source_lengths[query_index]
        target_length = target_lengths[query_index]

        final_seq[index : index + source_length] = source_symbols[query_index][:source_length]
        index += source_length

        final_seq[index] = 2
        index += 1

        final_seq[index : index + target_length] = target_symbols[query_index][:target_length]
        target_mask[index : index + target_length] = 1
        input_mask[index - 1 : index + target_length - 1] = 1 
        index += target_length

        final_seq[index] = 1
        target_mask[index : index + 1] = 1
        input_mask[index - 1 : index] = 1 
        index += 1

        return (final_seq, target_mask, input_mask)
    
    
    def gen_chunky_add_vec_data(self, num_context_statements):    

        num_query_statements = 1
        num_context_statements += num_query_statements

        source_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_context_statements, self.max_word_length))
        target_symbols = source_symbols + self.num_source_symbols
        
        source_lengths = np.random.choice(range(0,self.max_word_length), num_context_statements) + 1
        target_lengths = source_lengths


        final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
        target_mask = np.zeros(self.max_seq_length)
        input_mask = np.zeros(self.max_seq_length)   
        
        total_length = source_lengths.sum() + target_lengths.sum()  +  2 * (num_context_statements)
        
#         index = np.random.choice(range(0, self.max_seq_length - total_length + 1))
        index = 0
        for i in range(num_context_statements):
            source_length = source_lengths[i]
            target_length = target_lengths[i]

            final_seq[index : index + source_length] = source_symbols[i][:source_length]
            index += source_length

            final_seq[index] = 2
            index += 1

            final_seq[index : index + target_length] = target_symbols[i][:target_length]
            if i == num_context_statements - 1:
                target_mask[index : index + target_length] = 1
                input_mask[index - 1 : index + target_length - 1] = 1 
        
            index += target_length

            final_seq[index] = 1
            if i == num_context_statements - 1:
                target_mask[index : index + 1] = 1
                input_mask[index - 1 : index] = 1         
            index += 1

        
        return (final_seq, target_mask, input_mask)
    
    def gen_chunky_subseq_data(self, num_context_statements):    

        num_query_statements = 1
        num_context_statements += num_query_statements
        while True:
        # sample formula
            seq_length = np.random.choice(range(self.max_args_length)) + 1
            formula_length = np.random.choice(range(1, self.max_formula_length)) + 1
            
            # num_const = formula_length - 1
            num_const = min(self.max_num_const, formula_length - 1)
            formula = np.random.choice(range(seq_length + num_const), formula_length)

            const_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (formula_length - 1, self.max_word_length))
            const_lengths = np.random.choice(range(0,self.max_word_length), formula_length - 1) + 1

            # construct examples 
#             group_id = 0
#             for i in range(min(self.log_perm_group, formula_length)):
#                 group_id = group_id*10 + formula[i]
            offset = 0
#             if self.token_shift:
#                 offset = (group_id + 1)* self.num_source_symbols
            
            var_symbols = np.random.choice(range(offset + self.num_special_tokens, offset + self.num_special_tokens + self.num_source_symbols), (num_context_statements, seq_length, self.max_word_length))
            
            
            var_lengths = np.random.choice(range(0,self.max_word_length), (num_context_statements, seq_length)) + 1


            final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
            target_mask = np.zeros(self.max_seq_length)
            input_mask = np.zeros(self.max_seq_length)   

            index = 0

            try:
                for i in range(num_context_statements):
                    var_length = var_lengths[i]

                    for j in range(seq_length):
                        word_length = var_length[j]
                        word = var_symbols[i][j]
                        final_seq[index : index + word_length] = word[:word_length]
                        index += word_length

                        if j < seq_length - 1:
                            final_seq[index] = 3
                            index += 1

                    final_seq[index] = 2
                    index += 1

                    for j, s in enumerate(formula):
                        if s < seq_length:
                            word_length = var_length[s]
                            word = var_symbols[i][s]
                        else:
                            word_length = const_lengths[s - seq_length]
                            word = const_symbols[s - seq_length]

                        final_seq[index : index + word_length] = word[:word_length]
                        if i == num_context_statements - 1:
                            target_mask[index : index + word_length] = 1
                            input_mask[index - 1 : index + word_length - 1] = 1 

                        index += word_length
                        
                        if j < formula_length - 1:
                            final_seq[index] = 3
                            if i == num_context_statements - 1:
                                target_mask[index : index + 1] = 1
                                input_mask[index - 1 : index] = 1 
                            index += 1

                    final_seq[index] = 1
                    if i == num_context_statements - 1:
                        target_mask[index : index + 1] = 1
                        input_mask[index - 1 : index] = 1 
                    index += 1
                break
            except:
                pass
#                 print("too long, resample", index, word_length, self.max_seq_length)
            
#         total_length = index    
#         rnd_start = np.random.choice(range(0, self.max_seq_length - total_length + 1))
#         final_seq[rnd_start:rnd_start + total_length] = final_seq[:index]
#         final_seq[:rnd_start] = 0
        
#         target_mask[rnd_start:rnd_start + total_length] = target_mask[:index]
#         target_mask[:rnd_start] = 0
        
#         input_mask[rnd_start:rnd_start + total_length] = input_mask[:index]
#         input_mask[:rnd_start] = 0
        
        return (final_seq, target_mask, input_mask)

    ################
    ######   control
    ################
    def gen_chunky_subseq_control_data(self, num_context_statements, num_controls):    

        num_query_statements = 1
        num_context_statements += num_query_statements
        while True:
        # sample formula
            seq_length = np.random.choice(range(self.max_args_length)) + 1
            seq_length = seq_length + 1
            
            
            control_index = np.random.choice(range(seq_length), num_controls)
            
            formula_length = np.random.choice(range(1, self.max_formula_length), num_controls) + 1
            
            formula = [np.random.choice(list(set(range(seq_length + formula_length[i] - 1)) - set([control_index[i]])), formula_length[i]) for i in range(num_controls)]

            control_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_controls, self.max_word_length))
            control_lengths = np.random.choice(range(0,self.max_word_length), num_controls) + 1
            
            const_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_controls, self.max_formula_length, self.max_word_length))
            const_lengths = np.random.choice(range(0,self.max_word_length), (num_controls, self.max_formula_length)) + 1

            # control assignment
            assert num_context_statements > num_controls
            statement_control_map = list(range(num_controls)) +  list(np.random.choice(range(num_controls), num_context_statements - num_controls - 1))
            random.shuffle(statement_control_map)
            statement_control_map += list(np.random.choice(range(num_controls), 1))
            
            # construct examples                               
            var_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_context_statements, seq_length, self.max_word_length))
            var_lengths = np.random.choice(range(0,self.max_word_length), (num_context_statements, seq_length)) + 1


            final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
            target_mask = np.zeros(self.max_seq_length)
            input_mask = np.zeros(self.max_seq_length)   

            index = 0

            try:
#             if True:   
                for i in range(num_context_statements):
                    control_type = statement_control_map[i]
                    
                    var_length = var_lengths[i]
                    

                    for j in range(seq_length):
                        if j != control_index[control_type]:
                            word_length = var_length[j]
                            word = var_symbols[i][j]
                        else:
                            word_length = control_lengths[control_type]
                            word = control_symbols[control_type]
                            
                        final_seq[index : index + word_length] = word[:word_length]
                        index += word_length

                        if j < seq_length - 1:
                            final_seq[index] = 3
                            index += 1

                    final_seq[index] = 2
                    index += 1
                    for j, s in enumerate(formula[control_type]):
                        if s < seq_length:
                            word_length = var_length[s]
                            word = var_symbols[i][s]
                        else:
                            word_length = const_lengths[control_type][s - seq_length]
                            word = const_symbols[control_type][s - seq_length]

                        final_seq[index : index + word_length] = word[:word_length]
                        if i == num_context_statements - 1:
                            target_mask[index : index + word_length] = 1
                            input_mask[index - 1 : index + word_length - 1] = 1 

                        index += word_length
                        
                        if j < formula_length[control_type] - 1:
                            final_seq[index] = 3
                            if i == num_context_statements - 1:
                                target_mask[index : index + 1] = 1
                                input_mask[index - 1 : index] = 1 
                            index += 1

                    final_seq[index] = 1
                    if i == num_context_statements - 1:
                        target_mask[index : index + 1] = 1
                        input_mask[index - 1 : index] = 1 
                    index += 1
                break
            except:
#                 pass
                print("too long, resample", index, word_length, self.max_seq_length)
            
        total_length = index    
        rnd_start = np.random.choice(range(0, self.max_seq_length - total_length + 1))
        final_seq[rnd_start:rnd_start + total_length] = final_seq[:index]
        final_seq[:rnd_start] = 0
        
        target_mask[rnd_start:rnd_start + total_length] = target_mask[:index]
        target_mask[:rnd_start] = 0
        
        input_mask[rnd_start:rnd_start + total_length] = input_mask[:index]
        input_mask[:rnd_start] = 0
        
        return (final_seq, target_mask, input_mask)

    def gen_chunky_add_vec_control_data(self, num_context_statements, num_controls):    

        num_query_statements = 1
        num_context_statements += num_query_statements
        while True:
            control_index = np.random.choice(range(0, 2), num_controls)
            control_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_controls, self.max_word_length))
            control_lengths = np.random.choice(range(0,self.max_word_length), num_controls) + 1



            # control assignment
            assert num_context_statements > num_controls
            statement_control_map = list(range(num_controls)) +  list(np.random.choice(range(num_controls), num_context_statements - num_controls - 1))
            random.shuffle(statement_control_map)
            statement_control_map += list(np.random.choice(range(num_controls), 1))


            # construct examples
            source_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_context_statements, self.max_word_length))
            target_symbols = np.array(source_symbols) + self.num_source_symbols * (np.array(statement_control_map).reshape(-1,1) + 1)

            source_lengths = np.random.choice(range(0,self.max_word_length), num_context_statements) + 1
            target_lengths = source_lengths


            final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
            target_mask = np.zeros(self.max_seq_length)
            input_mask = np.zeros(self.max_seq_length)   

            total_length = source_lengths.sum() + target_lengths.sum()  +  2 * (num_context_statements)

#             index = np.random.choice(range(0, self.max_seq_length - total_length + 1))
            index = 0
            try:
                for i in range(num_context_statements):
                    source_length = source_lengths[i]
                    target_length = target_lengths[i]
                    control_type = statement_control_map[i]

                    if control_index[control_type] == 0:
                        word_length = control_lengths[control_type]
                        final_seq[index : index + word_length] = control_symbols[control_type][:word_length]
                        index += word_length

                        final_seq[index : index + 1] = 3
                        index += 1

                    final_seq[index : index + source_length] = source_symbols[i][:source_length]
                    index += source_length

                    if control_index[control_type] == 1:
                        final_seq[index : index + 1] = 3
                        index += 1

                        word_length = control_lengths[control_type]
                        final_seq[index : index + word_length] = control_symbols[control_type][:word_length]
                        index += word_length

                    final_seq[index] = 2
                    index += 1
                    final_seq[index : index + target_length] = target_symbols[i][:target_length]
                    if i == num_context_statements - 1:
                        target_mask[index : index + target_length] = 1
                        input_mask[index - 1 : index + target_length - 1] = 1 

                    index += target_length

                    final_seq[index] = 1
                    if i == num_context_statements - 1:
                        target_mask[index : index + 1] = 1
                        input_mask[index - 1 : index] = 1         
                    index += 1
                break
            except:
#                 pass
                print("too long, resample", index, word_length, self.max_seq_length)

        
        return (final_seq, target_mask, input_mask)

    
    ################
    ######  hierarchy
    ################
    def gen_hierarchy_data(self, num_context_statements):    

        num_query_statements = 1
        num_context_statements += num_query_statements
        while True:
        # sample formula
            seq_length = np.random.choice(range(self.max_args_length)) + 1
            formula_length = np.random.choice(range(1, self.max_formula_length)) + 1
            
            formula = np.random.choice(range(seq_length + formula_length - 1), formula_length)

            const_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (formula_length - 1, self.max_word_length))
            const_lengths = np.random.choice(range(0,self.max_word_length), formula_length - 1) + 1

            # construct examples                               
            var_symbols = np.random.choice(range(self.num_special_tokens, self.num_special_tokens + self.num_source_symbols), (num_context_statements, seq_length, self.max_word_length*9))
            var_lengths = np.random.choice(range(0,self.max_word_length), (num_context_statements, seq_length, 9)) + 1
            group_var1 = np.random.choice(range(1,4), (num_context_statements, seq_length)) 
            group_var2 = np.random.choice(range(1,4), (num_context_statements, seq_length, 3)) 


            final_seq = np.zeros(self.max_seq_length, dtype = np.int32)
            target_mask = np.zeros(self.max_seq_length)
            input_mask = np.zeros(self.max_seq_length)   

            index = 0

            try:
                for i in range(num_context_statements):
                    var_length = var_lengths[i]

                    for j in range(seq_length):
                        word_length = var_length[j]
                        word = var_symbols[i][j]
                        final_seq[index : index + word_length] = word[:word_length]
                        index += word_length

                        if j < seq_length - 1:
                            final_seq[index] = 3
                            index += 1

                    final_seq[index] = 2
                    index += 1

                    for j, s in enumerate(formula):
                        if s < seq_length:
                            word_length = var_length[s]
                            word = var_symbols[i][s]
                        else:
                            word_length = const_lengths[s - seq_length]
                            word = const_symbols[s - seq_length]

                        final_seq[index : index + word_length] = word[:word_length]
                        if i == num_context_statements - 1:
                            target_mask[index : index + word_length] = 1
                            input_mask[index - 1 : index + word_length - 1] = 1 

                        index += word_length
                        
                        if j < formula_length - 1:
                            final_seq[index] = 3
                            if i == num_context_statements - 1:
                                target_mask[index : index + 1] = 1
                                input_mask[index - 1 : index] = 1 
                            index += 1

                    final_seq[index] = 1
                    if i == num_context_statements - 1:
                        target_mask[index : index + 1] = 1
                        input_mask[index - 1 : index] = 1 
                    index += 1
                break
            except:
                pass
#                 print("too long, resample", index, word_length, self.max_seq_length)
            
        total_length = index    
        rnd_start = np.random.choice(range(0, self.max_seq_length - total_length + 1))
        final_seq[rnd_start:rnd_start + total_length] = final_seq[:index]
        final_seq[:rnd_start] = 0
        
        target_mask[rnd_start:rnd_start + total_length] = target_mask[:index]
        target_mask[:rnd_start] = 0
        
        input_mask[rnd_start:rnd_start + total_length] = input_mask[:index]
        input_mask[:rnd_start] = 0
        
        return (final_seq, target_mask, input_mask)

    
    
    
    
    
    ################
    ######  compositional
    ################
    
    def gen_chunky_subseq_lookup_data(self, num_context_statements):
        """Generate one "lookup table followed by subsequence statements" example.

        The sequence is written from position 0 (no random start offset is
        applied) as two consecutive chunks:

        1. A lookup table with one line per source word, in shuffled order:
           ``<source> 2 <target> 1``. It contains an entry for every argument
           word of every statement plus ``self.max_num_const`` constants.
        2. ``num_context_statements + 1`` statements of the form
           ``<arg> 3 <arg> ... 2 <out> 3 <out> ... 1``. Every output word is the
           *target* symbol of either one of that statement's own arguments or
           of one of the constants, as selected by a formula that is shared by
           all statements.

        Only the outputs of the last statement (the query) are marked:
        ``target_mask`` is 1 on those positions, including the ``3``
        separators and the closing ``1``; ``input_mask`` is the same mask
        shifted one position to the left.

        Args:
            num_context_statements: number of context statements; one query
                statement is appended to them.

        Returns:
            Tuple ``(final_seq, target_mask, input_mask)``, each a 1-D array of
            length ``self.max_seq_length``. ``final_seq`` is ``int32``; unused
            trailing positions stay 0.

        A sample that raises ``IndexError`` or ``ValueError`` while being built
        (for example because it does not fit into ``self.max_seq_length``) is
        discarded and a new one is drawn. Any other exception propagates
        instead of being retried forever.
        """
        assert self.max_word_length == 1
        num_query_statements = 1
        num_context_statements += num_query_statements

        source_ids = range(
            self.num_special_tokens,
            self.num_special_tokens + self.num_source_symbols,
        )
        target_ids = range(
            self.num_special_tokens + self.num_source_symbols,
            self.num_special_tokens + self.num_target_symbols + self.num_source_symbols,
        )
        word_length_choices = range(0, self.max_word_length)

        while True:
            try:
                # sample formula
                seq_length = np.random.choice(range(self.max_args_length)) + 1
                formula_length = np.random.choice(range(1, self.max_formula_length)) + 1

                # Formula entries < seq_length pick an argument of the statement,
                # entries >= seq_length pick one of the first `num_const` constants.
                num_const = min(self.max_num_const, formula_length - 1)
                formula = np.random.choice(range(seq_length + num_const), formula_length)

                # The lookup table below always lists `self.max_num_const`
                # constants, so the constant tables need exactly that many rows
                # regardless of the sampled formula length.
                # NOTE(review): constants are drawn with replacement from the
                # same id range as the argument words, so a constant may collide
                # with an argument or with another constant -- confirm intended.
                const_symbols = np.random.choice(
                    source_ids, (self.max_num_const, self.max_word_length)
                )
                const_lengths = np.random.choice(word_length_choices, self.max_num_const) + 1

                num_var_words = num_context_statements * seq_length
                num_table_entries = num_var_words + self.max_num_const

                # construct examples: argument words are unique across all statements
                var_symbols = np.random.choice(
                    source_ids,
                    (num_context_statements, seq_length, self.max_word_length),
                    replace=False,
                )
                var_lengths = np.random.choice(
                    word_length_choices, (num_context_statements, seq_length)
                ) + 1

                # one unique target word per lookup-table entry
                target_symbols = np.random.choice(
                    target_ids, (num_table_entries, self.max_word_length), replace=False
                )
                target_lengths = np.random.choice(word_length_choices, num_table_entries) + 1

                final_seq = np.zeros(self.max_seq_length, dtype=np.int32)
                target_mask = np.zeros(self.max_seq_length)
                input_mask = np.zeros(self.max_seq_length)

                index = 0

                ### lookup table
                all_order = np.random.choice(num_table_entries, num_table_entries, replace=False)
                for order in all_order:
                    if order < num_var_words:
                        statement = order // seq_length
                        position = order % seq_length
                        source_symbols = var_symbols[statement][position]
                        source_length = var_lengths[statement][position]
                    else:
                        const_index = order - num_var_words
                        source_symbols = const_symbols[const_index]
                        source_length = const_lengths[const_index]

                    target_length = target_lengths[order]

                    final_seq[index : index + source_length] = source_symbols[:source_length]
                    index += source_length

                    final_seq[index] = 2
                    index += 1

                    final_seq[index : index + target_length] = target_symbols[order][:target_length]
                    index += target_length

                    final_seq[index] = 1
                    index += 1

                ### subseq
                for i in range(num_context_statements):
                    # only the last statement (the query) contributes to the masks
                    is_query = i == num_context_statements - 1

                    for j in range(seq_length):
                        word_length = var_lengths[i][j]
                        final_seq[index : index + word_length] = var_symbols[i][j][:word_length]
                        index += word_length

                        if j < seq_length - 1:
                            final_seq[index] = 3
                            index += 1

                    final_seq[index] = 2
                    index += 1

                    for j, s in enumerate(formula):
                        if s < seq_length:
                            entry = i * seq_length + s
                        else:
                            entry = num_var_words + s - seq_length
                        word_length = target_lengths[entry]
                        word = target_symbols[entry]

                        final_seq[index : index + word_length] = word[:word_length]
                        if is_query:
                            target_mask[index : index + word_length] = 1
                            input_mask[index - 1 : index + word_length - 1] = 1

                        index += word_length

                        if j < formula_length - 1:
                            final_seq[index] = 3
                            if is_query:
                                target_mask[index : index + 1] = 1
                                input_mask[index - 1 : index] = 1
                            index += 1

                    final_seq[index] = 1
                    if is_query:
                        target_mask[index : index + 1] = 1
                        input_mask[index - 1 : index] = 1
                    index += 1
            except (IndexError, ValueError) as e:
                # Running past max_seq_length surfaces as IndexError; failed
                # draws without replacement surface as ValueError.
                print(e)
                print("too long, resample")
                continue

            return (final_seq, target_mask, input_mask)
    
if __name__ == "__main__":
    # Manual inspection entry point: build a small on-the-fly "subseq" dataset
    # and stop in the debugger so it can be explored interactively.
    # Earlier experiments here used the "lookup", "add_vec" and
    # "add_vec_control" tasks, the first two with 5000 * 100 * 10 * 3 examples.
    import pdb

    dataset = SynDataset(1000, task="subseq", on_fly=True)
    pdb.set_trace()
    