"""Functions used to manipulate pytorch tensors and numpy arrays."""

import numbers
from collections import defaultdict
from typing import List, Dict, Optional, DefaultDict, Any

import PIL
import numpy as np
import torch
from PIL import Image
from tensorboardX import SummaryWriter as SummaryWriterBase, summary as tbxsummary
from tensorboardX.proto.summary_pb2 import Summary as TBXSummary
from tensorboardX.utils import _prepare_video as tbx_prepare_video
from tensorboardX.x2num import make_np as tbxmake_np


def to_device_recursively(input: Any, device: str, inplace: bool = True):
    """Recursively places tensors on the appropriate device.

    # Parameters

    input : A tensor, or an arbitrarily nested structure of tuples, lists,
        dicts and sets containing tensors. `None`, numpy arrays, numpy/python
        scalars and strings are returned unchanged; any other object exposing
        a `to` attribute is moved via `input.to(device=device, inplace=inplace)`.
    device : Target device, passed to `torch.Tensor.to`.
    inplace : If True, lists, dicts and sets are updated in place and the same
        container object is returned; otherwise new containers are built.
        Tuples are always rebuilt because they are immutable.

    # Returns

    The input with every tensor moved to `device`.

    # Raises

    NotImplementedError : If a value of an unsupported type is encountered.
    """
    if input is None:
        return input
    elif isinstance(input, torch.Tensor):
        return input.to(device)
    elif isinstance(input, tuple):
        return tuple(
            to_device_recursively(input=subinput, device=device, inplace=inplace)
            for subinput in input
        )
    elif isinstance(input, list):
        if inplace:
            for i, subinput in enumerate(input):
                input[i] = to_device_recursively(
                    input=subinput, device=device, inplace=inplace
                )
            return input
        else:
            return [
                to_device_recursively(input=subpart, device=device, inplace=inplace)
                for subpart in input
            ]
    elif isinstance(input, dict):
        if inplace:
            # Re-assigning existing keys does not change the dict's size, so
            # it is safe while iterating over `items()`.
            for key, value in input.items():
                input[key] = to_device_recursively(
                    input=value, device=device, inplace=inplace
                )
            return input
        else:
            return {
                k: to_device_recursively(input=v, device=device, inplace=inplace)
                for k, v in input.items()
            }
    elif isinstance(input, set):
        converted = [
            to_device_recursively(element, device=device, inplace=inplace)
            for element in input
        ]
        if inplace:
            # Converted elements may be new objects (e.g. a tensor moved to
            # another device), so the set's contents are replaced wholesale.
            input.clear()
            input.update(converted)
            # Previously this branch fell through and returned None.
            return input
        else:
            return set(converted)
    elif isinstance(input, np.ndarray) or np.isscalar(input) or isinstance(input, str):
        return input
    elif hasattr(input, "to"):
        # noinspection PyCallingNonCallable
        return input.to(device=device, inplace=inplace)
    else:
        raise NotImplementedError(
            "Sorry, value of type {} is not supported.".format(type(input))
        )


def detach_recursively(input: Any, inplace=True):
    """Recursively detaches tensors in some data structure from their
    computation graph.

    # Parameters

    input : A tensor, or an arbitrarily nested structure of tuples, lists,
        dicts and sets containing tensors. `None`, numpy arrays, numpy/python
        scalars and strings are returned unchanged; any other object exposing
        a `detach_recursively` attribute is delegated to via
        `input.detach_recursively(inplace=inplace)`.
    inplace : If True, lists, dicts and sets are updated in place and the same
        container object is returned; otherwise new containers are built.
        Tuples are always rebuilt because they are immutable.

    # Returns

    The input with every tensor replaced by `tensor.detach()`.

    # Raises

    NotImplementedError : If a value of an unsupported type is encountered.
    """
    if input is None:
        return input
    elif isinstance(input, torch.Tensor):
        return input.detach()
    elif isinstance(input, tuple):
        return tuple(
            detach_recursively(input=subinput, inplace=inplace) for subinput in input
        )
    elif isinstance(input, list):
        if inplace:
            for i, subinput in enumerate(input):
                input[i] = detach_recursively(subinput, inplace=inplace)
            return input
        else:
            return [
                detach_recursively(input=subinput, inplace=inplace)
                for subinput in input
            ]
    elif isinstance(input, dict):
        if inplace:
            # Re-assigning existing keys does not change the dict's size, so
            # it is safe while iterating over `items()`.
            for key, value in input.items():
                input[key] = detach_recursively(value, inplace=inplace)
            return input
        else:
            return {k: detach_recursively(v, inplace=inplace) for k, v in input.items()}
    elif isinstance(input, set):
        detached = [detach_recursively(element, inplace=inplace) for element in input]
        if inplace:
            # `detach()` returns new tensor objects, so the set's contents are
            # replaced wholesale rather than element by element.
            input.clear()
            input.update(detached)
            # Previously this branch fell through and returned None.
            return input
        else:
            return set(detached)
    elif isinstance(input, np.ndarray) or np.isscalar(input) or isinstance(input, str):
        return input
    elif hasattr(input, "detach_recursively"):
        # noinspection PyCallingNonCallable
        return input.detach_recursively(inplace=inplace)
    else:
        raise NotImplementedError(
            "Sorry, hidden state of type {} is not supported.".format(type(input))
        )


