import torch

from library import load_datasets
from library import metrics
from library import misc
from library import model_io
from library import models
from library import results_json
from library import train
from library import baseline_configs
from library import configs
import matplotlib.pyplot as plt

# DO NOT EVALUATE IN TRAIN MODE!

# Sweep table: short dataset key -> (configs.Dataset member, number of classes).
# The key is embedded in each experiment name; the class count is used to
# size the model's last layer.
datasets = {
    "mnist": (configs.Dataset.MNIST, 10),
    "fmnist": (configs.Dataset.FMNIST, 10),
    "kmnist": (configs.Dataset.KMNIST, 10),
    "qmnist": (configs.Dataset.QMNIST, 10),
    "emnist_letters": (configs.Dataset.EMNIST_LETTERS, 26),
    "emnist_balanced": (configs.Dataset.EMNIST_BALANCED, 47),
    "cifar10": (configs.Dataset.CIFAR10, 10),
    "cifar100": (configs.Dataset.CIFAR100, 100),
}

# Values swept over `model_config.ffn_layer_size`.
layer_sizes = [
    256,
    512,
    1024,
]

def _build_config(dataset_key, dataset, num_classes, size, run):
    """Return the config for one (dataset, layer size, seed) baseline run.

    Starts from ``baseline_configs.Get_MNIST_Config()`` and overrides the
    dataset, experiment name, training schedule and model flags. Among the
    layer-variant flags it toggles, only ``full_ffn`` is enabled.

    Args:
        dataset_key: Short name used in the experiment name (e.g. "mnist").
        dataset: ``configs.Dataset`` member to load.
        num_classes: Class count used to round the last layer's width.
        size: Value assigned to ``model_config.ffn_layer_size``.
        run: Run index; also used as the seed.
    """
    config = baseline_configs.Get_MNIST_Config()

    config.data_config.dataset = dataset
    # Round num_neurons down to the largest multiple of num_classes.
    # NOTE(review): this uses model_config.num_neurons, not the swept
    # ffn_layer_size set below -- confirm that is intended for full_ffn runs.
    config.model_config.last_layer_neurons = (
        config.model_config.num_neurons // num_classes
    ) * num_classes

    config.experiment_config.experiment_name = (
        f"{dataset_key}_ffn_baseline_final_{size}_{run}"
    )

    config.train_config.save_model_on = 'bin'

    config.train_config.extensive_eval = False
    config.train_config.eval_freq = 4
    config.train_config.learning_rate = 0.00001
    config.train_config.num_epochs = 100
    config.train_config.extensive_eval_train = True
    config.test_config.extensive_eval_test = True

    config.model_config.distanceLayer = False
    config.model_config.distanceLayer2 = False
    config.model_config.use_mygroupsum = False
    config.model_config.use_groupsum = False
    config.model_config.full_ffn = True
    config.model_config.use_ffn = False
    config.model_config.use_ffbinary = False
    config.model_config.ffn_layer_size = size

    config.model_config.seed = run
    return config


def _train_and_save(config):
    """Seed, build, train and save the model described by ``config``."""
    print(config.model_config)

    misc.set_seed(config.model_config.seed)

    (
        train_loader,
        validation_loader,
        test_loader,
        bin_loader,
        test_bin_loader,
    ) = load_datasets.load_dataset(config)
    network = models.create_model(config)

    # Move the model to its device *before* constructing the optimizer, as
    # the torch.optim docs recommend, so Adam is built over the parameters
    # it will actually update.
    if config.data_config.device == "cuda":
        network = network.cuda()

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        network.parameters(), lr=config.train_config.learning_rate
    )

    results = results_json.ResultsJSON(config)
    train.train(model=network,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=validation_loader,
        binarized_loader=bin_loader,
        test_loader=test_loader,
        test_loader_bin=test_bin_loader,
        results=results,
        config=config)

    model_io.save_model(network, config=config, model_path="./models/",
                        model_name=config.experiment_config.experiment_name)


# Full sweep: every layer size x every dataset x 3 seeded runs.
for size in layer_sizes:
    for key, (dataset, num_classes) in datasets.items():
        for run in range(3):
            _train_and_save(_build_config(key, dataset, num_classes, size, run))