import ml_collections
import os



def general():
    """Build the base experiment config shared by every named config in this file.

    Returns:
        ml_collections.ConfigDict holding top-level run/logging settings plus the
        ``train``, ``sample``, ``pretrained`` and ``per_prompt_stat_tracking``
        sub-configs. The sampling batch fields set here are defaults; the task
        configs below overwrite them (``set_config_batch`` or direct assignment).
    """
    config = ml_collections.ConfigDict()

    ###### General ######
    config.project_name = "DiffusionSampleBaseline"
    # prompt function used for evaluation; empty here, set by the task configs.
    config.eval_prompt_fn = ''
    # checkpoint-interpolation ("soup") evaluation; enabled by evaluate_soup(),
    # which also sets resume_from / resume_from_2.
    config.soup_inference = False
    # checkpoint save frequency (presumably in epochs -- confirm in the trainer).
    config.save_freq = 4
    # checkpoint(s) to load; empty string presumably means "start from scratch".
    config.resume_from = ""
    config.resume_from_2 = ""
    # visualization frequency (presumably in epochs) and cap on logged images.
    config.vis_freq = 1
    config.max_vis_images = 2
    # when True, only evaluation is run (used by evaluate() / evaluate_soup()).
    config.only_eval = False
    # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime.
    config.run_name = ""
    # presumably toggles LoRA fine-tuning instead of full weights -- confirm in the trainer.
    config.use_lora = True
    # weight of a KL term -- TODO confirm where the trainer reads this.
    config.kl_coeff = 1.0

    # prompting
    config.prompt_fn = "simple_animals"
    config.reward_fn = "aesthetic"
    config.debug = False
    # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
    config.mixed_precision = "fp16"
    # number of checkpoints to keep before overwriting old ones.
    config.num_checkpoint_limit = 10
    # top-level logging directory for checkpoint saving.
    config.logdir = "logs"
    # random seed for reproducibility.
    config.seed = 42
    # number of epochs to train for. each epoch is one round of sampling from the model followed by training on those
    # samples.
    config.num_epochs = 100

    # allow tf32 on Ampere GPUs, which can speed up training.
    config.allow_tf32 = True

    # whether images are visualized during training / evaluation.
    config.visualize_train = False
    config.visualize_eval = True

    config.same_evaluation = True

    ###### Training ######
    config.train = train = ml_collections.ConfigDict()
    # whether to use the 8bit Adam optimizer from bitsandbytes.
    train.use_8bit_adam = False
    # learning rate.
    train.learning_rate = 3e-4
    # Adam beta1.
    train.adam_beta1 = 0.9
    # Adam beta2.
    train.adam_beta2 = 0.999
    # Adam weight decay.
    train.adam_weight_decay = 1e-4
    # Adam epsilon.
    train.adam_epsilon = 1e-8
    # maximum gradient norm for gradient clipping.
    train.max_grad_norm = 1.0

    ###### Sampling ######
    config.sample = sample = ml_collections.ConfigDict()
    # number of sampler inference steps.
    sample.num_steps = 50
    # eta parameter for the DDIM sampler. this controls the amount of noise injected into the sampling process, with 0.0
    # being fully deterministic and 1.0 being equivalent to the DDPM sampler.
    sample.eta = 1.0
    # classifier-free guidance weight. 1.0 is no guidance.
    sample.guidance_scale = 5.0  # 7.5 in original AlignProp
    # batch size (per GPU!) to use for sampling.
    sample.batch_size = 1
    # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
    # batch_size * num_gpus`.
    sample.num_batches_per_epoch = 2

    ###### Pretrained Model ######
    config.pretrained = pretrained = ml_collections.ConfigDict()
    # base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
    pretrained.model = "runwayml/stable-diffusion-v1-5"
    # revision of the model to load.
    pretrained.revision = "main"

    ###### Per-Prompt Stat Tracking ######
    # when enabled, the model will track the mean and std of reward on a per-prompt basis and use that to compute
    # advantages. set `config.per_prompt_stat_tracking` to None to disable per-prompt stat tracking, in which case
    # advantages will be calculated using the mean and std of the entire batch.
    config.per_prompt_stat_tracking = ml_collections.ConfigDict()
    # number of reward values to store in the buffer for each prompt. the buffer persists across epochs.
    config.per_prompt_stat_tracking.buffer_size = 16
    # the minimum number of reward values to store in the buffer before using the per-prompt mean and std. if the buffer
    # contains fewer than `min_count` values, the mean and std of the entire batch will be used instead.
    config.per_prompt_stat_tracking.min_count = 16

    return config



