import torch
import time
from TrainCondition import train, eval, eval_gen_tail, eval_gen_all, generate_features


def main(model_config=None):
    """Build the run configuration and dispatch to the selected workflow.

    The ``"state"`` entry of the resolved configuration selects what runs:

    * ``"train"``      -> ``train(modelConfig)``
    * ``"eval"``       -> ``eval(modelConfig)``
    * ``"gen_tail"``   -> ``eval_gen_tail(modelConfig)``
    * ``"gen_all"``    -> ``eval_gen_all(modelConfig)``
    * ``"gen_custom"`` -> ``generate_features(modelConfig, target_classes=...,
      samples_per_class=...)``

    Any other value prints a usage message and returns without running
    anything.

    Args:
        model_config: Optional dict of overrides. Keys it provides replace the
            built-in defaults defined below; keys it omits keep their default
            value, so a partial dict such as ``{"state": "train", "epoch": 10}``
            is valid. A new merged dict is passed downstream; the caller's dict
            itself is not modified by this function. For the ``"gen_custom"``
            state, the optional keys ``"target_classes"`` (default
            ``[0, 1, 2, 3, 4]``) and ``"samples_per_class"`` (default ``1000``)
            are also honoured.

    Returns:
        None.
    """
    defaults = {
        "state": "gen_tail",  # one of: train, eval, gen_tail, gen_all, gen_custom
        "epoch": 50,
        "batch_size": 128,
        "eval_batch_size": 500,
        "data_dim": 256,
        "T": 500,
        "channel": 128,
        "width": 512,
        "channel_mult": [1, 2, 2, 2],
        "num_res_blocks": 2,
        "n_blocks": 16,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.5,
        "beta_1": 1e-4,
        "beta_T": 0.028,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",
        "w": 1.8,
        "num_labels": 100,  # Number of classes
        "preprocessing": "none",  # 'none', 'zscore', 'minmax', 'robust'
        "save_dir": "./CheckpointsCondition/",
        "training_load_weight": None,
        "test_load_weight": "ckpt_best_Film.pt",
        "nrow": 8,
    }
    # Layer the caller's overrides on top of the defaults instead of replacing
    # the whole dict, so a partial config does not lose required keys.
    modelConfig = {**defaults, **(model_config or {})}

    state = modelConfig["state"]
    if state == "train":
        train(modelConfig)
    elif state == "eval":
        eval(modelConfig)
    elif state == "gen_tail":
        eval_gen_tail(modelConfig)
    elif state == "gen_all":
        eval_gen_all(modelConfig)
    elif state == "gen_custom":
        # Fall back to the previously hard-coded selection when not configured.
        target_classes = modelConfig.get("target_classes", [0, 1, 2, 3, 4])
        samples_per_class = modelConfig.get("samples_per_class", 1000)
        generate_features(
            modelConfig,
            target_classes=target_classes,
            samples_per_class=samples_per_class,
        )
    else:
        print("Invalid state. Use: 'train', 'eval', 'gen_tail', 'gen_all', or 'gen_custom'")


if __name__ == "__main__":
    # Script entry point: run with the built-in default configuration.
    main(model_config=None)
