import colorsys
import math
import random

from einops import rearrange
import torch as pt
import torch.nn.functional as ptnf
import torchvision.transforms as ptvt
import torchvision.transforms.v2 as ptvt2

from ..utils import unsqueeze_to, DictTool


class Compose:
    """Chain keyword-based transforms; each one's output feeds the next."""

    def __init__(self, transforms):
        # sequence of callables, each invoked as ``t(**kwds)``
        self.transforms = transforms

    def __call__(self, **kwds):
        """Run every transform in order and return the last result."""
        result = kwds
        for transform in self.transforms:
            result = transform(**result)
        return result

    def __repr__(self) -> str:
        """Class name followed by one indented line per transform."""
        listing = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({listing}\n)"


class Filter:
    """Keep only the entries listed in ``keys``; everything else is dropped."""

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, **pack: dict) -> dict:
        """Return a fresh dict holding just the values read under ``self.keys``."""
        kept = {}
        for name in self.keys:
            picked = DictTool.getattr(pack, name)
            DictTool.setattr(kept, name, picked)
        return kept


class Normalize:
    """Standardize keyed tensors as ``(input - mean) / std``.

    Supports any tensor shape, as long as ``mean`` and ``std`` are
    broadcastable against it.
    """

    def __init__(self, keys, mean=None, std=None):
        """
        keys: keys of the tensors to normalize
        mean: fixed mean; ``None`` means use ``input.mean()`` at call time
        std: fixed std; ``None`` means use ``input.std()`` at call time
        """
        self.keys = keys
        # ``pt.as_tensor(None)`` raises, so keep ``None`` as the "compute from
        # the input" marker that ``__call__`` checks for
        self.mean = None if mean is None else pt.as_tensor(mean)
        self.std = None if std is None else pt.as_tensor(std)

    def __call__(self, **pack: dict) -> dict:
        """Normalize every keyed tensor in ``pack`` and return ``pack``."""
        for key in self.keys:
            input = DictTool.getattr(pack, key)
            mean = input.mean() if self.mean is None else self.mean
            std = input.std() if self.std is None else self.std
            output = (input - mean) / std
            DictTool.setattr(pack, key, output)
        return pack


class Logarithm:
    """Element-wise logarithm of ``delta + input`` in the given ``base``."""

    def __init__(self, keys, base=2, delta=1):
        self.keys = keys
        self.delta = delta
        # change of base: log_b(x) = ln(x) / ln(b); ln(b) is computed once here
        self.log_base = pt.tensor(base).log()

    def __call__(self, **pack: dict) -> dict:
        """Replace every keyed tensor by its shifted logarithm; return ``pack``."""
        for name in self.keys:
            shifted = self.delta + DictTool.getattr(pack, name)
            DictTool.setattr(pack, name, pt.log(shifted) / self.log_base)
        return pack


class Concat:
    """Concatenate the tensors under ``src_keys`` along ``dim`` into ``dst_key``."""

    def __init__(self, src_keys, dst_key, dim):
        self.src_keys = src_keys
        self.dst_key = dst_key
        self.dim = dim

    def __call__(self, **pack: dict) -> dict:
        """Store the concatenation under ``dst_key``; source entries are kept."""
        parts = []
        for name in self.src_keys:
            parts.append(DictTool.getattr(pack, name))
        merged = pt.cat(parts, dim=self.dim)
        DictTool.setattr(pack, self.dst_key, merged)
        return pack


class Rearrange:
    """Apply one einops ``rearrange`` pattern to every keyed tensor."""

    def __init__(self, keys, pattern):
        self.keys = keys
        self.pattern = pattern

    def __call__(self, **pack: dict) -> dict:
        """Rearrange each keyed tensor with ``self.pattern``; return ``pack``."""
        for name in self.keys:
            source = DictTool.getattr(pack, name)
            DictTool.setattr(pack, name, rearrange(source, self.pattern))
        return pack


