from Generator import *

# 2D arc horizontal vs vertical generators
class Gen_arc_extend(Generator):
    """Generator for 2D ARC "pile left" tasks.

    Each input row holds 0-3 pixels of a single row-specific colour at random
    columns; the matching output row has the same number of pixels of that
    colour packed against the left edge.
    """

    # Maximum number of all-background (blank) rows allowed in one io pair.
    # Blank rows are additionally only allowed in training pairs (see
    # gen_extend_pair_mc).
    max_blank = 2
    # Blank rows produced so far for the io pair currently being generated;
    # reset at the start of every gen_extend_pair_mc call.
    num_blank = 0

    def __init__(self, save_folder, file_prefix, min_len=32, max_len=33):
        super().__init__(save_folder, file_prefix, min_len, max_len)
        # Presumably invoked by the base Generator to build each io pair.
        self.gen_func = self.gen_extend_pair_mc

    def gen_extend_pair_mc(self, seq_len, io_idx):
        """Generate one input/output pair for the horizontal pile task.

        The grid is square with a random side length in [4, 6]. Every row is
        assigned a distinct colour from 1..self.max_digits and between 0 and 3
        pixels of that colour. In the input the pixels sit at random distinct
        columns; in the output they are piled into the leftmost columns.

        Blank rows (0 pixels) are capped at ``self.max_blank`` per pair and
        are never produced when ``io_idx == 3`` (presumably the test pair,
        since blanks must stay in the training set -- confirm against the
        base Generator).

        Args:
            seq_len: Unused; accepted to match the ``gen_func`` call signature.
            io_idx: Index of the io pair being generated.

        Returns:
            ``(input_grid, output_grid)``: lists of rows, each row a list of
            colour ints with ``self.back_ground`` in empty cells.

        Raises:
            IndexError: from ``random.choice`` if ``self.max_digits`` is
                smaller than the drawn grid size (not enough distinct colours).
        """
        grid_len = randrange(4, 7)
        remaining_colors = list(range(1, self.max_digits + 1))

        input_grid, output_grid = [], []
        self.num_blank = 0

        for _ in range(grid_len):
            # Each row gets its own colour, never reused within the pair.
            row_color = random.choice(remaining_colors)
            remaining_colors.remove(row_color)

            pixel_num = randrange(4)
            if pixel_num == 0:
                if self.num_blank >= self.max_blank or io_idx == 3:
                    # A blank row is not allowed here: redraw a non-zero count.
                    pixel_num = randrange(1, 4)
                else:
                    self.num_blank += 1

            # Random distinct columns for the input pixels.
            columns = list(range(grid_len))
            random.shuffle(columns)

            input_row = [self.back_ground] * grid_len
            for col in columns[:pixel_num]:
                input_row[col] = row_color

            # Output: the same pixels piled against the left edge.
            output_row = [self.back_ground] * grid_len
            output_row[:pixel_num] = [row_color] * pixel_num

            input_grid.append(input_row)
            output_grid.append(output_row)

        return input_grid, output_grid


# Build the horizontal-pile generator and call gen_json 50 times with
# gen_type="2d" (presumably emitting one dataset file per call into
# save_folder -- see the base Generator).
gen1 = Gen_arc_extend(
    save_folder="../dataset/arc_hv_tasks/arc_pile_h",
    file_prefix="arc_pile_h",
)
for _ in range(50):
    gen1.gen_json(dryrun=False, gen_type="2d")
