import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
# from livelossplot import PlotLosses
from train_model import train_model

from pdb import set_trace as bp

def get_pil_transform():
    """Build the geometric part of the image pipeline.

    Every image is resized to exactly 256x256 (aspect ratio is NOT preserved),
    then the central 224x224 region is cropped. No tensor conversion or
    normalisation happens here.
    """
    geometric_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ]
    return transforms.Compose(geometric_steps)

def get_preprocess_transform():
    """Build the tensor part of the image pipeline.

    Converts the image to a tensor with ``ToTensor`` and then normalises each
    RGB channel with the commonly used ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

# Module-level pipeline stages; init_transform() reads both of these globals.
pill_transf, preprocess_transform = get_pil_transform(), get_preprocess_transform()

def init_transform():
    """Chain the geometric stage and the tensor/normalisation stage.

    Uses the module-level ``pill_transf`` and ``preprocess_transform`` objects,
    so it must only be called after both have been created.
    """
    return transforms.Compose([pill_transf, preprocess_transform])

# Full resize/crop -> tensor -> normalise pipeline applied to every VOC image.
init_transf = init_transform()

# Ordered class-name table: a sample's label is the position of its object name
# in this array (see collate_fn_voc_detection). The path is relative to the CWD.
voc_cls_maps = np.load('utilities/voc_cls.npy')

# Per-class counts for 20 classes, presumably in the same order as entries 1..20
# of voc_cls_maps (TODO confirm). They become normalised inverse-frequency loss
# weights; label index 0 is given weight 0.0 so it never contributes to the loss.
weights = [327, 268, 395, 260, 365, 213, 590, 539, 566, 151, 269, 632, 237, 265, 1994, 269, 171, 257, 273, 290]
weights = torch.tensor([0.0, *(1 / count for count in weights)])
weights /= weights.sum()

def area_rect(xmin, x_max, y_min, y_max):
    """Return the area of an axis-aligned rectangle from its min/max coordinates.

    Each coordinate may be anything ``float()`` accepts (numbers or numeric
    strings). Absolute side lengths are used, so swapped min/max values still
    give a non-negative area.
    """
    # Convert in the original argument order so a bad value fails identically.
    x0, x1, y0, y1 = (float(v) for v in (xmin, x_max, y_min, y_max))
    width = abs(x1 - x0)
    height = abs(y1 - y0)
    return width * height

def collate_fn_voc_detection(data):
    """Collate VOCDetection samples into an ``(images, labels)`` classification batch.

    Each sample is reduced to a single label: the class of the annotated object
    with the largest bounding-box area (the first such object on ties, as with
    ``np.argmax``).

    Args:
        data: sequence of samples where ``sample[0]`` is the transformed image
            tensor and ``sample[1]`` is the annotation dict; each entry of
            ``sample[1]['annotation']['object']`` has a ``'name'`` and a
            ``'bndbox'`` with ``'xmin'``/``'xmax'``/``'ymin'``/``'ymax'``.
            (The parameter name shadows the module-level ``data`` alias only
            inside this function; it is kept for caller compatibility.)

    Returns:
        ``(images, labels)`` where ``images`` is ``torch.stack`` of the image
        tensors (they must all share one shape) and ``labels`` is a 1-D tensor
        of integer indices into ``voc_cls_maps``.

    Raises:
        ValueError: if a sample has no objects (``np.argmax`` of an empty
            sequence) or an object name is not present in ``voc_cls_maps``.
    """
    # Loop-invariant: build the name list once per batch, not once per sample.
    class_names = list(voc_cls_maps)
    images, labels = [], []
    for sample in data:
        image, target = sample[0], sample[1]
        objects = target['annotation']['object']
        areas = [
            area_rect(obj['bndbox']['xmin'], obj['bndbox']['xmax'],
                      obj['bndbox']['ymin'], obj['bndbox']['ymax'])
            for obj in objects
        ]
        # The whole image is labelled with its largest object.
        largest = objects[int(np.argmax(areas))]
        images.append(image)
        labels.append(class_names.index(largest['name']))
    return torch.stack(images), torch.tensor(labels)



# NOTE(review): not used below -- the VOC datasets are built with ``init_transf``.
# Kept so anything importing ``data_transforms`` from this script still works.
data_transforms = {'train': transforms.Compose([transforms.ToTensor()]),
                   'val': transforms.Compose([transforms.ToTensor()])}

data_dir = '/home/datasets'

# PASCAL VOC 2012 detection splits (must already be present under data_dir,
# download=False). collate_fn_voc_detection turns every sample into a
# single-label classification example (label = class of the largest object).
image_datasets = {x: datasets.VOCDetection(root=data_dir, year='2012', image_set=x,
                                           download=False, transform=init_transf)
                  for x in ['train', 'val']}

# Only the training split is shuffled; a fixed validation order makes runs easier
# to compare (assumes train_model only aggregates validation metrics, so order
# does not affect them -- verify).
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=128,
                                              shuffle=(x == 'train'),
                                              collate_fn=collate_fn_voc_detection)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# ViT-B/16 initialised from the 'IMAGENET1K_SWAG_LINEAR_V1' weights
# ('IMAGENET1K_V1' is the previously tried alternative). The classifier head is
# replaced by a fresh 21-way linear layer, matching the 21 loss weights.
model = models.vit_b_16(weights='IMAGENET1K_SWAG_LINEAR_V1')
num_ftrs = model.heads.head.in_features
model.heads.head = nn.Linear(num_ftrs, 21)

device_id = "cuda:0"
device = torch.device(device_id if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Class-weighted loss; the weight tensor must be on the same device as the logits.
criterion = nn.CrossEntropyLoss(weight=weights.to(device))
# All parameters (backbone and new head) are fine-tuned.
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# NOTE(review): if train_model steps this scheduler once per epoch, the LR is
# divided by 10 every 7 epochs, reaching about 1e-3 * 0.1**14 = 1e-17 by the last
# epochs of a 100-epoch run -- most epochs would barely update. Confirm intended.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

save_path = 'models/ViT-finetuned-voc.pt'
# Create the output directory before training so a missing/unwritable directory
# fails immediately instead of after the whole training run.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Pass the device that was actually selected. On a CPU-only machine the model
# lives on the CPU, and the hard-coded "cuda:0" would point train_model at a
# device that does not exist. On a GPU machine str(device) == "cuda:0".
model_ft = train_model(model, dataloaders, dataset_sizes, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=100, device_id=str(device))

torch.save(model_ft.state_dict(), save_path)