import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch
import matplotlib.pyplot as plt

# from models.aggregators.ddfnet.ddf.ddf import FilterNorm, DDFFunction
from models.backbones.ddfnet.ddf.ddf import FilterNorm, DDFFunction

# Token embedding width of each supported DINOv2 ViT variant. The keys are the
# entry-point names passed to torch.hub.load for 'facebookresearch/dinov2'.
DINOV2_ARCHS = dict(
    dinov2_vits14=384,
    dinov2_vitb14=768,
    dinov2_vitl14=1024,
    dinov2_vitg14=1536,
)

class DAAdapter(nn.Module):
    """Low-rank token adapter combining a static conv with a dynamic DDF filter.

    Patch tokens are projected down to ``hidden_dim`` channels, laid out as a
    square feature map, filtered by a static 3x3 convolution and by a
    ``DDFFunction`` call whose per-location and per-channel filters are
    generated from the same features. The sum of both branches is projected
    back to ``embed_dim`` and added to the input patches, weighted by the
    learnable scalar ``gamma``. The first token along dim 1 is split off as
    the class token and returned unchanged.

    Args:
        embed_dim (int): channel width D of the incoming tokens.
        hidden_dim (int): bottleneck width r of the low-rank branch.
        kernel_size (int): side k of the generated dynamic filters.
        scale_factor (float): stored as ``self.scale``; ``forward`` does not
            currently apply it.
    """

    def __init__(self, embed_dim, hidden_dim, kernel_size, scale_factor=1.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.k = kernel_size
        # NOTE(review): kept for interface compatibility; never read in forward.
        self.scale = scale_factor

        # Low-rank projection branch (D -> r -> D)
        self.down = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.up = nn.Linear(hidden_dim, embed_dim, bias=False)

        # Static 3x3 convolution on the low-rank channels (r -> r)
        self.static = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)

        # Spatial branch: one k*k filter per location, produced by a 1x1 conv
        self.sp_conv = nn.Conv2d(hidden_dim, self.k * self.k, kernel_size=1, bias=True)
        self.sp_filternorm = FilterNorm(1, kernel_size, 'spatial', 'relu')

        # Channel branch: one k*k filter per channel, produced from pooled features
        self.ch_fc = nn.Linear(hidden_dim, hidden_dim * self.k * self.k, bias=True)
        self.ch_filternorm = FilterNorm(hidden_dim, kernel_size, 'channel', 'relu', running_std=True)

        # Learnable weight of the adapter output in the residual sum
        self.gamma = nn.Parameter(torch.tensor(0.5))

    def forward(self, tokens):
        """Adapt the patch tokens and return tokens of the same shape.

        Args:
            tokens: tensor indexed as [B, 1 + P, D]. Index 0 along dim 1 is
                treated as the class token; the remaining P tokens must form
                a square H x W grid.

        Returns:
            Tensor [B, 1 + P, D]: the class token unchanged, followed by
            ``patches + gamma * adapter(patches)``.

        Raises:
            ValueError: if P is not a perfect square.
        """
        cls_tok, patches = tokens[:, :1, :], tokens[:, 1:, :]
        B, P, _ = patches.shape

        # Validate the grid up front: int(P ** 0.5) would silently truncate and
        # the failure would only surface later as an opaque view() size error.
        side = int(round(P ** 0.5))
        if side * side != P:
            raise ValueError(
                f"DAAdapter expects a square grid of patch tokens, got {P} patches"
            )
        H = W = side
        r = self.hidden_dim
        k = self.k

        # 1) Low-rank down projection
        x_down = self.relu(self.down(patches))  # [B, P, r]

        # 2) Tokens -> feature map
        feat = x_down.view(B, H, W, r).permute(0, 3, 1, 2).contiguous()  # [B, r, H, W]

        # 3) Static conv branch
        out_static = self.static(feat)  # [B, r, H, W]

        # 4) Per-location filters
        sp_f = self.sp_conv(feat)  # [B, k*k, H, W]
        sp_filter = self.sp_filternorm(sp_f).contiguous()

        # 5) Per-channel filters from globally pooled features
        ch_avg = F.adaptive_avg_pool2d(feat, (1, 1)).view(B, r)  # [B, r]
        ch_f = self.ch_fc(ch_avg)  # [B, r*k*k]
        ch_filter = self.ch_filternorm(ch_f.view(B, r, k * k)).view(B, r, k, k).contiguous()

        # 6) Dynamic branch.
        # NOTE(review): the trailing arguments are presumably kernel size,
        # dilation=1, stride=1 and the 'mul' combination mode - confirm against ddf.py.
        out_dyn = DDFFunction.apply(feat, ch_filter, sp_filter, k, 1, 1, 'mul')

        # 7) Sum both branches on the feature map, then flatten back to tokens
        out_map = out_static + out_dyn  # [B, r, H, W]
        out_tok = out_map.reshape(B, r, -1).permute(0, 2, 1)  # [B, P, r]

        # 8) Up-project and fuse with the input patches
        tok_up = self.up(out_tok)  # [B, P, D]
        fused = patches + self.gamma * tok_up

        return torch.cat((cls_tok, fused), dim=1)

