import ml_collections
import os
from config.general import aesthetic_general, hps_general, general



def align_prop(config):
    """Apply the AlignProp defaults to ``config`` in place and return it.

    Sets the project name, the gradient-checkpointing and truncated-backprop
    switches, the training loss coefficient and the sampling guidance scale.

    Args:
        config: Config object supporting attribute assignment and exposing
            ``train`` and ``sample`` sub-configs (presumably an
            ``ml_collections.ConfigDict`` -- confirm against
            ``config.general``).

    Returns:
        The same ``config`` object that was passed in.
    """
    config.project_name = "align-prop"

    # AlignProp-specific top-level settings, applied in the listed order.
    alignprop_defaults = (
        ("grad_checkpoint", True),
        ("truncated_backprop", False),
        ("truncated_backprop_rand", False),
        ("truncated_backprop_minmax", (35, 45)),
        ("trunc_backprop_timestep", 100),
    )
    for field, value in alignprop_defaults:
        setattr(config, field, value)

    train_cfg = config.train
    train_cfg.loss_coeff = 0.01

    sample_cfg = config.sample
    sample_cfg.guidance_scale = 7.5

    return config



def aesthetic():
    """Build the AlignProp config for a run using the 'aesthetic' reward.

    Starts from ``aesthetic_general()``, applies the ``align_prop`` defaults,
    then overrides the prompt functions, reward, optimizer settings,
    checkpointing and truncated-backprop options.

    Returns:
        The populated config object.
    """
    config = aesthetic_general()
    config = align_prop(config)

    # Single effective value: an earlier `num_epochs = 200` assignment in
    # this function was always overwritten by 7 and has been dropped.
    config.num_epochs = 7
    config.prompt_fn = "simple_animals"

    config.eval_prompt_fn = "eval_simple_animals"

    # Reward selector. The original note hinted at other options
    # (e.g. CLIP, imagenet) -- verify the accepted names against the trainer.
    config.reward_fn = 'aesthetic'
    config.train.max_grad_norm = 5.0
    config.train.loss_coeff = 0.01
    config.train.learning_rate = 1e-3
    config.train.adam_weight_decay = 0.1

    config.save_freq = 1
    config.num_checkpoint_limit = 14
    config.truncated_backprop_rand = True
    config.truncated_backprop_minmax = (0, 50)
    config.trunc_backprop_timestep = 40
    config.truncated_backprop = True

    config.aesthetic_target = 10
    config.grad_scale = 1

    return config

def aesthetic_k1():
    """Variant of ``aesthetic()`` with a fixed truncation timestep.

    Takes the config produced by ``aesthetic()``, turns off
    ``truncated_backprop_rand`` and sets ``trunc_backprop_timestep`` to 49.

    Returns:
        The modified config object.
    """
    cfg = aesthetic()

    overrides = (
        ("truncated_backprop_rand", False),
        ("trunc_backprop_timestep", 49),
    )
    for field, value in overrides:
        setattr(cfg, field, value)

    return cfg



def hps():
    """Build the AlignProp config for a run using the 'hps' reward.

    Starts from ``hps_general()``, applies the ``align_prop`` defaults, then
    sets the epoch count, prompt functions, reward, per-prompt stat tracking,
    optimizer settings and truncated-backprop options.

    Returns:
        The populated config object.
    """
    cfg = align_prop(hps_general())
    cfg.num_epochs = 200

    # Prompt sources and reward selection.
    cfg.prompt_fn = "hps_v2_all"
    cfg.eval_prompt_fn = 'eval_hps_v2_all'
    cfg.reward_fn = 'hps'
    cfg.per_prompt_stat_tracking = dict(buffer_size=32, min_count=16)

    # Optimizer-related overrides live on the ``train`` sub-config.
    train_cfg = cfg.train
    train_cfg.max_grad_norm = 5.0
    train_cfg.loss_coeff = 0.01
    train_cfg.learning_rate = 1e-3
    train_cfg.adam_weight_decay = 0.1

    # Truncated-backprop settings, applied in the listed order.
    truncation_overrides = (
        ("trunc_backprop_timestep", 40),
        ("truncated_backprop", True),
        ("truncated_backprop_rand", True),
        ("truncated_backprop_minmax", (0, 50)),
    )
    for field, value in truncation_overrides:
        setattr(cfg, field, value)

    return cfg

