import torch
import torch.nn as nn
import os
from timm.models import create_model
from timm.models.vision_transformer import VisionTransformer
from collections import OrderedDict

def get_model_embed_dim(model_name: str) -> int:
    """Return the embedding width associated with a DeiT model name.

    Names that are not in the table below fall back to 768, the width of the
    base-sized models.
    """
    known_widths = dict(
        deit_tiny_patch16_224=192,
        deit_small_patch16_224=384,
        deit_base_patch16_224=768,
        deit_base_patch16_384=768,
        deit_large_patch16_224=1024,
        deit_huge_patch14_224=1280,
    )
    try:
        return known_widths[model_name]
    except KeyError:
        # Unknown model names are treated like a base-sized model.
        return 768


def check_model_compatibility(model_name: str, pretrained_model_name: str) -> bool:
    """Tell whether weights can be moved between two models.

    Two models count as compatible when ``get_model_embed_dim`` reports the
    same embedding width for both names.
    """
    target_width = get_model_embed_dim(model_name)
    source_width = get_model_embed_dim(pretrained_model_name)
    return target_width == source_width

class MyDeiT(VisionTransformer):
    """DeiT model assembled from the submodules of a timm-created network.

    A reference network is built with ``create_model`` and its patch
    embedding, tokens, positional embedding, blocks, final norm and head are
    attached to this ``VisionTransformer`` subclass. ``encode`` exposes the
    normalised class-token feature.

    Weight source, in priority order:

    1. ``ckpt_path`` given: the reference network is created without
       pretrained weights and the checkpoint is loaded into this instance.
    2. ``pretrained=True``: for ``deit_tiny_patch16_224`` a checkpoint cached
       under ``torch.hub.get_dir()/checkpoints`` is used when present;
       otherwise (and for every other model name) the weights are requested
       through ``create_model(..., pretrained=True)``.
    3. Otherwise the weights stay randomly initialised.

    Args:
        model_name: timm model name used to build the reference network.
        pretrained: Whether to start from pretrained weights (ignored when
            ``ckpt_path`` is given).
        num_classes: Size of the classification head.
        ckpt_path: Optional path to a checkpoint to load after construction.
        **kwargs: Extra keyword arguments forwarded to ``create_model``.
    """

    def __init__(self, model_name='deit_tiny_patch16_224', pretrained=False, num_classes=1000, ckpt_path=None, **kwargs):
        # An explicit checkpoint takes precedence over pretrained weights.
        use_pretrained = bool(pretrained) and ckpt_path is None
        model = self._build_reference_model(model_name, use_pretrained, num_classes, kwargs)

        try:
            img_size = model.patch_embed.img_size
        except AttributeError:
            # Patch embedding does not expose img_size: infer it from the name.
            img_size = (224, 224) if "384" not in model_name else (384, 384)

        super().__init__(
            img_size=img_size,
            patch_size=model.patch_embed.patch_size,
            in_chans=model.patch_embed.proj.in_channels,
            num_classes=num_classes,
            embed_dim=model.embed_dim,
            depth=len(model.blocks),
            num_heads=model.blocks[0].attn.num_heads,
            qkv_bias=True,
            norm_layer=nn.LayerNorm,
        )

        # Replace the freshly initialised submodules with the reference ones so
        # this instance carries the reference network's weights.
        self.patch_embed = model.patch_embed
        self.cls_token = model.cls_token
        self.dist_token = getattr(model, 'dist_token', None)
        self.pos_embed = model.pos_embed
        self.pos_drop = model.pos_drop
        self.blocks = model.blocks
        self.norm = model.norm
        self.head = model.head
        self.head_dist = getattr(model, 'head_dist', None)

        if ckpt_path is not None:
            # NOTE(security): torch.load unpickles the file; only pass
            # checkpoints that come from a trusted source.
            state_dict = torch.load(ckpt_path, map_location='cpu')
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
                # Drop the 'model.' wrapper segment from the parameter names.
                state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
            load_result = self.load_state_dict(state_dict, strict=False)
            print(f"✅ Loaded checkpoint from {ckpt_path}")
            # strict=False hides key mismatches, so surface them explicitly.
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"⚠️ Checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
                    f"{len(load_result.unexpected_keys)} unexpected"
                )
        elif not pretrained:
            print("🆕 Initialized MyDeiT with random weights.")

    @staticmethod
    def _build_reference_model(model_name, use_pretrained, num_classes, create_kwargs):
        """Create the timm network whose submodules this wrapper adopts.

        Args:
            model_name: timm model name.
            use_pretrained: Whether pretrained weights should be loaded.
            num_classes: Size of the classification head.
            create_kwargs: Dict of extra keyword arguments for ``create_model``.

        Returns:
            The created timm model.
        """
        if not use_pretrained:
            return create_model(model_name, pretrained=False, num_classes=num_classes, **create_kwargs)

        if model_name == 'deit_tiny_patch16_224':
            # Look in the torch.hub checkpoint cache rather than a
            # user-specific absolute path.
            checkpoint_path = os.path.join(
                torch.hub.get_dir(), 'checkpoints', 'deit_tiny_patch16_224-a1311bcf.pth'
            )
            if os.path.exists(checkpoint_path):
                print(f"🔍 Loading weights from local checkpoint: {checkpoint_path}")
                model = create_model(model_name, pretrained=False, num_classes=num_classes, **create_kwargs)
                # NOTE(security): torch.load unpickles the file.
                checkpoint = torch.load(checkpoint_path, map_location='cpu')
                state_dict = checkpoint.get('model', checkpoint)
                # Skip the classifier so a different num_classes does not
                # cause a size mismatch.
                filtered_dict = OrderedDict(
                    (k, v) for k, v in state_dict.items() if not k.startswith('head.')
                )
                model.load_state_dict(filtered_dict, strict=False)
                print("🆕 Initialized MyDeiT with local deit_tiny_patch16_224 weights.")
                return model
            print(f"⚠️ Local checkpoint not found at {checkpoint_path}")
            print("🌐 Using online pretrained weights instead...")

        # Every model name honours pretrained=True; previously only DeiT-Tiny
        # did and the others silently kept random weights.
        model = create_model(model_name, pretrained=True, num_classes=num_classes, **create_kwargs)
        print(f"🆕 Initialized MyDeiT with online {model_name} weights.")
        return model

    def load_pretrained_weights(self, pretrained_model_name, strict=False):
        """Copy weights from a pretrained timm model into this instance.

        Only tensors whose name and shape both match this model are copied,
        so entries of a different size are skipped instead of raising a
        size-mismatch error.

        Args:
            pretrained_model_name: timm model name to take the weights from.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            The value returned by ``load_state_dict``.
        """
        source_state = create_model(pretrained_model_name, pretrained=True).state_dict()
        own_state = self.state_dict()
        compatible = OrderedDict(
            (k, v) for k, v in source_state.items()
            if k in own_state and own_state[k].shape == v.shape
        )
        skipped = len(source_state) - len(compatible)
        load_result = self.load_state_dict(compatible, strict=strict)
        print(
            f"✅ Loaded {len(compatible)} tensors from pretrained {pretrained_model_name} "
            f"({skipped} skipped)"
        )
        return load_result

    @torch.no_grad()
    def encode(self, x):
        """Return the normalised class-token feature for a batch of inputs.

        Runs with gradients disabled. Only the class token is prepended to the
        patch tokens; ``dist_token`` is not used here.

        Args:
            x: Input batch accepted by ``self.patch_embed``.

        Returns:
            The first token of the normalised sequence, one row per sample.
        """
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # [B, 1, D]
        x = torch.cat((cls_tokens, x), dim=1)  # [B, 1 + N, D]
        # Slice the positional embedding to the current sequence length.
        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]  # [B, D]

