import torch
from PIL import Image
import math
import itertools
from functools import partial
import mmcv
from mmcv.runner import load_checkpoint
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as ff
from dinov2.eval.depth.models import build_depther


class DINO(torch.nn.Module):
    """Frozen ``dinov2_vits14`` backbone plus depther head with a result cache.

    The backbone is loaded from a local ``torch.hub`` repository, the depther
    is assembled by ``_create_depther`` from an mmcv config, and both are put
    in eval mode on ``device``.  ``forward`` runs the depther one image at a
    time and memoises each result on the CPU, keyed by a hash of the pixel
    values.

    NOTE(review): ``self.cache`` is unbounded; it grows by one entry for every
    distinct image ever passed to ``forward``.
    """

    # Default asset locations (the values this module originally hard-coded).
    # Override them through the constructor on machines with another layout.
    DEFAULT_REPO_DIR = "/data/Hypothesis/theorem/3dgen/dinov2"
    DEFAULT_BACKBONE_WEIGHTS = "/data/models/dinov2/dinov2_vits14_pretrain.pth"
    DEFAULT_DEPTHER_CONFIG = "/data/Hypothesis/theorem/3dgen/projection-conditioned-point-cloud-diffusion/experiments/dinov2_vits12_nyu_dpt_config.py"
    DEFAULT_HEAD_WEIGHTS = "/data/models/dinov2/dinov2_vits14_nyu_dpt_head.pth"

    def __init__(
        self,
        repo_dir=None,
        backbone_weights=None,
        depther_config=None,
        head_weights=None,
        device="cuda",
    ):
        """Load the backbone and depther weights and move them to ``device``.

        Args:
            repo_dir: Local ``torch.hub`` repository holding the DINOv2 code.
            backbone_weights: State-dict file for the backbone.
            depther_config: mmcv config file describing the depther.
            head_weights: Checkpoint loaded into the assembled depther.
            device: Device used for the weights and for inference.

        Every path argument falls back to the matching ``DEFAULT_*`` class
        attribute when left as ``None``.
        """
        super().__init__()

        if repo_dir is None:
            repo_dir = self.DEFAULT_REPO_DIR
        if backbone_weights is None:
            backbone_weights = self.DEFAULT_BACKBONE_WEIGHTS
        if depther_config is None:
            depther_config = self.DEFAULT_DEPTHER_CONFIG
        if head_weights is None:
            head_weights = self.DEFAULT_HEAD_WEIGHTS

        # Maps hash_tensor(image) -> features tensor stored on the CPU.
        self.cache = {}
        # NOTE: a later ``.to(...)`` / ``.cpu()`` on this module does not
        # update this attribute; forward() keeps targeting this device.
        self.device = torch.device(device)

        # Backbone: build the architecture without downloading weights, then
        # load the local state dict.
        self.dino_backbone_name = "dinov2_vits14"
        self.dino_backbone_model = torch.hub.load(
            repo_dir,
            self.dino_backbone_name,
            source="local",
            pretrained=False,
        )
        self.dino_backbone_model.load_state_dict(
            torch.load(backbone_weights, map_location=self.device)
        )
        self.dino_backbone_model.eval()
        self.dino_backbone_model.to(self.device)

        # Depther: wrap the backbone according to the mmcv config, then load
        # the head checkpoint into the assembled model.
        cfg = mmcv.Config.fromfile(depther_config)
        self.dino_model = _create_depther(cfg, backbone_model=self.dino_backbone_model)
        load_checkpoint(self.dino_model, head_weights, map_location=self.device)
        self.dino_model.eval()
        self.dino_model.to(self.device)

    def hash_tensor(self, tensor):
        """Return a Python ``hash`` of the tensor's flattened values.

        The shape is not part of the key, so tensors holding the same values
        in a different layout hash identically.  ``tolist()`` copies every
        element to the host, which makes this expensive for large images.
        """
        return hash(tuple(tensor.reshape(-1).tolist()))

    def forward(self, images):
        """Run the depther on each image and concatenate the results.

        Args:
            images: Non-empty iterable of images.  Tensors are used as they
                are (a batch dimension is prepended before inference —
                presumably each is ``(C, H, W)``; confirm against callers).
                Any other object is converted with torchvision's
                ``to_tensor`` first, so PIL images are accepted.

        Returns:
            The per-image outputs of ``whole_inference`` concatenated along
            dim 0, on ``self.device``.

        ``images`` must not be empty: ``torch.cat`` fails on an empty list.
        """
        batch_features = []
        for image in images:
            # Convert before hashing: hash_tensor() needs tensor methods, so
            # hashing a PIL image directly would raise.
            if not isinstance(image, torch.Tensor):
                image = ff.to_tensor(image)

            image_hash = self.hash_tensor(image)
            features = self.cache.get(image_hash)
            if features is None:
                # inference_mode() already disables autograd, so no extra
                # no_grad() is needed.  CUDA autocast is only enabled when
                # the model actually lives on a CUDA device.
                with torch.inference_mode(), torch.cuda.amp.autocast(
                    enabled=self.device.type == "cuda"
                ):
                    features = self.dino_model.whole_inference(
                        image.unsqueeze(0).to(self.device),
                        img_meta=None,
                        rescale=True,
                    )
                # Keep the cached copy on the CPU to save device memory.
                self.cache[image_hash] = features.cpu()
            batch_features.append(features.to(self.device))

        # Concatenate along the batch dimension added before inference.
        return torch.cat(batch_features, dim=0)


