from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import einops
from jaxtyping import Float

from . import min_dinov2


def better_resize(imgs: torch.Tensor, image_size: int) -> torch.Tensor:
    """Center-crop to a square, then resize to ``image_size`` x ``image_size``.

    Accepts a single 3-channel image ``(3, H, W)`` or a batch ``(B, 3, H, W)``
    and returns a tensor with the same number of dimensions.

    When the square crop is at least twice ``image_size``, it is first
    average-pooled by the integer factor ``crop // image_size``. Presumably
    this limits aliasing before the final bilinear resize. The bilinear step
    then produces the exact target size (upsampling if the crop is smaller).
    """
    shape = imgs.shape
    assert shape[-3] == 3

    height, width = shape[-2:]
    is_single = len(shape) == 3
    # The pooling / interpolation ops below operate on an explicit batch dim.
    batch = imgs.unsqueeze(0) if is_single else imgs

    crop = min(height, width)
    pool = crop // image_size

    batch = TF.center_crop(batch, [crop, crop])
    if pool > 1:
        batch = F.avg_pool2d(batch, pool)
    batch = F.interpolate(batch, [image_size, image_size], mode="bilinear")

    return batch[0] if is_single else batch


class DinoV2Reg(nn.Module):
    """DINOv2 backbone from ``torch.hub`` that yields last-layer features.

    Inputs are images with values in ``[-1, 1]``. They are center-cropped and
    resized to ``model_size``, mapped to ``[0, 1]``, normalized with ImageNet
    mean/std, and fed to the backbone's ``get_intermediate_layers`` (last
    layer only, with the class token).

    Args:
        model_size: side length the input images are resized to.
        model_version: entry point name in the ``facebookresearch/dinov2`` hub repo.
        gradient_last_blocks: if a positive int, that many trailing transformer
            blocks are set to require grad and put in train mode, even when
            the rest of the backbone is frozen.
        reshape: passed to ``get_intermediate_layers``. When True, patch
            features are returned as a spatial map.
        out: ``"features"`` returns only patch features, ``"class"`` only the
            class token, and any other value returns ``(features, class_token)``.
        requires_grad: True makes the whole backbone trainable (train mode);
            False freezes it (eval mode).
    """

    def __init__(
        self,
        model_size: int = 224,
        model_version: str = "dinov2_vitl14_reg",
        gradient_last_blocks: None | int = None,
        reshape: bool = True,
        out: str = "both",
        requires_grad: bool = False,
    ) -> None:
        super().__init__()
        self.model_version = model_version
        self.model_size = model_size
        self.reshape = reshape
        self.out = out

        # torch.hub.load is called without force_reload, so a cached copy is reused.
        backbone: nn.Module = torch.hub.load("facebookresearch/dinov2", model_version)
        self.model: nn.Module = backbone

        # Route calls to the backbone through get_intermediate_layers so that
        # calling the module returns (patch_features, class_token) pairs.
        backbone.forward = partial(backbone.get_intermediate_layers, return_class_token=True, reshape=reshape)

        # One flag drives both gradient tracking and train/eval mode.
        backbone.requires_grad_(requires_grad)
        backbone.train(requires_grad)

        backbone.compile()

        self.gradient_last_blocks = gradient_last_blocks
        # None and non-positive values both mean "no extra trainable blocks".
        if (gradient_last_blocks or 0) > 0:
            for block in backbone.blocks[-gradient_last_blocks:]:
                block.requires_grad_(True)
                block.train()

    def forward(
        self, imgs: Float[torch.Tensor, "B C H W"]
    ) -> tuple[Float[torch.Tensor, "B D h' w'"], Float[torch.Tensor, "B D"]]:
        """Run the backbone on a batch of ``[-1, 1]`` images.

        Returns patch features, the class token, or both, depending on
        ``self.out``. The patch features are spatial only when ``reshape``
        was True.
        """
        assert imgs.min() >= -1.0
        assert imgs.max() <= 1.0
        assert imgs.ndim == 4

        resized = better_resize(imgs, self.model_size)
        unit_range = (resized + 1.0) / 2.0
        # ImageNet mean/std (same constants as the transformers preprocessor).
        normalized = TF.normalize(unit_range, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # A single (last) layer is requested, so take the first pair.
        last_layer = self.model(normalized)[0]
        features, cls = last_layer

        if self.out == "features":
            return features
        if self.out == "class":
            return cls
        return features, cls


class MinDinoV2Reg(nn.Module):  # Same as DinoV2Reg, but based on Simo's minimal implementation
    """DINOv2 wrapper that runs the local ``min_dinov2`` ViT.

    It is meant to mirror ``DinoV2Reg``: same constructor arguments, same
    ``[-1, 1]`` input preprocessing and same ``out`` semantics. The weights
    are copied (``strict=True``) from the ``torch.hub`` checkpoint of the same
    ``model_version``. Only ``forward_features`` is compiled, with
    ``fullgraph=True`` and ``dynamic=False``.

    Args:
        model_size: side length the input images are resized to.
        model_version: one of the ``dinov2_vit{s,b,l,g}14[_reg]`` names.
        gradient_last_blocks: if a positive int, that many trailing transformer
            blocks are set to require grad and put in train mode.
        reshape: when True, patch tokens are returned as a ``(B, D, h, w)``
            map. When False, they stay as flat ``(B, N, D)`` tokens, matching
            ``DinoV2Reg``.
        out: ``"features"`` returns only patch features, ``"class"`` only the
            class token, and any other value returns ``(features, class_token)``.
        requires_grad: True makes the whole model trainable (train mode);
            False freezes it (eval mode).

    Raises:
        KeyError: if ``model_version`` is not a supported checkpoint name.
    """

    # Every supported checkpoint is a "*14" model, i.e. uses 14x14 patches.
    _PATCH_SIZE = 14

    def __init__(
        self,
        model_size: int = 224,
        model_version: str = "dinov2_vitl14_reg",
        gradient_last_blocks: None | int = None,
        reshape: bool = True,
        out: str = "both",
        requires_grad: bool = False,
    ) -> None:
        super().__init__()
        self.model_version = model_version
        self.model_size = model_size
        self.reshape = reshape
        self.out = out

        builders = {
            "dinov2_vits14": min_dinov2.vit_small,
            "dinov2_vitb14": min_dinov2.vit_base,
            "dinov2_vitl14": min_dinov2.vit_large,
            "dinov2_vitg14": min_dinov2.vit_giant2,
            "dinov2_vits14_reg": min_dinov2.vit_small_reg,
            "dinov2_vitb14_reg": min_dinov2.vit_base_reg,
            "dinov2_vitl14_reg": min_dinov2.vit_large_reg,
            "dinov2_vitg14_reg": min_dinov2.vit_giant2_reg,
        }
        if model_version not in builders:
            # Fail before the potentially slow torch.hub download.
            raise KeyError(f"Unsupported model_version {model_version!r}; expected one of {sorted(builders)}")

        # default is force_reload=False
        original_model: nn.Module = torch.hub.load("facebookresearch/dinov2", model_version)
        self.model = builders[model_version]()
        self.model.load_state_dict(original_model.state_dict(), strict=True)
        # Only the weights were needed; release the hub copy.
        del original_model

        if requires_grad:
            self.model.requires_grad_(True)
            self.model.train()
        else:
            self.model.requires_grad_(False)
            self.model.eval()

        self.model.forward_features = torch.compile(self.model.forward_features, fullgraph=True, dynamic=False)

        self.gradient_last_blocks = gradient_last_blocks
        if gradient_last_blocks is not None and gradient_last_blocks > 0:
            for l in self.model.blocks[-gradient_last_blocks:]:
                l.requires_grad_(True)
                l.train()

    def forward(
        self, imgs: Float[torch.Tensor, "B C H W"]
    ) -> tuple[Float[torch.Tensor, "B D h' w'"], Float[torch.Tensor, "B D"]]:
        """Run the model on a batch of ``[-1, 1]`` images.

        Returns patch features, the class token, or both, depending on
        ``self.out``. The patch features are a spatial map only when
        ``self.reshape`` is True.
        """
        assert imgs.min() >= -1.0
        assert imgs.max() <= 1.0
        assert len(imgs.shape) == 4

        imgs = better_resize(imgs, self.model_size)

        imgs = (imgs + 1.0) / 2.0
        # copied from transformers preprocessor
        imgs = TF.normalize(imgs, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        torch.compiler.cudagraph_mark_step_begin()
        d = self.model.forward_features(imgs, masks=None)
        cls = d["x_norm_clstoken"]

        if self.out == "class":
            return cls

        features = d["x_norm_patchtokens"]
        if self.reshape:
            # better_resize made the input square, so the token grid is side x side.
            side = self.model_size // self._PATCH_SIZE
            features = einops.rearrange(features, "b (h w) d -> b d h w", h=side, w=side)

        if self.out == "features":
            return features
        return features, cls