# Factory functions for different DeiT variants
def create_deit_tiny(num_classes: int = 1000, pretrained: bool = True, ckpt_path=None, **kwargs) -> MyDeiT:
    """Create DeiT-Tiny model.

    Args:
        num_classes: Size of the classification head.
        pretrained: Whether to start from pretrained weights.
        ckpt_path: Optional checkpoint path passed to ``MyDeiT``. It is an
            explicit parameter so the signature matches the other factories;
            it used to be accepted only through ``**kwargs``.
        **kwargs: Extra keyword arguments forwarded to ``MyDeiT``.

    Returns:
        A ``MyDeiT`` built for ``'deit_tiny_patch16_224'``.
    """
    return MyDeiT(
        model_name='deit_tiny_patch16_224',
        num_classes=num_classes,
        pretrained=pretrained,
        ckpt_path=ckpt_path,
        **kwargs
    )


def create_deit_small(num_classes: int = 1000, pretrained: bool = True, ckpt_path=None, **kwargs) -> MyDeiT:
    """Build a ``MyDeiT`` for the ``'deit_small_patch16_224'`` model name.

    ``num_classes``, ``pretrained`` and ``ckpt_path`` are handed to ``MyDeiT``
    unchanged, together with any extra keyword arguments.
    """
    variant = 'deit_small_patch16_224'
    small_model = MyDeiT(
        model_name=variant,
        pretrained=pretrained,
        ckpt_path=ckpt_path,
        num_classes=num_classes,
        **kwargs,
    )
    return small_model


def create_deit_base(num_classes: int = 1000, pretrained: bool = True, ckpt_path=None, **kwargs) -> MyDeiT:
    """Build a ``MyDeiT`` for the ``'deit_base_patch16_224'`` model name.

    ``num_classes``, ``pretrained`` and ``ckpt_path`` are handed to ``MyDeiT``
    unchanged, together with any extra keyword arguments.
    """
    variant = 'deit_base_patch16_224'
    base_model = MyDeiT(
        model_name=variant,
        pretrained=pretrained,
        ckpt_path=ckpt_path,
        num_classes=num_classes,
        **kwargs,
    )
    return base_model


