import torch

from library import load_datasets
from library import metrics
from library import misc
from library import model_io
from library import models
from library import results_json
from library import train
from library import baseline_configs
from library import configs
import matplotlib.pyplot as plt

import gc

# Sweep over the model's `tau` and dropout settings on the synthetic dataset.
# For every (tau, dropout) pair a fresh config is built, a model is trained
# via train.train, and all per-run objects are released before the next run.

# tau grid: 6 geometrically spaced points from tau_start to tau_end
# (0.1, ~0.398, ~1.585, ~6.31, ~25.12, 100.0).
tau_start = 0.1
tau_end = 100
tau_values = [tau_start * (tau_end / tau_start) ** (i / 5) for i in range(6)]
# Only the last two grid points (~25.12 and 100.0) are run in this sweep.
tau_values = tau_values[4:]
dropout_values = [0.7]

for tau in tau_values:
    for drop in dropout_values:
        config = baseline_configs.Get_MNIST_Config()

        ############
        # Data: synthetic dataset settings.
        config.data_config.dataset = configs.Dataset.SYNTHETIC
        config.data_config.num_classes = 2000
        config.data_config.input_size = 784
        config.data_config.samples_per_class_train = 500
        config.data_config.samples_per_class_test = 100
        config.data_config.fixed_bits_per_class = 40
        config.data_config.lower_bound_fixed = 5

        # Model.
        config.model_config.num_neurons = 64000
        # Round num_neurons down to a multiple of 1000; derived from
        # num_neurons so the two settings cannot drift apart.
        config.model_config.last_layer_neurons = (config.model_config.num_neurons // 1000) * 1000
        config.model_config.num_layers = 6
        config.model_config.tau = tau
        config.model_config.dropout_percentage = drop
        config.model_config.distanceLayer = False
        config.model_config.use_mygroupsum = True
        config.model_config.use_groupsum = False
        config.model_config.full_ffn = False
        config.model_config.use_ffn = False
        config.model_config.use_ffbinary = False

        # Training / evaluation.
        config.train_config.save_model_on = 'valid'
        config.train_config.batch_size = 64
        config.train_config.learning_rate = 0.01
        config.train_config.num_epochs = 200
        config.train_config.extensive_eval = True
        config.train_config.extensive_eval_train = True
        config.train_config.eval_freq = 4
        config.test_config.batch_size = 128
        config.test_config.extensive_eval_test = True

        config.experiment_config.experiment_name = f"synthetic_tau_dropout_02_{tau}_{drop}"
        ############

        model_config = config.model_config
        print(model_config)

        misc.set_seed(config.model_config.seed)

        (train_loader, validation_loader, test_loader,
         bin_loader, test_bin_loader) = load_datasets.load_dataset(config)
        network = models.create_model(config)

        # Move the model to the GPU *before* building the optimizer: per the
        # torch.nn.Module.cuda() docs, the move may replace parameter objects,
        # so it must precede optimizer construction.
        if config.data_config.device == "cuda":
            network = network.cuda()

        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(network.parameters(), lr=config.train_config.learning_rate)

        results = results_json.ResultsJSON(config)
        train.train(model=network,
                    loss_fn=loss_fn,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    validation_loader=validation_loader,
                    binarized_loader=bin_loader,
                    test_loader=test_loader,
                    test_loader_bin=test_bin_loader,
                    results=results,
                    config=config)

        # model_io.save_model(network, config=config, model_path="./models/", model_name=config.experiment_config.experiment_name)

        # ==== FREE MEMORY ====
        # Drop every per-run reference (including test_bin_loader) so nothing
        # from this run is still alive while the next run loads its data.
        del network, optimizer, loss_fn, results
        del train_loader, validation_loader, test_loader, bin_loader, test_bin_loader
        # Collect cyclic garbage first so freed tensors go back to PyTorch's
        # caching allocator, then release the unoccupied cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
        # ======================









