import torch
import torch.nn as nn
from functools import partial
# from model import get_model

import timm
# from timm.layers.adaptive_avgmax_pool import FastAdaptiveAvgPool

# Local checkpoint files, keyed by model name.
WEIGHT_PATHS = {
    'resnet18': './checkpoints/resnet18-5c106cde.pth',
    'resnet50': './checkpoints/resnet50_a1_0-14fe96d1.pth',
    'resnet101': './checkpoints/resnet101_a1h-36d3f2aa.pth',
    'densenet121': './checkpoints/densenet121_ra-50efcf5c.pth',
}

class CustomResNetModel(nn.Module):
    """timm CNN wrapper that returns image-level and per-location outputs.

    A forward hook on the backbone's global pooling layer appends every
    spatial feature vector of the pooling input to the pooled batch. The
    layers that run after the pooling therefore process the pooled features
    and each spatial location in one pass, and ``forward`` splits the two
    groups apart again.

    NOTE: this assumes the layers after the pooling treat each row of the
    batch independently (e.g. flatten + linear classifier).
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        """Create the timm model and attach the pooling hook.

        Args:
            model_name: Model name passed to ``timm.create_model``.
            num_classes: Passed to ``timm.create_model`` as ``num_classes``.
            pretrained: Passed to ``timm.create_model`` as ``pretrained``.

        Raises:
            RuntimeError: If no pooling layer to hook can be found.
        """
        super().__init__()

        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        self._register_pool_hook()

    def _find_pool_module(self):
        """Return the pooling module to hook, or ``None`` if there is none."""
        pool = getattr(self.model, 'global_pool', None)
        if isinstance(pool, nn.Module):
            return pool
        # The model does not expose ``global_pool`` as a module, so search its
        # submodules for the first adaptive average pooling layer instead.
        for module in self.model.modules():
            if isinstance(module, nn.AdaptiveAvgPool2d):
                return module
        return None

    def _register_pool_hook(self):
        """Attach the dense-feature hook to the model's pooling layer."""
        pool = self._find_pool_module()
        if pool is None:
            # Fail here rather than with an AttributeError on ``self.b`` in forward().
            raise RuntimeError(
                "CustomResNetModel: no 'global_pool' module or AdaptiveAvgPool2d "
                "found to attach the dense-feature hook to"
            )
        # A bound method (not a closure) so that copy.deepcopy() of this wrapper
        # produces a hook bound to the copy; a plain function would be shared
        # and keep writing b/h/w onto the original wrapper.
        pool.register_forward_hook(self._avg_pooling_hookfn)

    def _avg_pooling_hookfn(self, module, inputs, output):
        """Append the un-pooled per-location features to the pooled output.

        inputs[0]: (b, c, h, w) feature map entering the pooling layer.
        output:    (b, c) or (b, c, 1, 1) pooled features.
        returns:   (b + b*h*w, c[, 1, 1]) -- pooled rows first, then one row
                   per spatial location in (b, h, w) order.
        """
        feats = inputs[0]
        # Remember the layout so forward() can undo the concatenation.
        self.b, c, self.h, self.w = feats.shape
        # Match the rank of the pooled output: a flattening pool yields (b, c),
        # a bare AdaptiveAvgPool2d yields (b, c, 1, 1).
        trailing = (1,) * (output.dim() - 2)
        dense_output = feats.permute(0, 2, 3, 1).reshape(-1, c, *trailing)
        return torch.cat((output, dense_output), dim=0)

    def forward(self, x):
        """Run the model and split pooled from dense outputs.

        Returns:
            A tuple ``(pooled, dense)``: ``pooled`` is the first ``b`` rows of
            the model output; ``dense`` is the remaining rows reshaped to
            ``(b, k, h, w)``, where ``k`` is the model's output width.
        """
        output = self.model(x)  # (b + b*h*w, k)
        pooled = output[:self.b, :]
        dense = output[self.b:, :].reshape(self.b, self.h, self.w, -1).permute(0, 3, 1, 2)
        return pooled, dense
        

