# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import math
from functools import partial
from typing import Optional, Union

import torch
from torch import nn

from fourm.utils.timm.registry import register_model

from .encoder_embeddings import ImageEncoderEmbedding
from .fm_utils import Block, LayerNorm


# Public API of this module: the names of the model factory functions defined
# further down in this file, grouped by feedforward / attention variant.
__all__ = [
    # GELU models
    'fm_vit_tiny_6e_gelu',
    'fm_vit_small_8e_gelu',
    'fm_vit_base_12e_gelu',
    'fm_vit_large_24e_gelu',
    'fm_vit_xlarge_24e_gelu',
    # SwiGLU models
    'fm_vit_tiny_6e_swiglu_nobias',
    'fm_vit_small_8e_swiglu_nobias',
    'fm_vit_base_12e_swiglu_nobias',
    'fm_vit_large_24e_swiglu_nobias',
    'fm_vit_xlarge_24e_swiglu_nobias',
    # SwiGLU + QKNorm models
    'fm_vit_base_12e_swiglu_qknorm_nobias',
    'fm_vit_large_24e_swiglu_qknorm_nobias',
    'fm_vit_xlarge_24e_swiglu_qknorm_nobias',
]

class FourMViT(nn.Module):
    """Modified 4M model, adapted to behave as a simple RGB-only ViT.

    The image is patch-embedded by a single ``ImageEncoderEmbedding`` (stored in
    ``encoder_embeddings`` under the key ``rgb@{img_size}``), passed through a
    stack of ``Block`` modules, optionally normalized, and finally fed to
    ``output_head`` (an identity when no head is supplied).

    Args:
        img_size (int): Input image size.
        patch_size (int): Patch size.
        in_chans (int): Number of input image channels.
        dim (int): Patch embedding dimension.
        encoder_depth (int): Depth of ViT / number of encoder blocks.
        num_heads (int): Number of attention heads in each ViT block.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
        proj_bias (bool): If True, adds a bias to the attention out proj layer.
        mlp_bias (bool): If True, adds a learnable bias for the feedforward.
        drop_path_rate (float): Stochastic depth rate.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        act_layer: Constructor (class or ``partial``) of the activation layer.
        norm_layer: Constructor (class or ``partial``) of the normalization
            layer; it is called as ``norm_layer(dim)``.
        gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU)
        qk_norm (bool): If True, normalizes the query and keys (as in ViT-22B)
        encoder_norm (bool): If True, adds a norm layer after the last encoder block.
        output_head (Optional[nn.Module]): Optional output head after the encoder.
            If it exposes an ``init`` attribute, ``output_head.init(dim)`` is
            called once the encoder weights have been initialized.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        mlp_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        act_layer: Union[partial, type] = nn.GELU,
        norm_layer: Union[partial, type] = partial(LayerNorm, eps=1e-6),
        gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
        qk_norm: bool = False,
        encoder_norm=True,
        output_head: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.img_size = img_size
        self.init_std = 0.02  # std used for nn.Embedding weights in init_weights()

        rgb_embedding = ImageEncoderEmbedding(num_channels=in_chans, patch_size=patch_size,
                                              dim_tokens=dim, sincos_pos_emb=True, image_size=img_size)
        self.num_patches = rgb_embedding.num_patches
        # The key encodes the image size; forward() looks the embedding up with the same key.
        self.encoder_embeddings = nn.ModuleDict({f"rgb@{img_size}": rgb_embedding})

        # stochastic depth decay rule: drop-path rate grows linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                  drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer,
                  gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])

        self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()

        # Weight init
        self.init_weights()

        # Output head is attached after init_weights() to allow for special init scale
        if output_head is not None:
            self.output_head = output_head
            if hasattr(self.output_head, 'init'):
                self.output_head.init(dim)
        else:
            self.output_head = nn.Identity()

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme"""

        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue

            # Linear
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm
            elif isinstance(m, (nn.LayerNorm, LayerNorm)):
                # weight / bias may be None (e.g. nn.LayerNorm without affine
                # parameters or without bias); only reset what actually exists.
                if getattr(m, 'weight', None) is not None:
                    nn.init.constant_(m.weight, 1.0)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias, 0)
            # Embedding
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=self.init_std)
            # Conv2d
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def get_num_layers_encoder(self):
        """Return the number of encoder blocks."""
        return len(self.encoder)

    def get_num_layers(self):
        """Return the total number of layers (encoder blocks only for this model)."""
        return self.get_num_layers_encoder()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the set of parameter names to exclude from weight decay.

        Names reported by each embedding module's own ``no_weight_decay()`` are
        prefixed with ``encoder_embeddings.<key>.``; embedding modules without
        that method contribute nothing.
        """
        no_wd_set = set()

        for mod, emb_module in self.encoder_embeddings.items():
            if hasattr(emb_module, 'no_weight_decay'):
                no_wd_set |= {
                    f'encoder_embeddings.{mod}.{name}'
                    for name in emb_module.no_weight_decay()
                }

        return no_wd_set

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor. Shape (B, C, H, W)

        Returns:
            torch.Tensor: Result of ``output_head`` applied to the (normalized)
            encoder tokens. With the default identity head this is the token
            tensor itself, shape (B, N, D) with N = num_patches; with a custom
            head the shape is whatever that head produces.
        """
        rgb_dict = {'tensor': x}
        rgb_dict = self.encoder_embeddings[f'rgb@{self.img_size}'](rgb_dict)

        # Add embeddings to patchified RGB image
        x = rgb_dict['x'] + rgb_dict['emb']  # Shape: (B, N, D) with N = num_patches

        for blk in self.encoder:
            x = blk(x)

        x = self.encoder_norm(x)  # Shape: (B, N, D)

        return self.output_head(x)

    def _set_encoder_requires_grad(self, requires_grad: bool, include_embeddings: bool) -> None:
        """Set ``requires_grad`` on encoder, encoder norm and optionally embedding parameters."""
        modules = [self.encoder, self.encoder_norm]
        if include_embeddings:
            modules.append(self.encoder_embeddings)
        for module in modules:
            for param in module.parameters():
                param.requires_grad = requires_grad

    def freeze_encoder(self, freeze_embeddings=True):
        """Disable gradients for the encoder blocks and encoder norm.

        Args:
            freeze_embeddings (bool): If True, also freeze ``encoder_embeddings``.
                The output head is never touched.
        """
        self._set_encoder_requires_grad(False, freeze_embeddings)

    def unfreeze_encoder(self, unfreeze_embeddings=True):
        """Enable gradients for the encoder blocks and encoder norm.

        Args:
            unfreeze_embeddings (bool): If True, also unfreeze ``encoder_embeddings``.
                The output head is never touched.
        """
        self._set_encoder_requires_grad(True, unfreeze_embeddings)

    @staticmethod
    def _collect_param_mismatches(prefix, module, checkpoint_state_dict, mismatch_info):
        """Compare ``module``'s parameters with the checkpoint entries under ``prefix``.

        Findings are appended in place to the ``missing_keys``,
        ``shape_mismatches`` and ``value_mismatches`` lists of ``mismatch_info``.
        """
        for name, param in module.named_parameters():
            full_name = f'{prefix}.{name}'
            if full_name not in checkpoint_state_dict:
                mismatch_info['missing_keys'].append(full_name)
                continue

            checkpoint_param = checkpoint_state_dict[full_name]
            if param.shape != checkpoint_param.shape:
                mismatch_info['shape_mismatches'].append((full_name, param.shape, checkpoint_param.shape))
                continue

            # torch.allclose needs both tensors on the same device with the same
            # dtype; align the checkpoint tensor with the model parameter so a
            # CPU-loaded or differently-typed checkpoint does not raise.
            checkpoint_param = checkpoint_param.detach().to(device=param.device, dtype=param.dtype)
            if not torch.allclose(param.detach(), checkpoint_param):
                mismatch_info['value_mismatches'].append(full_name)

    def check_encoder_weights_match(self, checkpoint_state_dict):
        """Check if encoder and encoder_embeddings weights match with checkpoint weights.

        Only parameters (``named_parameters``) of ``encoder``, ``encoder_norm``
        and ``encoder_embeddings`` are compared; buffers and the output head
        are not checked. Checkpoint tensors are moved to each parameter's
        device and dtype before the value comparison.

        Args:
            checkpoint_state_dict (dict): State dict from checkpoint to compare against

        Returns:
            tuple: (bool, dict) - Whether weights match and dict containing mismatch details.
            The dict has the keys ``missing_keys`` (model parameters absent from
            the checkpoint), ``unexpected_keys`` (checkpoint keys with an encoder
            prefix that the model does not have), ``shape_mismatches`` (tuples of
            ``(name, model_shape, checkpoint_shape)``) and ``value_mismatches``.
        """
        model_state_dict = self.state_dict()
        mismatch_info = {
            'missing_keys': [],
            'unexpected_keys': [],
            'shape_mismatches': [],
            'value_mismatches': []
        }

        # Same order as the sub-modules appear in the model
        checked_modules = (
            ('encoder', self.encoder),
            ('encoder_norm', self.encoder_norm),
            ('encoder_embeddings', self.encoder_embeddings),
        )
        for prefix, module in checked_modules:
            self._collect_param_mismatches(prefix, module, checkpoint_state_dict, mismatch_info)

        # Check for unexpected keys in checkpoint
        checked_prefixes = tuple(f'{prefix}.' for prefix, _ in checked_modules)
        mismatch_info['unexpected_keys'].extend(
            key for key in checkpoint_state_dict
            if key.startswith(checked_prefixes) and key not in model_state_dict
        )

        # Any non-empty list means there is at least one mismatch
        has_mismatches = any(mismatch_info.values())

        return not has_mismatches, mismatch_info