class Resize:
    """Resize the last two dims of keyed tensors.

    Tensor shape must be (...,c,h,w), or (...,h,w) when ``c=False``.
    """

    def __init__(
        self, keys, size=None, scale=None, interp="bilinear", max_size=None, c=True
    ):
        self.keys = keys
        self.size = size  # target (h, w)
        self.scale = scale  # scale factor
        self.interp = interp  # interpolation mode handed to ``interpolate``
        # self.max_size = max_size  # TODO max_size (accepted but not used yet)
        self.c = c  # whether the input carries a channel dim

    def __call__(self, **pack: dict) -> dict:
        """Resize every keyed tensor in ``pack`` and return ``pack``."""
        for name in self.keys:
            tensor = DictTool.getattr(pack, name)

            # int16/32/64 inputs go through uint8 and are cast back at the end
            is_wide_int = tensor.dtype in [pt.int16, pt.int32, pt.int64]
            if is_wide_int:
                original_dtype = tensor.dtype
                tensor = tensor.to(pt.uint8)

            if not self.c:
                tensor = tensor[..., None, :, :]  # insert c=1 -> (..,1,h,w)

            # fold all leading ``...`` dims into a single batch dim
            has_chw = tensor.ndim >= 3
            if has_chw:
                leading = tensor.shape[:-3] + (1,)
                tensor = tensor[..., None, :, :, :].flatten(0, -4)
            assert tensor.ndim == 4

            # skip interpolation when the size already matches or scale is 1
            if self.size and list(self.size) == list(tensor.shape[-2:]):
                resized = tensor
            elif self.scale and self.scale == 1:
                resized = tensor
            else:
                resized = ptnf.interpolate(
                    tensor, self.size, self.scale, self.interp
                )

            # undo the reshaping steps in reverse order
            if has_chw:
                resized = resized.unflatten(0, leading).squeeze(-4)
            if not self.c:
                resized = resized[..., 0, :, :]
            if is_wide_int:
                resized = resized.to(original_dtype)
            DictTool.setattr(pack, name, resized)
        return pack


class Flatten:
    """Flatten dims ``start..end`` of keyed tensors (which should share ndim)."""

    def __init__(self, keys, start=0, end=-1):
        self.keys = keys
        self.start = start
        self.end = end

    def __call__(self, **pack: dict) -> dict:
        """Flatten each keyed tensor over ``[start, end]``; return ``pack``."""
        for name in self.keys:
            tensor = DictTool.getattr(pack, name)
            DictTool.setattr(pack, name, tensor.flatten(self.start, self.end))
        return pack


class PadTo1:
    """Pad keyed tensors along one dim up to ``length``.

    Keyed tensors should have same size in self.dim. Tensors already at least
    ``length`` long in that dim are left untouched.
    """

    def __init__(self, keys, dim, length, mode="right", value=0):
        """
        keys: keys of the tensors to pad
        dim: dim to pad; negative values count from the last dim
        length: target size of ``dim`` after padding
        mode: where the padding goes: ``left``, ``sides`` (center), ``right``
        value: constant used to fill the padded area
        """
        self.keys = keys
        self.dim = dim
        self.length = length
        self.mode = mode
        self.value = value

    def __call__(self, **pack: dict) -> dict:
        """Pad every keyed tensor that is shorter than ``length``; return ``pack``."""
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            size = value0.size(self.dim)
            if self.length <= size:
                continue
            left, right = __class__.calc_padding(self.length, size, self.mode)
            value = __class__.pad1(value0, self.dim, left, right, self.value)
            assert value.size(self.dim) == self.length
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def calc_padding(target, size, mode):
        """Return ``(left, right)`` pad amounts growing ``size`` to ``target``.

        Raises ``ValueError`` for an unknown ``mode``.
        """
        if mode == "left":
            left = target - size
        elif mode == "sides":
            left = (target - size) // 2  # the odd extra element goes right
        elif mode == "right":
            left = 0
        else:
            raise ValueError(
                f"mode must be 'left', 'sides' or 'right', got {mode!r}"
            )
        right = target - size - left
        return left, right

    @staticmethod
    def pad1(input, dim, left, right, pad_value=0):
        """Constant-pad ``input`` along ``dim`` only.

        Raises ``IndexError`` when ``dim`` is out of range for ``input``.
        """
        if not -input.ndim <= dim < input.ndim:
            raise IndexError(
                f"dim {dim} out of range for tensor with {input.ndim} dims"
            )
        dim = dim % input.ndim  # normalize negative dims
        # ``ptnf.pad`` lists pairs from the last dim backwards, so emit zero
        # pairs for every dim after ``dim``; dims before it need no entry
        pad = [0, 0] * (input.ndim - dim - 1) + [left, right]
        return ptnf.pad(input, pad, value=pad_value)


