"""
Flowers dataset with "raw" images and caption features.

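Image modality x1: 3xHxW RGB images read from "transformed_jpg_{size}" (e.g. 3x128x128; note that load_flowers_data defaults to size=64)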
Sentence modality x2: (1024,) features

Train: 5,490 (62 classes)
Val: 1,544 (20 classes)
Test: 1,155 (20 classes)
(exact numbers may vary depending on the sampling-seed for the train/val-split)
"""
import inspect
import logging
import os
from collections import defaultdict
from pathlib import Path

import PIL
import numpy as np
import torch
from scipy.io import loadmat
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

import utils
from data.flowers.preprocess import preprocess_flowers_data
from data.load_captions import load_caption_data
from hyperparams.load import get_config

# Shared logger named 'custom' (presumably configured by utils.set_logger,
# which the __main__ block calls — TODO confirm for library use).
logger = logging.getLogger('custom')
# Project configuration, loaded at import time; `config.dirs['flowers_images']`
# is read by _load_original_flowers_data to locate the raw Flowers files.
config = get_config()


class FlowersDataset(Dataset):
    """Paired image/caption dataset for the Flowers data.

    Images are stored as file paths and only opened in ``__getitem__`` /
    ``get_image_tensors``; caption features are kept in memory.

    Args:
        x1: Array of image file paths, one per sample.
        x2: Tensor of caption features whose first dimension indexes samples.
        y: Class labels, converted via ``utils.to_torch`` to ``torch.int32``.
        caption_paths: Per-sample caption paths, returned as supplementary info.
    """

    def __init__(self, x1, x2, y, caption_paths):
        self.x = [x1, x2]

        # Supplementary information returned alongside each sample
        self.s = {'y': utils.to_torch(y, dtype=torch.int32)}
        self.len = len(self.s['y'])
        assert x1.shape[0] == x2.size(0) == self.len
        self.s['caption_paths'] = caption_paths

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``([image_tensor, caption_features], {'y': ..., 'caption_paths': ...})``."""
        x = [v[idx] for v in self.x]
        s = {k: v[idx] for k, v in self.s.items()}
        x[0] = self.open_image(x[0])  # replace the image path by the loaded image
        return x, s

    @staticmethod
    def open_image(path):
        """Load the image at ``path`` as an RGB float tensor of shape (3, H, W) in [0, 1]."""
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` is available as an attribute.
        from PIL import Image

        # The context manager closes the file handle even if conversion fails;
        # ``convert`` returns a fully loaded copy that outlives the handle.
        with Image.open(path) as img:
            image = img.convert('RGB')
        return transforms.ToTensor()(image)

    def get_image_tensors(self, idx=None):
        """Load image tensors (not their paths) and stack them.

        Args:
            idx: Iterable of sample indices (list, numpy array, tensor,
                range, ...). ``None`` selects every sample.

        Returns:
            Tensor stacking one image per index along a new first dimension.

        Raises:
            ValueError: If ``idx`` selects no samples.
        """
        # Compare against None explicitly: ``not idx`` is ambiguous for
        # multi-element arrays/tensors and wrongly true for e.g. np.array([0]).
        if idx is None:
            idx = range(self.len)
        imgs = [self.open_image(self.x[0][i]) for i in idx]
        if not imgs:
            raise ValueError('idx must select at least one image.')
        return torch.stack(imgs)


def load_flowers_data(mode, batch_size=64, size=64, **kwargs):
    """Build the Flowers dataset and its DataLoader for a single split.

    Args:
        mode: Split name ('train', 'val' or 'test'); only 'train' is shuffled.
        batch_size: Batch size of the returned loader.
        size: Image folder selector (``transformed_jpg_{size}``); a falsy
            value reads the original ``jpg`` folder.
        **kwargs: Forwarded unchanged to ``preprocess_flowers_data``.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    raw = _prepare_flowers_data(size)
    processed = preprocess_flowers_data(raw, **kwargs)
    dataset = _create_dataset(processed, mode)
    return dataset, _create_loader(dataset, batch_size, mode)


def _create_loader(dataset, batch_size, mode):
    loader = DataLoader(dataset, batch_size,
                        shuffle=mode == 'train',
                        pin_memory=True)
    return loader


def _create_dataset(data, mode='train'):
    """Instantiate a FlowersDataset restricted to the indices of split ``mode``."""
    sel = data['loc'][mode]
    image_paths = data['image_paths'][sel]
    caption_feats = data['caption_features'][sel]
    labels = data['y'][sel]
    caption_paths = data['caption_paths'][sel]
    return FlowersDataset(
        x1=image_paths,
        x2=caption_feats,
        y=labels,
        caption_paths=caption_paths,
    )


def _prepare_flowers_data(size):
    """Load labels and image paths, then attach caption split indices and features.

    Returns the data dict with keys 'y', 'image_paths', 'loc',
    'caption_features' and 'caption_paths'.
    """
    data = _load_original_flowers_data(size)
    captions = load_caption_data(dataset='flowers')
    captions['loc'] = _get_split(captions)
    data = _integrate_caption_feats(_integrate_caption_loc(data, captions), captions)
    del captions  # caption data is no longer needed past this point
    _assert_no_class_overlap(data)
    return data


def _assert_no_class_overlap(data):
    tr = np.unique(data['y'][data['loc']['train']])
    val = np.unique(data['y'][data['loc']['val']])
    test = np.unique(data['y'][data['loc']['test']])
    assert set(tr) & set(val) == set()
    assert set(tr) & set(test) == set()
    assert set(val) & set(test) == set()


def _integrate_caption_feats(data, caption_data):
    assert np.array(caption_data['names']).shape == \
           np.unique(np.array(caption_data['names'])).shape, \
        'Make data unique to securely use .index() function later.'
    data['caption_features'] = []
    data['caption_paths'] = []
    for p in data['image_paths']:
        n = Path(p).name[:-4]  # without extension
        idx = caption_data['names'].index(n)
        data['caption_features'].append(caption_data['emb'][idx])
        data['caption_paths'].append(caption_data['paths'][idx])
    data['caption_features'] = torch.stack(data['caption_features'])
    data['caption_paths'] = np.array(data['caption_paths'])
    return data


def _integrate_caption_loc(data, caption_data):
    # map image name to split
    name2loc = {}
    for l in ['train', 'val', 'test']:
        names = np.array(caption_data['names'])[caption_data['loc'][l]]
        for n in names:
            name2loc[Path(n).name + '.jpg'] = l

    # apply to new data
    data['loc'] = defaultdict(list)
    for idx, f in enumerate(data['image_paths']):
        name = Path(f).name
        split = name2loc[name]
        if split == 'train':
            data['loc']['train'].append(idx)
        elif split == 'val':
            data['loc']['val'].append(idx)
        elif split == 'test':
            data['loc']['test'].append(idx)
        else:
            raise Exception('Illegal split.')
    for k, v in data['loc'].items():
        data['loc'][k] = np.array(v)

    return data


def _get_split(caption_data):
    """Return the train/val/test index split stored in ``split.pt`` next to this file.

    The split is kept in the repository for reproducibility. If the file is
    missing, a new split is created with seed 23, saved there and logged.
    """
    here = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    split_path = os.path.join(here, 'split.pt')

    if os.path.isfile(split_path):
        return torch.load(split_path)

    utils.set_seeds(23)
    loc = _create_split(caption_data)
    torch.save(loc, split_path)
    logger.info(f'Created new split at "{split_path}".')
    return loc


def _create_split(caption_data):
    """Move 20 randomly chosen training classes into a held-out validation split.

    Mutates ``caption_data['loc']`` in place: ``'train'`` becomes a list with
    the validation samples removed, and ``'val'`` a list of the sample indices
    of the selected classes. Class selection uses the global numpy RNG
    (``_get_split`` seeds it before calling this).

    Returns:
        The updated ``caption_data['loc']`` dict.

    Raises:
        AssertionError: If the splits overlap or the class counts are not 62/20/20.
    """
    # Select held-out val-classes (RNG call sequence unchanged so that a
    # given seed reproduces the same split)
    loc_tr = caption_data['loc']['train']
    y_tr = caption_data['y'][loc_tr].unique()
    y_val = np.random.choice(utils.to_np(y_tr), 20, replace=False)
    loc_val = torch.cat([torch.where(cur_y == caption_data['y'])[0] for cur_y in y_val])
    loc_val = loc_val.tolist()
    # Set lookup instead of scanning the list per element (was O(n*m));
    # int() normalises tensor/numpy scalars so they hash like Python ints.
    val_set = set(loc_val)
    caption_data['loc']['train'] = [v for v in loc_tr if int(v) not in val_set]
    caption_data['loc']['val'] = loc_val

    # Assert no overlap. Indices are compared as Python ints: tensors hash by
    # identity, which would make a plain set intersection always empty.
    tr_set = {int(i) for i in caption_data['loc']['train']}
    test_set = {int(i) for i in caption_data['loc']['test']}
    assert tr_set & val_set == set()
    assert tr_set & test_set == set()
    assert val_set & test_set == set()

    # Assert class sizes
    assert caption_data['y'][caption_data['loc']['train']].unique().size(0) == 62
    assert caption_data['y'][caption_data['loc']['val']].unique().size(0) == 20
    assert caption_data['y'][caption_data['loc']['test']].unique().size(0) == 20

    return caption_data['loc']


def _load_original_flowers_data(size=64):
    """Load the Flowers labels and build the matching list of image paths.

    Args:
        size: If truthy, images are read from ``transformed_jpg_{size}``;
            otherwise from the original ``jpg`` folder.

    Returns:
        Dict with ``'y'`` (squeezed ``labels`` array from ``imagelabels.mat``)
        and ``'image_paths'`` (numpy array with one path per label).
    """
    data_root = config.dirs['flowers_images']

    # Labels
    imagelabels = loadmat(os.path.join(data_root, 'imagelabels.mat'))
    labels = imagelabels['labels'].squeeze()

    # Image paths: pictures are named image_00001.jpg, image_00002.jpg, ...
    # in ascending order, assumed to follow the label order. The count is
    # derived from the labels (8189 for the standard dataset) instead of being
    # hard-coded, so paths and labels always have the same length.
    image_dir = os.path.join(data_root, f'transformed_jpg_{size}' if size else 'jpg')
    image_paths = np.array([
        os.path.join(image_dir, f'image_{i:05d}.jpg')
        for i in range(1, len(labels) + 1)
    ])

    return {'y': labels, 'image_paths': image_paths}


if __name__ == '__main__':
    # Manual smoke test: build the training split and its loader.
    # verbosity=10 presumably corresponds to logging.DEBUG — TODO confirm.
    utils.set_logger(verbosity=10)
    train_dataset, train_loader = load_flowers_data(mode='train', batch_size=64)