class CustomInceptionV3Model(nn.Module):
    """timm CNN wrapper that returns image-level and per-location outputs.

    A forward hook on the model's first ``AdaptiveAvgPool2d`` appends every
    spatial feature vector of the pooling input to the pooled batch, so the
    layers after the pooling process the pooled features and each spatial
    location in one pass. ``forward`` splits the two groups apart again.

    NOTE: this assumes the layers after the pooling treat each row of the
    batch independently (e.g. flatten + linear classifier).
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        """Create the timm model and attach the pooling hook.

        Args:
            model_name: Model name passed to ``timm.create_model``.
            num_classes: Passed to ``timm.create_model`` as ``num_classes``.
            pretrained: Passed to ``timm.create_model`` as ``pretrained``.

        Raises:
            RuntimeError: If the model contains no ``AdaptiveAvgPool2d``.
        """
        super().__init__()

        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        self._register_pool_hook()

    def _register_pool_hook(self):
        """Attach the dense-feature hook to the first AdaptiveAvgPool2d."""
        pool = next(
            (m for m in self.model.modules() if isinstance(m, nn.AdaptiveAvgPool2d)),
            None,
        )
        if pool is None:
            # Fail here rather than with an AttributeError on ``self.b`` in forward().
            raise RuntimeError(
                "CustomInceptionV3Model: no AdaptiveAvgPool2d found to attach "
                "the dense-feature hook to"
            )
        # A bound method (not a closure) so that copy.deepcopy() of this wrapper
        # produces a hook bound to the copy; a plain function would be shared
        # and keep writing b/h/w onto the original wrapper.
        pool.register_forward_hook(self._avg_pooling_hookfn)

    def _avg_pooling_hookfn(self, module, inputs, output):
        """Append the un-pooled per-location features to the pooled output.

        inputs[0]: (b, c, h, w) feature map entering the pooling layer.
        output:    (b, c, 1, 1) pooled features.
        returns:   (b + b*h*w, c, 1, 1) -- pooled rows first, then one row
                   per spatial location in (b, h, w) order.
        """
        feats = inputs[0]
        # Remember the layout so forward() can undo the concatenation.
        self.b, c, self.h, self.w = feats.shape
        dense_output = feats.permute(0, 2, 3, 1).reshape(-1, c, 1, 1)  # (b*h*w, c, 1, 1)
        return torch.cat((output, dense_output), dim=0)

    def forward(self, x):
        """Run the model and split pooled from dense outputs.

        Returns:
            A tuple ``(pooled, dense)``: ``pooled`` is the first ``b`` rows of
            the model output; ``dense`` is the remaining rows reshaped to
            ``(b, k, h, w)``, where ``k`` is the model's output width.
        """
        output = self.model(x)  # (b + b*h*w, k)
        pooled = output[:self.b, :]
        dense = output[self.b:, :].reshape(self.b, self.h, self.w, -1).permute(0, 3, 1, 2)
        return pooled, dense

class CustomSwinModel(nn.Module):
    """timm wrapper that returns image-level and per-token outputs.

    A forward hook on the model's ``avgpool`` submodule appends every token's
    feature vector from the pooling input to the pooled batch, so the layers
    after the pooling process the pooled features and each token in one
    pass. ``forward`` splits the two groups apart again.

    NOTE: this assumes the layers after the pooling treat each row of the
    batch independently (e.g. flatten + linear head).
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        """Create the timm model and attach the pooling hook.

        Args:
            model_name: Model name passed to ``timm.create_model``.
            num_classes: Passed to ``timm.create_model`` as ``num_classes``.
            pretrained: Passed to ``timm.create_model`` as ``pretrained``.

        Raises:
            RuntimeError: If the model has no submodule named ``avgpool``.
        """
        super().__init__()

        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        self._register_pool_hook()

    def _register_pool_hook(self):
        """Attach the dense-token hook to the submodule named ``avgpool``."""
        pool = next(
            (m for name, m in self.model.named_modules() if name == 'avgpool'),
            None,
        )
        if pool is None:
            # Fail here rather than with an AttributeError on ``self.b`` in forward().
            raise RuntimeError(
                "CustomSwinModel: the model has no 'avgpool' submodule to attach "
                "the dense-token hook to"
            )
        # A bound method (not a closure) so that copy.deepcopy() of this wrapper
        # produces a hook bound to the copy; a plain function would be shared
        # and keep writing b/n onto the original wrapper.
        pool.register_forward_hook(self._avg_pooling_hookfn)

    def _avg_pooling_hookfn(self, module, inputs, output):
        """Append the un-pooled per-token features to the pooled output.

        inputs[0]: (b, c, n) token features entering the pooling layer.
        output:    (b, c, 1) pooled features.
        returns:   (b + b*n, c, 1) -- pooled rows first, then one row per
                   token in (b, n) order.
        """
        tokens = inputs[0]
        # Remember the layout so forward() can undo the concatenation.
        self.b, c, self.n = tokens.shape
        dense_output = tokens.permute(0, 2, 1).reshape(-1, c, 1)  # (b*n, c, 1)
        return torch.cat((output, dense_output), dim=0)

    def forward(self, x):
        """Run the model and split pooled from per-token outputs.

        Returns:
            A tuple ``(pooled, dense)``: ``pooled`` is the first ``b`` rows of
            the model output; ``dense`` is the remaining rows reshaped to
            ``(b, n, k)``, where ``k`` is the model's output width.
        """
        output = self.model(x)  # (b + b*n, k)
        pooled = output[:self.b, :]
        dense = output[self.b:, :].reshape(self.b, self.n, -1)
        return pooled, dense
    
class CustomViTModel(nn.Module):
    """timm ViT wrapper that applies the head to the class token and to every patch token."""

    def __init__(self, model_name, num_classes, pretrained=True):
        """Create the wrapped model through ``timm.create_model``."""
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    def forward_features(self, x):
        """Encode ``x`` and return ``(class_token, patch_tokens)``.

        The class token is prepended to the patch embeddings, so after the
        final norm it sits at index 0 and the patch tokens follow it.
        """
        vit = self.model
        patches = vit.patch_embed(x)
        # One copy of the learned class token per sample in the batch.
        cls_tokens = vit.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)

        tokens = vit.pos_drop(tokens + vit.pos_embed)
        tokens = vit.norm(vit.blocks(tokens))  # (b, 1 + num_patches, c)

        return tokens[:, 0], tokens[:, 1:]

    def forward(self, x):
        """Return ``(class_token_output, patch_token_output)`` after the model head.

        The patch tokens are flattened to 2-D before the head is applied and
        restored to ``(batch, num_patches, -1)`` afterwards.
        """
        cls_feat, patch_feats = self.forward_features(x)
        head = self.model.head

        cls_out = head(cls_feat)
        batch, num_patches, dim = patch_feats.shape
        flat_patches = patch_feats.reshape(-1, dim)
        patch_out = head(flat_patches).reshape(batch, num_patches, -1)

        return cls_out, patch_out
    