class CatTokenToSlot:
    """Append a constant token channel to the slot tensor.

    Use before PadSlot, in case fg_slot==bg_slot.
    """

    def __init__(self, slot_key, token=1):
        self.slot_key = slot_key
        self.token = token  # value written into the extra channel

    def __call__(self, **pack: dict) -> dict:
        """
        input: in shape (t,n,c)
        output: in shape (t,n,c+1)
        """
        value0 = DictTool.getattr(pack, self.slot_key)
        value = __class__.cat_token(value0, self.token)
        DictTool.setattr(pack, self.slot_key, value)
        return pack

    @staticmethod
    def cat_token(slot, token=1):
        """Return ``slot`` (t,n,c) with one extra channel filled with ``token``."""
        t, n, c = slot.shape
        # build the column on the slot's own device so ``cat`` also works for
        # tensors that do not live on the CPU
        column = pt.ones(t, n, 1, dtype=slot.dtype, device=slot.device) * token
        return pt.cat([slot, column], dim=2)


class PadSlot:
    """Pad the slot dim to ``max_num``: [background] + entities + null fillers.

    TODO XXX try bg=-1, null=0
    """

    def __init__(
        self, slot_key, max_num, mask_key=None, background=0.0, null=0.0, t=True
    ):
        """
        slot_key: key of the slot tensor
        max_num: slot count after padding
        mask_key: optional key of the slot-validity mask, padded alongside
        background: fill value of the prepended background slot; ``None``
            disables the background slot
        null: fill value of the appended null slots
        t: whether the tensors carry a leading ``t`` dim
        """
        self.slot_key = slot_key
        self.max_num = max_num  # background + entities + null
        self.mask_key = mask_key
        self.background = background  # None: not pad background
        self.null = null
        self.t = t

    def __call__(self, **pack: dict) -> dict:
        """
        input: slot in shape (t,n,c), mask in shape (t,n)
        output: slot in shape (t,m,c), mask in shape (t,m)

        With ``t=False`` the leading ``t`` dim is absent in inputs and outputs.
        """
        has_mask = self.mask_key is not None
        slot = DictTool.getattr(pack, self.slot_key)
        mask = DictTool.getattr(pack, self.mask_key) if has_mask else None
        if not self.t:  # add a temporary t=1 dim
            assert slot.ndim == 2  # (n,c)
            slot = slot[None, :, :]
            if has_mask:
                assert mask.ndim == 1  # (n,)
                mask = mask[None, :]
        slot2, mask2 = __class__.pad_slot(
            slot, mask, self.max_num, self.background, self.null
        )
        if not self.t:  # drop the temporary t dim again
            slot2 = slot2[0]
            if has_mask:
                mask2 = mask2[0]
        DictTool.setattr(pack, self.slot_key, slot2)
        if has_mask:
            DictTool.setattr(pack, self.mask_key, mask2)
        return pack

    @staticmethod
    def pad_slot(slot, mask, max_num, background=0.0, null=0.0):
        """Pad ``slot`` (t,n,c) and optional ``mask`` (t,n) to ``max_num`` slots.

        Returns ``(slot2, mask2)``; ``mask2`` is ``None`` when ``mask`` is.
        The background entry is marked 1 in the mask, null entries 0.
        """
        t, n, c = slot.shape
        if mask is not None:
            assert [t, n] == list(mask.shape)
        used = n if background is None else 1 + n
        assert max_num >= used
        m = max_num - used  # number of null slots to append

        # fillers follow the dtype and device of the tensors they extend
        slot_opts = dict(dtype=slot.dtype, device=slot.device)
        parts_slot = [slot, pt.ones([t, m, c], **slot_opts) * null]
        if background is not None:
            parts_slot.insert(0, pt.ones([t, 1, c], **slot_opts) * background)
        slot2 = pt.cat(parts_slot, dim=1)

        mask2 = None
        if mask is not None:
            mask_opts = dict(dtype=mask.dtype, device=mask.device)
            parts_mask = [mask, pt.zeros([t, m], **mask_opts)]
            if background is not None:
                parts_mask.insert(0, pt.ones([t, 1], **mask_opts))
            mask2 = pt.cat(parts_mask, dim=1)
        return slot2, mask2


