import torch


class HParams:
    """Plain container of data-location and training settings.

    Every attribute is assigned a fixed literal in ``__init__``, except
    ``use_cuda``, which is probed from ``torch.cuda.is_available()`` each
    time an instance is constructed.
    """

    def __init__(self):
        # ---- data ----
        # Directory that holds the origin data files.
        self.data_location = 'quickdraw/'
        # Data file names to load (presumably resolved against
        # data_location by the loader — confirm in the data code).
        self.category = [
            "airplane.npz",
            "angel.npz",
            "alarm clock.npz",
            "apple.npz",
            "butterfly.npz",
            "belt.npz",
            "bus.npz",
            "cake.npz",
            "cat.npz",
            "clock.npz",
            "eye.npz",
            "fish.npz",
            "pig.npz",
            "sheep.npz",
            "spider.npz",
            "The Great Wall of China.npz",
            "umbrella.npz",
        ]
        # NOTE(review): presumably the directory where checkpoints are saved.
        self.model_save = "model_save"

        # ---- optimisation ----
        # How lr / lr_decay / min_lr interact is decided by the training
        # code, not here.
        self.batch_size = 200
        self.lr = 0.01
        self.lr_decay = 0.99999
        self.min_lr = 0.001

        # ---- sequence and sketch limits ----
        self.max_seq_length = 200
        self.min_seq_length = 0
        self.epochs = 1_000_000
        # Max stroke number of a sketch; 0 here, presumably filled in
        # after the data is scanned — confirm with the loader.
        self.Nmax = 0
        # Graphs per sketch: the first one is global, the rest are local.
        local_graph_count = 20
        self.graph_number = 1 + local_graph_count
        # Size of each graph picture (the old "128" note no longer matched).
        self.graph_picture_size = 225
        self.mask_prob = 0.1

        # ---- runtime ----
        # Evaluated at construction time, not import time of this class.
        self.use_cuda = torch.cuda.is_available()


# Shared module-level settings instance; importers read attributes from it.
hp: HParams = HParams()