class _CenterPadding(torch.nn.Module):
    """Pad every dimension from index 2 onward up to a multiple of ``multiple``.

    The padding is split between both sides of each dimension; when the total
    is odd, the extra element goes on the trailing side.  Padding is applied
    with ``F.pad`` using its default mode.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(leading, trailing)`` padding that rounds ``size`` up."""
        target = math.ceil(size / self.multiple) * self.multiple
        gap = target - size
        leading = gap // 2
        return leading, gap - leading

    @torch.inference_mode()
    def forward(self, x):
        """Return ``x`` padded so dims 2.. are multiples of ``self.multiple``."""
        # F.pad expects (leading, trailing) pairs starting from the LAST
        # dimension, hence the reversed walk over the trailing dims.
        padding = []
        for extent in reversed(x.shape[2:]):
            padding.extend(self._get_pad(extent))
        return F.pad(x, padding)


def _create_depther(cfg, backbone_model):
    """Build the depther described by ``cfg`` and wire it to ``backbone_model``.

    The depther comes from ``build_depther(cfg.model, ...)``.  Its
    ``backbone.forward`` is then replaced by
    ``backbone_model.get_intermediate_layers`` with the layer indices,
    class-token and norm options taken from ``cfg.model.backbone``.  When the
    backbone model exposes ``patch_size``, a forward pre-hook pads the first
    positional input with ``_CenterPadding(patch_size)``.

    Args:
        cfg: Config object providing ``model`` plus optional ``train_cfg`` /
            ``test_cfg`` entries (read with ``cfg.get``).
        backbone_model: Model providing ``get_intermediate_layers``.

    Returns:
        The assembled depther.
    """
    depther = build_depther(
        cfg.model,
        train_cfg=cfg.get("train_cfg"),
        test_cfg=cfg.get("test_cfg"),
    )

    # Route the depther's backbone through the supplied model's
    # intermediate-layer extraction instead of its own forward().
    backbone_cfg = cfg.model.backbone
    depther.backbone.forward = partial(
        backbone_model.get_intermediate_layers,
        n=backbone_cfg.out_indices,
        reshape=True,
        return_class_token=backbone_cfg.output_cls_token,
        norm=backbone_cfg.final_norm,
    )

    if not hasattr(backbone_model, "patch_size"):
        return depther

    def _pad_to_patch_multiple(_module, inputs):
        # patch_size is read at call time, exactly as the original hook did.
        return _CenterPadding(backbone_model.patch_size)(inputs[0])

    depther.backbone.register_forward_pre_hook(_pad_to_patch_multiple)
    return depther
