import os

import PIL
import PIL.Image
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm

import utils
from data.flowers.main_raw import _load_original_flowers_data
from hyperparams.load import get_config


class FlowersDataset(Dataset):
    """Map-style dataset that loads flower images from disk on access.

    Args:
        x: Indexable sequence of image file paths.
        y: Indexable sequence of labels, aligned position-by-position with ``x``.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform
        # Length is taken from the paths; `y` is assumed to match it.
        self.len = len(self.x)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        # Open inside a context manager so the file handle is closed
        # deterministically; convert() returns an independent image, so it
        # stays valid after the source file is closed.
        with PIL.Image.open(self.x[idx]) as img:
            x = img.convert('RGB')
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]


def main(device):
    """Extract ResNet-101 features for every image in the flowers data.

    Loads the backbone, builds the preprocessed dataset from the raw flowers
    data, runs feature extraction on ``device``, and returns the features.
    """
    model = _load_resnet(device)
    flowers = _load_original_flowers_data(size=None)
    return _extract_features(_create_dataset(flowers), model, device)


@torch.no_grad()
def _extract_features(dataset, resnet, device, save_dir=None):
    """Run ``resnet`` over ``dataset`` in order and save the stacked outputs.

    Args:
        dataset: Map-style dataset yielding ``(input_tensor, label)`` pairs;
            labels are ignored.
        resnet: Model applied to each batch. Its output is flattened to one
            row per sample (presumably ``(batch, C, 1, 1)`` pooled features —
            TODO confirm for other backbones).
        device: Device each input batch is moved to before the forward pass.
        save_dir: Directory in which ``resnet_features.pt`` is written.
            Defaults to ``get_config().dirs['flowers_images']``.

    Returns:
        CPU tensor with one row per dataset item, in dataset order.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=256,
                                         shuffle=False)

    batches = []
    for cur_x, _ in tqdm(loader):
        out = resnet(cur_x.to(device))
        # flatten(1) always keeps the batch dimension; squeeze() dropped it
        # when the final batch had a single sample, which made torch.cat fail.
        # Moving to CPU per batch avoids holding every feature on the device.
        batches.append(out.flatten(1).cpu())
    ft = torch.cat(batches)

    if save_dir is None:
        # Resolve the config here rather than relying on the module-level
        # `config` global, which exists only when this file runs as a script.
        # NOTE(review): assumes get_config() is safe to call again — confirm.
        save_dir = get_config().dirs['flowers_images']
    save_path = os.path.join(save_dir, 'resnet_features.pt')
    torch.save(ft, save_path)
    print(f'Extracted resnet features and saved them at "{save_path}".')
    return ft


def _create_dataset(data):
    """Wrap the raw flowers data in a FlowersDataset with eval preprocessing.

    Args:
        data: Mapping read for its ``'image_paths'`` and ``'y'`` entries.

    Returns:
        FlowersDataset whose images are resized to 256, center-cropped to 224,
        converted to tensors, and normalized per channel.
    """
    # Channel mean/std taken from the PyTorch ImageNet example:
    # https://github.com/pytorch/examples/blob/97304e232807082c2e7b54c597615dc0ad8f6173/imagenet/main.py#L197-L198
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    return FlowersDataset(
        x=data['image_paths'],
        y=data['y'],
        transform=preprocessing,
    )


def _load_resnet(device):
    """Return a pretrained ResNet-101 without its last child module, in eval mode.

    The last child of torchvision's ResNet is the ``fc`` classifier, so the
    returned ``nn.Sequential`` outputs the features that would feed it. The
    model is moved to ``device``.
    """
    models = torchvision.models
    weights_enum = getattr(models, 'ResNet101_Weights', None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is
        # the checkpoint that flag loads, so the extracted features match.
        backbone = models.resnet101(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision only understands the boolean flag.
        backbone = models.resnet101(pretrained=True)
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    return feature_extractor.to(device).eval()


if __name__ == '__main__':
    # Script entry point: load the project config, pick a device, extract.
    # `config` is bound at module scope on purpose, since functions in this
    # module may look it up as a global.
    config = get_config()
    device = utils.setup_device()
    features = main(device=device)
