import ml_collections
import imp
import os
from config.general import aesthetic_general, hps_general, general

# Load the sibling ``base.py`` (same directory as this file) as a module named "base".
# NOTE(review): the ``imp`` module is deprecated and was removed in Python 3.12;
# migrate to ``importlib.util.spec_from_file_location`` + ``module_from_spec`` when
# upgrading. ``base`` is not referenced in the code visible here — confirm whether
# other code still relies on it before removing.
base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"))

def ddpo(config):
    """Apply the DDPO-specific settings to ``config`` in place and return it.

    Sets ``project_name`` and the PPO-style training knobs under
    ``config.train`` (``cfg``, ``adv_clip_max``, ``clip_range``,
    ``timestep_fraction``).
    """
    config.project_name = "ddpo-pytorch"

    train = config.train
    # Classifier-free guidance during training; when on, training reuses the
    # guidance scale configured for sampling.
    train.cfg = True
    # Advantages are clipped to the symmetric interval [-adv_clip_max, adv_clip_max].
    train.adv_clip_max = 5
    # PPO clipping range for the policy ratio.
    train.clip_range = 1e-4
    # Fraction of timesteps trained on per sample. Values below 1.0 train on a
    # subset, which is faster but gives noisier policy-gradient estimates.
    train.timestep_fraction = 1.0

    return config


def aesthetic():
    """Return the aesthetic preset: ``aesthetic_general()`` with DDPO settings applied."""
    return ddpo(aesthetic_general())

def hps():
    """Return the HPS preset: ``hps_general()`` with DDPO settings applied."""
    return ddpo(hps_general())

def evaluate():
    """Return an evaluation preset built on ``general()`` with DDPO settings.

    The run resumes from a fixed checkpoint, scores with the ``"pick"`` reward
    on the ``"eval_hps_v2_all_qualitative"`` prompts, and sets ``only_eval``,
    ``same_evaluation``, ``max_vis_images`` and ``sample.batch_size``.
    """
    config = ddpo(general())

    # Earlier alternative (aesthetic-reward run): resume from
    # "ddpo-pytorch/logs/2024.08.12_15.00.58/checkpoints/checkpoint_198" with
    # reward_fn="aesthetic" and prompt_fn="eval_simple_animals".
    config.resume_from = "ddpo-pytorch/logs/2024.08.27_00.37.05/checkpoints/checkpoint_198"
    config.reward_fn = "pick"
    config.prompt_fn = "eval_hps_v2_all_qualitative"

    config.only_eval = True
    config.same_evaluation = True
    config.max_vis_images = 2
    config.sample.batch_size = 4
    return config

def get_config(name):
    """Build the config produced by the module-level factory called ``name``.

    ``name`` is resolved against this module's globals, so any zero-argument
    callable defined or imported here (e.g. ``"aesthetic"``, ``"hps"``,
    ``"evaluate"``) can be selected.

    Args:
        name: Name of the config factory to call.

    Returns:
        Whatever the selected factory returns.

    Raises:
        KeyError: if ``name`` is not a global of this module or does not refer
            to a callable (e.g. an imported module such as ``os``).
    """
    factory = globals().get(name)
    # Reject non-callables up front so that a typo or a module name yields a
    # clear lookup error instead of "'module' object is not callable".
    if not callable(factory):
        raise KeyError(f"unknown config name {name!r}")
    return factory()
