# import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
# import torchvision.datasets as datasets
# import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm
import pickle

from xai_VOC import *

from pdb import set_trace as bp

# Input/output locations. Each one can be overridden through an environment
# variable so the script can run on another machine without editing the
# source; the defaults are the original hard-coded values.
# NOTE: despite its name, `dataset_dir` is the path of a single file that is
# read with torch.load, not a directory.
dataset_dir = os.environ.get(
    'VOC_XAI_DATASET_PATH',
    '/home/XAI_exp/VOC_bbox_refined_dataset/ViT/VOC_filtered/correct_preds_and_annotations.pt',
)
# Checkpoint (state dict) of the fine-tuned ViT model.
model_path = os.environ.get(
    'VOC_XAI_MODEL_PATH',
    '/home/XAI_exp/finetune_bbox_VOC/ViT/models/ViT-finetuned-voc.pt',
)
# Passed to evaluate_explanation as the location for explanation masks.
save_explanations_path = os.environ.get('VOC_XAI_EXPLANATIONS_DIR', 'explanation_masks')

# NOTE(review): loaded here but not referenced anywhere else in this script --
# confirm whether it is still needed.
voc_cls_maps = np.load('utilities/voc_cls.npy')

class VOC_XAI_Dataset(Dataset):
    """In-memory dataset of images, labels and annotations for XAI evaluation.

    Images and labels are moved to ``device`` once, at construction time;
    annotations are stored as given and only indexed.
    """

    def __init__(self, images, labels, annts, device):
        """
        Arguments:
            images (torch.Tensor): Images, indexed along the first dimension.
            labels (torch.Tensor): Labels, one per sample; their length
                defines the length of the dataset.
            annts: Indexable container of per-sample annotations, indexed
                with the same index as ``images`` and ``labels``.
            device (torch.device or str): Device that ``images`` and
                ``labels`` are moved to.
        """
        self.images = images.to(device)
        self.labels = labels.to(device)
        self.annts = annts

    def __len__(self):
        # The dataset length is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label, annotation)`` triple at ``idx``."""
        # Convert a tensor index to a Python int/list so it can also index
        # non-tensor annotation containers.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.images[idx], self.labels[idx], self.annts[idx]

device_id = "cuda:0"
# Fall back to the CPU when CUDA is unavailable; `device` is where the data
# and the model are placed.
device = torch.device(device_id if torch.cuda.is_available() else "cpu")

# Load onto the CPU first so a file that was saved with GPU tensors can still
# be read on a CPU-only machine; VOC_XAI_Dataset moves the tensors to `device`.
images, labels, annts = torch.load(dataset_dir, map_location='cpu')

# torch.as_tensor avoids the redundant copy (and the UserWarning) that
# torch.tensor produces when `images`/`labels` are already tensors.
xai_dataset = VOC_XAI_Dataset(torch.as_tensor(images), torch.as_tensor(labels), annts, device)
xai_dataset_size = len(xai_dataset)
# One sample per batch, in file order.
xai_dataloader = DataLoader(dataset=xai_dataset,
                            batch_size=1,
                            shuffle=False)


# ViT-B/16 with its classification head replaced by a 21-output linear layer
# so it matches the fine-tuned checkpoint. The pretrained weights requested
# here are overwritten by load_state_dict below.
# NOTE(review): torchvision may derive architecture settings from the weights
# enum, so verify before switching to weights=None to skip the download.
model = models.vit_b_16(weights='IMAGENET1K_SWAG_LINEAR_V1')
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 21)

# Map the checkpoint onto `device` rather than the hard-coded `device_id`, so
# loading also works when the script has fallen back to the CPU.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

def batch_predict(images):
    """Return softmax class probabilities for a batch of images.

    Passed to ``evaluate_explanation`` as the prediction callback.

    Args:
        images: Iterable of images accepted by
            ``torchvision.transforms.ToTensor`` (PIL images or HxWxC numpy
            arrays); all must convert to tensors of the same shape so they
            can be stacked.

    Returns:
        numpy.ndarray of shape ``(len(images), num_classes)`` holding the
        per-class probabilities, copied back to host memory.
    """
    to_tensor = transforms.ToTensor()
    model.eval()
    batch = torch.stack([to_tensor(img) for img in images], dim=0)

    # Use the module-level `device`, which falls back to the CPU when CUDA is
    # unavailable (a hard-coded `device_id` would crash on CPU-only machines).
    model.to(device)
    batch = batch.to(device)

    # Inference only: skip building the autograd graph, which would just
    # waste memory since the result is converted to numpy anyway.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

# Evaluate explanations over the whole dataset. The two results are
# presumably overall and per-class IoU scores (judging by their names) --
# confirm against evaluate_explanation in xai_VOC.
store_ious, store_ious_cls = evaluate_explanation(
    model,
    xai_dataloader,
    xai_dataset_size,
    batch_predict,
    save_explanations_path,
    expl_thr='mean',
    device_id=device_id,
)

# Persist both results in the current working directory: the first with
# NumPy, the second with pickle.
np.save('store_ious.npy', store_ious)
with open('store_ious_cls.pkl', 'wb') as f:
    pickle.dump(store_ious_cls, f)