"""
CUB dataset with "raw" images and caption features.

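Image modality x1: 3 x 64 x 64 as returned by `CubDataset.__getitem__` (bounding-box crop, then `_transform` with its default `size=64`)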
Sentence modality: (1024,) features
Train: 6,557 (111 classes)
Val: 2,298 (39 classes)
Test: 2,933 (50 classes)
"""
import logging
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

import utils
from data.cub.main_ft import _load_multimodal_data
from data.cub.preprocess import preprocess_cub
from hyperparams.load import get_config

# Logger looked up by the fixed name 'custom' (not ``__name__``), so it is the
# same logger object as any other module requesting 'custom'.
logger = logging.getLogger('custom')
# Configuration is loaded once, at import time.
config = get_config()
# Root directory of the CUB files; the loaders below join 'images',
# 'bounding_boxes.txt', 'images.txt' (and, via config, 'classes.txt') onto it.
cub_path = config.dirs['cub_standard']


class CubDataset(Dataset):
    """Paired CUB images and caption features.

    Each item is ``(x, s)`` where ``x = [image, caption_feature]`` and ``s`` is
    a dict of supplementary information with keys ``'y'`` and
    ``'caption_paths'``. Images are stored as file paths and loaded lazily in
    ``__getitem__``: opened, converted to RGB, cropped around their bounding
    box (``_crop``) and transformed (``_transform``).

    Args:
        x1: Image file paths; must support ``.shape`` and indexing
            (e.g. a NumPy array).
        x2: Caption features; must support ``.size(0)`` and indexing
            (e.g. a torch tensor).
        y: Class labels, converted with ``utils.to_torch(..., dtype=torch.int32)``.
        caption_paths: Per-item caption file paths, indexed like ``x1``.
        name2bbox: Maps an image file stem to its bounding box, which is
            passed to ``_crop``.

    Raises:
        ValueError: if ``x1``, ``x2``, ``y`` and ``caption_paths`` do not all
            have the same length.
    """

    def __init__(self, x1, x2, y, caption_paths, name2bbox):
        self.x = [x1, x2]
        # Not used by the methods below (images go through ``_transform`` in
        # ``open_image``); kept so existing external references still work.
        self.transform = transforms.ToTensor()
        self.name2bbox = name2bbox
        # supplementary information
        self.s = {'y': utils.to_torch(y, dtype=torch.int32)}
        self.len = len(self.s['y'])
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let misaligned inputs through silently.
        if not (x1.shape[0] == x2.size(0) == self.len == len(caption_paths)):
            raise ValueError(
                f'Length mismatch: {x1.shape[0]} image paths, '
                f'{x2.size(0)} caption features, {self.len} labels, '
                f'{len(caption_paths)} caption paths.')
        self.s['caption_paths'] = caption_paths

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        x = [v[idx] for v in self.x]
        s = {k: v[idx] for k, v in self.s.items()}
        # x[0] is an image path; replace it with the loaded image.
        x[0] = self.open_image(x[0])
        return x, s

    def open_image(self, path, **kwargs):
        """Load, crop and transform a single image.

        Args:
            path: Image file path; its stem is the key into ``name2bbox``.
            **kwargs: Forwarded to ``_transform`` (e.g. ``size``).

        Returns:
            The output of ``_transform`` for the cropped RGB image.
        """
        bbox = self.name2bbox[Path(path).stem]
        # ``convert`` returns a new in-memory image, so the file can be closed
        # immediately; the context manager makes that deterministic.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        image = _crop(image, bbox)
        return _transform(image, **kwargs)


def _crop(im, bbox):
    # https://github.com/hanzhanggit/StackGAN-v2
    w, h = im.width, im.height
    r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
    center_x = int((2 * bbox[0] + bbox[2]) / 2)
    center_y = int((2 * bbox[1] + bbox[3]) / 2)
    y1 = np.maximum(0, center_y - r)
    y2 = np.minimum(h, center_y + r)
    x1 = np.maximum(0, center_x - r)
    x2 = np.minimum(w, center_x + r)
    im = im.crop([x1, y1, x2, y2])
    return im


def _transform(im, size=64):
    """Resize, centre-crop and convert an image to a tensor.

    Mirrors StackGAN-v2 (https://github.com/hanzhanggit/StackGAN-v2): the
    image is first resized with ``transforms.Resize(int(size * 76 / 64))``
    (a margin of 76/64 over the target) and then centre-cropped to
    ``size`` x ``size``.

    Args:
        im: Input image (e.g. a PIL image).
        size: Output side length in pixels.

    Returns:
        The ``transforms.ToTensor()`` output for the cropped image.
    """
    resized_side = int(size * 76 / 64)
    steps = [
        transforms.Resize(resized_side),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(im)


def load_cub_data(mode, batch_size=64, **kwargs):
    """Build the CUB dataset and its data loader for one split.

    All data is collected and preprocessed again on every call.

    Args:
        mode: Split name, used to select ``data['loc'][mode]``; ``'train'``
            additionally enables shuffling in the loader.
        batch_size: Loader batch size.
        **kwargs: Forwarded to ``prepare_cub_data`` (and from there to
            ``preprocess_cub``).

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = _create_dataset(prepare_cub_data(**kwargs), mode)
    return dataset, _create_loader(dataset, batch_size, mode)


def _create_loader(dataset, batch_size, mode):
    loader = DataLoader(dataset, batch_size,
                        shuffle=mode == 'train',
                        pin_memory=True)
    return loader


def _create_dataset(data, mode='train'):
    loc = data['loc'][mode]
    dataset = CubDataset(
        x1=data['image_paths'][loc],
        x2=data['caption_features'][loc],
        y=data['y'][loc],
        caption_paths=data['caption_paths'][loc],
        name2bbox=data['name2bbox']
    )
    return dataset


def prepare_cub_data(**kwargs):
    """
    Collects and preprocesses caption data and image meta data.

    Steps, in order: gather image paths, attach bounding boxes and class
    labels, load the caption-feature data via ``_load_multimodal_data``, copy
    its split assignment onto the images, align captions to the image order,
    and finally run ``preprocess_cub(data, **kwargs)``.

    Returns:
        The ``data`` dict returned by ``preprocess_cub``.
    """
    data = {}
    for collect in (_get_image_paths, _get_bounding_boxes, _get_label_information):
        data = collect(data)

    data_ft = _load_multimodal_data()
    for merge in (_get_split_information, _align_images_and_captions):
        data = merge(data, data_ft)

    return preprocess_cub(data, **kwargs)


def _get_bounding_boxes(data):
    """Add ``data['name2bbox']``: image file stem -> bounding-box row.

    Reads ``bounding_boxes.txt`` and ``images.txt`` from ``cub_path`` and
    pairs their rows positionally (row i of one file with row i of the other).
    The first column of ``bounding_boxes.txt`` is dropped; the remaining
    columns, cast to int, form the box consumed by ``_crop`` (presumably
    ``x, y, width, height`` as in CUB-200-2011 — confirm against the dataset
    README).

    Raises:
        ValueError: if the two files have a different number of rows; ``zip``
            would otherwise silently truncate and mis-pair boxes and images.
    """
    bbox_path = os.path.join(cub_path, 'bounding_boxes.txt')
    # ``sep=r'\s+'`` is pandas' documented equivalent of the deprecated
    # ``delim_whitespace=True``.
    bbox = pd.read_csv(bbox_path, sep=r'\s+', header=None).astype(int)
    bbox = np.array(bbox.iloc[:, 1:])  # Only select columns with bbox info

    im_path = os.path.join(cub_path, 'images.txt')
    # Read as one comma-separated column; ``Path(...).stem`` below keeps only
    # the final file name, so any prefix before the last '/' is ignored.
    im = pd.read_csv(im_path, header=None).values
    if len(bbox) != len(im):
        raise ValueError(
            f'{bbox_path} has {len(bbox)} rows but {im_path} has {len(im)}.')

    data['name2bbox'] = {
        Path(cur_im[0]).stem: cur_bbox for cur_bbox, cur_im in zip(bbox, im)
    }
    return data


def _align_images_and_captions(data, data_ft):
    # map image name to caption location in 'data_ft' space
    image2idx = {}
    for idx, p in enumerate(data_ft['image_paths']):
        n = Path(p).name
        image2idx[n] = idx

    # reorder captions to 'data' space
    data['caption_features'] = []
    data['caption_paths'] = []
    for p in data['image_paths']:
        n = Path(p).name
        idx = image2idx[n]
        cf = data_ft['caption_features'][idx]
        data['caption_features'].append(cf)
        cp = data_ft['caption_paths'][idx]
        data['caption_paths'].append(cp)
    data['caption_features'] = torch.stack(data['caption_features'])
    data['caption_paths'] = np.array(data['caption_paths'])

    return data


def _get_label_information(data):
    """Add ``data['y']``: one class id per entry of ``data['image_paths']``.

    ``classes.txt`` under ``cub_path`` is read as space-separated rows whose
    column 0 is the id and column 1 the class name. Each image is labelled by
    matching the first three characters of its parent directory name against
    the first three characters of the class names (presumably a numeric
    prefix such as ``001`` in the CUB layout — confirm).

    Raises:
        KeyError: if an image's directory prefix matches no class.
    """
    # Use the module-level ``cub_path`` like the other loaders here instead
    # of re-reading ``config.dirs['cub_standard']``.
    classes = pd.read_csv(os.path.join(cub_path, 'classes.txt'),
                          delimiter=' ', header=None)
    # Class-name prefix -> class id; on a shared prefix the last row wins.
    prefix2id = {name[:3]: cls_id for cls_id, name in zip(classes[0], classes[1])}

    data['y'] = np.array(
        [prefix2id[Path(f).parent.name[:3]] for f in data['image_paths']])
    return data


def _get_split_information(data, data_ft):
    # map image name to split
    name2loc = {}
    for l in ['train', 'val', 'test']:
        names = data_ft['image_paths'][data_ft['loc'][l]]
        for n in names:
            name2loc[Path(n).name] = l

    # apply to new data
    data['loc'] = defaultdict(list)
    for idx, f in enumerate(data['image_paths']):
        name = Path(f).name
        split = name2loc[name]
        if split == 'train':
            data['loc']['train'].append(idx)
        elif split == 'val':
            data['loc']['val'].append(idx)
        elif split == 'test':
            data['loc']['test'].append(idx)
        else:
            raise Exception('Illegal split.')
    for k, v in data['loc'].items():
        data['loc'][k] = np.array(v)

    return data


def _get_image_paths(data):
    """Add ``data['image_paths']``: every file under ``<cub_path>/images``.

    The tree is walked in sorted order. ``os.walk`` on its own yields entries
    in filesystem-dependent order, and every index built downstream (labels,
    splits, caption alignment) follows this array's order. Sorting makes
    those indices reproducible across machines.

    Returns:
        ``data`` with ``'image_paths'`` set to a NumPy array of path strings.

    Raises:
        FileNotFoundError: if ``<cub_path>/images`` is not a directory
            (``os.walk`` would otherwise silently yield nothing).
    """
    image_dir = os.path.join(cub_path, 'images')
    if not os.path.isdir(image_dir):
        raise FileNotFoundError(f'CUB image directory not found: {image_dir}')

    paths = []
    for root, dirs, files in os.walk(image_dir):
        dirs.sort()  # in-place sort controls the order os.walk descends
        paths.extend(os.path.join(root, f) for f in sorted(files))
    data['image_paths'] = np.array(paths)
    return data