def set_config_batch(config, total_samples_per_epoch, total_batch_size, train_gpu_capacity=4, sample_gpu_capacity=4,
                     num_gpus=None):
    """Derive per-GPU training/sampling batch settings from global targets.

    Mutates ``config.train`` and ``config.sample`` in place and returns ``config``.

    Args:
        config: config object with ``train`` and ``sample`` sub-configs.
        total_samples_per_epoch: number of samples drawn per epoch across all GPUs.
        total_batch_size: global training batch size; each GPU's share is split into
            micro-batches of ``train_gpu_capacity`` via gradient accumulation.
        train_gpu_capacity: micro-batch size that fits on one GPU during training.
        sample_gpu_capacity: per-GPU batch size used for sampling.
        num_gpus: number of GPUs/processes. When None, it is the number of non-empty
            entries in ``CUDA_VISIBLE_DEVICES`` (an empty value counts as 1).

    Returns:
        The same ``config`` object.

    Raises:
        KeyError: ``num_gpus`` is None and ``CUDA_VISIBLE_DEVICES`` is not set.
        ValueError: a size is not positive or a divisibility constraint fails.
            (Explicit raises instead of ``assert`` so the checks survive ``python -O``.)
    """
    train = config.train

    #  Samples per epoch
    train.total_samples_per_epoch = total_samples_per_epoch  # (~~~~ this is desired ~~~~)
    if num_gpus is None:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible is None:
            raise KeyError("CUDA_VISIBLE_DEVICES must be set (or pass num_gpus explicitly)")
        # Skip empty entries so a trailing comma ("0,1,") is not counted as a GPU;
        # an empty variable still counts as a single process.
        num_gpus = max(1, sum(1 for device in visible.split(",") if device.strip()))
    train.num_gpus = num_gpus

    # Reject non-positive sizes up front instead of failing later with ZeroDivisionError.
    if num_gpus <= 0 or total_batch_size <= 0 or train_gpu_capacity <= 0 or sample_gpu_capacity <= 0:
        raise ValueError("num_gpus, total_batch_size, train_gpu_capacity and sample_gpu_capacity must be positive")

    if total_samples_per_epoch % num_gpus != 0:
        raise ValueError("total_samples_per_epoch must be divisible by num_gpus")
    train.samples_per_epoch_per_gpu = total_samples_per_epoch // num_gpus

    #  Total batch size
    train.total_batch_size = total_batch_size  # (~~~~ this is desired ~~~~)
    if total_batch_size % num_gpus != 0:
        raise ValueError("total_batch_size must be divisible by num_gpus")
    train.batch_size_per_gpu = total_batch_size // num_gpus
    train.batch_size_per_gpu_available = train_gpu_capacity  # (this quantity depends on the gpu used)
    if train.batch_size_per_gpu % train_gpu_capacity != 0:
        raise ValueError("batch_size_per_gpu must be divisible by batch_size_per_gpu_available")

    train.batch_size = train_gpu_capacity  # for coherence with ddpo

    # Micro-batches accumulated per GPU to reach batch_size_per_gpu.
    train.gradient_accumulation_steps = train.batch_size_per_gpu // train_gpu_capacity

    if train.samples_per_epoch_per_gpu % train_gpu_capacity != 0:
        raise ValueError("samples_per_epoch_per_gpu must be divisible by batch_size_per_gpu_available")
    train.data_loader_iterations = train.samples_per_epoch_per_gpu // train_gpu_capacity

    if total_samples_per_epoch % total_batch_size != 0:
        raise ValueError("total_samples_per_epoch must be divisible by total_batch_size")
    train.num_inner_epochs = total_samples_per_epoch // total_batch_size

    config.sample.batch_size = sample_gpu_capacity
    if total_samples_per_epoch % (num_gpus * sample_gpu_capacity) != 0:
        raise ValueError("total_samples_per_epoch must be divisible by config.train.num_gpus * sample_gpu_capacity")
    config.sample.num_batches_per_epoch = total_samples_per_epoch // (num_gpus * sample_gpu_capacity)

    return config


