import random
from collections.abc import Mapping, Sequence
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
import torch.nn.functional as F


# Per-sample keys that are split off and collated by ``collate_data_and_cast``
# when a sample dict contains "imgs_global_crops".
_IMG_KEYS = (
    "imgs_global_crops",
    "imgs_local_crops",
    "imgs_global_crops_teacher",
    "imgs_offsets",
    "imgs_masks_list",
    "imgs_masks_upperbound",
    "img_size",
)


def _pad_views(index, num_views):
    """Pad dim 1 of a 3-D ``index`` tensor with -1 up to ``num_views`` entries.

    A negative difference (dim 1 already larger than ``num_views``) makes
    ``F.pad`` crop instead of pad.
    """
    # F.pad pads the last dim, so move dim 1 last, pad, and move it back.
    padded = F.pad(index.permute(0, 2, 1), (0, num_views - index.shape[1]), value=-1)
    return padded.permute(0, 2, 1)


def _collate_mapping(batch):
    """Collate a list of per-sample dicts into one dict (see ``collate_fn``)."""
    imgs_batch = None
    if "imgs_global_crops" in batch[0]:
        # Split the image entries off into their own samples. Shallow copies
        # are built so the caller's dicts are left untouched.
        img_samples = [{key: data[key] for key in _IMG_KEYS} for data in batch]
        batch = [
            {key: value for key, value in data.items() if key not in _IMG_KEYS}
            for data in batch
        ]
        imgs_batch = collate_data_and_cast(img_samples, dtype=torch.float32)

    if "img_num" in batch[0]:
        max_img_num = max(data["img_num"] for data in batch)
    else:
        # NOTE(review): fallback view count when samples carry no "img_num";
        # confirm it matches the dataset configuration.
        max_img_num = 4

    collated = {}
    for key in batch[0]:
        values = [data[key] for data in batch]
        if "mask_index" in key:
            # Samples may have different view counts; pad them to a common one.
            collated[key] = collate_fn([_pad_views(v, max_img_num) for v in values])
        elif "offset" in key:
            # offset -> per-segment counts -> concatenate counts -> global offset
            counts = [v.diff(prepend=torch.tensor([0])) for v in values]
            collated[key] = torch.cumsum(collate_fn(counts), dim=0)
        else:
            collated[key] = collate_fn(values)

    if imgs_batch is not None:
        collated.update(imgs_batch)
    return collated


def collate_fn(batch):
    """Recursively collate a batch of point-cloud samples.

    Dispatch is on the type of ``batch[0]``:

    * ``torch.Tensor``: samples are concatenated along dim 0.
    * ``str``: returned as a plain list.
    * other ``Sequence``: each sample gets its point count appended. The count
      is ``shape[0]`` of its first element, e.g. 'coord', which is why that
      element is required to determine 'offset'. Positions are then collated
      element-wise, and the last entry becomes the cumulative int32 offset.
    * ``Mapping``:
      * Image-crop keys, when present, go through ``collate_data_and_cast``.
      * Keys containing "mask_index" are padded with -1 along dim 1 to the
        largest "img_num".
      * Keys containing "offset" are rebased into batch-global cumulative
        offsets.
      * Every other key is collated recursively.
    * anything else: ``default_collate``.

    Input samples are not modified.

    Raises:
        TypeError: if ``batch`` is not a Sequence.
    """
    if not isinstance(batch, Sequence):
        raise TypeError(f"{type(batch)} is not supported.")

    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        return torch.cat(list(batch))
    if isinstance(elem, str):
        # str is also a kind of Sequence, so it must be checked before Sequence.
        return list(batch)
    if isinstance(elem, Sequence):
        # Build new lists instead of appending in place. This keeps the
        # caller's samples intact and also works for tuple samples.
        batch = [list(data) + [torch.tensor([data[0].shape[0]])] for data in batch]
        batch = [collate_fn(samples) for samples in zip(*batch)]
        batch[-1] = torch.cumsum(batch[-1], dim=0).int()
        return batch
    if isinstance(elem, Mapping):
        return _collate_mapping(batch)
    return default_collate(batch)


def _mix_instance_ids(instance, offset):
    """Shift, in place, the instance ids of every odd sample past those of its even partner.

    Entries equal to -1 are left untouched.
    """
    start = 0
    shift = 0
    for i, end in enumerate(offset):
        segment = instance[start:end]  # view: in-place ops write into ``instance``
        if i % 2 == 0:
            # +1 so the partner's id 0 lands past this sample's largest id.
            # An empty or all-ignored (-1) sample gives a shift of 0.
            shift = segment.max() + 1 if segment.numel() > 0 else 0
        else:
            segment += shift * (segment != -1)
        start = end


def _mix_view_index(index, offset):
    """Merge a (N, v, n) per-view index tensor into a (N, 2v, n) tensor for paired samples.

    Even samples fill view slots [0, v) and odd samples fill [v, 2v). Unused
    slots are -1. With an odd sample count, the unpaired last sample also fills
    [v, 2v), i.e. it is paired with itself.
    """
    num_points, num_views, width = index.shape
    # Keep the input dtype/device. A plain ``-torch.ones`` would silently turn
    # integer indices into float32.
    mixed = torch.full(
        (num_points, 2 * num_views, width), -1, dtype=index.dtype, device=index.device
    )
    start = 0
    for i, end in enumerate(offset):
        slots = slice(0, num_views) if i % 2 == 0 else slice(num_views, None)
        mixed[start:end, slots] = index[start:end]
        start = end
    if len(offset) % 2 == 1:
        start = 0 if len(offset) == 1 else offset[-2]
        mixed[start:num_points, num_views:] = index[start:num_points]
    return mixed


