import gc

import matplotlib.pyplot as plt
import torch

from library import load_datasets
from library import metrics
from library import misc
from library import model_io
from library import models
from library import results_json
from library import train
from library import baseline_configs
from library import configs

# Sweep definition: `n` class counts spaced geometrically between `start`
# and `end`; each value is truncated to an int and capped at 1000.
start, end = 5, 1000
n = 5  # number of steps

num_classes = [
    min(int(start * (end / start) ** (step / (n - 1))), 1000)
    for step in range(n)
]
# Only the entries from index 4 onward are kept for the runs below.
num_classes = num_classes[4:]

# Run one full configure / train / save cycle per entry in `num_classes`.
#
# NOTE(review): `num` is not applied anywhere in the loop body: the
# `data_config.num_classes` override is commented out and the experiment
# name hard-codes 1000. If the sweep ever yields more than one entry, every
# iteration will use identical settings and the same experiment name.
for num in num_classes:
    config = baseline_configs.Get_MNIST_Config()

    # ---- Experiment-specific overrides on top of the baseline config ----
    config.data_config.dataset = configs.Dataset.IMAGENET32
    # config.data_config.num_classes = num
    # config.model_config.num_neurons = 256000
    config.train_config.batch_size = 16
    config.test_config.batch_size = 64
    config.train_config.learning_rate = 0.0005
    # config.model_config.last_layer_neurons = (256000//num) * num
    # config.model_config.tau = 1

    config.experiment_config.experiment_name = f"imagenet_num_classes_05_{1000}"

    config.model_config.num_layers = 6

    # Evaluation schedule.
    config.train_config.extensive_eval = True
    config.train_config.eval_freq = 4
    config.train_config.num_epochs = 100
    config.train_config.extensive_eval_train = True

    config.test_config.extensive_eval_test = True

    # Architecture switches: only `full_ffn` is enabled.
    config.model_config.distanceLayer = False
    config.model_config.use_mygroupsum = False
    config.model_config.use_groupsum = False
    config.model_config.full_ffn = True
    config.model_config.use_ffn = False
    config.model_config.use_ffbinary = False
    # ----------------------------------------------------------------------

    model_config = config.model_config
    print(model_config)

    # Seed before the loaders and the model are created.
    misc.set_seed(config.model_config.seed)

    train_loader, validation_loader, test_loader, bin_loader, test_bin_loader = (
        load_datasets.load_dataset(config)
    )
    network = models.create_model(config)

    # Move the model to the GPU *before* building the optimizer, as the
    # torch.optim documentation recommends, so the optimizer is constructed
    # from the parameters that will actually be trained.
    if config.data_config.device == "cuda":
        network = network.cuda()

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        network.parameters(), lr=config.train_config.learning_rate
    )

    results = results_json.ResultsJSON(config)
    train.train(
        model=network,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=validation_loader,
        binarized_loader=bin_loader,
        test_loader=test_loader,
        test_loader_bin=test_bin_loader,
        results=results,
        config=config,
    )

    model_io.save_model(
        network,
        config=config,
        model_path="./models/",
        model_name=config.experiment_config.experiment_name,
    )

    # ==== FREE MEMORY ====
    # Drop every per-run object (including the binarized test loader), then
    # collect reference cycles first so that any tensors they kept alive are
    # released before the CUDA caching allocator is asked to free its blocks.
    del network
    del optimizer
    del train_loader
    del validation_loader
    del test_loader
    del bin_loader
    del test_bin_loader
    gc.collect()
    torch.cuda.empty_cache()
    # ======================









