# import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
# import torchvision.datasets as datasets
# import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm
import pickle

from xai_VOC import evaluate_explanation

from pdb import set_trace as bp

# Path to the serialized evaluation data read with torch.load() later in this
# script. Despite the "_dir" suffix this is a single .pt file, not a directory.
dataset_dir = '/home/XAI_exp/VOC_bbox_refined_dataset/ViT/VOC_filtered/correct_preds_and_annotations.pt'
# Path to the fine-tuned ViT checkpoint; it is passed to load_state_dict() later
# in this script, so it is expected to contain a state dict.
model_path = '/home/XAI_exp/finetune_bbox_VOC/ViT/models/ViT-finetuned-voc.pt'

# Array loaded from a path relative to the current working directory.
# NOTE(review): presumably the VOC class mapping; it is not referenced anywhere
# else in the visible part of this file - confirm whether it is still needed.
voc_cls_maps = np.load('utilities/voc_cls.npy')

class VOC_XAI_Dataset(Dataset):
    """Dataset yielding (image, label, annotation) triples held in memory."""

    def __init__(self, images, labels, annts, device):
        """
        Arguments:
            images: Indexable object exposing ``.to(device)`` (e.g. a tensor);
                moved to ``device`` once, at construction time.
            labels: Indexable object exposing ``.to(device)`` (e.g. a tensor);
                moved to ``device`` once, at construction time. Its length
                defines the length of the dataset.
            annts: Indexable collection of per-sample annotations; stored
                as given, without any device transfer.
            device: Target passed to ``.to()`` for ``images`` and ``labels``.
        """
        self.annts = annts
        self.images = images.to(device)
        self.labels = labels.to(device)

    def __len__(self):
        # The number of samples is taken from the labels container.
        return len(self.labels)

    def __getitem__(self, idx):
        # Tensor indices are converted to plain Python values first, so the
        # same index can be applied to all three containers.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        sample = (self.images[index], self.labels[index], self.annts[index])
        return sample

# ---------------------------------------------------------------------------
# Script body: load the data and the fine-tuned ViT, run evaluate_explanation()
# over the dataset, and write the two returned objects to disk.
# ---------------------------------------------------------------------------

# Use the first CUDA device when available, otherwise fall back to the CPU.
# device_id is derived from the device actually selected so that the string
# handed to torch.load() and evaluate_explanation() is "cpu" on CPU-only hosts
# instead of a hard-coded "cuda:0" (which made the CPU fallback fail).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_id = str(device)

# The file is expected to unpack into exactly three objects.
images, labels, annts = torch.load(dataset_dir)

# as_tensor() converts lists/arrays like torch.tensor() does, but passes
# existing tensors through without the extra copy (and warning) that
# torch.tensor(<tensor>) produces.
xai_dataset = VOC_XAI_Dataset(
    torch.as_tensor(images),
    torch.as_tensor(labels),
    annts,
    device,
)
xai_dataset_size = len(xai_dataset)
# One sample per batch, in stored order.
xai_dataloader = DataLoader(dataset=xai_dataset, batch_size=1, shuffle=False)

# Build the architecture only: every parameter is overwritten by the strict
# load_state_dict() call below, so fetching pretrained ImageNet weights here
# would be wasted work (and an unnecessary network dependency).
model = models.vit_b_16(weights=None)
# Replace the classification head with a 21-output linear layer to match the
# fine-tuned checkpoint.
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 21)

# map_location follows the selected device so a checkpoint saved from a GPU
# can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

# NOTE(review): the exact contents of the two return values are defined in
# xai_VOC.evaluate_explanation; judging by the names they are IoU scores,
# overall and per class - confirm there.
store_ious, store_ious_cls = evaluate_explanation(
    model,
    xai_dataloader,
    xai_dataset_size,
    expl_thr=0.2,
    device_id=device_id,
)

# Persist the results in the current working directory.
np.save('store_ious.npy', store_ious)
with open('store_ious_cls.pkl', 'wb') as f:
    pickle.dump(store_ious_cls, f)