class BaseConfig:
    """Default settings for training the guided classifier.

    ``num_classes`` and ``out_dim`` are two public names for one stored value,
    ``share_num_classes``: assigning either one changes what both return.
    """

    def __init__(self):
        # Input resolution and classifier architecture.
        self.image_size = 32
        self.classifier_attention_resolutions = (32, 16, 8)
        self.classifier_width = 64
        self.classifier_depth = 2
        # Presumably the mixed-precision mode ("fp16"/"bf16") despite the
        # name `fps` -- confirm with the training script before renaming.
        self.fps = "fp16"
        # Training-loop settings.
        self.batch_size = 512
        self.lr = 0.001
        self.device = "cuda"
        self.ema_decay = 0.9999
        self.out_dir = "guided_classifier"
        self.gradient_accumulation_steps = 1
        self.num_epoches = 500
        self.num_train_timesteps = 1000
        self.lr_warmup_steps = 500
        # UNet layout.
        self.unet_dim = 128
        self.dim_mults = (1, 1, 2, 2, 4, 8)
        # NOTE(review): presumably the step interval between accuracy
        # reports -- confirm against the training loop.
        self.accuracy_show = 500
        # Diffusion model to load.
        self.model_dir = "google/ddpm-cifar10-32"
        # Single backing field behind the num_classes / out_dim aliases.
        self.share_num_classes = 10

    def _get_shared_num_classes(self):
        return self.share_num_classes

    def _set_shared_num_classes(self, value):
        self.share_num_classes = value

    num_classes = property(
        _get_shared_num_classes,
        _set_shared_num_classes,
        doc="Number of classes; alias of ``share_num_classes``.",
    )
    out_dim = property(
        _get_shared_num_classes,
        _set_shared_num_classes,
        doc="Classifier output size; alias of ``share_num_classes``.",
    )

    # The accessors only exist to build the properties above; remove them so
    # the class namespace exposes just the two aliases.
    del _get_shared_num_classes, _set_shared_num_classes
    
class OneSampleConfig(BaseConfig):
    """Single-sample variant of ``BaseConfig``.

    Every ``BaseConfig`` default is inherited as-is; only the batch size, the
    output directory and the classifier head size are overridden.
    """

    def __init__(self):
        super().__init__()
        # Overrides relative to BaseConfig; image size, optimiser, schedule
        # and UNet layout all keep the inherited values.
        self.batch_size = 1
        self.out_dir = "guided_classifier_one_sample"
        # out_dim is an alias of num_classes, so this makes both equal to 2.
        self.out_dim = 2
