from Generator import *

# 2D arc horizontal vs vertical generators
class Gen_arc_extend(Generator):
    """Generator for 3x3 ARC "extend" pairs.

    Each pair contains one small object drawn on a background grid:
      * a 2-cell bar ('vbar_2' or 'hbar_2'),
      * a 3-cell bar ('hbar_3'), or
      * a 3-cell L ('L_<angle>').
    The output is the same object moved one cell further along the second
    (inner-list) grid index.
    """

    # Allow only <= 2 blank rows in each io pair, and must be in training set.
    # NOTE(review): gen_extend_pair_mc only ever resets this to 0; the limit
    # described above is presumably enforced by the Generator base class --
    # confirm there.
    num_blank = 0

    # Bounding-box extent of each shape as (first-index extent,
    # second-index extent). The insertion order is also the order shapes are
    # sampled in, so keep it stable.
    _SHAPE_SIZES = {
        'vbar_2': (1, 2),
        'hbar_2': (2, 1),
        'hbar_3': (3, 1),
        'L_0': (2, 2),
        'L_90': (2, 2),
        'L_180': (2, 2),
        'L_270': (2, 2),
    }

    # Cells of a 2x2 square. An 'L_<angle>' shape is this square minus the
    # corner at position angle // 90.
    _L_SQUARE = ((1, 0), (1, 1), (0, 1), (0, 0))

    def __init__(self, save_folder, file_prefix, min_len=32, max_len=33):
        """Forward all arguments to Generator and register the pair builder.

        self.gen_func is set to gen_extend_pair_mc; presumably the base class
        calls it to produce each input/output pair.
        """
        super().__init__(save_folder, file_prefix, min_len, max_len)
        self.gen_func = self.gen_extend_pair_mc

    @classmethod
    def _shape_cells(cls, shape):
        """Return the (first, second) index offsets of the cells of *shape*.

        An unrecognised shape name yields no cells, so nothing is drawn.
        """
        if shape.startswith('L'):
            # Floor division keeps the index an int (0..3), so it can be
            # compared directly with the enumerate() index.
            cut_idx = int(shape.split("_")[1]) // 90
            return [cell for idx, cell in enumerate(cls._L_SQUARE) if idx != cut_idx]
        if shape == 'vbar_2':
            return [(0, 0), (0, 1)]
        if shape.startswith('hbar'):
            bar_len = int(shape.split("_")[1])
            return [(offset, 0) for offset in range(bar_len)]
        return []

    def gen_extend_pair_mc(self, seq_len, io_idx):
        """Generate one 3x3 input/output pair.

        A random colour in 1..self.max_digits and a random shape are chosen.
        The shape is placed so that both it and its copy, shifted one cell
        along the second index, stay inside the grid.

        Args:
            seq_len: Not used here; presumably passed by the Generator base
                class through self.gen_func.
            io_idx: Not used here; same as seq_len.

        Returns:
            (grid_in, grid_out): two 3x3 lists of row lists. Every cell is
            self.back_ground except the object's cells, which are set to the
            chosen colour.
        """
        grid_len = 3
        obj_color = random.choice(list(range(1, self.max_digits + 1)))
        # The result of this draw is never used. The call is kept so the
        # random stream, and therefore any seeded dataset, is unchanged.
        randrange(2, 4)

        shape = random.choice(list(self._SHAPE_SIZES))
        size_x, size_y = self._SHAPE_SIZES[shape]
        # First index: the shape itself must fit. Second index: leave one
        # spare cell so the shifted copy in the output still fits.
        init_x = randrange(grid_len - size_x + 1)
        init_y = randrange(grid_len - size_y)

        self.num_blank = 0
        # Build each row as a separate list so that writing one cell does
        # not alias other rows.
        grid_in = [[self.back_ground] * grid_len for _ in range(grid_len)]
        grid_out = [[self.back_ground] * grid_len for _ in range(grid_len)]

        for dx, dy in self._shape_cells(shape):
            grid_in[init_x + dx][init_y + dy] = obj_color
            grid_out[init_x + dx][init_y + dy + 1] = obj_color

        return grid_in, grid_out


# Number of gen_json calls to make. Each call presumably writes one task file
# under save_folder (dryrun=False) -- see Generator.gen_json to confirm.
_NUM_RUNS = 50

gen1 = Gen_arc_extend(
    save_folder="../dataset/arc_hv_tasks/arc_downone_h",
    file_prefix="arc_downone_h",
)
for _ in range(_NUM_RUNS):
    gen1.gen_json(dryrun=False, gen_type="2d")