def batch_observations(
    observations: List[Dict], device: Optional[torch.device] = None
) -> Dict[str, torch.Tensor]:
    """Transpose a batch of observation dicts to a dict of batched
    observations.

    Every value is converted with `to_tensor`, grouped by its key, and the
    per-key tensors are stacked along a new leading (batch) dimension.

    # Arguments

    observations : List of dicts of observations.
    device : The torch.device to put the resulting tensors on.
        Will not move the tensors if None.

    # Returns

    A `defaultdict` mapping each observation key to its stacked tensor.
    An empty input yields an empty mapping.
    """
    grouped: DefaultDict = defaultdict(list)

    # First pass: collect the per-key values, in observation order.
    for observation in observations:
        for sensor_name, value in observation.items():
            grouped[sensor_name].append(to_tensor(value))

    # Second pass: stack each group into a single batched tensor.
    for sensor_name, tensors in grouped.items():
        grouped[sensor_name] = torch.stack(tensors, dim=0).to(device=device)

    return grouped


def to_tensor(v) -> torch.Tensor:
    """Return a torch.Tensor version of the input.

    Tensors are returned as-is. Numpy arrays are wrapped with
    `torch.from_numpy`, so they share memory with the result. Anything else
    goes through `torch.tensor`: integral scalars become int64, and every
    other value (including lists) becomes the default float dtype.

    # Parameters

    v : Input values that can be coerced into being a tensor.

    # Returns

    A tensor version of the input.
    """
    if torch.is_tensor(v):
        return v
    if isinstance(v, np.ndarray):
        return torch.from_numpy(v)

    target_dtype = torch.float
    if isinstance(v, numbers.Integral):
        target_dtype = torch.int64
    return torch.tensor(v, dtype=target_dtype)


def tile_images(images: List[np.ndarray]) -> np.ndarray:
    """Tile multiple images into a single image laid out on a grid.

    Images are placed row-major on a grid with `ceil(sqrt(n))` rows and
    `ceil(n / rows)` columns. Unused grid cells are filled with zeros.

    # Parameters

    images : list of images, or any sequence including a stacked ndarray.
        Each image has dimension (height x width x channels), and all
        images must share the same shape.

    # Returns

    Tiled image (n_rows * height x n_cols * width x channels), with the dtype
    of `np.array(images)`. The result never aliases the input.
    """
    assert len(images) > 0, "empty list of images"
    np_images = np.array(images)
    n_images, height, width, n_channels = np_images.shape
    n_rows = int(np.ceil(np.sqrt(n_images)))
    n_cols = int(np.ceil(float(n_images) / n_rows))
    # Pad with blank (all-zero) images to complete the rectangle. Concatenating
    # arrays instead of `images + [...]` also works when `images` is an
    # ndarray (where `+` would broadcast-add) or a tuple (where it raises).
    n_missing = n_rows * n_cols - n_images
    if n_missing > 0:
        padding = np.zeros(
            (n_missing, height, width, n_channels), dtype=np_images.dtype
        )
        np_images = np.concatenate([np_images, padding], axis=0)
    # (grid_row, grid_col, pixel_row, pixel_col, channel)
    out_image = np_images.reshape((n_rows, n_cols, height, width, n_channels))
    # (grid_row, pixel_row, grid_col, pixel_col, channel)
    out_image = out_image.transpose(0, 2, 1, 3, 4)
    # (grid_row * pixel_row, grid_col * pixel_col, channel)
    return out_image.reshape((n_rows * height, n_cols * width, n_channels))


class SummaryWriter(SummaryWriterBase):
    """tensorboardX `SummaryWriter` with an `add_vid` method.

    `add_vid` logs an already-encoded video summary value, presumably one
    produced by `tensor_to_video`.
    """

    def _video(self, tag, vid):
        """Wrap `vid` in a `Summary` proto under a sanitized version of `tag`.

        The video payload is stored in the `image` field of the summary value.
        """
        value = TBXSummary.Value(tag=tbxsummary._clean_tag(tag), image=vid)
        return TBXSummary(value=[value])

    def add_vid(self, tag, vid, global_step=None, walltime=None):
        """Write the video summary `vid` for `tag` to the event file.

        `global_step` and `walltime` are forwarded unchanged to the
        underlying file writer's `add_summary`.
        """
        # The file writer is fetched before the summary is built, matching the
        # evaluation order of the original one-liner.
        file_writer = self._get_file_writer()
        file_writer.add_summary(self._video(tag, vid), global_step, walltime)


def tensor_to_video(tensor, fps=4):
    """Encode a video tensor with tensorboardX's `make_video`.

    # Parameters

    tensor : Video data accepted by tensorboardX's `make_np` (e.g. a torch
        tensor or numpy array). Its layout must be whatever tensorboardX's
        `_prepare_video` expects (presumably N x T x C x H x W; confirm
        against the installed tensorboardX version).
    fps : Frames per second passed to `make_video`.

    # Returns

    The value returned by `tensorboardX.summary.make_video` for the frames.
    """
    video = tbx_prepare_video(tbxmake_np(tensor))
    if video.dtype == np.uint8:
        # Already in the byte range; no rescaling required.
        frames = video
    else:
        # Presumably float frames in [0, 1]; values are not clipped before
        # the uint8 cast.
        frames = (video * 255.0).astype(np.uint8)
    return tbxsummary.make_video(frames, fps)


class ScaleBothSides(object):
    """Rescales the input PIL image to the given `width` and `height`.

    Both sides are set independently, so the aspect ratio of the input image
    is not preserved.

    Attributes
        width: new width
        height: new height
        interpolation: resampling filter passed to `Image.resize`.
            Default: PIL.Image.BILINEAR
    """

    def __init__(self, width: int, height: int, interpolation=Image.BILINEAR):
        self.width = width
        self.height = height
        self.interpolation = interpolation

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return `img` resized to (`self.width`, `self.height`)."""
        return img.resize((self.width, self.height), self.interpolation)

    def __repr__(self) -> str:
        return "{}(width={}, height={}, interpolation={})".format(
            type(self).__name__, self.width, self.height, self.interpolation
        )
