from Generator import *

# 2D arc horizontal vs vertical generators
class Gen_arc_extend(Generator):
    """Generator for the 2D ARC "extend" task.

    Each input row holds a single colored cell on the background. In the
    matching output row, that color fills every cell from column 0 up to
    and including the original cell, i.e. the color is stretched leftwards.
    """

    # Intent (from the original author): allow at most one "zero row" (the
    # cell already sits in column 0, so input and output rows are equal),
    # and blank rows only in training pairs.
    # This is a class-level default. gen_extend_pair_mc sets it to True on
    # the instance the first time column 0 is picked. Afterwards column 0 is
    # never picked again for this instance; nothing here resets it.
    zero_flag = False

    def __init__(self, save_folder, file_prefix, min_len=32, max_len=33):
        super().__init__(save_folder, file_prefix, min_len, max_len)
        # NOTE(review): presumably the Generator base calls self.gen_func to
        # build each input/output pair -- confirm in Generator.
        self.gen_func = self.gen_extend_pair_mc

    def gen_extend_pair_mc(self, seq_len, io_idx):
        """Build one (input, output) grid pair for the extend task.

        Three distinct values are drawn without replacement from
        ``1 .. self.max_digits + 1``. The value ``self.max_digits + 1`` is a
        sentinel meaning "blank row" (background only). Every other value is
        a real color, placed once per input row and stretched leftwards to
        column 0 in the output row. Each row has 3 cells.

        Args:
            seq_len: unused by this method.
            io_idx: index of the pair being generated. When it equals 3, a
                drawn blank row is dropped instead of emitted, so that grid
                has only 2 rows. (Assumes 3 is the non-training/test pair,
                per the original "only in training sets" comment -- confirm
                against Generator.)

        Returns:
            Tuple ``(input_grid, output_grid)``; each is a list of rows and
            each row is a list of 3 cell values.

        NOTE(review): the previous docstring said "horizontally and
        vertically, size=3x3", but only the leftward horizontal extension
        is implemented here, and the grid can be 2x3 when io_idx == 3.
        """
        blank_color = self.max_digits + 1
        remaining = list(range(1, blank_color + 1))
        grid_in, grid_out = [], []

        for _ in range(3):
            color = random.choice(remaining)
            remaining.remove(color)  # keep the three row colors distinct

            is_blank = color == blank_color
            if is_blank and io_idx == 3:
                # Blank rows are only emitted outside pair 3; here the row is skipped.
                continue

            in_row = [self.back_ground] * 3
            out_row = list(in_row)

            if not is_blank:
                # Once a zero row has been produced, column 0 is excluded.
                col = randrange(1, 3) if self.zero_flag else randrange(3)
                if col == 0:
                    self.zero_flag = True
                in_row[col] = color
                # Fill columns 0..col with the color (leftward extension).
                out_row[:col + 1] = [color] * (col + 1)

            grid_in.append(in_row)
            grid_out.append(out_row)

        return grid_in, grid_out


# Script entry: build the horizontal "extend" generator and write 50 tasks
# with it. This runs at import time, exactly as before.
gen1 = Gen_arc_extend(
    save_folder="../dataset/arc_hv_tasks/arc_extend_h",
    file_prefix="arc_extend_h",
)
for _ in range(50):
    gen1.gen_json(dryrun=False, gen_type="2d")
