import random
from enum import Enum
from typing import Tuple, List, Type, Iterator

import PIL
import torch.utils.data.dataset
import torch.utils.data.sampler
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import transforms


class Base(torch.utils.data.dataset.Dataset):
    """Base class for the detection datasets selectable via ``Base.from_name``.

    Subclasses implement loading (``__len__``/``__getitem__``), evaluation and
    ``image_ratios``/``num_classes``. This class provides the shared image
    preprocessing, a padding collate function, and a sampler that batches
    images of the same orientation together.
    """

    class Mode(Enum):
        """Split/usage mode passed to the dataset constructor."""
        TRAIN = 'train'
        EVAL = 'eval'

    # Names accepted by `from_name`.
    OPTIONS = ['voc2007', 'coco2017', 'voc2007-cat-dog', 'coco2017-person', 'coco2017-car', 'coco2017-animal']

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        """Return the dataset class registered under ``name``.

        Imports are done lazily inside each branch so only the selected
        dataset module is loaded.

        Raises:
            ValueError: if ``name`` is not one of ``Base.OPTIONS``.
        """
        if name == 'voc2007':
            from datasets.voc2007 import VOC2007
            return VOC2007
        elif name == 'coco2017':
            from datasets.coco2017 import COCO2017
            return COCO2017
        elif name == 'voc2007-cat-dog':
            from datasets.voc2007_cat_dog import VOC2007CatDog
            return VOC2007CatDog
        elif name == 'coco2017-person':
            from datasets.coco2017_person import COCO2017Person
            return COCO2017Person
        elif name == 'coco2017-car':
            from datasets.coco2017_car import COCO2017Car
            return COCO2017Car
        elif name == 'coco2017-animal':
            from datasets.coco2017_animal import COCO2017Animal
            return COCO2017Animal
        else:
            raise ValueError('Unknown dataset name {!r}; expected one of {}'.format(name, Base.OPTIONS))

    def __init__(self, path_to_data_dir: str, mode: Mode, image_min_side: float, image_max_side: float):
        """Store the dataset location, mode and the resize bounds used by ``preprocess``.

        Args:
            path_to_data_dir: root directory of the dataset files.
            mode: ``Base.Mode.TRAIN`` or ``Base.Mode.EVAL``.
            image_min_side: target length of the shorter image side.
            image_max_side: upper bound on the longer image side after scaling.
        """
        self._path_to_data_dir = path_to_data_dir
        self._mode = mode
        self._image_min_side = image_min_side
        self._image_max_side = image_max_side

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index: int) -> Tuple[str, Tensor, Tensor, Tensor, Tensor]:
        """Subclass hook.

        Returns ``(image_id, image, scale, bboxes, labels)``, which is the
        element layout that ``padding_collate_fn`` unpacks.
        """
        raise NotImplementedError

    def evaluate(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]) -> Tuple[float, str]:
        """Subclass hook: evaluate detections.

        Returns a ``(float, str)`` pair, presumably a score and a printable
        report. Confirm against subclasses.
        """
        raise NotImplementedError

    def _write_results(self, path_to_results_dir: str, image_ids: List[str], bboxes: List[List[float]], classes: List[int], probs: List[float]):
        raise NotImplementedError

    @property
    def image_ratios(self) -> List[float]:
        """Subclass hook: one aspect ratio per image, consumed by ``NearestRatioRandomSampler``.

        The sampler treats ratio < 1 as "tall", so this is presumably
        width / height. Confirm in subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def num_classes() -> int:
        raise NotImplementedError

    @staticmethod
    def preprocess(image: 'PIL.Image.Image', image_min_side: float, image_max_side: float) -> Tuple[Tensor, float]:
        """Resize and normalize ``image``.

        Resize rules:
          1. scale the shorter side to ``image_min_side``;
          2. if the longer side then exceeds ``image_max_side``, scale further
             so that the longer side equals ``image_max_side``.

        Returns:
            ``(tensor, scale)``, where ``tensor`` is the resized image converted
            by ``ToTensor`` and normalized with the commonly used ImageNet
            mean/std, and ``scale`` is the overall resize factor applied.
        """
        scale_for_shorter_side = image_min_side / min(image.width, image.height)
        longer_side_after_scaling = max(image.width, image.height) * scale_for_shorter_side
        scale_for_longer_side = (image_max_side / longer_side_after_scaling) if longer_side_after_scaling > image_max_side else 1
        scale = scale_for_shorter_side * scale_for_longer_side

        transform = transforms.Compose([
            # Resize takes (height, width); interpolation `BILINEAR` is applied by default.
            transforms.Resize((round(image.height * scale), round(image.width * scale))),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = transform(image)

        return image, scale

    @staticmethod
    def padding_collate_fn(batch: List[Tuple[str, Tensor, Tensor, Tensor, Tensor]]) -> Tuple[List[str], Tensor, Tensor, Tensor, Tensor]:
        """Collate variable-size samples into batched tensors by zero padding.

        Images are padded on the right and bottom up to the largest width and
        height in the batch. Bounding boxes are padded with all-zero rows and
        labels with zeros up to the longest list in the batch. Label 0 as
        padding presumably denotes background; confirm with the model's class
        indexing.

        Returns:
            ``(image_ids, images, scales, bboxes, labels)``. The last four are
            stacked along a new leading batch dimension.
        """
        image_id_batch, image_batch, scale_batch, bboxes_batch, labels_batch = zip(*batch)

        max_image_width = max(it.shape[2] for it in image_batch)
        max_image_height = max(it.shape[1] for it in image_batch)
        max_bboxes_length = max(len(it) for it in bboxes_batch)
        max_labels_length = max(len(it) for it in labels_batch)

        # F.pad pads the last dims first: (left, right, top, bottom).
        padded_image_batch = [
            F.pad(input=image, pad=(0, max_image_width - image.shape[2], 0, max_image_height - image.shape[1]))
            for image in image_batch
        ]
        # `.to(other)` matches the padding's dtype and device to the original tensor.
        padded_bboxes_batch = [
            torch.cat([bboxes, torch.zeros(max_bboxes_length - len(bboxes), 4).to(bboxes)])
            for bboxes in bboxes_batch
        ]
        padded_labels_batch = [
            torch.cat([labels, torch.zeros(max_labels_length - len(labels)).to(labels)])
            for labels in labels_batch
        ]

        return (
            list(image_id_batch),
            torch.stack(padded_image_batch, dim=0),
            torch.stack(scale_batch, dim=0),
            torch.stack(padded_bboxes_batch, dim=0),
            torch.stack(padded_labels_batch, dim=0),
        )

    class NearestRatioRandomSampler(torch.utils.data.sampler.Sampler):
        """Yield indices in shuffled groups of ``num_neighbors`` same-orientation images.

        Images are split into "tall" (ratio < 1) and "fat" (ratio >= 1) sets.
        Each set is shuffled, truncated to a multiple of ``num_neighbors`` and
        cut into chunks. All chunks are then shuffled together and flattened.
        Leftover images that do not fill a chunk are skipped for that pass.
        """

        def __init__(self, image_ratios: List[float], num_neighbors: int):
            """
            Args:
                image_ratios: one aspect ratio per dataset index.
                num_neighbors: chunk size (number of consecutive same-orientation indices).

            Raises:
                ValueError: if ``num_neighbors`` is less than 1.
            """
            if num_neighbors < 1:
                raise ValueError('num_neighbors must be a positive integer, got {}'.format(num_neighbors))
            super().__init__(data_source=None)
            self._image_ratios = image_ratios
            self._num_neighbors = num_neighbors

        def _split_indices(self) -> Tuple[Tensor, Tensor]:
            """Return the dataset indices with ratio < 1 (tall) and ratio >= 1 (fat)."""
            image_ratios = torch.tensor(self._image_ratios)
            tall_indices = (image_ratios < 1).nonzero().view(-1)
            fat_indices = (image_ratios >= 1).nonzero().view(-1)
            return tall_indices, fat_indices

        def __len__(self) -> int:
            """Return the number of indices ``__iter__`` actually yields.

            Each orientation set is truncated to a multiple of ``num_neighbors``,
            so this can be smaller than ``len(image_ratios)``.
            """
            return sum(
                len(indices) - len(indices) % self._num_neighbors
                for indices in self._split_indices()
            )

        def __iter__(self) -> Iterator[int]:
            chunks = []
            for indices in self._split_indices():
                indices = indices[torch.randperm(len(indices))]
                num_kept = len(indices) - len(indices) % self._num_neighbors
                # Each row is one chunk of same-orientation indices.
                chunks.append(indices[:num_kept].view(-1, self._num_neighbors))

            merge_indices = torch.cat(chunks, dim=0)
            # Shuffle whole chunks (rows), then flatten so each chunk stays contiguous.
            merge_indices = merge_indices[torch.randperm(len(merge_indices))].view(-1)

            return iter(merge_indices.tolist())