class Generate_ClassCond_Config(OneSampleConfig):
    """Settings for class-conditional generation guided by a trained classifier.

    Starts from ``OneSampleConfig`` and then applies overrides chosen by
    ``dataset_name``. A name that matches none of the branches below keeps
    the shared defaults (EncoderUnet classifier, inherited diffusion model).

    Args:
        dataset_name: dataset key. ``out_dir`` becomes
            ``"guided_classifier_" + dataset_name``, except for ``"cifar10"``,
            which uses ``"guided_classifier"``.
        classifier_choose: epoch of the classifier checkpoint, stored as
            ``./<out_dir>/epoch<classifier_choose>.pth``.
        generate_choose: epoch of the self-trained EMA diffusion checkpoint;
            only read by the ``celebahq-128`` variants.
    """

    def __init__(self, dataset_name, classifier_choose=52, generate_choose=1500):
        super().__init__()
        self.classification_target = 6
        self.torch_device = "cuda"
        # num_classes and out_dim share one backing field, so this single
        # assignment sets both to 10.
        self.num_classes = 10
        # NOTE(review): "encoderuent" looks like a typo of "encoderunet";
        # kept verbatim in case other code compares against it.
        self.model = "encoderuent"
        self.out_dir = "guided_classifier_" + dataset_name
        self.classifier_choose = classifier_choose
        self.inversion_model_dir = self._classifier_checkpoint(classifier_choose)
        self.class_cond = True
        self.batch_size = 16
        self.use_ddim = True
        self.self_train = False
        self.ema = False
        self.dataset_name = dataset_name
        self.classifier_type = "EncoderUnet"

        if dataset_name in ("celebahq", "celebahq-256", "celebahq-synthetic"):
            self.num_classes = 307
            self.image_size = 256
            if dataset_name == "celebahq-synthetic":
                self.model_dir = "google/ddpm-celebahq-256"
                self._use_resnet_classifier()
            else:
                self.model_dir = "google/ddpm-ema-celebahq-256"
                self.dim_mults = (1, 1, 2, 2, 2, 4, 8)
        elif dataset_name in ("celebahq-128", "celebahq-128-synthetic"):
            # Locally trained 128px model, loaded from its EMA checkpoint.
            self.image_size = 128
            self.num_classes = 307
            self.unet_dim = 64
            self.model_dir = f"./ddpm-celebhq-128/model_EMA_epoch{generate_choose!s}.pt"
            self.self_train = True
            self.ema = True
            if dataset_name == "celebahq-128-synthetic":
                self._use_resnet_classifier()
            else:
                self.dim_mults = (1, 2, 4, 8)
        # NOTE(review): "Cifar10-synthetic" is capitalised unlike every other
        # key, so a lowercase "cifar10-synthetic" matches no branch.
        elif dataset_name in ("cifar10", "Cifar10-synthetic"):
            # NOTE(review): `train` is only ever set here, whereas the other
            # branches use `self_train` (already False above). Possibly a typo;
            # kept so anything reading `config.train` keeps working.
            self.train = False
            self.model_dir = "google/ddpm-cifar10-32"
            if dataset_name == "cifar10":
                # Plain CIFAR-10 uses the un-suffixed directory, so the
                # checkpoint path is rebuilt from the new out_dir.
                self.out_dir = "guided_classifier"
                self.inversion_model_dir = self._classifier_checkpoint(classifier_choose)
            else:
                self._use_resnet_classifier()

    def _classifier_checkpoint(self, epoch):
        """Return ``./<out_dir>/epoch<epoch>.pth`` for the current ``out_dir``."""
        return f"./{self.out_dir}/epoch{epoch!s}.pth"

    def _use_resnet_classifier(self):
        """Switch the classifier to the ResNet used by the synthetic variants."""
        self.classifier_type = "ResNet"
        # The dim_mults field is reused for the ResNet per-stage block counts
        # (labelled "resnet34" in the earlier code). A fresh list is built on
        # every call so instances never share it.
        self.dim_mults = [3, 4, 6, 3]
class Generate_OneSample_Config(OneSampleConfig):
    """Generation settings for the two-class, single-sample classifier.

    Args:
        classifier_choose: epoch of the classifier checkpoint loaded from
            ``./<out_dir>/epoch<classifier_choose>.pth``. Defaults to 52, the
            epoch that used to be hard-coded here and the same default that
            ``Generate_ClassCond_Config`` uses, so existing callers are
            unaffected.
    """

    def __init__(self, classifier_choose=52):
        super().__init__()
        self.classification_target = 6
        self.torch_device = "cuda"
        # num_classes and out_dim share one backing field; this sets both to 2.
        self.num_classes = 2
        # NOTE(review): "encoderuent" looks like a typo of "encoderunet";
        # kept verbatim in case other code compares against it.
        self.model = "encoderuent"
        self.classifier_choose = classifier_choose
        # out_dir is inherited from OneSampleConfig.
        self.inversion_model_dir = f"./{self.out_dir}/epoch{classifier_choose}.pth"
        self.class_cond = True
        self.batch_size = 900
        self.use_ddim = False
class CelebAHQ_Config_Base(BaseConfig):
    """``BaseConfig`` retargeted at CelebA-HQ.

    Only the settings below differ from ``BaseConfig``; everything else is
    inherited unchanged.
    """

    def __init__(self):
        super().__init__()
        # Dataset shape: image_size 256 and 307 classes. num_classes is an
        # alias of out_dim, so both become 307.
        self.image_size = 256
        self.num_classes = 307
        # Diffusion model to load and its UNet channel multipliers.
        self.model_dir = "google/ddpm-celebahq-256"
        self.dim_mults = (1, 1, 2, 2, 2, 4, 8)
        # Precision-mode string, switched from BaseConfig's "fp16".
        self.fps = "bf16"