class Slice1:
    """Slice keyed tensors along one dim with a fixed ``start:end:step``.

    Keyed tensors should have same size in self.dim.
    """

    def __init__(self, keys, dim, start, end, step=None):
        self.keys = keys
        self.dim = dim
        self.start = start
        self.end = end
        self.step = step

    def __call__(self, **pack: dict) -> dict:
        """Slice every keyed tensor along ``self.dim``; return ``pack``."""
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            value = __class__.slice1(value0, self.dim, self.start, self.end, self.step)
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def slice1(x, dim, start, end, step):
        """Return ``x[start:end:step]`` taken along ``dim``.

        Falsy ``start`` / ``end`` / ``step`` (``None`` or ``0``) fall back to
        the slice defaults, so ``end=0`` means "up to the end". Negative
        ``dim`` counts from the last dim.
        """
        if dim < 0:
            dim += x.ndim
        # plain slice objects replace building and eval-ing an index string
        window = slice(start or None, end or None, step or None)
        index = (slice(None),) * dim + (window, Ellipsis)
        return x[index]


class RandomSlice1:
    """Randomly crop keyed tensors to ``length`` along one dim.

    Keyed tensors should have same size in self.dim. One window is drawn per
    call and shared by all keys, so the keyed tensors stay aligned.
    """

    def __init__(self, keys, dim, length, step=None):
        self.keys = keys
        self.dim = dim
        self.length = length
        self.step = step

    def __call__(self, **pack: dict) -> dict:
        """Crop every keyed tensor longer than ``length``; return ``pack``."""
        window = None
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            size = value0.size(self.dim)
            if self.length >= size:
                continue
            if window is None:
                # drawn once, from the first tensor that needs cropping, then
                # reused; a fresh draw per key would desynchronize the tensors
                window = __class__.calc_slicing(self.length, size)
            start, end = window
            value = Slice1.slice1(value0, self.dim, start, end, self.step)
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def calc_slicing(target, size):
        """Draw a uniformly random ``(start, end)`` window of ``target`` items."""
        start = random.randint(0, size - target)
        end = start + target
        return start, end


class StridedRandomSlice1(RandomSlice1):
    """Random crop restricted to windows aligned to multiples of ``length``.

    No overlap between slices. One window is drawn per call and shared by all
    keys, so the keyed tensors stay aligned.
    """

    def __call__(self, **pack: dict) -> dict:
        """Crop every keyed tensor longer than ``length``; return ``pack``."""
        window = None
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            size = value0.size(self.dim)
            if self.length >= size:
                continue
            if window is None:
                # drawn once, from the first tensor that needs cropping, then
                # reused; a fresh draw per key would desynchronize the tensors
                window = __class__.calc_slicing(self.length, size)
            start, end = window
            value = Slice1.slice1(value0, self.dim, start, end, self.step)
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def calc_slicing(target, size):
        """Draw a random ``(start, end)`` window starting at a multiple of ``target``.

        ``size`` must currently be divisible by ``target``.
        """
        start = random.randint(0, math.ceil(size / target) - 1) * target
        assert size % target == 0  # TODO XXX remove this restrict
        end = start + target
        return start, end


