"""Model loading utilities for various SSL models."""

from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional

import torch
import torch.nn as nn
from PIL import Image


class SSLModel(ABC):
    """Common interface for every wrapped vision / self-supervised model.

    A wrapper bundles a torch module with the callable that converts one
    image into the input tensor the module expects. Subclasses must implement
    :meth:`encode_image`; those with a text tower also override
    :meth:`encode_text`.

    Attributes:
        name: Identifier of the loaded model.
        model: Underlying ``nn.Module``.
        preprocess: Callable applied to each input image.
        device: Device inputs are moved to before the forward pass.
        embed_dim: Width of the produced embeddings.
        has_text_encoder: Whether :meth:`encode_text` is supported.
        tokenizer: Text tokenizer, or ``None`` for image-only models.
    """

    def __init__(
        self,
        name: str,
        model: nn.Module,
        preprocess: Callable,
        device: torch.device,
        embed_dim: int,
        has_text_encoder: bool = False,
        tokenizer: Optional[Any] = None,
    ):
        # Bookkeeping.
        self.name = name
        self.embed_dim = embed_dim
        # Inference machinery.
        self.model = model
        self.preprocess = preprocess
        self.device = device
        # Optional text tower.
        self.has_text_encoder = has_text_encoder
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """Return L2-normalised embeddings for ``images``."""
        ...

    def encode_text(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalised embeddings for ``texts``.

        Raises:
            NotImplementedError: always, unless a subclass overrides this.
        """
        raise NotImplementedError(f"{self.name} does not have a text encoder")


class CLIPModel(SSLModel):
    """Image/text dual encoder built with open_clip.

    Both towers return embeddings L2-normalised along the last dimension, so
    image-text similarity reduces to a dot product.
    """

    def __init__(
        self,
        name: str,
        model: nn.Module,
        preprocess: Callable,
        tokenizer: Any,
        device: torch.device,
        embed_dim: int,
    ):
        super().__init__(
            tokenizer=tokenizer,
            has_text_encoder=True,
            embed_dim=embed_dim,
            device=device,
            preprocess=preprocess,
            model=model,
            name=name,
        )

    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """Preprocess, batch and embed ``images`` with the vision tower."""
        pixel_batch = torch.stack(list(map(self.preprocess, images)))
        pixel_batch = pixel_batch.to(self.device)
        with torch.no_grad():
            embeddings = self.model.encode_image(pixel_batch)
        return torch.nn.functional.normalize(embeddings, dim=-1)

    def encode_text(self, texts: List[str]) -> torch.Tensor:
        """Tokenize and embed ``texts`` with the text tower."""
        token_ids = self.tokenizer(texts)
        with torch.no_grad():
            embeddings = self.model.encode_text(token_ids.to(self.device))
        return torch.nn.functional.normalize(embeddings, dim=-1)


class DINOModel(SSLModel):
    """Image-only wrapper for DINO / DINOv2 backbones.

    The module's plain forward call is used as the image embedding.
    """

    def __init__(
        self,
        name: str,
        model: nn.Module,
        preprocess: Callable,
        device: torch.device,
        embed_dim: int,
    ):
        super().__init__(
            has_text_encoder=False,
            embed_dim=embed_dim,
            device=device,
            preprocess=preprocess,
            model=model,
            name=name,
        )

    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """Embed ``images`` and L2-normalise along the last dimension."""
        tensors = [self.preprocess(image) for image in images]
        batch = torch.stack(tensors).to(self.device)
        with torch.no_grad():
            output = self.model(batch)
        return torch.nn.functional.normalize(output, dim=-1)


class TimmModel(SSLModel):
    """Image-only wrapper for timm models (MAE, BEiT, ConvNeXt, ...).

    Embeddings come from ``forward_features`` and are reduced to one vector
    per image:

    * 4-D output ``(batch, channels, H, W)`` -> spatial mean;
    * 3-D output ``(batch, tokens, features)`` -> first token;
    * anything else is used unchanged.
    """

    def __init__(
        self,
        name: str,
        model: nn.Module,
        preprocess: Callable,
        device: torch.device,
        embed_dim: int,
    ):
        super().__init__(
            has_text_encoder=False,
            embed_dim=embed_dim,
            device=device,
            preprocess=preprocess,
            model=model,
            name=name,
        )

    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """Embed ``images`` and L2-normalise along the last dimension."""
        batch = torch.stack([self.preprocess(im) for im in images])
        batch = batch.to(self.device)
        with torch.no_grad():
            tokens = self.model.forward_features(batch)
            rank = tokens.ndim
            if rank == 4:
                # Conv feature map: global average pool over H and W.
                pooled = tokens.mean(dim=[2, 3])
            elif rank == 3:
                # Token sequence: keep the leading (CLS) token.
                pooled = tokens[:, 0]
            else:
                pooled = tokens
        return torch.nn.functional.normalize(pooled, dim=-1)


class ResNetModel(SSLModel):
    """Image-only wrapper for ResNet-style backbones (SimCLR, VICReg, ...).

    Whatever the forward call returns is flattened to one vector per image
    when it has more than two dimensions.
    """

    def __init__(
        self,
        name: str,
        model: nn.Module,
        preprocess: Callable,
        device: torch.device,
        embed_dim: int,
    ):
        super().__init__(
            has_text_encoder=False,
            embed_dim=embed_dim,
            device=device,
            preprocess=preprocess,
            model=model,
            name=name,
        )

    def encode_image(self, images: List[Image.Image]) -> torch.Tensor:
        """Embed ``images`` and L2-normalise along the last dimension."""
        tensors = [self.preprocess(picture) for picture in images]
        batch = torch.stack(tensors).to(self.device)
        with torch.no_grad():
            out = self.model(batch)
            if out.ndim > 2:
                # Collapse any trailing spatial dims into the feature axis.
                out = out.view(out.shape[0], -1)
        return torch.nn.functional.normalize(out, dim=-1)


def load_clip(
    model_name: str,
    device: torch.device,
    pretrained: Optional[str] = None,
) -> CLIPModel:
    """Load a CLIP model via open_clip.

    Args:
        model_name: open_clip architecture name (e.g. ``"ViT-B-32"``).
        device: Device the model is moved to.
        pretrained: open_clip pretrained tag. When ``None`` the per-architecture
            default from ``name_map`` is used, falling back to ``"openai"``.

    Returns:
        A :class:`CLIPModel` named ``clip_<model_name lower-cased, '-' removed>``.
    """
    import open_clip

    # Friendly name -> (open_clip architecture, default pretrained tag).
    # convnext_base_w has no "openai" checkpoint in open_clip, so it needs an
    # explicit LAION tag or model creation fails.
    name_map = {
        "ViT-B-32": ("ViT-B-32", "openai"),
        "ViT-B-16": ("ViT-B-16", "openai"),
        "ViT-L-14": ("ViT-L-14", "openai"),
        "ViT-L-14-336": ("ViT-L-14-336", "openai"),
        "convnext_base_w": ("convnext_base_w", "laion2b_s13b_b82k"),
    }

    arch, default_tag = name_map.get(model_name, (model_name, "openai"))
    tag = pretrained if pretrained is not None else default_tag

    model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained=tag)
    model = model.to(device).eval()
    tokenizer = open_clip.get_tokenizer(arch)

    # Probe the embedding width with a blank image run through the model's own
    # preprocessing, so the probe has the architecture's input resolution
    # (e.g. 336 px for ViT-L-14-336) instead of an assumed 224 px.
    with torch.no_grad():
        dummy = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0).to(device)
        embed_dim = model.encode_image(dummy).shape[-1]

    return CLIPModel(
        name=f"clip_{model_name.lower().replace('-', '')}",
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        device=device,
        embed_dim=embed_dim,
    )


def load_dinov2(model_name: str, device: torch.device) -> DINOModel:
    """Load a DINOv2 backbone via torch hub.

    Args:
        model_name: Entry point in ``facebookresearch/dinov2``
            (e.g. ``"dinov2_vitb14"`` or ``"dinov2_vitb14_reg"``).
        device: Device the model is moved to.

    Returns:
        A :class:`DINOModel` named after the full hub entry point.
    """
    model = torch.hub.load("facebookresearch/dinov2", model_name)
    model = model.to(device).eval()

    from torchvision import transforms
    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Probe the embedding width with one blank 224x224 batch.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224).to(device)
        embed_dim = model(dummy).shape[-1]

    # Keep the whole hub name: using only the last "_"-separated token would
    # name both dinov2_vitb14_reg and dinov2_vitl14_reg "dinov2_reg".
    name = model_name if model_name.startswith("dinov2_") else f"dinov2_{model_name}"

    return DINOModel(
        name=name,
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


def load_mae(model_name: str, device: torch.device) -> TimmModel:
    """Build a masked-autoencoder ViT through timm.

    Args:
        model_name: timm identifier, e.g. ``"vit_base_patch16_224.mae"``.
        device: Device the model is moved to.

    Returns:
        A :class:`TimmModel` named ``mae_<model_name with '.' -> '_'>``.
    """
    import timm

    net = timm.create_model(model_name, pretrained=True).to(device).eval()

    # Eval-time transform derived from the model's own data config.
    cfg = timm.data.resolve_model_data_config(net)
    transform = timm.data.create_transform(**cfg, is_training=False)

    if hasattr(net, "embed_dim"):
        width = net.embed_dim
    else:
        width = net.num_features

    return TimmModel(
        name="mae_" + model_name.replace(".", "_"),
        model=net,
        preprocess=transform,
        device=device,
        embed_dim=width,
    )


def load_simclr_resnet50(device: torch.device) -> ResNetModel:
    """Load a ResNet-50 stand-in for SimCLR.

    Real SimCLR weights need vissl or a manual download, so this uses
    torchvision's ImageNet-pretrained (``IMAGENET1K_V2``) ResNet-50 as a proxy
    for exercising the pipeline.
    """
    import torchvision.models as models
    from torchvision import transforms

    backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    # Drop the final classifier child and flatten the pooled map instead.
    trunk = list(backbone.children())[:-1]
    model = nn.Sequential(*trunk, nn.Flatten()).to(device).eval()

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    return ResNetModel(
        name="simclr_resnet50",
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=2048,
    )


def load_ijepa(model_name: str, device: torch.device) -> TimmModel:
    """Load an I-JEPA stand-in via timm.

    Note: official I-JEPA weights may require manual download; a timm ViT
    with the same name is used as a proxy for testing the pipeline.

    If creating the requested model raises, this falls back to
    ``vit_base_patch16_224`` and emits a ``RuntimeWarning``. In that case the
    returned model still carries the requested name but is a different
    architecture.

    Args:
        model_name: Requested timm ViT identifier.
        device: Device the model is moved to.

    Returns:
        A :class:`TimmModel` named ``ijepa_<model_name with '.' -> '_'>``.
    """
    import warnings

    import timm

    # Map I-JEPA names to timm equivalents for testing (currently identity).
    timm_map = {
        "vit_huge_patch14_224": "vit_huge_patch14_224",
        "vit_large_patch16_224": "vit_large_patch16_224",
        "vit_base_patch16_224": "vit_base_patch16_224",
    }
    fallback_name = "vit_base_patch16_224"

    timm_name = timm_map.get(model_name, model_name)

    try:
        model = timm.create_model(timm_name, pretrained=True)
    except Exception as err:  # best-effort: unknown name, download failure, ...
        # Keep the deliberate fallback, but never substitute silently: the
        # returned name would otherwise misreport the architecture.
        warnings.warn(
            f"load_ijepa: could not create {timm_name!r} ({err!r}); "
            f"falling back to {fallback_name!r}. The returned model is still "
            f"named after {model_name!r}.",
            RuntimeWarning,
            stacklevel=2,
        )
        model = timm.create_model(fallback_name, pretrained=True)

    model = model.to(device).eval()

    data_config = timm.data.resolve_model_data_config(model)
    preprocess = timm.data.create_transform(**data_config, is_training=False)

    embed_dim = model.embed_dim if hasattr(model, "embed_dim") else model.num_features

    return TimmModel(
        name=f"ijepa_{model_name.replace('.', '_')}",
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


def load_dino_v1(model_name: str, device: torch.device) -> DINOModel:
    """Load a DINO (v1) backbone via torch hub.

    Args:
        model_name: Entry point in ``facebookresearch/dino:main``
            (e.g. ``"dino_vits16"``).
        device: Device the model is moved to.

    Returns:
        A :class:`DINOModel` named after the full hub entry point.
    """
    model = torch.hub.load("facebookresearch/dino:main", model_name)
    model = model.to(device).eval()

    from torchvision import transforms
    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Probe the embedding width with one blank 224x224 batch.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224).to(device)
        embed_dim = model(dummy).shape[-1]

    # Keep the whole hub name: using only the last "_"-separated token makes
    # distinct entry points sharing a suffix (e.g. "..._p16") collide.
    name = model_name if model_name.startswith("dino_") else f"dino_{model_name}"

    return DINOModel(
        name=name,
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


def load_beit(model_name: str, device: torch.device) -> TimmModel:
    """Build a BEiT-family (or other timm ViT) model through timm.

    Args:
        model_name: timm identifier, e.g. ``"beit_base_patch16_224"``.
        device: Device the model is moved to.

    Returns:
        A :class:`TimmModel` named ``model_name`` with '.' replaced by '_'.
    """
    import timm

    net = timm.create_model(model_name, pretrained=True).to(device).eval()

    # Eval-time transform derived from the model's own data config.
    cfg = timm.data.resolve_model_data_config(net)
    transform = timm.data.create_transform(**cfg, is_training=False)

    if hasattr(net, "embed_dim"):
        width = net.embed_dim
    else:
        width = net.num_features

    return TimmModel(
        name="_".join(model_name.split(".")),
        model=net,
        preprocess=transform,
        device=device,
        embed_dim=width,
    )


def load_siglip(
    model_name: str,
    device: torch.device,
    pretrained: str = "webli",
) -> CLIPModel:
    """Load a SigLIP model via open_clip.

    Args:
        model_name: open_clip architecture name (e.g. ``"ViT-B-16-SigLIP"``).
        device: Device the model is moved to.
        pretrained: open_clip pretrained tag (defaults to ``"webli"``).

    Returns:
        A :class:`CLIPModel` named
        ``siglip_<model_name lower-cased, '-' removed, '/' -> '_'>``.
    """
    import open_clip

    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    model = model.to(device).eval()
    tokenizer = open_clip.get_tokenizer(model_name)

    # Probe the embedding width with a blank image run through the model's own
    # preprocessing, so the probe has the tower's input resolution (SigLIP
    # variants use fixed 224/256/384 px inputs) instead of an assumed 224 px.
    with torch.no_grad():
        dummy = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0).to(device)
        embed_dim = model.encode_image(dummy).shape[-1]

    return CLIPModel(
        name=f"siglip_{model_name.lower().replace('-', '').replace('/', '_')}",
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        device=device,
        embed_dim=embed_dim,
    )


def load_eva_clip(
    model_name: str,
    device: torch.device,
    pretrained: Optional[str] = None,
) -> CLIPModel:
    """Load an EVA-CLIP model via open_clip.

    Args:
        model_name: open_clip architecture name (e.g. ``"EVA02-B-16"``).
        device: Device the model is moved to.
        pretrained: open_clip pretrained tag. When ``None``, a per-architecture
            default is used, falling back to ``"merged2b_s8b_b131k"``.

    Returns:
        A :class:`CLIPModel` named
        ``eva_clip_<model_name lower-cased, '-' removed, '/' -> '_'>``.
    """
    import open_clip

    # open_clip publishes EVA02-CLIP checkpoints under architecture-specific
    # tags; e.g. EVA02-L-14 has no "merged2b_s8b_b131k" checkpoint.
    default_tags = {
        "EVA02-B-16": "merged2b_s8b_b131k",
        "EVA02-L-14": "merged2b_s4b_b131k",
        "EVA02-L-14-336": "merged2b_s6b_b61k",
    }
    if pretrained is None:
        pretrained = default_tags.get(model_name, "merged2b_s8b_b131k")

    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    model = model.to(device).eval()
    tokenizer = open_clip.get_tokenizer(model_name)

    # Probe the embedding width with a blank image run through the model's own
    # preprocessing, so the probe matches the tower's input resolution.
    with torch.no_grad():
        dummy = preprocess(Image.new("RGB", (224, 224))).unsqueeze(0).to(device)
        embed_dim = model.encode_image(dummy).shape[-1]

    return CLIPModel(
        name=f"eva_clip_{model_name.lower().replace('-', '').replace('/', '_')}",
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        device=device,
        embed_dim=embed_dim,
    )


def load_convnext(model_name: str, device: torch.device) -> TimmModel:
    """Build a ConvNeXt model through timm (supervised baseline).

    Args:
        model_name: timm identifier, e.g. ``"convnext_base.fb_in22k_ft_in1k"``.
        device: Device the model is moved to.

    Returns:
        A :class:`TimmModel` named ``model_name`` with '.' replaced by '_',
        whose ``embed_dim`` is the model's ``num_features``.
    """
    import timm

    net = timm.create_model(model_name, pretrained=True).to(device).eval()

    # Eval-time transform derived from the model's own data config.
    cfg = timm.data.resolve_model_data_config(net)
    transform = timm.data.create_transform(**cfg, is_training=False)

    return TimmModel(
        name="_".join(model_name.split(".")),
        model=net,
        preprocess=transform,
        device=device,
        embed_dim=net.num_features,
    )


def load_vicreg(model_name: str, device: torch.device) -> ResNetModel:
    """Load a VICReg backbone via torch hub (explicit uniformity constraint).

    VICReg trains with variance-invariance-covariance regularization, which
    explicitly promotes isotropic representations.

    Args:
        model_name: Entry point in ``facebookresearch/vicreg:main``
            (e.g. ``"resnet50"``).
        device: Device the model is moved to.

    Returns:
        A :class:`ResNetModel` named ``vicreg_<model_name>`` whose
        ``embed_dim`` is measured with a blank 224x224 probe.
    """
    from torchvision import transforms

    backbone = torch.hub.load('facebookresearch/vicreg:main', model_name).to(device).eval()

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # Discover the feature width from one blank batch.
    with torch.no_grad():
        probe = torch.zeros(1, 3, 224, 224).to(device)
        feature_dim = backbone(probe).shape[-1]

    return ResNetModel(
        name="vicreg_" + model_name,
        model=backbone,
        preprocess=preprocess,
        device=device,
        embed_dim=feature_dim,
    )


def load_barlow_twins(model_name: str, device: torch.device) -> ResNetModel:
    """Load a Barlow Twins backbone via torch hub (decorrelation constraint).

    Barlow Twins uses a cross-correlation loss that promotes decorrelated
    (more isotropic) representations.

    Args:
        model_name: Entry point in ``facebookresearch/barlowtwins:main``
            (e.g. ``"resnet50"``).
        device: Device the model is moved to.

    Returns:
        A :class:`ResNetModel` named ``barlow_twins_<model_name>`` whose
        ``embed_dim`` is measured with a blank 224x224 probe.
    """
    from torchvision import transforms

    model = torch.hub.load('facebookresearch/barlowtwins:main', model_name)
    # The hub entry loads the SSL backbone checkpoint into a torchvision
    # ResNet with strict=False, so its 1000-way ``fc`` classifier is not
    # trained. Drop it so the model emits the pooled backbone features rather
    # than a random projection of them.
    if isinstance(getattr(model, "fc", None), nn.Linear):
        model.fc = nn.Identity()
    model = model.to(device).eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Probe the embedding width with one blank 224x224 batch.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224).to(device)
        embed_dim = model(dummy).shape[-1]

    return ResNetModel(
        name=f"barlow_twins_{model_name}",
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


def load_moco_v3(model_name: str, device: torch.device) -> TimmModel:
    """Load a MoCo v3-style ViT via timm.

    Note: this loads whatever pretrained weights timm resolves for
    ``model_name``. Unless that identifier names a MoCo v3 checkpoint, the
    weights are not MoCo v3's — NOTE(review): confirm before relying on it.

    Args:
        model_name: timm model identifier.
        device: Device the model is moved to.

    Returns:
        A :class:`TimmModel` named ``mocov3_<model_name with '.' -> '_'>``.
    """
    import timm

    # MoCo v3 used ViT backbones; build one through timm.
    model = timm.create_model(model_name, pretrained=True)
    model = model.to(device).eval()

    # timm's eval transform already covers preprocessing, so torchvision is
    # not needed here.
    data_config = timm.data.resolve_model_data_config(model)
    preprocess = timm.data.create_transform(**data_config, is_training=False)

    embed_dim = model.embed_dim if hasattr(model, "embed_dim") else model.num_features

    return TimmModel(
        name=f"mocov3_{model_name.replace('.', '_')}",
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


def load_swav(device: torch.device) -> ResNetModel:
    """Load the SwAV ResNet-50 via torch hub (online clustering).

    SwAV uses swapped assignments between views, which implicitly promotes
    more uniform cluster assignments.

    Args:
        device: Device the model is moved to.

    Returns:
        A :class:`ResNetModel` named ``swav_resnet50`` whose ``embed_dim`` is
        measured with a blank 224x224 probe.
    """
    from torchvision import transforms

    model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
    # The hub entry is a torchvision ResNet-50 whose ``fc`` classifier is not
    # part of the SwAV checkpoint (left randomly initialised). Without this
    # the outputs would be a random 1000-d projection instead of the 2048-d
    # pooled backbone features.
    if isinstance(getattr(model, "fc", None), nn.Linear):
        model.fc = nn.Identity()
    model = model.to(device).eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Measure the width rather than hard-coding it, so it always matches the
    # tensors encode_image will actually produce.
    with torch.no_grad():
        dummy = torch.zeros(1, 3, 224, 224).to(device)
        embed_dim = model(dummy).shape[-1]

    return ResNetModel(
        name="swav_resnet50",
        model=model,
        preprocess=preprocess,
        device=device,
        embed_dim=embed_dim,
    )


# Model registry: key -> callable taking a torch.device and returning the loaded model.
MODEL_REGISTRY = {
    # CLIP models (OpenAI)
    "clip_vitb32": lambda device: load_clip("ViT-B-32", device),
    "clip_vitb16": lambda device: load_clip("ViT-B-16", device),
    "clip_vitl14": lambda device: load_clip("ViT-L-14", device),
    # DINOv2 models
    "dinov2_vits14": lambda device: load_dinov2("dinov2_vits14", device),
    "dinov2_vitb14": lambda device: load_dinov2("dinov2_vitb14", device),
    "dinov2_vitl14": lambda device: load_dinov2("dinov2_vitl14", device),
    # DINO v1 models
    "dino_vits16": lambda device: load_dino_v1("dino_vits16", device),
    "dino_vits8": lambda device: load_dino_v1("dino_vits8", device),
    "dino_vitb16": lambda device: load_dino_v1("dino_vitb16", device),
    "dino_vitb8": lambda device: load_dino_v1("dino_vitb8", device),
    # MAE models
    "mae_vitb16": lambda device: load_mae("vit_base_patch16_224.mae", device),
    "mae_vitl16": lambda device: load_mae("vit_large_patch16_224.mae", device),
    # BEiT models
    "beit_vitb16": lambda device: load_beit("beit_base_patch16_224", device),
    "beit_vitl16": lambda device: load_beit("beit_large_patch16_224", device),
    "beitv2_vitb16": lambda device: load_beit("beitv2_base_patch16_224", device),
    "beitv2_vitl16": lambda device: load_beit("beitv2_large_patch16_224", device),
    # SimCLR (proxy - using ImageNet pretrained ResNet)
    "simclr_resnet50": load_simclr_resnet50,
    # I-JEPA (proxy - using timm ViT)
    "ijepa_vitb16": lambda device: load_ijepa("vit_base_patch16_224", device),
    "ijepa_vitl16": lambda device: load_ijepa("vit_large_patch16_224", device),
    "ijepa_vith14": lambda device: load_ijepa("vit_huge_patch14_224", device),
    # SigLIP models (Google's CLIP alternative)
    "siglip_vitb16": lambda device: load_siglip("ViT-B-16-SigLIP", device),
    # open_clip registers the ViT-L/16 SigLIP only with an explicit input
    # resolution suffix (-256 / -384); plain "ViT-L-16-SigLIP" is not a model.
    "siglip_vitl16": lambda device: load_siglip("ViT-L-16-SigLIP-256", device),
    # EVA-CLIP models
    "eva_clip_vitb16": lambda device: load_eva_clip("EVA02-B-16", device),
    "eva_clip_vitl14": lambda device: load_eva_clip("EVA02-L-14", device),
    # ConvNeXt (supervised baseline)
    "convnext_base": lambda device: load_convnext("convnext_base.fb_in22k_ft_in1k", device),
    "convnext_large": lambda device: load_convnext("convnext_large.fb_in22k_ft_in1k", device),
    # DINOv2 with registers (potentially different geometry)
    "dinov2_vitb14_reg": lambda device: load_dinov2("dinov2_vitb14_reg", device),
    "dinov2_vitl14_reg": lambda device: load_dinov2("dinov2_vitl14_reg", device),
    # EVA (MIM pretrained, expect high anisotropy like MAE)
    "eva_vitb16": lambda device: load_beit("eva_base_patch16_224.in22k_ft_in1k", device),
    "eva_vitl14": lambda device: load_beit("eva_large_patch14_336.in22k_ft_in1k", device),
    # CLIP variants for isotropy comparison
    "clip_convnext_base": lambda device: load_clip("convnext_base_w", device),
    # === ISOTROPY-TARGETING METHODS (explicit uniformity constraints) ===
    # VICReg - variance-invariance-covariance regularization
    "vicreg_resnet50": lambda device: load_vicreg("resnet50", device),
    "vicreg_resnet50x2": lambda device: load_vicreg("resnet50x2", device),
    # Barlow Twins - cross-correlation / decorrelation loss
    "barlow_twins_resnet50": lambda device: load_barlow_twins("resnet50", device),
    # SwAV - swapped assignments (implicit uniformity via clustering)
    "swav_resnet50": load_swav,
}


def load_model(model_key: str, device: torch.device) -> SSLModel:
    """Instantiate the model registered under ``model_key`` on ``device``.

    Raises:
        ValueError: if ``model_key`` is not in ``MODEL_REGISTRY``; the message
            lists every valid key in sorted order.
    """
    try:
        factory = MODEL_REGISTRY[model_key]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(f"Unknown model: {model_key}. Available: {known}") from None
    # Called outside the try so errors raised while loading are not masked.
    return factory(device)


def list_models() -> List[str]:
    """Return every registered model key, sorted alphabetically."""
    return sorted(MODEL_REGISTRY)
