##################################### load stuff #####################################
from helpers.prepare_dataset import classification_dataloader
from torch.utils.data import Dataset, DataLoader
from model.create_torch_model import create_mlp
import torch
import torch.nn as nn
import numpy as np
from laplace import Laplace
from tqdm import tqdm
from model.dbnn_kron import DBNN_MLP_Kron

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

##################################### set hyperparameters #####################################
dataset_name = 'MNIST'
batch_size = 64
lr = 1e-3
weight_decay = 1e-5
n_epoch = 100
# Handed to create_mlp below; presumably the layer widths from input to output
# (confirm against create_mlp).
network_structure = [784, 128, 64, 10]

##################################### dataset loader #####################################
# The project helper returns the train / validation / test loaders, in that order.
loaders = classification_dataloader(dataset_name, batch_size, val_ratio=0.2)
train_loader, val_loader, test_loader = loaders

##################################### define model #####################################
model = create_mlp(network_structure, 'relu', 'classification')
model.to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

##################################### train model #####################################
for epoch in range(n_epoch):
    # Loss value of every batch seen in this epoch; the mean is printed below.
    train_loss = []

    for X, y in train_loader:
        # Move the batch to the same device as the model.
        inputs = X.to(device)
        targets = y.to(device)

        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        loss = loss_func(model(inputs), targets)
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    epoch_mean_loss = np.mean(train_loss)
    print(f'epoch {epoch} train loss {epoch_mean_loss}')


##################################### fit laplace #####################################
### kron fitting
# model[:-1] passes every module of the trained model except the last one
# (presumably the output activation; confirm against create_mlp).
la = Laplace(
    model[:-1],
    'classification',
    subset_of_weights='all',
    hessian_structure='kron',
)
### learn Hessian
la.fit(train_loader)
### learn prior precision
# Grid search scored on the validation loader: 100 candidates with log prior
# precision between -5 and 5.
prior_prec_search = dict(
    method="gridsearch",
    pred_type="glm",
    link_approx="probit",
    val_loader=val_loader,
    log_prior_prec_min=-5,
    log_prior_prec_max=5,
    grid_size=100,
    verbose=True,
    progress_bar=True,
)
la.optimize_prior_precision(**prior_prec_search)

##################################### init our model and fit scale factor #####################################
# init our model
scale_init = 1.

# Everything DBNN_MLP_Kron needs is read from the trained network and the
# fitted Laplace object.
backbone = model[:-1]  # same slice of the model that Laplace was fitted on
hessian_eigenvectors = la.H.eigenvectors
hessian_eigenvalues = la.H.eigenvalues
kron_factors = la.H_facs.kfacs
prior_precision = la.prior_precision.item()  # plain Python number

dbnn_model_la_kron = DBNN_MLP_Kron(
    backbone,
    hessian_eigenvectors,
    hessian_eigenvalues,
    kron_factors,
    'classification',
    scale_init,
    prior_precision,
)

# wrap the latent mean, latent variance and labels in a Dataset so a DataLoader can batch them
class torch_dataset(Dataset):
    """Dataset of (input, target) rows built from three NumPy arrays.

    ``X`` stores ``x_data`` as a float tensor.  ``Y`` stores, for every
    sample, the diagonal of the matching ``y_data`` matrix followed by the
    matching ``z_data`` row.
    """

    def __init__(self, x_data, y_data, z_data):
        """Convert the arrays to float tensors and assemble the targets.

        Args:
            x_data: NumPy array of inputs, one sample per entry along axis 0.
            y_data: NumPy array holding one matrix per sample (a full
                covariance, per the original note); only the diagonal taken
                over axes 1 and 2 is kept.
            z_data: NumPy array whose rows are appended after the diagonals.
        """
        self.X = torch.from_numpy(x_data).float()

        full_matrices = torch.from_numpy(y_data)
        extra_columns = torch.from_numpy(z_data)

        # Only the diagonal of each matrix is needed for the fit.
        diagonal_part = full_matrices.diagonal(dim1=1, dim2=2).float()
        self.Y = torch.hstack([diagonal_part, extra_columns.float()])

    def __len__(self):
        """Return the number of samples (rows of ``Y``)."""
        return self.Y.shape[0]

    def __getitem__(self, idx):
        """Return ``[X[idx], Y[idx]]`` as a two-element list."""
        sample_input = self.X[idx]
        sample_target = self.Y[idx]
        return [sample_input, sample_target]

# Per-batch results gathered while sweeping the validation loader.
mean_batches = []
var_batches = []
label_batches = []

for X, y in tqdm(val_loader):
    # forward_latent yields a (mean, variance) pair for the batch.
    out_mean, out_var = dbnn_model_la_kron.forward_latent(X.to(device))

    mean_batches.append(out_mean.detach().cpu().numpy())
    var_batches.append(out_var.detach().cpu().numpy())
    # y is converted directly, so it is expected to already live on the CPU.
    label_batches.append(y.numpy())

# Stack the per-batch arrays row-wise into one array each.
f_mean = np.vstack(mean_batches)
f_var = np.vstack(var_batches)
labels = np.vstack(label_batches)

# Argument order expected by torch_dataset: inputs, per-sample matrices, extra columns.
scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
scale_fit_dataloader = DataLoader(
    scale_fit_dataset,
    batch_size=16,
    shuffle=True,
)

# fit scale factor
alpha_init = 1.
scale_init = 1.
scale_fit_epoch = 50
scale_fit_lr = 1e-3

# Reset alpha and the scale factor to their starting values before fitting.
dbnn_model_la_kron.alpha = alpha_init
initial_scale = torch.Tensor([scale_init]).to(device)
dbnn_model_la_kron.scale_factor.data = initial_scale

# NOTE(review): the meaning of the trailing positional False is defined by
# fit_scale_factor, which is not visible here - confirm against DBNN_MLP_Kron.
train_nlpd = dbnn_model_la_kron.fit_scale_factor(
    scale_fit_dataloader,
    scale_fit_epoch,
    scale_fit_lr,
    False,
)

##################################### making prediction #####################################
# Only the first batch produced by the test loader is used.
first_test_batch = next(iter(test_loader))
X, y = first_test_batch

test_inputs = X.to(device)
pred = dbnn_model_la_kron(test_inputs)