def hps_k1():
    """Variant of ``hps()`` with a fixed truncation timestep.

    Takes the config produced by ``hps()``, turns off
    ``truncated_backprop_rand`` and sets ``trunc_backprop_timestep`` to 49.

    Returns:
        The modified config object.
    """
    cfg = hps()
    for field, value in {
        "truncated_backprop_rand": False,
        "trunc_backprop_timestep": 49,
    }.items():
        setattr(cfg, field, value)
    return cfg

def evaluate():
    """Build a config for an evaluation-only run (``only_eval = True``).

    Starts from ``general()`` plus the ``align_prop`` defaults, sets
    ``sample.eta`` to 0.0, points ``resume_from`` at a hard-coded checkpoint
    path, and selects the "pick" reward with the
    "eval_hps_v2_all_qualitative" prompt function.

    Returns:
        The populated config object.
    """
    cfg = align_prop(general())

    sampling = cfg.sample
    sampling.eta = 0.0

    # Alternative setup kept for reference:
    #   resume_from = "AlignProp/logs/iconic-pyramid-34/checkpoints/checkpoint_10"
    #   reward_fn   = "aesthetic"
    #   prompt_fn   = "eval_simple_animals"
    eval_settings = (
        ("resume_from", "AlignProp/logs/twilight-sponge-54/checkpoints/checkpoint_15"),
        ("reward_fn", "pick"),
        ("prompt_fn", "eval_hps_v2_all_qualitative"),
        ("only_eval", True),
        ("same_evaluation", True),
        ("max_vis_images", 2),
    )
    for field, value in eval_settings:
        setattr(cfg, field, value)

    sampling.batch_size = 4
    return cfg

def evaluate_kl():
    """Build a second evaluation-only config (``only_eval = True``).

    Starts from ``general()`` plus the ``align_prop`` defaults. Compared with
    ``evaluate()`` in this file it uses ``sample.eta = 1.0``, a different
    hard-coded ``resume_from`` checkpoint, the "eval_hps_v2_all" prompt
    function and ``max_vis_images = 8``.

    Returns:
        The populated config object.
    """
    base = general()
    cfg = align_prop(base)

    cfg.sample.eta = 1.0

    # Alternative setup kept for reference:
    #   resume_from = "AlignProp/logs/decent-shadow-82/checkpoints/checkpoint_10"
    #   reward_fn   = "aesthetic"
    #   prompt_fn   = "eval_simple_animals"
    top_level_settings = {
        "resume_from": "AlignProp/logs/rosy-pyramid-83/checkpoints/checkpoint_15",
        "reward_fn": "pick",
        "prompt_fn": "eval_hps_v2_all",
        "only_eval": True,
        "same_evaluation": True,
        "max_vis_images": 8,
    }
    for field, value in top_level_settings.items():
        setattr(cfg, field, value)

    cfg.sample.batch_size = 4
    return cfg

def get_config(name):
    """Look up a config factory in this module by name and call it.

    Args:
        name: Name of a module-level callable taking no arguments
            (e.g. "aesthetic", "hps", "evaluate").

    Returns:
        Whatever the selected factory returns.

    Raises:
        KeyError: If no module-level object is called ``name``. The message
            lists the public callables that are available.
        TypeError: If ``name`` refers to something that is not callable
            (for example an imported module).
    """
    namespace = globals()
    try:
        factory = namespace[name]
    except KeyError:
        # Same exception type as the bare lookup, but say what was asked for
        # and what could have been asked for instead.
        available = sorted(
            key
            for key, value in namespace.items()
            if callable(value) and not key.startswith("_") and key != "get_config"
        )
        raise KeyError(
            "Unknown config {!r}; callable names in this module: {}".format(
                name, ", ".join(available)
            )
        ) from None

    if not callable(factory):
        raise TypeError(
            "{!r} is not a callable config factory (got {})".format(
                name, type(factory).__name__
            )
        )
    return factory()