from torch import nn
from scripts.classification_datasets.fashionmnist.fashionmnist_torch_dataset import get_dataloaders
from utils.multiple_standard_data_experiments_utils import estimate_all_metrics_plain
import os
from utils.laplace_evaluation_utils import zero_one_loss,ECE_wrapper,NLLLoss_with_log_transform
from utils.model_utils import CNN_nobatchnorm
import torch

'''
This script estimates the required metrics for the Laplace approximation on the FashionMNIST dataset.
'''

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Loss objects. `nll_with_log_tranform` is passed to the metrics routine as one
# of the test losses; `loss_fn` is not referenced again in the visible script.
loss_fn = nn.CrossEntropyLoss()
nll_with_log_tranform = NLLLoss_with_log_transform()

# Arguments forwarded to get_dataloaders.
batch_size, image_transforms = 40, False

# Forwarded as grid_lambda / min_temperature / max_temperature.
# NOTE(review): grid_lambda is presumably the number of grid points of the
# temperature sweep — confirm against estimate_all_metrics_plain.
grid_lambda = 20
min_temperature, max_temperature = 0.1, 20

# Forwarded as grid_prior_variance / min_prior_variance / max_prior_variance.
# NOTE(review): how a grid_prior_variance of None is handled is decided by
# estimate_all_metrics_plain — confirm there.
grid_prior_variance = None
min_prior_variance, max_prior_variance = 0.1, 20

# Forwarded as n_samples / hessian_structure.
n_samples_metrics_la, hessian_structure = 100, 'kron'

# Get dataset
# Working directory to switch to before the metrics are estimated.
# Placeholder: fill in before running (an empty value keeps the current
# working directory, see below).
path = ''

# Directory handed to get_dataloaders as `dir`. Placeholder: fill in before
# running.
dir_fashionmnist = ''

# NOTE(review): the unpacking order (test, train, validation) is assumed to
# match the order returned by get_dataloaders — confirm against its definition.
test_dataloader, train_dataloader, validation_dataloader = get_dataloaders(
    dir=dir_fashionmnist,
    batch_size=batch_size,
    image_transforms=image_transforms,
)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Change directory
# os.chdir('') raises FileNotFoundError, so an unset (empty) `path` leaves the
# working directory untouched instead of aborting after the data was loaded.
if path:
    os.chdir(path)

# Set model and estimate bounds
model = CNN_nobatchnorm()
model.to(device)
estimate_all_metrics_plain(
    train_dataloader,
    test_dataloader,
    validation_dataloader,
    model,
    likelihood='classification',
    # The i-th name below labels the i-th loss function.
    loss_functions_test=[nll_with_log_tranform, ECE_wrapper, zero_one_loss],
    loss_functions_test_names=['nll', 'ECE', 'zero_one'],
    grid_lambda=grid_lambda,
    min_temperature=min_temperature,
    max_temperature=max_temperature,
    grid_prior_variance=grid_prior_variance,
    min_prior_variance=min_prior_variance,
    max_prior_variance=max_prior_variance,
    n_samples=n_samples_metrics_la,
    hessian_structure=hessian_structure,
)