import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm

from train_model import infer_model

from pdb import set_trace as bp
from collections import OrderedDict

# Root directory passed to torchvision's VOCSegmentation dataset further down.
dataset_dir = '/home/datasets'
# Checkpoint file whose state dict is loaded into the ViT model further down.
model_path = '/home/XAI_exp/finetune_bbox_VOC/ViT/models/ViT-finetuned-voc.pt'

def get_pil_transform():
    """Build the geometric stage of the image pipeline.

    Returns:
        A ``transforms.Compose`` that resizes its input to 256x256 and then
        takes a 224x224 center crop.
    """
    geometric_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ]
    return transforms.Compose(geometric_steps)

def get_preprocess_transform():
    """Build the tensor stage of the image pipeline.

    Returns:
        A ``transforms.Compose`` that converts its input with ``ToTensor`` and
        then applies per-channel normalization.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics
    # quoted in the torchvision model docs -- confirm they match training.
    channel_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    to_tensor = transforms.ToTensor()
    return transforms.Compose([to_tensor, channel_norm])

# Instantiate both pipeline stages once at import time; init_transform()
# below chains these two module-level objects.
pill_transf = get_pil_transform()
preprocess_transform = get_preprocess_transform()

def init_transform():
    """Chain the module-level ``pill_transf`` and ``preprocess_transform``.

    Returns:
        A ``transforms.Compose`` that runs ``pill_transf`` first and
        ``preprocess_transform`` second.
    """
    return transforms.Compose([pill_transf, preprocess_transform])

# Complete image pipeline; passed as `transform` to the VOC datasets below.
init_transf = init_transform()

# Loaded from a path relative to the current working directory, so the script
# must be launched from the directory that contains `utilities/`.
# NOTE(review): presumably a class-id/colour lookup for VOC -- confirm the
# contents; the name is not referenced in the code visible here.
voc_cls_maps = np.load('utilities/voc_cls.npy')

def area_rect(xmin, x_max, y_min, y_max):
    """Return the area of the axis-aligned rectangle spanned by the bounds.

    Every bound is coerced with ``float()``, so numeric strings are accepted.
    Side lengths are absolute differences, so the result is non-negative even
    when a "min" bound is larger than its "max" counterpart.
    """
    left, right, bottom, top = (
        float(bound) for bound in (xmin, x_max, y_min, y_max)
    )
    width = abs(right - left)
    height = abs(top - bottom)
    return width * height

def collate_fn_voc_detection(data):
    """Collate VOC segmentation samples into a classification-style batch.

    The label of each sample is the class id that covers the most pixels in
    its segmentation mask, ignoring ids 0 and 255 (background and the "void"
    boundary label by the PASCAL VOC convention). Ties go to the smallest
    class id, because ``np.unique`` returns ids sorted and ``np.argmax``
    picks the first maximum.

    Args:
        data: Sequence of samples where ``sample[0]`` is an image tensor (all
            the same shape, so they can be stacked) and ``sample[1]`` is a
            mask convertible with ``np.array``.

    Returns:
        Tuple ``(images, labels, annts)``: the stacked image tensor, a 1-D
        tensor holding the dominant class id per sample, and a list with one
        mask array per sample.

    Raises:
        ValueError: If a mask contains no class id other than 0 and 255.
    """
    ignored_ids = [0, 255]
    images, labels, annts = [], [], []
    for sample in data:
        # Convert the mask once and reuse it for the label and the output.
        mask = np.array(sample[1])
        cls, area = np.unique(mask, return_counts=True)

        # Build the filter once and apply it to both parallel arrays.
        keep = ~np.isin(cls, ignored_ids)
        if not keep.any():
            # Without this check np.argmax fails with an opaque
            # "attempt to get argmax of an empty sequence" error.
            raise ValueError(
                "segmentation mask has no foreground class "
                "(only ids 0/255 present)"
            )
        gnd_cls = cls[keep][np.argmax(area[keep])]

        images.append(sample[0])
        labels.append(gnd_cls)
        annts.append(mask)
    return (torch.stack(images), torch.tensor(labels), annts)

# One VOC2012 segmentation dataset per split, read from `dataset_dir` without
# downloading. `init_transf` is passed as `transform` only, so the
# segmentation target is handed to the collate function as loaded.
image_datasets = {
    split: datasets.VOCSegmentation(
        root=dataset_dir,
        year='2012',
        image_set=split,
        download=False,
        transform=init_transf,
    )
    for split in ['train', 'val']
}

# Shuffled loaders with 128 samples per batch; batches are assembled by
# collate_fn_voc_detection, which also derives the per-image class label.
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=128,
        shuffle=True,
        collate_fn=collate_fn_voc_detection,
    )
    for split, split_dataset in image_datasets.items()
}

# Number of samples per split.
dataset_sizes = {
    split: len(split_dataset)
    for split, split_dataset in image_datasets.items()
}

# Build ViT-B/16 initialised from the torchvision SWAG-linear ImageNet weights.
model = models.vit_b_16(weights='IMAGENET1K_SWAG_LINEAR_V1')

# Replace the classification head with a 21-way linear layer so the module
# matches the checkpoint loaded below.
# NOTE(review): 21 presumably = 20 VOC classes + background -- confirm
# against the fine-tuning script.
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 21)

# Resolve the device once and reuse it everywhere. Previously `device` fell
# back to CPU while `device_id` stayed "cuda:0", so the checkpoint load and
# the inference call still targeted CUDA on CPU-only hosts.
device_id = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_id)

# map_location remaps the saved tensors onto the resolved device, so a
# checkpoint saved from GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)

# infer_model comes from train_model (not visible here); it receives the
# resolved device identifier as a string, as before.
infer_model(model, dataloaders, dataset_sizes, device_id = device_id)