class SliceTo1:
    """Deterministically crop keyed tensors to at most ``length`` along one dim."""

    def __init__(self, keys, dim, length, step=None, mode="center"):
        """
        keys: keys of the tensors to crop
        dim: dim to crop
        length: window size taken from ``dim`` before striding
        step: stride applied inside the window; ``None`` means 1
        mode: window position: left, center, right
        """
        self.keys = keys
        self.dim = dim
        self.length = length
        self.step = step or 1
        self.mode = mode

    def __call__(self, **pack: dict) -> dict:
        """Crop (and stride) every keyed tensor; return ``pack``."""
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            size = value0.size(self.dim)
            if size <= self.length and self.step == 1:
                continue  # already short enough and nothing to stride
            # a tensor shorter than ``length`` is still strided over its full
            # extent instead of tripping the ``target <= size`` check
            target = min(self.length, size)
            start, end = __class__.calc_slicing(target, size, self.mode)
            value = Slice1.slice1(value0, self.dim, start, end, self.step)
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def calc_slicing(target, size, mode):
        """Return the ``(start, end)`` window of ``target`` items within ``size``.

        Raises ``ValueError`` for an unknown ``mode``.
        """
        assert target <= size
        if mode == "left":
            start = 0
        elif mode == "center":
            start = (size - target) // 2
        elif mode == "right":
            start = size - target
        else:
            raise ValueError(
                f"mode must be 'left', 'center' or 'right', got {mode!r}"
            )
        end = start + target
        return start, end


# class RandomResizedCrop:
#     """Crop the given image to random scale and aspect ratio. Support tensor shape (..,c,h,w)."""

#     def __init__(
#         self,
#         keys,
#         size=(224, 224),
#         crop_ratio_range=(0.08, 1.0),
#         aspect_ratio_range=(3.0 / 4.0, 4.0 / 3.0),
#         max_attempts=5,
#     ):
#         self.keys = keys
#         self.size = size
#         self.crop_ratio_range = crop_ratio_range
#         self.aspect_ratio_range = aspect_ratio_range
#         self.max_attempts = max_attempts

#     def rand_crop_params(self, image):
#         h, w = image.shape[-2:]
#         area = h * w

#         for _ in range(self.max_attempts):
#             target_area = random.uniform(*self.crop_ratio_range) * area
#             log_ratio = (
#                 math.log(self.aspect_ratio_range[0]),
#                 math.log(self.aspect_ratio_range[1]),
#             )
#             aspect_ratio = math.exp(random.uniform(*log_ratio))
#             target_w = int(round(math.sqrt(target_area * aspect_ratio)))
#             target_h = int(round(math.sqrt(target_area / aspect_ratio)))

#             if 0 < target_w <= w and 0 < target_h <= h:
#                 offset_h = random.randint(0, h - target_h)
#                 offset_w = random.randint(0, w - target_w)
#                 return offset_h, offset_w, target_h, target_w

#         # Fallback to central crop
#         in_ratio = float(w) / float(h)
#         if in_ratio < min(self.aspect_ratio_range):
#             target_w = w
#             target_h = int(round(target_w / min(self.aspect_ratio_range)))
#         elif in_ratio > max(self.aspect_ratio_range):
#             target_h = h
#             target_w = int(round(target_h * max(self.aspect_ratio_range)))
#         else:  # whole image
#             target_w = w
#             target_h = h

