# import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
# import torchvision.datasets as datasets
# import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm
import pickle

from xai_VOC import *
from xai_VOC_scoregenerator import *

from pdb import set_trace as bp

# Class-map array, loaded from a path relative to the current working
# directory (presumably a VOC class index/name lookup — TODO confirm; it is
# not referenced by the code in this script as shown).
voc_cls_maps = np.load('utilities/voc_cls.npy')

# Input / output locations for this run.
save_explanations_path = 'explanation_masks'  # relative path, presumably where explanation masks are written
model_path = (
    '/home/XAI_exp/finetune_bbox_VOC/ViT/models/ViT-finetuned-voc.pt'
)  # fine-tuned ViT state_dict, loaded via load_state_dict
dataset_dir = (
    '/home/XAI_exp/VOC_bbox_refined_dataset/ViT/VOC_filtered/correct_preds_and_annotations.pt'
)  # despite the name, a single .pt file (read with torch.load), not a directory

class VOC_XAI_Dataset(Dataset):
    """In-memory dataset of ``(image, label, annotation)`` samples for XAI evaluation.

    Images and labels are moved to ``device`` once, at construction time, so
    every item is returned already on that device. Annotations are stored
    exactly as given and are not moved.
    """

    def __init__(self, images, labels, annts, device):
        """Store the samples, moving images and labels to ``device``.

        Arguments:
            images (torch.Tensor): Images, one per sample, indexed along
                dim 0.
            labels (torch.Tensor): Labels, one per sample, indexed along
                dim 0. Its length defines the dataset length.
            annts: Per-sample annotations; any container indexable by the
                same index as ``images``/``labels`` (e.g. a list). Presumably
                the bounding-box annotations saved alongside the images —
                TODO confirm.
            device (torch.device or str): Target device for ``images`` and
                ``labels``.
        """
        self.images = images.to(device)
        self.labels = labels.to(device)
        self.annts = annts  # kept as-is; not moved to `device`

    def __len__(self):
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label, annotation)`` triple at ``idx``."""
        # A tensor index is converted to a Python int/list so that it can
        # also index plain Python containers such as `annts`.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.images[idx], self.labels[idx], self.annts[idx]

# Prefer the first GPU and fall back to the CPU. device_id must name the
# device actually selected: it is also used elsewhere in this script (e.g.
# passed to score_calc). A hard-coded "cuda:0" would break on CPU-only hosts.
device_id = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_id)

# The .pt file unpacks into an (images, labels, annotations) triple.
images, labels, annts = torch.load(dataset_dir)

# torch.as_tensor reuses already-tensor inputs instead of copying them (and
# avoids the copy-construct UserWarning torch.tensor emits for tensors);
# array/list inputs are converted exactly as before.
xai_dataset = VOC_XAI_Dataset(torch.as_tensor(images), torch.as_tensor(labels), annts, device)
xai_dataset_size = len(xai_dataset)
# One sample per batch, in file order.
xai_dataloader = DataLoader(dataset=xai_dataset,
                            batch_size=1,
                            shuffle=False)


# ViT-B/16 backbone with its classification head replaced by a 21-way linear
# layer (presumably 20 VOC classes + background — TODO confirm). Pretrained
# SWAG weights are loaded here, but load_state_dict below (strict by default)
# replaces every parameter with the fine-tuned checkpoint.
model = models.vit_b_16(weights='IMAGENET1K_SWAG_LINEAR_V1')
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 21)

# Map the checkpoint onto the device that was actually selected, not the
# "cuda:0" string. This keeps loading working when CUDA is unavailable.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
# The model is only used for inference/explanations: switch to eval mode.
model.eval()

# Score the explanations for every sample in the dataloader. score_calc is not
# defined in this file; presumably it comes from the star imports of xai_VOC /
# xai_VOC_scoregenerator. expl_thr='mean' is forwarded unchanged; its meaning
# is defined by score_calc.
# An alternative entry point, evaluate_explanation, accepts the same arguments
# and returns (store_ious, store_ious_cls).
score_calc(model, xai_dataloader, xai_dataset_size, save_explanations_path, expl_thr='mean', device_id=device_id)
