import ml_collections
import imp
import os
from config.general import evaluate

def freedom():
    """Build the base FreeDoM configuration on top of ``evaluate()``.

    Starts from the object returned by ``config.general.evaluate`` and
    overrides or adds the fields listed below.

    Returns:
        The config from ``evaluate()`` with ``use_lora`` set to ``False``,
        ``project_name`` set to ``"FreeDoM"``, a new ``freedom`` sub-config
        (``rho``, ``time_travel_repeat``), overrides on the existing
        ``sample`` sub-config (``batch_size``, ``num_steps``, ``eta``) and
        ``max_vis_images`` set to 8.
    """
    cfg = evaluate()

    cfg.use_lora = False
    cfg.project_name = "FreeDoM"

    # Method-specific settings are kept in their own sub-config.
    cfg.freedom = ml_collections.ConfigDict(
        {"rho": 0.2, "time_travel_repeat": 1}
    )
    cfg.max_vis_images = 8

    # Overrides on the sampling sub-config supplied by evaluate().
    sampling = cfg.sample
    sampling.batch_size = 2
    sampling.num_steps = 100
    sampling.eta = 0.0

    return cfg

def aesthetic():
    """Return the FreeDoM config using the ``"aesthetic"`` reward.

    Takes the config from ``freedom()`` and sets ``reward_fn`` to
    ``"aesthetic"`` and ``prompt_fn`` to ``"eval_simple_animals"``.
    """
    cfg = freedom()
    cfg.reward_fn, cfg.prompt_fn = "aesthetic", "eval_simple_animals"
    return cfg

def hps():
    """Return the FreeDoM config using the ``"hps"`` reward.

    Takes the config from ``freedom()`` and sets ``reward_fn`` to ``"hps"``
    and ``prompt_fn`` to ``"eval_hps_v2_all"``.
    """
    cfg = freedom()
    overrides = (
        ("reward_fn", "hps"),
        ("prompt_fn", "eval_hps_v2_all"),
    )
    for field, value in overrides:
        setattr(cfg, field, value)
    return cfg

def pick():
    """Return the FreeDoM config using the ``"pick"`` reward.

    Prints ``PickScore`` to stdout, then takes the config from ``freedom()``
    and sets ``reward_fn`` to ``"pick"`` and ``prompt_fn`` to
    ``"eval_hps_v2_all"``. It also overrides the base values of
    ``sample.batch_size`` (to 1) and ``max_vis_images`` (to 16).
    """
    # The banner is emitted before the base config is built.
    print("PickScore")

    cfg = freedom()
    cfg.reward_fn, cfg.prompt_fn = "pick", "eval_hps_v2_all"

    # Overrides of values already set by freedom().
    cfg.sample.batch_size = 1
    cfg.max_vis_images = 16
    return cfg

def inpaint():
    """Return the FreeDoM config for the ``"inpaint"`` reward.

    Takes the config from ``freedom()``, sets both ``reward_fn`` and
    ``prompt_fn`` to ``"inpaint"``, sets ``sample.guidance_scale`` to 0,
    attaches an ``inpaint`` sub-config (``x``, ``width``, ``y``, ``height``,
    ``sample_name``) and sets ``max_vis_images`` to 1.
    """
    cfg = freedom()
    cfg.reward_fn = cfg.prompt_fn = "inpaint"
    cfg.sample.guidance_scale = 0.0

    # Fill the sub-config first, then attach it to the main config.
    # NOTE(review): x/y/width/height look like a pixel rectangle and
    # sample_name like an input identifier -- confirm against the consumer.
    region = ml_collections.ConfigDict()
    region.x = 0
    region.width = 128
    region.y = 0
    region.height = 128
    region.sample_name = "00003"
    cfg.inpaint = region

    cfg.max_vis_images = 1
    return cfg


def get_config(name):
    """Build and return the configuration selected by ``name``.

    ``name`` is looked up among this module's global names and the object
    found is called with no arguments.

    Args:
        name: Name of a zero-argument config builder defined in this module
            (for example ``"aesthetic"`` or ``"pick"``).

    Returns:
        Whatever the selected builder returns.

    Raises:
        KeyError: If no module-level object is called ``name``. The message
            lists the functions defined in this module as a hint.
    """
    namespace = globals()
    try:
        builder = namespace[name]
    except KeyError:
        # Suggest only callables defined in this module; imported helpers,
        # private names and this dispatcher itself are left out.
        known = sorted(
            key
            for key, value in namespace.items()
            if callable(value)
            and getattr(value, "__module__", None) == __name__
            and key != "get_config"
            and not key.startswith("_")
        )
        raise KeyError(
            f"Unknown config {name!r}; available configs: {', '.join(known)}"
        ) from None
    return builder()