#         offset_h = (h - target_h) // 2
#         offset_w = (w - target_w) // 2
#         return offset_h, offset_w, target_h, target_w


#     def __call__(self, **pack: dict) -> dict:
#         for key in self.keys:
#             value0 = DictTool.getattr(pack, key)
#             offset_h, offset_w, target_h, target_w = self.rand_crop_params(value0)
#             value = value0[
#                 ...,
#                 offset_h : offset_h + target_h - 1,
#                 offset_w : offset_w + target_w - 1,
#             ]
#             if value.ndim < 4:
#                 assert value.ndim >= 3
#                 value = ptnf.interpolate(value[None], self.size)[0]
#             else:
#                 value = ptnf.interpolate(value, self.size)
#             DictTool.setattr(pack, key, value)
#         return pack
class RandomResizedCrop:
    """Random-area, random-aspect crop followed by a resize to ``size``.

    Support image in shape (..,c,h,w). For ``keys``, must place image key
    first: the crop window is drawn from its spatial size and reused for
    every key.

    NOTE(review): the original docstring also promised bbox support in shape
    (..,4); every keyed value is handed to ``resized_crop`` as is -- confirm
    that bbox inputs arrive wrapped in a type it treats as boxes.
    """

    # interpolation name (e.g. "bilinear") -> torchvision InterpolationMode
    INTERPS = {_.value: _ for _ in ptvt2.InterpolationMode}

    def __init__(
        self, keys, size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), interp="bilinear"
    ):
        """
        keys: keys to transform; the first one must be the image
        size: output size passed to ``resized_crop``
        scale: (min, max) fraction of the image area to crop
        ratio: (min, max) aspect ratio (w/h) of the crop
        interp: one interpolation name for all keys, or one name per key

        Raises ``ValueError`` when ``interp`` is neither str nor list/tuple.
        """
        self.keys = keys
        self.size = size
        self.scale = scale
        self.ratio = ratio
        if isinstance(interp, str):
            interp = [interp] * len(keys)
        elif isinstance(interp, (list, tuple)):
            assert len(interp) == len(keys)
        else:
            raise ValueError(
                f"interp must be a str or a list/tuple of str, got {type(interp)}"
            )
        self.interp = [__class__.INTERPS[_] for _ in interp]

    @staticmethod
    def get_params(image, ratio, scale) -> dict:
        """Draw a crop window ``dict(top, left, height, width)`` for ``image``.

        Makes up to 10 random attempts; if none fits inside the image, falls
        back to a central crop clamped to the allowed aspect-ratio range.
        """
        height, width = image.shape[-2:]
        area = height * width
        log_ratio = pt.log(pt.as_tensor(ratio))

        for _ in range(10):
            target_area = area * pt.empty(1).uniform_(scale[0], scale[1]).item()
            # aspect ratio is drawn log-uniformly
            aspect_ratio = pt.exp(
                pt.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = pt.randint(0, height - h + 1, size=(1,)).item()
                j = pt.randint(0, width - w + 1, size=(1,)).item()
                break

        else:  # Fallback to central crop
            in_ratio = float(width) / float(height)
            if in_ratio < min(ratio):
                w = width
                h = int(round(w / min(ratio)))
            elif in_ratio > max(ratio):
                h = height
                w = int(round(h * max(ratio)))
            else:  # whole image
                w = width
                h = height
            i = (height - h) // 2
            j = (width - w) // 2

        return dict(top=i, left=j, height=h, width=w)

    def __call__(self, **pack: dict) -> dict:
        """Apply one shared random crop+resize to every keyed value; return ``pack``."""
        image = DictTool.getattr(pack, self.keys[0])
        params = self.get_params(image, self.ratio, self.scale)
        for i, key in enumerate(self.keys):
            input = DictTool.getattr(pack, key)
            mode = self.interp[i]
            output = ptvt2.functional.resized_crop(
                input,
                **params,
                size=self.size,
                interpolation=mode,
                # ``mode`` is an enum member, so test its ``.value``; only the
                # bilinear/bicubic modes support antialiasing
                antialias=mode.value in ("bilinear", "bicubic"),
            )
            DictTool.setattr(pack, key, output)
        return pack


class CenterResizedCrop:
    """Center-crop a square then resize it to ``size``.

    Support tensor shape (..,c,h,w).
    """

    def __init__(self, keys, size, crop_padding: int = 32):
        """
        keys: keys of the tensors to transform
        size: square output size given as a pair (s, s)
        crop_padding: the crop side is ``s / (s + crop_padding)`` of the
            shorter image side
        """
        self.keys = keys
        assert len(size) == 2 and len(set(size)) == 1
        self.size = size
        self.crop_padding = crop_padding

    def __call__(self, **pack: dict) -> dict:
        """Crop and resize every keyed tensor; return ``pack``."""
        for key in self.keys:
            value0 = DictTool.getattr(pack, key)
            value = __class__.crop_resize(value0, self.size, self.crop_padding)
            DictTool.setattr(pack, key, value)
        return pack

    @staticmethod
    def crop_resize(input, size, crop_padding):
        """Center-crop ``input`` (..,c,h,w) and resize the crop to ``size``."""
        h, w = input.shape[-2:]
        crop_size, offset_h, offset_w = __class__.calc_cropping(
            h, w, size, crop_padding
        )
        # slice ends are exclusive, so ``offset + crop_size`` keeps exactly
        # ``crop_size`` pixels per side
        crop = input[
            ...,
            offset_h : offset_h + crop_size,
            offset_w : offset_w + crop_size,
        ]
        # ``interpolate`` needs (b,c,h,w): fold all leading dims into one
        # batch dim and restore them afterwards
        leading = crop.shape[:-3]
        batch = crop.reshape(-1, *crop.shape[-3:])
        resized = ptnf.interpolate(batch, size)
        return resized.reshape(*leading, *resized.shape[-3:])

    @staticmethod
    def calc_cropping(h, w, target, crop_padding):
        """Return ``(crop_size, offset_h, offset_w)`` of the centered square crop."""
        small_side = min(h, w)
        crop_size = int(target[0] / (target[0] + crop_padding) * small_side)  # TODO XXX
        offset_h = max(0, int(round((h - crop_size) / 2.0)))
        offset_w = max(0, int(round((w - crop_size) / 2.0)))
        return crop_size, offset_h, offset_w


class RandomFlip:
    """Flip all keyed tensors along one randomly chosen dim from ``dims``."""

    def __init__(self, keys, dims, prob=0.5):
        self.keys = keys
        self.dims = dims
        self.prob = prob

    def __call__(self, **pack: dict) -> dict:
        """Flip every keyed tensor, or return ``pack`` unchanged.

        One draw decides whether to flip and one picks the dim, so all keyed
        tensors are flipped the same way.
        """
        skip = random.random() > self.prob
        if skip:
            return pack
        axis = random.choice(self.dims)
        for name in self.keys:
            flipped = DictTool.getattr(pack, name).flip(axis)
            DictTool.setattr(pack, name, flipped)
        return pack


class Mask:
    """Multiply keyed tensors by a mask taken from the same pack.

    The mask is adaptively unsqueezed to each input via ``unsqueeze_to``;
    ``mask.ndim <= input.ndim`` is assumed.
    """

    def __init__(self, keys, mask_key, keep=True):
        self.keys = keys
        self.mask_key = mask_key
        # self.keep = keep  (accepted but currently unused)

    def __call__(self, **pack: dict) -> dict:
        """Multiply every keyed tensor by the mask; return ``pack``."""
        mask = DictTool.getattr(pack, self.mask_key)
        for name in self.keys:
            target = DictTool.getattr(pack, name)
            # ``mask.to(target)`` is applied before the unsqueeze
            aligned = unsqueeze_to(mask.to(target), target)
            DictTool.setattr(pack, name, target * aligned)
        return pack


class Clip:
    """Clamp keyed tensors into ``[min, max]``."""

    def __init__(self, keys, min=None, max=None):
        self.keys = keys
        self.min = min
        self.max = max

    def __call__(self, **pack: dict) -> dict:
        """Clip every keyed tensor; return ``pack``."""
        for name in self.keys:
            clipped = pt.clip(DictTool.getattr(pack, name), self.min, self.max)
            DictTool.setattr(pack, name, clipped)
        return pack


class ToDevice:
    """Move keyed tensors to ``device``."""

    def __init__(self, keys, device="cuda"):
        self.keys = keys
        self.device = device

    def __call__(self, **pack: dict) -> dict:
        """Call ``.to(self.device)`` on every keyed tensor; return ``pack``."""
        for name in self.keys:
            moved = DictTool.getattr(pack, name).to(self.device)
            DictTool.setattr(pack, name, moved)
        return pack


class Detach:
    """Replace keyed tensors by their ``detach()``-ed versions."""

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, **pack: dict) -> dict:
        """Detach every keyed tensor; return ``pack``."""
        for name in self.keys:
            detached = DictTool.getattr(pack, name).detach()
            DictTool.setattr(pack, name, detached)
        return pack


