#%%
import os
import pickle as pic
import time

import numpy as np
import torch as to
from torch.utils.data import DataLoader

import DatasetTools as dts
from MLP import NN

def tau_schedule(start, finish, steps):
    """Return a geometric schedule going from ``start`` to ``finish``.

    Entry ``ix`` of the result equals ``start * c ** ix`` with
    ``c = (finish / start) ** (1 / steps)``, so the first entry is ``start``
    and the last entry is ``finish`` up to floating-point rounding.

    Args:
        start: First value of the schedule. Must be non-zero, since the
            ratio ``finish / start`` is computed.
        finish: Value reached after ``steps`` multiplicative updates.
        steps: Number of multiplicative updates. ``0`` yields the one-entry
            schedule ``[start]``.

    Returns:
        A 1-D NumPy array with ``steps + 1`` entries.
    """
    if steps == 0:
        # No update is ever applied, so there is no ratio to compute;
        # evaluating ``1 / steps`` below would raise ZeroDivisionError.
        return start * np.ones(1)

    # Constant per-step multiplier that lands on ``finish`` after ``steps``.
    c = (finish / start) ** (1 / steps)

    sched = start * np.ones(steps + 1)

    # Scalar loop kept on purpose: it reproduces the exact values produced
    # by earlier runs of this script.
    for ix in range(steps + 1):
        sched[ix] *= c ** ix

    return sched
  

if __name__ == "__main__":
    # Experiment driver: trains one NN per run configuration, evaluates it
    # with and without a hidden-unit mask, and pickles a summary per run.

    print('start!')
    t0 = time.time()

    #%%
    # RUN CONFIGURATION
    # Each weight seed is shared by three consecutive runs: one learning
    # 'biases' followed by two learning 'softmask'.
    n_weight_seeds = 5
    runs_per_weight_seed = 3
    n_runs = n_weight_seeds * runs_per_weight_seed

    learned_params0 = np.tile(['biases', 'softmask', 'softmask'], n_weight_seeds)
    seed0 = np.random.uniform(0, 100000, size=n_runs).astype(int)
    seed_w0 = np.random.uniform(0, 100000, size=n_weight_seeds).astype(int).repeat(runs_per_weight_seed)

    # FILES, FOLDERS
    save_folder = 'data'
    save_file = 'summary'
    # Create the output folder up front so a missing directory cannot make
    # the script fail only after a full training run has completed.
    os.makedirs(save_folder, exist_ok=True)

    # DATA (identical for every run, so loaded and split only once)
    val_size = 10000
    dataset_train = dts.get_mnist_images_dataset(one_hot_output=True, cluster=False)
    dataset_val = dts.get_mnist_images_dataset(one_hot_output=False, cluster=False)
    # Both splits use a fresh generator seeded with 42.
    # NOTE(review): presumably this makes test_dataset the samples that are
    # held out of the training split -- assumes both datasets index the same
    # samples in the same order; confirm against DatasetTools.
    _, dataset = to.utils.data.random_split(
        dataset_train,
        [val_size, len(dataset_train) - val_size],
        generator=to.Generator().manual_seed(42),
    )
    test_dataset, _ = to.utils.data.random_split(
        dataset_val,
        [val_size, len(dataset_val) - val_size],
        generator=to.Generator().manual_seed(42),
    )

    # MODEL / TRAINING HYPER-PARAMETERS (identical for every run)
    n_input = 784
    n_hidden = 10000
    n_output = 10
    activation = 'relu'
    init_a = - 0.01
    init_b = - init_a
    sig_w = 1 / np.sqrt(n_input)
    loss_fn = to.nn.CrossEntropyLoss()
    n_epochs = 30
    tau = tau_schedule(1, 0.001, n_epochs - 1)
    batch_size = 512

    #%%
    for ix, (learned_params, seed, seed_w) in enumerate(zip(learned_params0, seed0, seed_w0)):

        # Separate generators: `rng` is passed to init_uniform and the data
        # loaders, `rng_w` only to init_weights_normal.
        rng = to.Generator()
        rng.manual_seed(int(seed))
        rng_w = to.Generator()
        rng_w.manual_seed(int(seed_w))

        # DEFINE MODELS
        if learned_params == 'biases':
            lr = 0.01
        elif learned_params == 'softmask':
            lr = 0.01
        else:
            # Without this branch an unexpected value would silently reuse
            # `lr` (and later `mask`) from the previous iteration.
            raise ValueError(f'unknown learned_params: {learned_params!r}')

        model = NN(n_input, n_hidden, n_output, learning_rate=lr, to_learn=learned_params, activation_type=activation, crit=loss_fn)
        model.init_uniform(a=init_a, b=init_b, generator=rng)
        model.init_weights_normal(0, sig_w, generator=rng_w)

        # TRAINING DATA
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=rng)

        # TRAIN MODEL
        model.mytrain(dataloader, tau=tau, n_epochs=n_epochs, wandb=False)

        # EVALUATE
        test_dataloader = DataLoader(test_dataset, batch_size=1024, generator=rng)
        unmasked_acc = model.evaluate(testloader=test_dataloader)
        print(unmasked_acc)

        # GET HIDDEN STATES AND CALCULATE MASKS
        h = model.get_hidden_states(test_dataset)
        hvar = to.var(h, 0)  # variance along dim 0 of the hidden states
        if learned_params == 'softmask':
            mask = model.alphas.detach() > 0
        else:  # 'biases' -- any other value was rejected above
            mask = hvar >= 0.1  # HARDCODED FOR GAUSSIAN

        masked_acc = model.mask_evaluate(test_dataloader, mask)

        #%%
        save_dict1 = {
            'hvar': hvar,
            'mask': mask,
            'unmasked_acc': unmasked_acc,
            'masked_acc': masked_acc,
            'history': model.loss_history,
            'lr': lr,
            'learned_params': learned_params,
            'seed': seed,
            'seedw': seed_w,
        }

        # One file per run: <save_file>_<learned_params>_<run index>.pkl
        save_path = os.path.join(save_folder, f'{save_file}_{learned_params}_{ix}.pkl')
        with open(save_path, 'wb') as handle:
            pic.dump(save_dict1, handle, protocol=pic.HIGHEST_PROTOCOL)

    print(f'done. script runtime: {(time.time() - t0) / 60} minutes')


# %%
