import math
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor
# from torchvision.transforms.functional import get_image_size, crop, center_crop, resized_crop
from torchvision.transforms.functional import crop, center_crop, resized_crop

try:
    import accimage
except ImportError:
    accimage = None



def five_focus_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Crop the given image into four overlapping corner crops and a central crop.

    Each corner crop spans ``h // 2 + 1`` rows and ``w // 2 + 1`` columns, so
    neighbouring corner crops overlap near the middle of the image. The central
    crop spans one row and one column fewer. Every crop is resized to ``size``.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of each crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       tuple: tuple (tl, tr, bl, br, center)
       Corresponding top left, top right, bottom left, bottom right and center crop.

    Raises:
        ValueError: If ``size`` does not describe exactly two dimensions.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    # On a Tensor, ``.size`` is a method rather than a (width, height) attribute
    # as on PIL images, so read the spatial dims from the trailing [H, W] axes.
    if isinstance(img, Tensor):
        h, w = int(img.shape[-2]), int(img.shape[-1])
    else:
        w, h = img.size  # PIL: (width, height)

    step_h = h // 2 + 1
    step_w = w // 2 + 1
    # The second row/column of corner crops starts one pixel before ``step``.
    # For even dimensions that start would make the crop end one pixel past the
    # image edge, so the overrun would be padded and resized into the output.
    # Clamping to ``dim - step`` keeps the crop inside the image; for odd
    # dimensions the clamp is a no-op.
    seps_h = [0, max(0, min(step_h - 1, h - step_h))]
    seps_w = [0, max(0, min(step_w - 1, w - step_w))]

    # Row-major order: top-left, top-right, bottom-left, bottom-right.
    crop_boxes = [[top, left, step_h, step_w] for top in seps_h for left in seps_w]

    # Central crop: positioned from the corner-crop size, one pixel smaller per side.
    crop_top = int(round((h - step_h) / 2.0))
    crop_left = int(round((w - step_w) / 2.0))
    crop_boxes.append([crop_top, crop_left, step_h - 1, step_w - 1])

    tl, tr, bl, br, center = (
        resized_crop(img, top, left, height, width, size)
        for top, left, height, width in crop_boxes
    )
    return tl, tr, bl, br, center


def nine_focus_crop(img: Tensor, size: List[int]) -> List[Tensor]:
    """Crop the given image into a 3x3 grid of overlapping crops.

    Each crop spans ``h // 3 + 1`` rows and ``w // 3 + 1`` columns. Crops in a
    row/column start at multiples of ``step - 1``, clamped so that no crop
    extends past the image edge. Every crop is resized to ``size``.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    .. Note::
        This transform returns a list of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        size (sequence or int): Desired output size of each crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Returns:
       list: nine crops [crop1, ..., crop9] in row-major order: top-left,
       top-middle, top-right, middle-left, center, middle-right, bottom-left,
       bottom-middle, bottom-right.

    Raises:
        ValueError: If ``size`` does not describe exactly two dimensions.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    # On a Tensor, ``.size`` is a method rather than a (width, height) attribute
    # as on PIL images, so read the spatial dims from the trailing [H, W] axes.
    if isinstance(img, Tensor):
        h, w = int(img.shape[-2]), int(img.shape[-1])
    else:
        w, h = img.size  # PIL: (width, height)

    step_h = h // 3 + 1
    step_w = w // 3 + 1
    # Starts at 0, step-1, 2*(step-1). When a dimension is a multiple of 3, the
    # last start would make the crop end one pixel past the image edge, so the
    # overrun would be padded. Clamping to ``dim - step`` keeps every crop
    # inside the image and leaves in-bounds starts unchanged.
    seps_h = [max(0, min(i * (step_h - 1), h - step_h)) for i in range(3)]
    seps_w = [max(0, min(i * (step_w - 1), w - step_w)) for i in range(3)]

    return [
        resized_crop(img, top, left, step_h, step_w, size)
        for top in seps_h
        for left in seps_w
    ]