def _repeat_last_sample_views(tensor, num_samples):
    """Append a copy of the last sample's rows when ``num_samples`` is odd.

    Assumes dim 0 stacks an equal number of views per sample, in sample order
    (TODO confirm).
    """
    if num_samples % 2 == 0:
        return tensor
    views_per_sample = tensor.shape[0] // num_samples
    return torch.cat([tensor, tensor[-views_per_sample:]], dim=0)


def _mix_corresponding(corresponding, offset):
    """Shift, in place, channel 0 of each odd sample's non-negative entries.

    NOTE(review): the shift is the negation of the first negative channel-0
    value found anywhere in the whole batch. Only when no negative value exists
    is it the max channel-0 value of the even partner. Confirm this is intended.
    """
    start = 0
    for i, end in enumerate(offset):
        if i % 2 == 0:
            negatives = corresponding[corresponding[:, :, 0] < 0]
            if negatives.size(0) > 0:
                num_add = -negatives[0, 0]
            else:
                num_add = max(corresponding[start:end, :, 0].reshape(-1))
        else:
            valid = corresponding[start:end, :, 0] >= 0
            corresponding[start:end, :, 0] += num_add * valid
        start = end


def point_collate_fn(batch, mix_prob=0):
    """Collate point-cloud sample dicts and, with probability ``mix_prob``, pair samples.

    When mixing is triggered, consecutive samples (0, 1), (2, 3), ... are merged
    into one scene by dropping every intermediate offset. The per-sample
    annotations are adjusted to match:

    * "instance": odd-sample ids are shifted past the even partner's ids.
    * "*mask_index": views of both partners are laid side by side (v -> 2v).
    * "imgs" / "dino_feature": the last sample's views are repeated when the
      sample count is odd, so it is paired with itself.
    * "corresponding": channel 0 of odd samples is shifted.

    Args:
        batch: list of per-sample dicts.
        mix_prob: probability of applying the pairing.

    Returns:
        The collated dict.
    """
    assert isinstance(
        batch[0], Mapping
    )  # currently, only support input_dict, rather than input_list
    batch = collate_fn(batch)
    if random.random() < mix_prob:
        if "instance" in batch:
            _mix_instance_ids(batch["instance"], batch["offset"])
        for key in ("mask_index", "global_mask_index", "dino_mask_index"):
            if key in batch:
                batch[key] = _mix_view_index(batch[key], batch["offset"])
        for key in ("imgs", "dino_feature"):
            if key in batch:
                batch[key] = _repeat_last_sample_views(batch[key], len(batch["offset"]))
        if "corresponding" in batch:
            _mix_corresponding(batch["corresponding"], batch["offset"])
        if "offset" in batch:
            # Keep the end offset of every pair, plus the final end.
            batch["offset"] = torch.cat(
                [batch["offset"][1:-1:2], batch["offset"][-1].unsqueeze(0)], dim=0
            )
    return batch


def gaussian_kernel(dist2: np.ndarray, a: float = 1, c: float = 5):
    """Evaluate the unnormalized Gaussian ``a * exp(-dist2 / (2 * c**2))`` elementwise.

    Args:
        dist2: squared distances (array or scalar).
        a: peak amplitude, i.e. the value returned where ``dist2 == 0``.
        c: width (standard deviation) of the Gaussian.
    """
    return a * np.exp(-dist2 / (2 * c**2))


def collate_data_and_cast(samples_list, dtype):
    """Collate per-sample image-crop dicts into a single batch dict.

    Args:
        samples_list: list of dicts. Each dict holds "imgs_global_crops" and
            "imgs_local_crops" (indexable sequences of tensors),
            "imgs_masks_list" and "imgs_masks_upperbound" (tensors), and
            "img_size".
        dtype: dtype the collated global and local crops are cast to.

    Returns:
        dict with the concatenated crops, the masks flattened per row, the flat
        indices of masked positions, a per-masked-position weight, the summed
        mask upper bound, the number of masked positions, and the first
        sample's "img_size".
    """
    # dtype = torch.half  # TODO: Remove
    num_global = len(samples_list[0]["imgs_global_crops"])
    num_local = len(samples_list[0]["imgs_local_crops"])

    def gather_crops(key, count):
        # Crop-major order: crop 0 of every sample, then crop 1 of every sample, ...
        pieces = []
        for crop_idx in range(count):
            for sample in samples_list:
                pieces.append(sample[key][crop_idx])
        return torch.cat(pieces, dim=0)

    global_crops = gather_crops("imgs_global_crops", num_global)
    local_crops = gather_crops("imgs_local_crops", num_local)

    stacked_masks = torch.cat([sample["imgs_masks_list"] for sample in samples_list], dim=0)
    bounds = torch.stack([sample["imgs_masks_upperbound"] for sample in samples_list])
    summed_bound = bounds.sum(0)

    masks = stacked_masks.flatten(1)
    masked_positions = masks.flatten().nonzero().flatten()

    # Weight each masked position by 1 / (masked positions in its row).
    # The count is clamped to at least 1.
    row_weight = 1 / masks.sum(-1).clamp(min=1.0)
    weights = row_weight.unsqueeze(-1).expand_as(masks)[masks]

    num_masked = torch.full(
        (1,), fill_value=masked_positions.shape[0], dtype=torch.long
    )
    return {
        "imgs_global_crops": global_crops.to(dtype),
        "imgs_local_crops": local_crops.to(dtype),
        "imgs_masks": masks,
        "imgs_mask_indices_list": masked_positions,
        "imgs_masks_weight": weights,
        "imgs_masks_upperbound": summed_bound,
        "imgs_n_masked_patches": num_masked,
        "img_size": samples_list[0]["img_size"],
    }
