import math
from functools import reduce
from weakref import WeakValueDictionary

import torch
import transformers
from torch import nn
from torchvision.transforms import Normalize

from dae.utils.torch_utils import Ref, freeze_model, mark_initialized

# 3-channel mean/std applied (after mapping [-1, 1] -> [0, 1]) when an encoder
# uses the "clip" image-normalization scheme.
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
# 3-channel mean/std applied (after mapping [-1, 1] -> [0, 1]) when an encoder
# uses the "imagenet" image-normalization scheme.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


# Maps a model id to a loaded model. Values are held weakly, so an entry
# disappears once nothing else keeps that model alive; the cache itself never
# extends a model's lifetime.
_model_cache = WeakValueDictionary()


def instanciate_pretrained_model(model_name, cache_dir=None, allow_ref=True):
    """Load a pretrained HuggingFace model, optionally sharing a live cached one.

    Args:
        model_name: HuggingFace hub repository id. Ids containing ``"siglip"``
            are loaded with ``transformers.SiglipVisionModel``; all others with
            ``transformers.AutoModel``.
        cache_dir: download cache directory forwarded to ``from_pretrained``.
        allow_ref: if True and a model for ``model_name`` is still alive in
            ``_model_cache``, return ``Ref(model)`` around that shared instance
            instead of loading a new copy.

    Returns:
        Either ``Ref(cached_model)``, or a freshly loaded model. A fresh model
        is registered in ``_model_cache`` (keyed by ``model_name`` only;
        ``cache_dir`` is not part of the key) and passed to ``mark_initialized``.
    """
    if allow_ref:
        # Single lookup on purpose: with a WeakValueDictionary the entry can be
        # collected between an ``in`` test and a ``[]`` access, which would
        # raise KeyError. ``get`` returns a strong reference or None atomically.
        cached = _model_cache.get(model_name)
        if cached is not None:
            return Ref(cached)

    if "siglip" in model_name:
        model = transformers.SiglipVisionModel.from_pretrained(model_name, cache_dir=cache_dir)
    else:
        model = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir)

    _model_cache[model_name] = model
    mark_initialized(model)
    return model