class TupleToIndex:
    """Collapse per-group indices along ``dim`` into a single flat index."""

    def __init__(self, keys, groups, dim):
        self.keys = keys
        self.groups = groups
        self.dim = dim

    def __call__(self, **pack: dict) -> dict:
        """Convert every keyed tuple-index tensor to a flat index; return ``pack``."""
        for name in self.keys:
            tuple_index = DictTool.getattr(pack, name)
            flat_index = TupleToIndex.tuple_to_number(
                tuple_index, self.groups, self.dim
            )
            DictTool.setattr(pack, name, flat_index)
        return pack

    @staticmethod
    def tuple_to_number(tidx, groups: list, dim=1):
        """
        tidx: shape=(b,g,..)

        Mixed-radix encoding: entry ``i`` along ``dim`` is weighted by the
        product of ``groups[:i]``, so the first group is least significant.
        """
        device = tidx.device
        nidx = 0
        weight = 1
        for position, radix in enumerate(groups):
            picker = pt.as_tensor(position, device=device)
            digit = tidx.index_select(dim, picker).squeeze(dim)
            nidx += digit * weight
            weight *= radix
        return nidx


class SegmentToRgb:
    """Colorize uint8 segment-index maps (h,w) into RGB images (c=3,h,w)."""

    def __init__(self, keys, num_color, normaliz=True):
        self.keys = keys
        # lookup table (n,c=3): one uint8 RGB row per segment index
        self.spectrum = SegmentToRgb.generate_spectrum_colors(num_color)
        self.normaliz = normaliz  # rescale colors from [0,255] to [-1,1]

    def __call__(self, **pack: dict) -> dict:
        """Map every keyed index map to its color image; return ``pack``."""
        for name in self.keys:
            labels = DictTool.getattr(pack, name)
            assert labels.ndim == 2 and labels.dtype == pt.uint8
            rgb = self.spectrum[labels.long()]  # (h,w,c=3)
            rgb = rgb.permute(2, 0, 1)  # (c=3,h,w)
            if self.normaliz:
                rgb = (rgb - 127.5) / 127.5
            DictTool.setattr(pack, name, rgb)
        return pack

    @staticmethod
    def generate_spectrum_colors(num_color):
        """Return ``num_color`` fully saturated colors at evenly spaced hues."""
        palette = [
            [
                int(255 * channel)
                for channel in colorsys.hsv_to_rgb(k / float(num_color), 1.0, 1.0)
            ]
            for k in range(num_color)
        ]
        return pt.as_tensor(palette, dtype=pt.uint8)  # (n,c=3)