def aesthetic_general():
    """Training config for the aesthetic reward on the simple-animals prompts.

    Starts from ``general()``, overrides the task fields, then lays out batches with
    ``set_config_batch``: 256 samples per epoch, global batch 128, and 4 samples per
    GPU for both training and sampling.
    """
    cfg = general()

    # Task: prompts used for training / evaluation and the reward to optimise.
    cfg.prompt_fn = "simple_animals"
    cfg.eval_prompt_fn = "eval_simple_animals"
    cfg.reward_fn = 'aesthetic'

    # Schedule and logging.
    cfg.num_epochs = 200
    cfg.max_vis_images = 4

    return set_config_batch(
        cfg,
        total_samples_per_epoch=256,
        total_batch_size=128,
        train_gpu_capacity=4,
        sample_gpu_capacity=4,
    )



def hps_general():
    """Training config for the HPS reward on the full HPS-v2 prompt set.

    Starts from ``general()``, overrides the task fields, then lays out batches with
    ``set_config_batch``: 256 samples per epoch, global batch 128, and 4 samples per
    GPU for both training and sampling.
    """
    cfg = general()

    # Task: prompts used for training / evaluation and the reward to optimise.
    cfg.prompt_fn = "hps_v2_all"
    cfg.eval_prompt_fn = 'eval_hps_v2_all'
    cfg.reward_fn = 'hps'

    # Schedule and logging.
    cfg.num_epochs = 200
    cfg.max_vis_images = 4

    return set_config_batch(
        cfg,
        total_samples_per_epoch=256,
        total_batch_size=128,
        train_gpu_capacity=4,
        sample_gpu_capacity=4,
    )



def evaluate_soup():
    """Evaluation-only config for soup inference over two checkpoints.

    Both ``'<CHECKPOINT_NAME>'`` values are placeholders that must be replaced
    before running.
    """
    cfg = general()

    # Evaluation-only run on the aesthetic / simple-animals task.
    cfg.only_eval = True
    cfg.reward_fn = 'aesthetic'
    cfg.prompt_fn = "simple_animals"

    # Debug / evaluation switches and how many images to log.
    cfg.debug = False
    cfg.same_evaluation = True
    cfg.max_vis_images = 10

    # Soup inference between `resume_from` and `resume_from_2`.
    cfg.soup_inference = True
    cfg.resume_from = '<CHECKPOINT_NAME>'
    # To interpolate between base Stable Diffusion and `resume_from`, use
    # "stablediffusion" as the checkpoint name for resume_from_2.
    cfg.resume_from_2 = '<CHECKPOINT_NAME>'
    # Presumably the mixing weight applied to the first checkpoint -- confirm in the loader.
    cfg.mixing_coef_1 = 0.0

    cfg.sample.batch_size = 4
    return cfg


def evaluate():
    """Evaluation-only config scoring the qualitative HPS-v2 prompts with the "pick" reward."""
    cfg = general()

    # Evaluation-only run.
    cfg.only_eval = True
    cfg.same_evaluation = True

    # Task: prompts to render and reward used to score them.
    cfg.prompt_fn = "eval_hps_v2_all_qualitative"
    cfg.reward_fn = "pick"

    # Logging and sampling sizes.
    cfg.max_vis_images = 3
    cfg.sample.batch_size = 4
    return cfg


def get_config(name):
    """Build and return the config produced by the module-level function ``name``.

    Args:
        name: name of a zero-argument config function defined in this module,
            e.g. ``"aesthetic_general"``, ``"hps_general"`` or ``"evaluate"``.

    Returns:
        Whatever that function returns.

    Raises:
        KeyError: ``name`` is not a function defined in this module; the message
            lists the functions that are.
    """
    module_globals = globals()
    factory = module_globals.get(name)
    # Only accept functions defined in this module, so typos or names such as
    # ``os`` / ``__name__`` fail with a clear KeyError instead of an opaque TypeError.
    if not callable(factory) or getattr(factory, "__module__", None) != __name__:
        available = sorted(
            key for key, value in module_globals.items()
            if callable(value)
            and getattr(value, "__module__", None) == __name__
            and not key.startswith("_")
        )
        raise KeyError(f"Unknown config {name!r}; available: {', '.join(available)}")
    return factory()