# Model definitions
                
# GELU variants
@register_model
def fm_vit_tiny_6e_gelu(**kwargs):
    """Build the tiny GELU model: 6 encoder blocks, dim 384, 6 heads.

    Uses biased QKV and ``nn.LayerNorm`` with ``eps=1e-6``. Extra keyword
    arguments are passed through to ``FourMViT``; repeating one of the
    arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_small_8e_gelu(**kwargs):
    """Build the small GELU model: 8 encoder blocks, dim 512, 8 heads.

    Uses biased QKV and ``nn.LayerNorm`` with ``eps=1e-6``. Extra keyword
    arguments are passed through to ``FourMViT``; repeating one of the
    arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_base_12e_gelu(**kwargs):
    """Build the base GELU model: 12 encoder blocks, dim 768, 12 heads.

    Uses biased QKV and ``nn.LayerNorm`` with ``eps=1e-6``. Extra keyword
    arguments are passed through to ``FourMViT``; repeating one of the
    arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_large_24e_gelu(**kwargs):
    """Build the large GELU model: 24 encoder blocks, dim 1024, 16 heads.

    Uses biased QKV and ``nn.LayerNorm`` with ``eps=1e-6``. Extra keyword
    arguments are passed through to ``FourMViT``; repeating one of the
    arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourMViT(**config, **kwargs)

