import torch, timm
import torchvision.transforms as T

# Reference embedding width for each supported (encoder family, model size)
# pair. Built from one table per family so that each family's sizes read as a
# single row; the resulting keys are ``(family, size)`` tuples.
FEATURE_DIM_MAP = {
    (family, size): width
    for family, widths in (
        ("vit", {"tiny": 192, "small": 384, "base": 768, "large": 1024, "huge": 1280}),
        ("dinov2", {"small": 384, "base": 768, "large": 1024, "giant": 1536}),
    )
    for size, width in widths.items()
}

class EncoderManager:
    """Build a timm ViT / DINOv2 encoder and expose it as a feature extractor.

    The model is created through ``timm.create_model(..., num_classes=0)`` and
    moved to ``device``. Settings are read from ``cfg["encoder"]``; every key
    is optional and falls back to the default shown::

        {
            "type": "vit" | "dinov2",          # default "vit"
            "size": "tiny/small/base/large/huge/giant",  # default "base"
            "patch": 16 | 14 | ...,            # default 16
            "from_pretrained": True/False,     # default True
            "tune": False,                     # False freezes all parameters
            "precompute": True,
            "use_cache": True,                 # stored as ``self.save_cache``
            "precompute_batch_size": 256,      # stored only; not used in this class
        }

    Attributes set by ``__init__``: ``cfg``, ``device``, ``encoder_type``,
    ``model_size``, ``patch_size``, ``from_pretrained``, ``tune``,
    ``precompute``, ``save_cache``, ``precompute_batch_size``, ``model`` and
    ``feature_dim``.
    """

    # Fallbacks used when the model carries no (or an incomplete) ``default_cfg``.
    _FALLBACK_INPUT_HW = (224, 224)
    _FALLBACK_MEAN = (0.485, 0.456, 0.406)
    _FALLBACK_STD = (0.229, 0.224, 0.225)

    def __init__(self, cfg, device):
        """Create the encoder described by ``cfg["encoder"]`` on ``device``.

        Args:
            cfg: Mapping holding an optional ``"encoder"`` sub-mapping.
            device: Target passed to ``Module.to`` and used for input tensors.

        Raises:
            ValueError: If the encoder type is neither "vit" nor "dinov2".
            RuntimeError: If no DINOv2 model matches the requested size/patch.
        """
        import warnings  # local import: only needed for the sanity check below

        self.cfg = cfg
        self.device = device
        # ``or {}`` also covers an explicit ``"encoder": None`` entry.
        ecfg = cfg.get("encoder") or {}

        self.encoder_type = ecfg.get("type", "vit").lower()
        self.model_size = ecfg.get("size", "base").lower()
        self.patch_size = int(ecfg.get("patch", 16))
        self.from_pretrained = bool(ecfg.get("from_pretrained", True))
        self.tune = bool(ecfg.get("tune", False))
        self.precompute = bool(ecfg.get("precompute", True))
        self.save_cache = bool(ecfg.get("use_cache", True))
        self.precompute_batch_size = int(ecfg.get("precompute_batch_size", 256))

        model_name = self._resolve_model_name()
        self.model = timm.create_model(
            model_name,
            pretrained=self.from_pretrained,
            num_classes=0,
        ).to(self.device)

        if not self.tune:
            # Frozen encoder: no parameter receives gradients.
            for p in self.model.parameters():
                p.requires_grad = False

        # The measured output width is authoritative; the lookup table is only
        # a reference, so a disagreement is reported instead of silently hidden.
        table_dim = FEATURE_DIM_MAP.get((self.encoder_type, self.model_size))
        measured_dim = self._infer_feature_dim()
        if table_dim is not None and table_dim != measured_dim:
            warnings.warn(
                f"FEATURE_DIM_MAP lists {table_dim} for "
                f"({self.encoder_type}, {self.model_size}) but the model outputs "
                f"{measured_dim}; using the measured value.",
                stacklevel=2,
            )
        self.feature_dim = measured_dim

    @staticmethod
    def _select_dinov2_model(size, patch, available):
        """Pick the DINOv2 architecture name for ``size``/``patch``.

        Args:
            size: Model size token, e.g. "small".
            patch: Patch size as an int, e.g. 14.
            available: Iterable of model names, with or without a trailing
                ``.tag`` suffix.

        Returns:
            The plain ``vit_{size}_patch{patch}_dinov2`` name if listed, else
            the ``reg4`` variant, else the shortest listed name that contains
            ``vit_{size}_patch{patch}_`` and ``dinov2``.

        Raises:
            RuntimeError: If nothing in ``available`` matches.
        """
        available = list(available)
        primary = f"vit_{size}_patch{patch}_dinov2"
        fallback = f"vit_{size}_patch{patch}_reg4_dinov2"

        # Names may be listed as "architecture.tag"; compare on the
        # architecture part so the exact-match checks can actually succeed.
        architectures = {name.split(".", 1)[0] for name in available}
        for preferred in (primary, fallback):
            if preferred in architectures:
                return preferred

        # The trailing "_" keeps e.g. patch=1 from matching "patch14".
        stem = f"vit_{size}_patch{patch}_"
        candidates = [
            name for name in architectures
            if stem in name and "dinov2" in name
        ]
        if candidates:
            # Shortest name first, ties broken alphabetically (deterministic).
            return min(candidates, key=lambda name: (len(name), name))

        raise RuntimeError(
            f"Unknown DINOv2 model for size={size}, patch={patch}. "
            f"Tried: {primary}, {fallback}. "
            f"Available (pretrained) examples: {sorted(available)[:10]} ..."
        )

    def _resolve_model_name(self):
        """Translate encoder type, size and patch size into a timm model name.

        Raises:
            ValueError: For an encoder type other than "vit" or "dinov2".
            RuntimeError: If no pretrained DINOv2 entry matches.
        """
        et, sz, pt = self.encoder_type, self.model_size, self.patch_size

        if et == "vit":
            return f"vit_{sz}_patch{pt}_224"

        if et == "dinov2":
            return self._select_dinov2_model(
                sz, pt, timm.list_models("*dinov2*", pretrained=True)
            )

        raise ValueError(f"Unknown encoder_type={et}")

    def _data_config(self):
        """Return ``(height, width, mean, std)`` for the model's inputs.

        Values come from ``self.model.default_cfg`` when present; any missing
        entry falls back to 224x224 and the class-level mean/std defaults.
        """
        model_cfg = getattr(self.model, "default_cfg", None) or {}
        if "input_size" in model_cfg:
            _, height, width = model_cfg["input_size"]
        else:
            height, width = self._FALLBACK_INPUT_HW
        mean = tuple(model_cfg.get("mean") or self._FALLBACK_MEAN)
        std = tuple(model_cfg.get("std") or self._FALLBACK_STD)
        return height, width, mean, std

    @torch.no_grad()
    def _infer_feature_dim(self):
        """Measure the size of the model output's last dimension.

        Feeds a single all-zero image of the configured input size through the
        model. Side effect: the model is left in eval mode.
        """
        self.model.eval()
        height, width, _, _ = self._data_config()
        x = torch.zeros(1, 3, height, width, device=self.device)
        y = self.model(x)
        return int(y.shape[-1])

    def image_transform(self):
        """Build the torchvision preprocessing pipeline for this encoder.

        Resize (bicubic, antialiased) -> center crop to the model's input
        size -> ``ToTensor`` -> ``Normalize``. The normalisation statistics are
        taken from the model's own config when it declares them, because not
        every checkpoint is trained with the fallback statistics.
        """
        height, width, mean, std = self._data_config()
        return T.Compose([
            # Resizing the shorter edge to the larger target side guarantees
            # the crop below fits; for square inputs this equals Resize(height).
            T.Resize(
                max(height, width),
                interpolation=T.InterpolationMode.BICUBIC,
                antialias=True,
            ),
            T.CenterCrop((height, width)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])

    def can_precompute(self):
        """Return True when the encoder is frozen and precomputation is enabled."""
        return (not self.tune) and self.precompute

    @torch.no_grad()
    def encode(self, x):
        """Run ``x`` through the encoder in eval mode without gradients.

        ``x`` is moved to ``self.device`` first; the model output is returned
        as-is.

        NOTE(review): gradients are disabled here even when ``tune`` is True,
        so fine-tuning presumably goes through ``self.model`` directly --
        confirm against the training code.
        """
        self.model.eval()
        return self.model(x.to(self.device))

    def get_info(self):
        """Return a dict summarising the encoder configuration.

        The ``"use_cache"`` entry reports ``self.save_cache``.
        """
        return {
            "feature_dim": self.feature_dim,
            "type": self.encoder_type,
            "size": self.model_size,
            "patch": self.patch_size,
            "from_pretrained": self.from_pretrained,
            "tune": self.tune,
            "precompute": self.precompute,
            "use_cache": self.save_cache,
        }
