##################################### imports #####################################
from transformers import ViTForImageClassification, AutoImageProcessor
from transformers import get_scheduler
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from model.vit_model import ViT_Classification, GPTConfig
from helpers.prepare_dataset import classification_vit_dataloader
from helpers.utils import eval_vit_performance
from laplace import Laplace
from torch.utils.data import Dataset, DataLoader
from model.dbnn_llm_diag import ViT_DBNN_Diag

# Prefer the GPU whenever one is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
##################################### set hyperparameters #####################################
# True: fine-tune the attention value weights; False: fine-tune the MLP weights.
finetune_attn = True
dataset_name = 'cifar-10'
num_classes = 10
num_epochs = 50

##################################### define model #####################################
model_name = "google/vit-base-patch16-224"
huggingface_model = ViTForImageClassification.from_pretrained(model_name)
image_processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

## init model: wrap the pretrained HF ViT in the local GPT-style classification model
config = GPTConfig('small')
model = ViT_Classification(config, huggingface_model, num_classes, return_logits = False)

# Match the Hugging Face ViT, which (per the original note) does not use flash
# attention by default: force the explicit attention path in every block.
for transformer_block in model.transformer.h:
    setattr(transformer_block.attn, "flash", False)

model = model.to(device)

## freeze the transformer, then unfreeze the classifier plus selected weight matrices of blocks 10 and 11
for p in model.transformer.parameters():
    p.requires_grad_(False)

if finetune_attn:
    print("attention value only")
    # only the attention value-projection weight of each selected block
    part_name, weight_keys = "attn", ("c_attn_v.weight",)
else:
    print("all fcs in mlp")
    # both fully-connected weights of the MLP in each selected block
    part_name, weight_keys = "mlp", ("c_fc.weight", "c_proj.weight")

for block_idx in (11, 10):
    part = getattr(model.transformer.h[block_idx], part_name)
    for pname, p in part.named_parameters():
        # substring match, as parameter names may carry a module prefix
        if any(key in pname for key in weight_keys):
            p.requires_grad = True

for p in model.classifier.parameters():
    p.requires_grad = True

##################################### dataset loader #####################################
batch_size = 64
# NOTE(review): the helper is unpacked in (train, test, val) order -- confirm against
# helpers.prepare_dataset, since a mismatch would silently swap the splits.
(train_loader,
 test_loader,
 val_loader) = classification_vit_dataloader(dataset_name, batch_size, image_processor)

##################################### training  #####################################
criterion = nn.CrossEntropyLoss()
# Epoch 0 trains only the linear head (with its own Adam); epochs 1.. fine-tune every
# trainable parameter with AdamW under a linear warmup/decay schedule.
linear_head_optimizer = optim.Adam(model.classifier.parameters())
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay = 1e-5)

# The scheduler is only stepped during epochs 1..num_epochs-1, so size it to exactly
# that many steps; sizing it for all epochs leaves the linear decay unfinished.
num_training_steps = max(num_epochs - 1, 0) * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

# The checkpoint name depends only on which part of the network is fine-tuned.
ckpt_path = f'{dataset_name}_vit_attn.pt' if finetune_attn else f'{dataset_name}_vit_mlp.pt'
best_val_acc = 0

for epoch in range(num_epochs):
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    head_only = epoch == 0
    active_optimizer = linear_head_optimizer if head_only else optimizer

    # Train
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        # NOTE(review): labels are cast to float, which CrossEntropyLoss interprets as
        # class probabilities -- presumably the loader yields one-hot targets; confirm.
        inputs, labels = inputs.to(device), labels.to(device).float()

        # Clear the grads of *all* parameters. In the head-only epoch the backward pass
        # still reaches the unfrozen transformer weights, and clearing only the head's
        # grads would let theirs accumulate for the whole epoch.
        model.zero_grad(set_to_none=True)

        # Forward pass
        pred = model(inputs)
        loss = criterion(pred, labels)
        loss.backward()

        active_optimizer.step()
        if not head_only:
            lr_scheduler.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f"    Training Loss: {epoch_loss:.4f}")
    train_accuracy = eval_vit_performance(model, train_loader, device)
    print(f"    Train Accuracy: {train_accuracy:.3f}")
    val_accuracy = eval_vit_performance(model, val_loader, device, full_dataset=True)
    print(f"    Val Accuracy: {val_accuracy:.3f}")

    # Only fine-tuned epochs are checkpointed (epoch 0 trains the head alone), so with
    # num_epochs <= 1 no checkpoint is ever written.
    if not head_only and val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), ckpt_path)
    
# load the one with best val accuracy
# Rebuild the network (this time returning logits) and restore the best fine-tuned weights.
model = ViT_Classification(config, huggingface_model, num_classes, return_logits = True)

# The fresh instance must use the same non-flash attention path as the trained model.
# `flash` is a plain attribute set after construction, so load_state_dict won't restore it.
for block in model.transformer.h:
    block.attn.flash = False

