# models/openclip_vit.py

import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
import open_clip


def _as_size_pair(size):
    """Return *size* as a ``(first, second)`` tuple of ints; a scalar is used for both entries."""
    if isinstance(size, (tuple, list)):
        if len(size) != 2:
            raise ValueError(f"Expected a scalar or a pair, got {size!r}")
        return (int(size[0]), int(size[1]))
    return (int(size), int(size))


def interpolate_pos_embed_if_needed(vit_model, target_img_size):
    """Resize the visual positional embedding in place for a new input resolution.

    The first row of ``visual.positional_embedding`` is kept unchanged and the
    remaining rows are treated as a ``grid_size[0] x grid_size[1]`` patch grid
    that is bicubically resampled to the grid implied by ``target_img_size``
    and ``visual.patch_size``. ``visual.grid_size`` and ``visual.image_size``
    are updated to match. Nothing happens when the sizes already agree.

    Args:
        vit_model: Object exposing the visual tower as ``vit_model.model.visual``.
        target_img_size: Target resolution, either a single int (square input)
            or an ``(height, width)`` pair.

    Raises:
        ValueError: If a size is not a scalar or a pair, or if the target
            resolution is smaller than one patch.
    """
    visual = vit_model.model.visual
    required = ("positional_embedding", "grid_size", "image_size", "patch_size")
    if not all(hasattr(visual, name) for name in required):
        print("[WARN] No pos_embed attributes, skip interpolation")
        return

    old_size = _as_size_pair(visual.image_size)
    tgt = _as_size_pair(target_img_size)
    if old_size == tgt:
        return

    print(f"[INFO] Interpolating pos_embed: {old_size} → {tgt}")
    pos = visual.positional_embedding  # [N+1, C]
    og_g = _as_size_pair(visual.grid_size)
    patch = _as_size_pair(visual.patch_size)
    new_g = (tgt[0] // patch[0], tgt[1] // patch[1])
    if new_g[0] < 1 or new_g[1] < 1:
        raise ValueError(f"Target image size {tgt} is smaller than patch size {patch}")

    # The resize is a one-off weight edit, so keep it out of the autograd graph.
    with torch.no_grad():
        cls_tok, patch_pos = pos[:1], pos[1:]
        p = patch_pos.reshape(1, og_g[0], og_g[1], -1).permute(0, 3, 1, 2)
        # Interpolate in float32 and cast back so reduced-precision weights also work.
        p = F.interpolate(p.float(), size=new_g, mode='bicubic', align_corners=False)
        p = p.permute(0, 2, 3, 1).reshape(new_g[0] * new_g[1], -1).to(pos.dtype)
        # Assign through .data so the Parameter object itself stays the same.
        visual.positional_embedding.data = torch.cat([cls_tok, p], dim=0)

    visual.grid_size = new_g
    visual.image_size = tgt
    print(f"[INFO] New pos_embed shape: {visual.positional_embedding.shape}")


# Per-block parameter renames, OpenCLIP name -> timm/MAE name, applied to the
# part of the key that follows "transformer.resblocks.<idx>.".
_OPENCLIP_TO_MAE_BLOCK = (
    ("ln_1.", "norm1."),
    ("ln_2.", "norm2."),
    ("attn.in_proj_weight", "attn.qkv.weight"),
    ("attn.in_proj_bias", "attn.qkv.bias"),
    ("attn.out_proj.", "attn.proj."),
    ("mlp.c_fc.", "mlp.fc1."),
    ("mlp.c_proj.", "mlp.fc2."),
)


def _mae_candidate_keys(oc_key):
    """Return the MAE-style key names, in priority order, that may feed OpenCLIP visual key *oc_key*."""
    candidates = [oc_key]  # identical naming always wins
    if oc_key == "class_embedding":
        candidates.append("cls_token")
    elif oc_key == "positional_embedding":
        candidates.append("pos_embed")
    elif oc_key.startswith("conv1."):
        candidates.append("patch_embed.proj." + oc_key[len("conv1."):])
    elif oc_key.startswith("ln_post."):
        suffix = oc_key[len("ln_post."):]
        candidates.extend(["norm." + suffix, "fc_norm." + suffix])
    elif oc_key.startswith("transformer.resblocks."):
        idx, _, tail = oc_key[len("transformer.resblocks."):].partition(".")
        candidates.append(f"blocks.{idx}.{tail}")
        for oc_name, mae_name in _OPENCLIP_TO_MAE_BLOCK:
            if tail.startswith(oc_name):
                candidates.append(f"blocks.{idx}.{mae_name}{tail[len(oc_name):]}")
                break
    return candidates


def _fit_to_shape(tensor, target_shape):
    """Return *tensor* shaped as *target_shape* if they differ only by size-1 dims, else ``None``."""
    if tensor.shape == target_shape:
        return tensor
    source_core = tuple(d for d in tensor.shape if d != 1)
    target_core = tuple(d for d in target_shape if d != 1)
    if source_core == target_core:
        return tensor.reshape(target_shape)
    return None


def load_mae_weights_to_openclip(openclip_model, mae_path):
    """Load MAE ViT backbone weights into the OpenCLIP visual backbone.

    Only structurally corresponding tensors are copied. For every key of the
    visual tower's state dict the checkpoint is searched under the same name
    and under its MAE-style equivalent (``cls_token``, ``pos_embed``,
    ``patch_embed.proj.*``, ``blocks.<i>.*``, ``norm.*``/``fc_norm.*``).
    Tensors that differ only by singleton dimensions are reshaped; tensors
    with any other shape mismatch are skipped and reported instead of making
    ``load_state_dict`` fail. Keys with no counterpart keep their current
    values and show up in the printed ``missing`` list.

    Args:
        openclip_model: Object exposing the visual tower as
            ``openclip_model.model.visual``.
        mae_path: Path of the checkpoint file. Either a raw state dict or a
            dict holding it under the ``'model'`` key.
    """
    print(f"[INFO] Loading MAE ViT backbone weights from {mae_path} into OpenCLIP backbone...")
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # trusted sources (or pass weights_only=True where the installed PyTorch
    # supports it).
    mae_ckpt = torch.load(mae_path, map_location="cpu")
    mae_state_dict = mae_ckpt['model'] if 'model' in mae_ckpt else mae_ckpt

    visual = openclip_model.model.visual
    oc_sd = visual.state_dict()
    mapping = {}
    skipped = []
    for k, target in oc_sd.items():
        for mk in _mae_candidate_keys(k):
            source = mae_state_dict.get(mk)
            if not isinstance(source, torch.Tensor):
                continue
            fitted = _fit_to_shape(source, target.shape)
            if fitted is None:
                skipped.append(f"{mk}{tuple(source.shape)} -> {k}{tuple(target.shape)}")
            else:
                mapping[k] = fitted
            break  # the first candidate present in the checkpoint decides

    if skipped:
        print(f"[WARN] Skipped MAE tensors with incompatible shapes: {skipped}")

    # Update weights
    missing, unexpected = visual.load_state_dict(mapping, strict=False)
    print(f"[INFO] MAE weights loaded, missing={missing}, unexpected={unexpected}")


class OpenCLIPViT(torch.nn.Module):
    """Image encoder that wraps the visual tower of an OpenCLIP model.

    The full OpenCLIP model is kept in ``self.model``; ``forward`` only runs
    ``self.model.visual``. ``self.embed_dim`` holds the inferred feature size.
    """

    def __init__(self, cfg):
        """Create the OpenCLIP model described by *cfg* and load its weights.

        Args:
            cfg: Mapping with the keys
                ``'arch'`` (required) - OpenCLIP architecture name;
                ``'pretrained'`` (default ``'openai'``) - a pretrained tag, a
                path to a local file, or ``None``/``'none'``/``''`` for random
                initialisation. A local file ending in ``.pth`` is loaded
                through ``load_mae_weights_to_openclip``; any other local file
                is handed to OpenCLIP as the ``pretrained`` argument;
                ``'cache_dir'`` (optional) - forwarded to OpenCLIP.

        Raises:
            RuntimeError: If the feature dimension cannot be inferred.
        """
        super().__init__()
        pretrained_flag, local_weights = self._resolve_pretrained(cfg.get('pretrained', 'openai'))

        # Create the model; OpenCLIP loads weights itself when pretrained_flag is set.
        self.model, _, _ = open_clip.create_model_and_transforms(
            cfg['arch'], pretrained=pretrained_flag,
            cache_dir=cfg.get('cache_dir', None),
            precision='fp32'
        )

        # local_weights is only ever set for '.pth' files, which are mapped manually.
        if local_weights:
            load_mae_weights_to_openclip(self, local_weights)

        self.embed_dim = self._infer_embed_dim(self.model.visual)

    @staticmethod
    def _resolve_pretrained(raw_pretrained):
        """Split the ``pretrained`` config value into ``(pretrained_flag, local_weights)``.

        ``pretrained_flag`` is what gets passed to OpenCLIP, ``local_weights``
        is the path of a ``.pth`` file to map manually; at most one is set.
        """
        if raw_pretrained is None:
            # os.path.isfile(None) would raise TypeError.
            return None, None
        if isinstance(raw_pretrained, (str, os.PathLike)) and os.path.isfile(raw_pretrained):
            path = os.fspath(raw_pretrained)
            if path.endswith('.pth'):
                return None, path
            return path, None
        if isinstance(raw_pretrained, str) and raw_pretrained.lower() in ('none', ''):
            return None, None
        return raw_pretrained, None

    @staticmethod
    def _infer_embed_dim(vis):
        """Return the feature dimension of visual tower *vis*.

        Raises:
            RuntimeError: If none of the known attributes is available.
        """
        if hasattr(vis, 'output_dim'):
            return vis.output_dim
        proj = getattr(vis, 'proj', None)
        if proj is not None:
            # OpenCLIP applies the projection as ``x @ proj``, so the output
            # size is the last dimension of the matrix, not the first.
            return proj.shape[-1]
        projection = getattr(vis, 'projection', None)
        if projection is not None:
            # NOTE(review): layout of 'projection' is not visible here; the
            # original first-dimension convention is kept - confirm.
            return projection.shape[0]
        if hasattr(vis, 'embed_dim'):
            return vis.embed_dim
        raise RuntimeError("Cannot infer embed_dim from OpenCLIP model")

    def forward(self, x):
        """Run the visual tower on *x*, recomputing activations during backward to save memory."""
        visual = self.model.visual
        if not torch.is_grad_enabled():
            # No backward pass will follow, so there is nothing to checkpoint.
            return visual(x)
        # Non-reentrant mode is required here: the image batch does not require
        # grad, and the re-entrant implementation would then return a detached
        # output, leaving the unfrozen parameters without gradients.
        return cp.checkpoint(visual, x, use_reentrant=False)

    def unfreeze_last_n_blocks(self, n: int):
        """Freeze the visual tower, then unfreeze its last *n* transformer blocks.

        The final layer norm (``ln_post``) and the projection (``proj``) are
        unfrozen as well when present. If no transformer or block list is
        found, the whole visual tower stays frozen and a warning is printed.

        Args:
            n: Number of trailing blocks to make trainable; clamped to
                ``[0, number of blocks]``.
        """
        visual = self.model.visual

        # First freeze all parameters
        for param in visual.parameters():
            param.requires_grad = False

        # Find transformer blocks
        transformer = getattr(visual, 'transformer', None)
        if transformer is None:
            print("[WARN] No transformer found in visual model")
            return

        # Try different block attribute names
        blocks = None
        for attr_name in ('resblocks', 'blocks', 'layers'):
            if hasattr(transformer, attr_name):
                blocks = getattr(transformer, attr_name)
                print(f"[INFO] Found transformer blocks: {attr_name}")
                break

        if blocks is None:
            print("[WARN] No transformer blocks found")
            return

        total_blocks = len(blocks)
        n = max(0, min(n, total_blocks))  # Keep n within [0, total_blocks]

        print(f"[INFO] Unfreezing last {n} blocks out of {total_blocks} total blocks")

        # Unfreeze the last n blocks
        for idx in range(total_blocks - n, total_blocks):
            for param in blocks[idx].parameters():
                param.requires_grad = True
            print(f"  ✓ Unfroze block {idx}")

        # Unfreeze the final layer norm and projection (if exist)
        if hasattr(visual, 'ln_post'):
            for param in visual.ln_post.parameters():
                param.requires_grad = True
            print("  ✓ Unfroze ln_post")

        if hasattr(visual, 'proj') and visual.proj is not None:
            visual.proj.requires_grad = True
            print("  ✓ Unfroze projection")

        # Count trainable parameters
        trainable_params = sum(p.numel() for p in visual.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in visual.parameters())
        percent = 100 * trainable_params / total_params if total_params else 0.0
        print(
            f"  ✓ Vision encoder: {trainable_params:,}/{total_params:,} trainable ({percent:.1f}%)")


def build_openclip_vit(cfg):
    """Build an ``OpenCLIPViT`` from *cfg* and adapt it to the configured image size.

    The positional embedding is interpolated to ``cfg['img_size']`` (224 when
    the key is absent) before the model is returned.
    """
    model = OpenCLIPViT(cfg)
    target_size = cfg.get('img_size', 224)
    interpolate_pos_embed_if_needed(model, target_size)
    return model