@register_model
def fm_vit_xlarge_24e_gelu(**kwargs):
    """Build the xlarge GELU model: 24 encoder blocks, dim 2048, 32 heads.

    Uses biased QKV and ``nn.LayerNorm`` with ``eps=1e-6``. Extra keyword
    arguments are passed through to ``FourMViT``; repeating one of the
    arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return FourMViT(**config, **kwargs)


# SwiGLU variants
@register_model
def fm_vit_tiny_6e_swiglu_nobias(**kwargs):
    """Build the tiny bias-free SwiGLU model: 6 encoder blocks, dim 384, 6 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, and the feedforward is a gated
    MLP with SiLU activation. Extra keyword arguments are passed through to
    ``FourMViT``; repeating one of the arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_small_8e_swiglu_nobias(**kwargs):
    """Build the small bias-free SwiGLU model: 8 encoder blocks, dim 512, 8 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, and the feedforward is a gated
    MLP with SiLU activation. Extra keyword arguments are passed through to
    ``FourMViT``; repeating one of the arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_base_12e_swiglu_nobias(**kwargs):
    """Build the base bias-free SwiGLU model: 12 encoder blocks, dim 768, 12 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, and the feedforward is a gated
    MLP with SiLU activation. Extra keyword arguments are passed through to
    ``FourMViT``; repeating one of the arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_large_24e_swiglu_nobias(**kwargs):
    """Build the large bias-free SwiGLU model: 24 encoder blocks, dim 1024, 16 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, and the feedforward is a gated
    MLP with SiLU activation. Extra keyword arguments are passed through to
    ``FourMViT``; repeating one of the arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourMViT(**config, **kwargs)

@register_model
def fm_vit_xlarge_24e_swiglu_nobias(**kwargs):
    """Build the xlarge bias-free SwiGLU model: 24 encoder blocks, dim 2048, 32 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, and the feedforward is a gated
    MLP with SiLU activation. Extra keyword arguments are passed through to
    ``FourMViT``; repeating one of the arguments fixed here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
    )
    return FourMViT(**config, **kwargs)

# SwiGLU + QKNorm variants

@register_model
def fm_vit_base_12e_swiglu_qknorm_nobias(**kwargs):
    """Build the base bias-free SwiGLU + QK-norm model: 12 blocks, dim 768, 12 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, the feedforward is a gated MLP
    with SiLU activation, and ``qk_norm`` is enabled. Extra keyword arguments
    are passed through to ``FourMViT``; repeating one of the arguments fixed
    here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourMViT(**config, **kwargs)


@register_model
def fm_vit_large_24e_swiglu_qknorm_nobias(**kwargs):
    """Build the large bias-free SwiGLU + QK-norm model: 24 blocks, dim 1024, 16 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, the feedforward is a gated MLP
    with SiLU activation, and ``qk_norm`` is enabled. Extra keyword arguments
    are passed through to ``FourMViT``; repeating one of the arguments fixed
    here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourMViT(**config, **kwargs)

@register_model
def fm_vit_xlarge_24e_swiglu_qknorm_nobias(**kwargs):
    """Build the xlarge bias-free SwiGLU + QK-norm model: 24 blocks, dim 2048, 32 heads.

    QKV, attention-projection and MLP biases are disabled, the norm is the
    project ``LayerNorm`` with ``bias=False``, the feedforward is a gated MLP
    with SiLU activation, and ``qk_norm`` is enabled. Extra keyword arguments
    are passed through to ``FourMViT``; repeating one of the arguments fixed
    here raises ``TypeError``.
    """
    config = dict(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
    )
    return FourMViT(**config, **kwargs)