class PreTrainedEncoder(nn.Module):
    """Wrap a pretrained HuggingFace vision backbone as a token encoder.

    The backbone is chosen by a short alias (a key of ``MODEL_REPO_MAP``).
    ``forward`` normalizes and optionally resizes the input image, runs the
    backbone, and returns its token sequence.

    Class attributes:
        MODEL_REPO_MAP: alias -> HuggingFace hub repository id.
        MODEL_TYPE_MAP: alias -> model type, used to key the per-type tables.
        OUT_DIMS: alias -> output feature dimension (stored as ``out_dim``).
        HAS_CLS: model types whose output sequence starts with a CLS token.
        PATCH_SIZE: model type -> patch size in pixels.
        BASE_RES: model type -> default square input resolution in pixels.
        IMG_NORM: model type -> normalization scheme (``"imagenet"``,
            ``"clip"``, or ``False`` for none).
    """

    MODEL_REPO_MAP = {
        "dinov2_small": "facebook/dinov2-small",
        "dinov2_base": "facebook/dinov2-base",
        "dinov2_large": "facebook/dinov2-large",
        "dinov2_giant": "facebook/dinov2-giant",
        "siglip2_base": "google/siglip2-base-patch16-256",
        "siglip2_large": "google/siglip2-large-patch16-256",
        "siglip2_base_224": "google/siglip2-base-patch16-224",
        "siglip2_large_224": "google/siglip2-large-patch16-224",
        "siglip2_so400m_512": "google/siglip2-so400m-patch16-512",
    }
    MODEL_TYPE_MAP = {
        "dinov2_small": "dinov2",
        "dinov2_base": "dinov2",
        "dinov2_large": "dinov2",
        "dinov2_giant": "dinov2",
        "siglip2_base": "siglip",
        "siglip2_large": "siglip",
        "siglip2_base_224": "siglip_224",
        "siglip2_large_224": "siglip_224",
        "siglip2_so400m_512": "siglip_512",
    }
    OUT_DIMS = {
        "dinov2_small": 384,
        "dinov2_base": 768,
        "dinov2_large": 1024,
        "dinov2_giant": 1536,
        "siglip2_base": 768,
        "siglip2_large": 1024,
        "siglip2_base_224": 768,
        "siglip2_large_224": 1024,
        "siglip2_so400m_512": 1152,
    }
    HAS_CLS = ["mocov3", "dinov2", "siglip"]
    PATCH_SIZE = {
        "dinov2": 14,
        "jepa": 14,
        "dinov1": 16,
        "siglip": 16,
        "siglip_224": 16,
        "siglip_512": 16,
        "mae": 16,
        "mocov2": 16,
    }
    BASE_RES = {
        "dinov2": 224,
        "dinov1": 224,
        "siglip": 256,
        "siglip_224": 224,
        "siglip_512": 512,
    }

    IMG_NORM = {
        "dinov2": "imagenet",
        "dinov1": "imagenet",
        "clip": "clip",
        # "siglip" was missing, so __init__ raised KeyError for siglip2_base /
        # siglip2_large. Like the other SigLIP variants, inputs are passed
        # through unchanged (assumed already in [-1, 1]).
        "siglip": False,
        "siglip_224": False,
        "siglip_512": False,
    }

    def __init__(self, model_name, freeze=True, drop_cls=False, cache_dir=None):
        """Load the backbone registered under ``model_name``.

        Args:
            model_name: key of ``MODEL_REPO_MAP`` (e.g. ``"dinov2_base"``).
            freeze: if True, freeze the whole module with ``freeze_model``, and
                ``forward`` switches it back to eval mode whenever it is in
                training mode.
            drop_cls: if True, ``forward`` strips the first token for model
                types listed in ``HAS_CLS``.
            cache_dir: download cache directory forwarded to ``from_pretrained``.

        Raises:
            KeyError: if ``model_name`` is not a known alias.
        """
        super().__init__()

        self.model_type = self.MODEL_TYPE_MAP[model_name]
        self.out_dim = self.OUT_DIMS[model_name]
        self.base_patch_size = self.PATCH_SIZE[self.model_type]
        self.drop_cls = drop_cls
        self.base_res = self.BASE_RES.get(self.model_type, None)
        self.img_norm = self.IMG_NORM[self.model_type]

        # allow_ref=False: always load a private copy instead of sharing a cached one.
        self.model = instanciate_pretrained_model(self.MODEL_REPO_MAP[model_name], cache_dir=cache_dir, allow_ref=False)
        mark_initialized(self.model)

        self.freeze = freeze
        if freeze:
            freeze_model(self)
        elif self.model_type == "siglip":
            # NOTE(review): only the "siglip" type freezes the head here;
            # "siglip_224"/"siglip_512" leave it trainable even though forward
            # never uses their pooler_output. Confirm this is intended.
            freeze_model(self.model.vision_model.head)

    def rescale_and_process_image(self, x, target_n_tokens=None, ensure_base_res=False):
        """Normalize ``x`` for the backbone and optionally resize it.

        Args:
            x: 4-D image batch ``(batch, channels, height, width)``. The
                "clip"/"imagenet" branches map it from ``[-1, 1]`` to ``[0, 1]``
                via ``(x + 1) / 2`` before normalizing.
            target_n_tokens: desired patch-token count, either an int or a
                tuple/list of per-axis counts whose product is used. The image
                is resized, keeping its aspect ratio, so that
                ``H * W / patch_size**2`` is approximately this count.
            ensure_base_res: if True, resize to ``(base_res, base_res)``
                instead. This takes precedence over ``target_n_tokens``.

        Returns:
            The normalized and possibly resized tensor.

        Raises:
            ValueError: if the model type's normalization scheme is unknown.
        """
        # assert -1 <= x.min() and x.max() <= 1

        if self.img_norm == "clip":
            x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)((x + 1) / 2)
        elif self.img_norm == "imagenet":
            x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)((x + 1) / 2)
        elif self.img_norm is False:
            pass
        else:
            raise ValueError(f"Unknown image normalization for model {self.model_type}: {self.img_norm}")

        if isinstance(target_n_tokens, (tuple, list)):
            target_n_tokens = math.prod(target_n_tokens)

        if ensure_base_res:
            base_res = self.base_res
            # Compare both spatial dims. Checking only the width let non-square
            # inputs with a matching width through un-resized.
            if tuple(x.shape[-2:]) != (base_res, base_res):
                x = torch.nn.functional.interpolate(x, (base_res, base_res), mode="bicubic")
        elif target_n_tokens is not None:
            model_patch_size = self.PATCH_SIZE[self.model_type]
            _, _, h, w = x.shape
            # Uniform scale factor r such that (h*r)*(w*r) / patch**2 == target.
            r = math.sqrt(target_n_tokens * model_patch_size**2 / (h * w))
            H, W = round(h * r), round(w * r)
            x = torch.nn.functional.interpolate(x, (H, W), mode="bicubic")
        return x

    def forward(self, x, target_n_tokens=None, ensure_base_res=False):
        """Encode an image batch into a token sequence.

        Args:
            x, target_n_tokens, ensure_base_res: forwarded to
                ``rescale_and_process_image``.

        Returns:
            The backbone's ``last_hidden_state``. For model type ``"siglip"``,
            the backbone's ``pooler_output`` is prepended as a pseudo-CLS token.
            If ``drop_cls`` is set and the model type is in ``HAS_CLS``, the
            first token is removed.
        """
        if self.freeze and self.training:
            # A frozen encoder always runs in eval mode, even inside a parent
            # module that was put in train mode.
            self.eval()
        x = self.rescale_and_process_image(x, target_n_tokens, ensure_base_res=ensure_base_res)
        z = self.model(x)
        if self.model_type == "siglip":
            z = torch.cat([z.pooler_output.unsqueeze(1), z.last_hidden_state], dim=1)  # Add 'CLS' token
        else:
            z = z.last_hidden_state
        if self.drop_cls and self.model_type in self.HAS_CLS:
            z = z[:, 1:]
        return z