def create_deit_base_384(num_classes: int = 1000, pretrained: bool = True, ckpt_path=None, **kwargs) -> MyDeiT:
    """Build a ``MyDeiT`` for the ``'deit_base_patch16_384'`` model name.

    ``num_classes``, ``pretrained`` and ``ckpt_path`` are handed to ``MyDeiT``
    unchanged, together with any extra keyword arguments.
    """
    variant = 'deit_base_patch16_384'
    base_384_model = MyDeiT(
        model_name=variant,
        pretrained=pretrained,
        ckpt_path=ckpt_path,
        num_classes=num_classes,
        **kwargs,
    )
    return base_384_model


def create_mydeit_from_config(config_dict: dict) -> MyDeiT:
    """Create MyDeiT model from configuration dictionary.

    Recognised keys:
        model_name: timm model name (default ``'deit_base_patch16_224'``).
        num_classes: Size of the classification head (default 1000).
        pretrained: Whether to start from pretrained weights (default True).
        force_pretrained_model: Optional name of another model whose weights
            should be used. A value of ``None`` is treated as absent.

    Every other key is forwarded to ``MyDeiT`` as a keyword argument.

    Args:
        config_dict: Configuration mapping as described above.

    Returns:
        The constructed ``MyDeiT`` instance.
    """
    reserved_keys = ('model_name', 'num_classes', 'pretrained', 'force_pretrained_model')

    model_name = config_dict.get('model_name', 'deit_base_patch16_224')
    num_classes = config_dict.get('num_classes', 1000)
    pretrained = config_dict.get('pretrained', True)
    forced_model = config_dict.get('force_pretrained_model')

    # Handle the case where pretrained weights might not match
    if pretrained and forced_model is not None:
        # User explicitly wants to load from a specific pretrained model
        if not check_model_compatibility(model_name, forced_model):
            print(f"Warning: Model size mismatch between {model_name} and {forced_model}")
            print("Creating model without pretrained weights...")
            pretrained = False

    extra_kwargs = {k: v for k, v in config_dict.items() if k not in reserved_keys}
    model = MyDeiT(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        **extra_kwargs
    )

    # Without pretrained weights, try to load whatever the forced model offers.
    if not pretrained and forced_model is not None:
        # Guard the call: a model without this method used to raise
        # AttributeError here instead of returning a usable model.
        loader = getattr(model, 'load_pretrained_weights', None)
        if callable(loader):
            loader(forced_model, strict=False)
        else:
            print(
                f"Warning: {type(model).__name__} has no load_pretrained_weights(); "
                f"skipping weights from {forced_model}"
            )

    return model


# Example usage and testing
if __name__ == "__main__":
    # Manual smoke test: builds a few models and prints what happened.
    print("Testing MyDeiT with proper model loading...")

    # --- Test 1: each factory is asked for its own pretrained weights ---
    print("\n=== Test 1: Correct model-weight matching ===")
    try:
        tiny = create_deit_tiny(pretrained=True, num_classes=1000)
        print(f"✓ DeiT-Tiny loaded successfully, embed_dim: {tiny.embed_dim}")
    except Exception as err:
        print(f"❌ DeiT-Tiny loading failed: {err}")

    try:
        base = create_deit_base(pretrained=True, num_classes=1000)
        print(f"✓ DeiT-Base loaded successfully, embed_dim: {base.embed_dim}")
    except Exception as err:
        print(f"❌ DeiT-Base loading failed: {err}")

    # --- Test 2: randomly initialised base model with a 10-class head ---
    print("\n=== Test 2: Handling size mismatches ===")
    try:
        scratch_base = create_deit_base(pretrained=False, num_classes=10)
        print(f"✓ DeiT-Base without pretrained weights: {scratch_base.embed_dim}")
        n_samples = 2
        images = torch.randn(n_samples, 3, 224, 224)
        features = scratch_base.encode(images)
        print(f"✓ CLS features shape: {features.shape}")
        prediction = scratch_base(images)
        if not isinstance(prediction, tuple):
            print(f"✓ Output logits shape: {prediction.shape}")
        else:
            # A tuple output carries class logits and distillation logits.
            cls_logits, dist_logits = prediction
            print(f"✓ CLS logits shape: {cls_logits.shape}")
            print(f"✓ Distillation logits shape: {dist_logits.shape}")

    except Exception as err:
        print(f"❌ Model creation/testing failed: {err}")
        import traceback
        traceback.print_exc()

    # --- Test 3: embedding-width lookup and compatibility check ---
    print("\n=== Test 3: Model compatibility checking ===")
    tiny_name = 'deit_tiny_patch16_224'
    base_name = 'deit_base_patch16_224'
    print(f"DeiT-Tiny embed dim: {get_model_embed_dim(tiny_name)}")
    print(f"DeiT-Base embed dim: {get_model_embed_dim(base_name)}")
    print(f"Tiny->Base compatible: {check_model_compatibility(tiny_name, base_name)}")
    print(f"Base->Base compatible: {check_model_compatibility(base_name, base_name)}")

    print("\n✓ All tests completed!")