##################################################
# Multi-crop related code re-used from DINO
# https://github.com/facebookresearch/dino
##################################################

import random

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import ImageFilter, ImageOps


class MultiCropAugmentation(object):
    """Produce a list of augmented views ("crops") of a single image.

    Two torchvision pipelines are built:

    * ``global_tfm`` -- resize to 224, center-crop 224x224, convert to a
      tensor and normalise.
    * ``local_tfm`` -- random-resized crop to 96x96 (bicubic), convert to a
      tensor and normalise.

    NOTE: calling the instance currently returns *only* global crops; local
    crop generation is intentionally disabled in ``__call__``, so
    ``local_number`` is stored and reported by ``__repr__`` but does not
    affect the output.

    Args:
        global_number: Number of global views returned per call.
        global_scale: Accepted for interface compatibility but currently
            unused, because the global pipeline is a deterministic
            resize + center crop rather than a random-resized crop.
        local_number: Number of local views (stored, currently not generated).
        local_scale: ``scale`` range forwarded to the local
            ``RandomResizedCrop``.
    """

    def __init__(self, global_number, global_scale, local_number, local_scale):
        assert (global_number > 0) or (local_number > 0)
        self.global_number = global_number
        self.local_number = local_number

        # Normalisation constants for CLIP.
        normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )

        # Global views: deterministic resize + center crop. This replaces the
        # DINO-style RandomResizedCrop(224, scale=global_scale, BICUBIC)
        # variant, which is why ``global_scale`` is not used here.
        self.global_tfm = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        # Local (small) views. ``ToTensor`` must precede ``normalize``:
        # Normalize only operates on tensors, and the global pipeline shows
        # that inputs are expected to be convertible via ``ToTensor``.
        self.local_tfm = transforms.Compose([
            transforms.RandomResizedCrop(
                96,
                scale=local_scale,
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ])

    def __repr__(self) -> str:
        return (
            f"global_number={self.global_number}, "
            f"local_number={self.local_number}, "
            f"\nglobal_tfm={self.global_tfm}\nlocal_tfm={self.local_tfm}"
        )

    def __call__(self, image):
        """Return ``global_number`` views of ``image`` as a list.

        Each view is the result of applying ``global_tfm`` to ``image``.
        Local crops (``local_tfm`` applied ``local_number`` times) are
        deliberately not appended at the moment.
        """
        return [self.global_tfm(image) for _ in range(self.global_number)]



class MultiCropWrapper(nn.Module):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features along the first dimension and
    return them.

    The wrapped ``net`` must expose ``get_image_features(pixel_values=...)``.

    Args:
        net: Encoder module whose ``get_image_features`` is called.
        head: Accepted for interface compatibility; it is not stored or
            applied by this wrapper.
    """

    def __init__(self, net, head=None):
        super().__init__()
        # Bind the encoder passed by the caller (the parameter is ``net``).
        self.net = net

    def forward(self, x):
        """Encode one crop tensor or a list/tuple of crop tensors.

        Crops are grouped by *consecutive* equal size of their last
        dimension, so crops of the same resolution are expected to be
        adjacent in ``x``. Each group is concatenated along dim 0 and
        encoded in a single call.

        Args:
            x: A tensor, or a list/tuple of tensors.

        Returns:
            The encoder outputs of all groups, cast to float32 and
            concatenated along dim 0 in input order.

        Raises:
            ValueError: If ``x`` is an empty list/tuple.
        """
        # convert to list
        if not isinstance(x, (list, tuple)):
            x = [x]
        if not x:
            raise ValueError("MultiCropWrapper.forward received no crops")

        # Cumulative counts of consecutive same-width crops give the
        # exclusive end index of each resolution group.
        widths = torch.tensor([inp.shape[-1] for inp in x])
        _, counts = torch.unique_consecutive(widths, return_counts=True)
        group_ends = torch.cumsum(counts, 0).tolist()

        outputs = []
        start_idx = 0
        for end_idx in group_ends:
            batch = torch.cat(x[start_idx:end_idx])
            outputs.append(
                self.net.get_image_features(pixel_values=batch).float()
            )
            start_idx = end_idx

        # Single concatenation instead of growing a tensor per group.
        return torch.cat(outputs)




    
