# import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm
import pickle

from xai_VOC import evaluate_explanation

from pdb import set_trace as bp

# Serialized evaluation split consumed below via torch.load.
dataset_dir = (
    '/home/XAI_exp/VOC_bbox_refined_dataset/resnet101/'
    'VOC_filtered/correct_preds_and_annotations.pt'
)
# Fine-tuned ResNet-101 weights (state dict) loaded into the model below.
model_path = (
    '/home/XAI_exp/finetune_bbox_VOC/resnet101/'
    'models/resnet101-finetuned-voc.pt'
)

# NOTE(review): path is relative to the current working directory.
voc_cls_maps = np.load('utilities/voc_cls.npy')

class VOC_XAI_Dataset(Dataset):
    """In-memory dataset of (image, label, annotation) triples for XAI evaluation.

    Images and labels are moved to ``device`` once, at construction time, so
    every item returned by ``__getitem__`` already lives on that device.
    Annotations are stored exactly as given and are not moved.

    Arguments:
        images (torch.Tensor): Image tensor indexed along its first dimension.
        labels (torch.Tensor): Label tensor; its length defines the dataset
            length (presumably one label per image -- confirm with the caller).
        annts: Indexable collection of per-sample annotations.
        device (torch.device or str): Device to place images and labels on.
    """

    def __init__(self, images, labels, annts, device):
        self.images = images.to(device)
        self.labels = labels.to(device)
        # Annotations may be arbitrary Python objects, so they are kept as-is.
        self.annts = annts

    def __len__(self):
        """Return the number of samples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label, annotation)`` for sample ``idx``."""
        # Samplers may pass a tensor index; convert it to a Python int/list so
        # it can also index ``annts``, which need not be a tensor.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.images[idx], self.labels[idx], self.annts[idx]

device_id = "cuda:0"
device = torch.device(device_id if torch.cuda.is_available() else "cpu")
# Device string actually in use ("cuda:0" or "cpu"). Using this instead of the
# hard-coded device_id keeps the CPU fallback above from crashing later.
resolved_device_id = str(device)

images, labels, annts = torch.load(dataset_dir)

# torch.as_tensor accepts tensors, arrays and nested lists; unlike torch.tensor
# it does not warn and copy when handed an existing tensor.
xai_dataset = VOC_XAI_Dataset(
    torch.as_tensor(images), torch.as_tensor(labels), annts, device
)
xai_dataset_size = len(xai_dataset)
xai_dataloader = torch.utils.data.DataLoader(dataset=xai_dataset,
                                             batch_size=1,
                                             shuffle=False)


# Build the architecture without downloading ImageNet weights: the strict
# load_state_dict below overwrites every parameter and buffer anyway.
model = models.resnet101()

model.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 21)

# Map the checkpoint onto the device actually selected above, not a fixed GPU.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

store_ious, store_ious_cls = evaluate_explanation(
    model, xai_dataloader, xai_dataset_size,
    expl_thr='mean', device_id=resolved_device_id,
)

np.save('store_ious.npy', store_ious)
with open('store_ious_cls.pkl', 'wb') as f:
    pickle.dump(store_ious_cls, f)