class DINOv2DAAdapter(nn.Module):
    """
    Frozen DINOv2 backbone with DAAdapter modules attached to its last blocks.

    The backbone is loaded through torch.hub and all of its parameters are
    frozen. One DAAdapter is created for each of the last ``num_da_layers``
    transformer blocks; in those blocks the adapter runs in parallel to the
    block's MLP on the same normalised tokens.

    Args:
        model_name (str): backbone name; any '_da' is stripped and the result
            must be a key of DINOV2_ARCHS
        hidden_dim (int): bottleneck dimension r of each adapter
        kernel_size (int): dynamic filter size of each adapter
        dropout (float): dropout rate applied to the adapter outputs
        norm_layer (bool): apply the backbone's final norm to the tokens
        return_token (bool): also return the class token
        scale_factor (float): forwarded to every DAAdapter
        num_da_layers (int | None): number of trailing blocks that get an
            adapter; if None, every block gets one
    """
    def __init__(
        self,
        model_name='dinov2_vitb14_da',
        hidden_dim=64,
        kernel_size=3,
        dropout=0.1,
        norm_layer=False,
        return_token=False,
        scale_factor=1.0,
        num_da_layers=2
    ):
        super().__init__()
        base = model_name.replace('_da', '')
        assert base in DINOV2_ARCHS, f"Unknown model {base}"
        self.embed_dim = DINOV2_ARCHS[base]
        self.norm_layer = norm_layer
        self.return_token = return_token
        self.scale_factor = scale_factor

        # Load the backbone exactly once and freeze it. The block count is read
        # from this instance instead of building a second throw-away model.
        self.model = torch.hub.load('facebookresearch/dinov2', base)
        for p in self.model.parameters():
            p.requires_grad = False

        # number of trailing blocks that receive an adapter
        total_blocks = len(self.model.blocks)
        self.num_da_layers = total_blocks if num_da_layers is None else num_da_layers
        assert 0 < self.num_da_layers <= total_blocks

        # One adapter per adapted block. Use the resolved count so that
        # num_da_layers=None ("all blocks") works instead of raising TypeError.
        self.da_ffn = nn.ModuleList([
            DAAdapter(
                embed_dim=self.embed_dim,
                hidden_dim=hidden_dim,
                kernel_size=kernel_size,
                scale_factor=self.scale_factor
            ) for _ in range(self.num_da_layers)
        ])
        self.dropout = nn.Dropout(dropout)

        # parameter summary
        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total trainable parameters: {total_trainable/1e6:.3f}M")
        print("Parameter counts per module:")
        for name, module in self.named_children():
            trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
            total = sum(p.numel() for p in module.parameters())
            print(f"  {name}: {trainable/1e6:.3f}M trainable, {total/1e6:.3f}M total")

    def forward(self, x):
        """Run the adapted backbone and return a patch feature map.

        Args:
            x: image batch indexed as [B, C, H, W].

        Returns:
            ``feat`` of shape [B, embed_dim, H // patch, W // patch], or
            ``(feat, cls_token)`` when ``return_token`` is set.
        """
        B, _, H, W = x.shape
        tokens = self.model.prepare_tokens_with_masks(x)
        blocks = self.model.blocks
        cutoff = len(blocks) - self.num_da_layers
        da_idx = 0
        for i, blk in enumerate(blocks):
            if i < cutoff:
                # untouched pretrained block
                tokens = blk(tokens)
                continue

            # Adapted block. Reproduce the block's own computation first:
            # attention sub-layer, then MLP sub-layer, each with its LayerScale.
            # The previous code went straight to the MLP and omitted both the
            # attention residual and ls2, so adapted blocks no longer computed
            # the pretrained function.
            # NOTE(review): this changes numerical output relative to the old
            # forward; adapters trained against it need retraining.
            attn_out = blk.ls1(blk.attn(blk.norm1(tokens)))
            tokens = tokens + blk.drop_path1(attn_out)

            normed = blk.norm2(tokens)
            ffn_out = blk.drop_path2(blk.ls2(blk.mlp(normed)))
            # The adapter sees the same normalised tokens as the MLP and is
            # added as a parallel residual branch.
            da = self.dropout(self.da_ffn[da_idx](normed))
            tokens = tokens + ffn_out + da
            da_idx += 1

        # optional final norm
        if self.norm_layer:
            tokens = self.model.norm(tokens)

        # separate class token & feature patches
        t = tokens[:, 0] if self.return_token else None
        patch_size = self.model.patch_embed.patch_size
        if isinstance(patch_size, (tuple, list)):
            patch_size = patch_size[0]
        grid_h, grid_w = H // patch_size, W // patch_size
        patches = tokens[:, 1:1 + grid_h * grid_w]
        feat = patches.reshape(B, grid_h, grid_w, self.embed_dim)
        feat = feat.permute(0, 3, 1, 2)
        return (feat, t) if self.return_token else feat
  


if __name__ == '__main__':
    # Smoke test: push random tokens through a single DAAdapter and print the
    # output shape. The previous version instantiated SharedHypernet / VFMHOM,
    # which are neither defined nor imported in this module.

    # Define parameters for the test
    batch_size = 72
    embed_dim = 768
    hidden_dim = 128
    kernel_size = 3
    num_patches = 16 * 16  # 16x16 patch grid, e.g. a 224x224 image at patch size 14

    # NOTE(review): DDFFunction presumably requires CUDA tensors - confirm;
    # on a CPU-only machine this run may fail inside the DDF call.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create a dummy tokens tensor with a class token and patch tokens
    dummy_tokens = torch.randn(batch_size, num_patches + 1, embed_dim, device=device)

    # Instantiate the adapter under test
    adapter = DAAdapter(embed_dim, hidden_dim, kernel_size, scale_factor=1.0).to(device)

    # Run a forward pass; the output must have the same shape as the input
    output_tokens = adapter(dummy_tokens)
    print("Output tokens shape:", output_tokens.shape)