ckpt_path = f'{dataset_name}_vit_attn.pt' if finetune_attn else f'{dataset_name}_vit_mlp.pt'
# map_location makes a checkpoint written on one device loadable on another.
model.load_state_dict(torch.load(ckpt_path, map_location=device))

# Nothing else moves the rebuilt model to `device`, but the code below feeds it
# `device` tensors (Laplace fit, dbnn.forward_latent(X.to(device)), dbnn(X.to(device))).
model = model.to(device)

##################################### fit laplace #####################################
# Re-apply the training-time selection on the reloaded model: every other parameter is
# treated as deterministic (no gradient). Presumably Laplace below only considers
# parameters with requires_grad=True -- confirm against the laplace-torch version in use.
for p in model.transformer.parameters():
    p.requires_grad_(False)

if finetune_attn:
    print("attention value only")
    # only the attention value-projection weight of each selected block
    part_name, weight_keys = "attn", ("c_attn_v.weight",)
else:
    print("all fcs in mlp")
    # both fully-connected weights of the MLP in each selected block
    part_name, weight_keys = "mlp", ("c_fc.weight", "c_proj.weight")

for block_idx in (11, 10):
    part = getattr(model.transformer.h[block_idx], part_name)
    for pname, p in part.named_parameters():
        if any(key in pname for key in weight_keys):
            p.requires_grad = True

for p in model.classifier.parameters():
    p.requires_grad = True

# Diagonal-Hessian Laplace approximation for a classification model.
la = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='diag')
# learn the Hessian on the training set
la.fit(train_loader, progress_bar=True)

# Tune the prior precision by grid search on the validation set, scoring the GLM
# predictive with the probit link approximation.
prior_search_kwargs = dict(
    method="gridsearch",
    pred_type="glm",
    link_approx="probit",
    val_loader=val_loader,
    log_prior_prec_min=-5,
    log_prior_prec_max=5,
    grid_size=100,
    verbose=True,
    progress_bar=True,
)
la.optimize_prior_precision(**prior_search_kwargs)

##################################### init our model and fit scale factor #####################################
# Exactly one of the two sub-layer kinds is flagged deterministic: fine-tuning attention
# marks the MLP deterministic, fine-tuning the MLP marks attention deterministic.
MLP_determinstic = bool(finetune_attn)
Attn_determinstic = not finetune_attn

alpha_init = 1
scale_init = 1
scale_fit_epoch = 100
scale_fit_lr = 1e-3

posterior_var = la.posterior_variance.detach()
dbnn = ViT_DBNN_Diag(model, posterior_var, scale_init, MLP_determinstic, Attn_determinstic, alpha_init, num_det_blocks=10)

# Run the validation set through the DBNN once and cache the latent mean/variance with
# the labels as NumPy arrays, so the scale factor can be fitted on them directly.
mean_chunks, var_chunks, label_chunks = [], [], []
for X, y in tqdm(val_loader):
    latent_mean, latent_var = dbnn.forward_latent(X.to(device))
    mean_chunks.append(latent_mean.detach().cpu().numpy())
    var_chunks.append(latent_var.detach().cpu().numpy())
    label_chunks.append(y.numpy())

# NOTE(review): vstack on the label batches assumes 2-D (e.g. one-hot) labels; 1-D
# class indices with a short final batch would not stack -- confirm the loader format.
f_mean, f_var, labels = (np.vstack(chunks) for chunks in (mean_chunks, var_chunks, label_chunks))


class torch_dataset(Dataset):
    """In-memory dataset pairing feature rows with concatenated target rows.

    ``X`` holds ``x_data`` as float32. ``Y`` holds ``y_data`` and ``z_data``
    (each converted to float32) joined with ``torch.hstack``; for 2-D inputs
    the columns of ``z_data`` follow those of ``y_data``. All three inputs must
    be NumPy arrays, because they go through ``torch.from_numpy``.
    """

    def __init__(self, x_data, y_data, z_data):
        def as_float_tensor(arr):
            return torch.from_numpy(arr).float()

        self.X = as_float_tensor(x_data)
        self.Y = torch.hstack([as_float_tensor(y_data), as_float_tensor(z_data)])

    def __len__(self):
        # number of rows in the target tensor
        return self.Y.shape[0]

    def __getitem__(self, idx):
        features, target = self.X[idx], self.Y[idx]
        return [features, target]

# Inputs are the cached latent means; each target row packs [latent variance | labels].
scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
scale_fit_dataloader = DataLoader(dataset=scale_fit_dataset, shuffle=True, batch_size=16)

# fit scale factor, starting from the initial alpha / scale values
dbnn.alpha = alpha_init
initial_scale = torch.Tensor([scale_init])
dbnn.scale_factor.data = initial_scale.to(device)
train_nlpd = dbnn.fit_scale_factor(scale_fit_dataloader, scale_fit_epoch, scale_fit_lr, False)

##################################### making prediction #####################################
# Predict on the first test batch only.
first_test_batch = next(iter(test_loader))
X, y = first_test_batch
pred = dbnn(X.to(device))

