import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
import timm
import torch.nn.functional as F
from kan import KAN
PYKAN_AVAILABLE = True
import timm

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import segmentation_models_pytorch as smp


class FastKANHead(nn.Module):
    """Prediction head: 1x1 squeeze -> per-pixel KAN -> 1x1 expand.

    The input map is squeezed to ``bottleneck_channels`` with a 1x1 conv +
    InstanceNorm + ReLU. Each spatial position's bottleneck vector is then
    passed independently through a pykan ``KAN`` of width
    ``[bottleneck_channels] * 3``. Finally it is expanded to ``out_channels``
    with a 1x1 conv. When ``PYKAN_AVAILABLE`` is False, a 3x3 conv over the
    bottleneck map replaces the KAN.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned map.
        bottleneck_channels: width of the squeezed representation and KAN layers.
        kan_grid: ``grid`` argument forwarded to ``KAN``.
        kan_k: ``k`` argument forwarded to ``KAN``.
    """

    def __init__(self, in_channels, out_channels, bottleneck_channels=8, kan_grid=5, kan_k=3):
        super().__init__()
        self.squeeze = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False),
            nn.InstanceNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
        )
        if PYKAN_AVAILABLE:
            self.kan = KAN(width=[bottleneck_channels] * 3, grid=kan_grid, k=kan_k)
        else:
            self.kan = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, 1, 1)
        self.expand = nn.Conv2d(bottleneck_channels, out_channels, 1)

    def forward(self, features):
        """Map a (B, in_channels, H, W) tensor to (B, out_channels, H, W)."""
        b_feat = self.squeeze(features)
        if isinstance(self.kan, nn.Conv2d):
            # The fallback conv works on the 4-D map directly. Flattening it to
            # (N, C) as done for the KAN below would make Conv2d fail.
            return self.expand(self.kan(b_feat))
        b, c, h, w = b_feat.shape
        # Feed the KAN 2-D input: every pixel becomes one row of C features.
        kan_in = b_feat.permute(0, 2, 3, 1).reshape(-1, c)
        kan_out = self.kan(kan_in)
        # reshape (not view) so a non-contiguous KAN output is still handled.
        kan_feat = kan_out.reshape(b, h, w, c).permute(0, 3, 1, 2)
        return self.expand(kan_feat)

class BaseKeypointModel(nn.Module):
    """Wrap a dense-prediction network and, for keypoint tasks, resize its output.

    Args:
        model: the wrapped network; called as ``model(x)``.
        heatmap_size: side length of the square map returned when
            ``task_type == "keypoint"``.
        task_type: when ``"keypoint"``, outputs whose last two dims differ from
            ``(heatmap_size, heatmap_size)`` are bilinearly resized to it. Any
            other value returns the wrapped model's output untouched.
    """

    def __init__(self, model, heatmap_size=128, task_type="keypoint"):
        super().__init__()
        self.model = model
        self.heatmap_size = heatmap_size
        self.task_type = task_type

    def forward(self, x):
        out = self.model(x)
        target_hw = (self.heatmap_size, self.heatmap_size)
        needs_resize = self.task_type == 'keypoint' and out.shape[-2:] != target_hw
        if not needs_resize:
            return out
        return F.interpolate(out, size=target_hw, mode='bilinear', align_corners=False)

class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> InstanceNorm -> ReLU) stages; H and W are preserved.

    Args:
        in_channels: channels of the input map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        # Kept as a single Sequential named ``conv`` so state_dict keys are conv.0 / conv.3.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class ConvBlockWithDropout(nn.Module):
    """Two stacked (3x3 conv -> InstanceNorm -> ReLU -> Dropout2d) stages.

    Args:
        in_channels: channels of the input map.
        out_channels: channels produced by both convolutions.
        dropout_p: channel-dropout probability used after each stage.
    """

    def __init__(self, in_channels, out_channels, dropout_p=0.1):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_p),
            ))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, optional skip concat, conv block.

    Args:
        in_channels_deep: channels of the deeper input ``x1``; the transposed
            conv maps it to ``in_channels_deep // 2`` channels.
        in_channels_skip: channels of the skip tensor ``x2``. Pass 0 for a
            stage that will be called with ``x2=None``.
        out_channels: channels produced by the conv block.
        use_dropout: use ``ConvBlockWithDropout`` instead of ``ConvBlock``.
        dropout_p: dropout probability, only used when ``use_dropout`` is True.
    """

    def __init__(self, in_channels_deep, in_channels_skip, out_channels, use_dropout=False, dropout_p=0.1):
        super().__init__()
        half_deep = in_channels_deep // 2
        self.up = nn.ConvTranspose2d(in_channels_deep, half_deep, kernel_size=2, stride=2)

        fused_channels = half_deep + in_channels_skip
        if use_dropout:
            self.conv = ConvBlockWithDropout(fused_channels, out_channels, dropout_p=dropout_p)
        else:
            self.conv = ConvBlock(fused_channels, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        if x2 is None:
            # No skip connection at this stage: decode the upsampled map alone.
            return self.conv(upsampled)

        # Pad the upsampled map so its H/W match the skip (negative differences trim).
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class ViTBottleneck(nn.Module):
    """Run a CNN feature map through the transformer blocks of a timm ViT.

    The input is processed in these steps:
      1. Project to the ViT embedding width with a 1x1 conv.
      2. Turn every spatial position into a token and prepend the class token.
      3. Add the positional embedding and apply the ViT blocks, then a fresh
         LayerNorm.
      4. Drop the class token and project the tokens back to ``out_channels``
         with a 1x1 conv.

    Args:
        in_channels: channels of the incoming CNN feature map.
        out_channels: channels of the returned feature map.
        feature_map_size: (H, W) grid (or a single int) the positional
            embedding is resized to at construction time. Inputs with a
            different grid are accepted: the embedding is resampled per call.
        vit_model_name: timm model name, created with ``pretrained=False``.
            Assumes its ``pos_embed`` holds one class-token slot followed by a
            square patch grid.
        vit_pretrained_path: optional state-dict file loaded with
            ``strict=False`` before the positional embedding is resized.

    Raises:
        ValueError: if the ViT's patch positional embedding is not a square grid.
    """

    def __init__(self, in_channels, out_channels, feature_map_size=(16, 16),
                 vit_model_name='vit_tiny_patch16_224', vit_pretrained_path=None):
        super().__init__()
        if timm is None: raise ImportError("'timm' library is required.")
        self.vit = timm.create_model(vit_model_name, pretrained=False)
        vit_embed_dim = self.vit.embed_dim
        if vit_pretrained_path and os.path.exists(vit_pretrained_path):
            print(f"--- Loading ViT weights from: {vit_pretrained_path} ---")
            self.vit.load_state_dict(torch.load(vit_pretrained_path, map_location='cpu'), strict=False)

        if isinstance(feature_map_size, int):
            feature_map_size = (feature_map_size, feature_map_size)
        # Grid on which the stored pos_embed is laid out; forward() resamples when the input differs.
        self._pos_grid_hw = (int(feature_map_size[0]), int(feature_map_size[1]))

        # Derive the pretrained patch grid side instead of assuming 14
        # (14 only holds for 224px inputs with 16px patches).
        original_pos_embed = self.vit.pos_embed
        num_patches = original_pos_embed.shape[1] - 1
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"Expected a square positional-embedding grid, got {num_patches} patch positions."
            )
        with torch.no_grad():
            patch_pos = self._resize_patch_pos_embed(
                original_pos_embed[:, 1:], (side, side), self._pos_grid_hw
            )
            new_pos_embed = torch.cat([original_pos_embed[:, :1], patch_pos], dim=1)
        self.vit.pos_embed = nn.Parameter(new_pos_embed)

        self.cnn_to_vit_proj = nn.Conv2d(in_channels, vit_embed_dim, 1)
        self.vit_to_cnn_proj = nn.Conv2d(vit_embed_dim, out_channels, 1)
        self.norm = nn.LayerNorm(vit_embed_dim)

    @staticmethod
    def _resize_patch_pos_embed(patch_pos, src_hw, dst_hw):
        """Bilinearly resample (1, src_h*src_w, D) patch embeddings to (1, dst_h*dst_w, D)."""
        dim = patch_pos.shape[-1]
        grid = patch_pos.transpose(1, 2).reshape(1, dim, src_hw[0], src_hw[1])
        grid = F.interpolate(grid, size=dst_hw, mode='bilinear', align_corners=False)
        return grid.flatten(2).transpose(1, 2)

    def forward(self, x):
        """Map (B, in_channels, H, W) to (B, out_channels, H, W)."""
        x_proj = self.cnn_to_vit_proj(x)
        b, c, h, w = x_proj.shape
        x_seq = x_proj.flatten(2).transpose(1, 2)  # (B, H*W, embed_dim)
        cls_token = self.vit.cls_token.expand(b, -1, -1)

        pos_embed = self.vit.pos_embed
        if (h, w) != self._pos_grid_hw:
            # Input grid differs from the construction-time grid: resample so the add below fits.
            patch_pos = self._resize_patch_pos_embed(pos_embed[:, 1:], self._pos_grid_hw, (h, w))
            pos_embed = torch.cat([pos_embed[:, :1], patch_pos], dim=1)

        x_seq = torch.cat((cls_token, x_seq), dim=1) + pos_embed
        x_vit_out = self.norm(self.vit.blocks(x_seq))[:, 1:, :]  # drop the class token
        return self.vit_to_cnn_proj(x_vit_out.transpose(1, 2).reshape(b, c, h, w))


class CustomUnet(nn.Module):
    """U-Net built on a ``segmentation_models_pytorch`` encoder with a custom decoder/head.

    The encoder's outputs are used as follows: ``features[-1]`` feeds the
    bottleneck (a ``ConvBlock`` or ``ViTBottleneck``, both to 512 channels),
    ``features[1:-1]`` (deepest first) are the four skips, and ``features[0]``
    is unused. The output has the spatial size of ``features[1]``.

    Args:
        encoder_name: smp encoder name (ImageNet weights are requested).
        n_classes: output channels of the single head. Ignored when
            ``use_hierarchical_head`` is True, which always yields 3 channels.
        use_kan_head: use ``FastKANHead`` heads when pykan is available;
            otherwise plain 1x1 convolutions are used.
        use_hierarchical_head: predict 2 channels with ``ps_head``, then 1
            channel with ``fh_predictor`` from the decoder features fused with
            the (gradient-detached) first prediction.
        use_vit_bottleneck: use ``ViTBottleneck`` (16x16 token grid) instead of
            a ``ConvBlock`` at the bottleneck.
        vit_pretrained_path: weights file forwarded to ``ViTBottleneck``.
        use_dropout: enable dropout in the decoder blocks and before the head.
        dropout_p: dropout probability.
    """

    def __init__(self, encoder_name, n_classes=3, use_kan_head=False,
                 use_hierarchical_head=False, use_vit_bottleneck=False, vit_pretrained_path=None,
                 use_dropout=False, dropout_p=0.1):
        super().__init__()
        self.use_hierarchical_head = use_hierarchical_head

        self.encoder = smp.encoders.get_encoder(name=encoder_name, in_channels=3, weights="imagenet")
        e_channels = self.encoder.out_channels
        skip_channels = e_channels[1:-1][::-1]
        deepest_channels = e_channels[-1]
        decoder_channels = [256, 128, 64, 32]

        if use_vit_bottleneck:
            # NOTE(review): fixed 16x16 grid, presumably the deepest map size for the intended input size.
            bottleneck_size = (16, 16)
            self.bottleneck = ViTBottleneck(deepest_channels, 512, feature_map_size=bottleneck_size,
                                            vit_pretrained_path=vit_pretrained_path)
        else:
            self.bottleneck = ConvBlock(deepest_channels, 512)

        self.up1 = UpBlock(512, skip_channels[0], decoder_channels[0], use_dropout, dropout_p)
        self.up2 = UpBlock(decoder_channels[0], skip_channels[1], decoder_channels[1], use_dropout, dropout_p)
        self.up3 = UpBlock(decoder_channels[1], skip_channels[2], decoder_channels[2], use_dropout, dropout_p)
        self.up4 = UpBlock(decoder_channels[2], skip_channels[3], decoder_channels[3], use_dropout, dropout_p)

        final_in_channels = decoder_channels[3]
        self.pre_head_dropout = nn.Dropout2d(p=dropout_p) if use_dropout else nn.Identity()

        if self.use_hierarchical_head:
            self.ps_head = self._make_head(final_in_channels, 2, use_kan_head)
            self.fh_head_fusion_block = ConvBlock(final_in_channels + 2, final_in_channels)
            self.fh_predictor = self._make_head(final_in_channels, 1, use_kan_head)
        else:
            self.head = self._make_head(final_in_channels, n_classes, use_kan_head)

    @staticmethod
    def _make_head(in_channels, out_channels, use_kan_head):
        """Return a FastKANHead when requested and pykan is available, else a 1x1 conv.

        The previous code fell back to ``nn.Conv2d(in, out)`` without a kernel
        size when pykan was unavailable, which raises TypeError.
        """
        if use_kan_head and PYKAN_AVAILABLE:
            return FastKANHead(in_channels, out_channels)
        return nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        features = self.encoder(x)
        skips = features[1:-1][::-1]
        deepest = features[-1]

        b = self.bottleneck(deepest)
        d1 = self.up1(b, skips[0])
        d2 = self.up2(d1, skips[1])
        d3 = self.up3(d2, skips[2])
        d4 = self.up4(d3, skips[3])

        d4 = self.pre_head_dropout(d4)

        if self.use_hierarchical_head:
            heatmap_ps = self.ps_head(d4)
            # detach(): the second head conditions on the first prediction without
            # back-propagating its loss into ps_head through this path.
            fh_head_input = torch.cat([d4, heatmap_ps.detach()], dim=1)
            fused_features = self.fh_head_fusion_block(fh_head_input)
            heatmap_fh = self.fh_predictor(fused_features)
            return torch.cat([heatmap_ps, heatmap_fh], dim=1)
        return self.head(d4)

class HRNetForHeatmaps(nn.Module):
    """Fuse a timm HRNet feature pyramid at its highest resolution into heatmap logits.

    Every backbone feature map is bilinearly resized to the spatial size of the
    first one and concatenated along channels. The result is refined by two
    (3x3 conv -> BatchNorm -> ReLU) layers of width 256 and mapped to
    ``n_classes`` channels by a 1x1 conv. No output activation is applied.

    Args:
        encoder_name: timm model name, created with ``features_only=True``.
        n_classes: number of output channels.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, encoder_name='hrnet_w32', n_classes=3, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(encoder_name, pretrained=pretrained, in_chans=3, features_only=True)
        fused_channels = sum(self.backbone.feature_info.channels())
        width = 256

        def conv_bn_relu(c_in):
            return [
                nn.Conv2d(in_channels=c_in, out_channels=width, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ]

        self.refinement_head = nn.Sequential(*conv_bn_relu(fused_channels), *conv_bn_relu(width))
        self.final_head = nn.Conv2d(in_channels=width, out_channels=n_classes, kernel_size=1)

    def forward(self, x):
        pyramid = self.backbone(x)
        anchor = pyramid[0]
        size = anchor.shape[2:]
        resized = [anchor] + [
            F.interpolate(level, size=size, mode='bilinear', align_corners=False)
            for level in pyramid[1:]
        ]
        return self.final_head(self.refinement_head(torch.cat(resized, dim=1)))

class TransUNet(nn.Module):
    """Hybrid CNN + ViT encoder with a cascaded upsampling decoder.

    A timm CNN (``features_only``, ``out_indices=(1, 2, 3, 4)``) yields four
    feature maps. The deepest is processed as follows:
      1. Project it to the ViT embedding width and flatten it into tokens.
      2. Run the tokens through the ViT's transformer blocks.
      3. Reshape the result back to a map.
      4. Decode with four ``UpBlock`` stages; the first three take the other
         CNN maps as skips (deep to shallow).
    ``forward`` returns sigmoid probabilities, not logits.

    Args:
        n_classes: number of output heatmap channels.
        cnn_encoder_name: timm CNN used as the hybrid encoder.
        vit_model_name: timm ViT whose ``cls_token``, ``pos_embed``, ``blocks``
            and ``norm`` form the bottleneck. Assumes ``pos_embed`` holds one
            class-token slot followed by a square patch grid (standard timm ViTs).
        pretrained: forwarded to both ``timm.create_model`` calls.

    Raises:
        ValueError: (from forward) if the ViT's patch positional embedding is not square.
    """

    def __init__(self, n_classes=3, cnn_encoder_name='resnet50', vit_model_name='vit_base_patch16_224_in21k', pretrained=True):
        super().__init__()

        # --- 1. CNN hybrid encoder: 4 stage outputs (skips + bottleneck input) ---
        self.cnn_encoder = timm.create_model(
            cnn_encoder_name,
            pretrained=pretrained,
            features_only=True,
            out_indices=(1, 2, 3, 4),
        )
        cnn_feature_channels = self.cnn_encoder.feature_info.channels()

        # --- 2. Transformer bottleneck ---
        # Only its tokens/embeddings/blocks/norm are used; the classifier head is disabled.
        self.vit = timm.create_model(
            vit_model_name,
            pretrained=pretrained,
            num_classes=0,
        )

        # 1x1 projection from the deepest CNN stage to the ViT token width.
        self.vit_proj = nn.Conv2d(
            in_channels=cnn_feature_channels[-1],
            out_channels=self.vit.embed_dim,
            kernel_size=1,
        )

        # --- 3. Cascaded upsampling decoder ---
        decoder_channels = [256, 128, 64, 32]
        self.decoder_bottleneck = ConvBlock(self.vit.embed_dim, 512)
        self.up1 = UpBlock(512, cnn_feature_channels[2], decoder_channels[0])
        self.up2 = UpBlock(decoder_channels[0], cnn_feature_channels[1], decoder_channels[1])
        self.up3 = UpBlock(decoder_channels[1], cnn_feature_channels[0], decoder_channels[2])
        self.up4 = UpBlock(decoder_channels[2], 0, decoder_channels[3])  # last stage: no skip

        # --- 4. Output head ---
        self.head = nn.Conv2d(decoder_channels[3], n_classes, kernel_size=1)
        self.final_activation = nn.Sigmoid()

    def _encode_tokens(self, tokens, h, w):
        """Run (B, h*w, D) patch tokens through the ViT and return (B, h*w, D).

        ``vit.forward_features`` cannot be used here: it applies the ViT's own
        patch embedding, i.e. it expects an image rather than a token sequence.
        The pretrained positional grid is resampled to (h, w) when needed.
        """
        vit = self.vit
        cls_pos = vit.pos_embed[:, :1]
        patch_pos = vit.pos_embed[:, 1:]
        num_patches = patch_pos.shape[1]
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"Expected a square positional-embedding grid, got {num_patches} patch positions."
            )
        if (side, side) != (h, w):
            grid = patch_pos.transpose(1, 2).reshape(1, -1, side, side)
            grid = F.interpolate(grid, size=(h, w), mode='bilinear', align_corners=False)
            patch_pos = grid.flatten(2).transpose(1, 2)

        cls_token = vit.cls_token.expand(tokens.shape[0], -1, -1)
        seq = torch.cat([cls_token, tokens], dim=1) + torch.cat([cls_pos, patch_pos], dim=1)
        # Optional pre-block layers; not every timm version defines both.
        for name in ('pos_drop', 'norm_pre'):
            layer = getattr(vit, name, None)
            if layer is not None:
                seq = layer(seq)
        seq = vit.norm(vit.blocks(seq))
        return seq[:, 1:]  # drop the class token so the tokens reshape back to (h, w)

    def forward(self, x):
        # 1. CNN features, shallow -> deep; the deepest one feeds the transformer.
        skips = self.cnn_encoder(x)
        vit_input = self.vit_proj(skips[-1])  # (B, embed_dim, h, w)
        b, c, h, w = vit_input.shape

        # 2. Transformer bottleneck on the flattened (B, h*w, embed_dim) tokens.
        tokens = self._encode_tokens(vit_input.flatten(2).transpose(1, 2), h, w)
        bottleneck_features = tokens.transpose(1, 2).reshape(b, c, h, w)

        # 3. Decode, fusing CNN skips from deep to shallow; the last stage has none.
        d0 = self.decoder_bottleneck(bottleneck_features)
        d1 = self.up1(d0, skips[2])
        d2 = self.up2(d1, skips[1])
        d3 = self.up3(d2, skips[0])
        d4 = self.up4(d3, None)

        # 4. Heatmaps as sigmoid probabilities.
        logits = self.head(d4)
        return self.final_activation(logits)

class SwinUnet(nn.Module):
    """U-Net style decoder on top of a timm Swin Transformer feature pyramid.

    The deepest encoder map passes through a ``ConvBlock``. It is then decoded
    with four ``UpBlock`` stages; the first three take the shallower encoder
    maps as skips (deep to shallow). ``forward`` returns sigmoid probabilities.

    Args:
        n_classes: number of output heatmap channels.
        encoder_name: timm Swin model name, created with ``features_only=True``.
        pretrained: forwarded to ``timm.create_model``.
        img_size: input resolution forwarded to ``timm.create_model``. The
            default keeps the previously hard-coded (512, 512).
            NOTE(review): presumably inputs must match it — confirm for your timm version.
    """

    def __init__(self, n_classes=3, encoder_name='swin_tiny_patch4_window7_224', pretrained=True,
                 img_size=(512, 512)):
        super().__init__()

        self.encoder = timm.create_model(
            encoder_name,
            pretrained=pretrained,
            features_only=True,
            img_size=img_size,
        )

        encoder_channels = self.encoder.feature_info.channels()
        bottleneck_channels = encoder_channels[-1]
        skip_channels = encoder_channels[:-1][::-1]

        self.bottleneck = ConvBlock(bottleneck_channels, bottleneck_channels)

        decoder_channels = [256, 128, 64, 32]
        self.up1 = UpBlock(bottleneck_channels, skip_channels[0], decoder_channels[0])
        self.up2 = UpBlock(decoder_channels[0], skip_channels[1], decoder_channels[1])
        self.up3 = UpBlock(decoder_channels[1], skip_channels[2], decoder_channels[2])
        self.up4 = UpBlock(decoder_channels[2], 0, decoder_channels[3])  # last stage: no skip

        self.head = nn.Conv2d(decoder_channels[3], n_classes, kernel_size=1)
        self.final_activation = nn.Sigmoid()

    def forward(self, x):
        # The encoder maps are treated as channels-last (B, H, W, C) and permuted to
        # (B, C, H, W). NOTE(review): confirm this layout for the installed timm version.
        features_channels_first = [f.permute(0, 3, 1, 2) for f in self.encoder(x)]

        bottleneck = features_channels_first[-1]
        skips = features_channels_first[:-1][::-1]

        b = self.bottleneck(bottleneck)
        d1 = self.up1(b, skips[0])
        d2 = self.up2(d1, skips[1])
        d3 = self.up3(d2, skips[2])
        d4 = self.up4(d3, None)

        logits = self.head(d4)
        return self.final_activation(logits)

class ConvNeXtUnet(nn.Module):
    """U-Net style decoder on top of a timm ConvNeXt feature pyramid.

    The deepest encoder map passes through a ``ConvBlock`` (channels kept).
    It is then decoded by four ``UpBlock`` stages; the first three take the
    next-shallower encoder maps as skips, the last has none. ``forward``
    returns sigmoid probabilities.

    Args:
        n_classes: number of output heatmap channels.
        encoder_name: timm model name, created with ``features_only=True``.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, n_classes=3, encoder_name='convnext_tiny', pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=pretrained, features_only=True)

        # Channel counts are read from the encoder rather than hard-coded.
        channels = self.encoder.feature_info.channels()
        deep_ch = channels[-1]
        skip_ch = list(reversed(channels[:-1]))

        self.bottleneck = ConvBlock(deep_ch, deep_ch)

        widths = [256, 128, 64, 32]
        self.up1 = UpBlock(deep_ch, skip_ch[0], widths[0])
        self.up2 = UpBlock(widths[0], skip_ch[1], widths[1])
        self.up3 = UpBlock(widths[1], skip_ch[2], widths[2])
        self.up4 = UpBlock(widths[2], 0, widths[3])

        self.head = nn.Conv2d(widths[3], n_classes, kernel_size=1)
        self.final_activation = nn.Sigmoid()

    def forward(self, x):
        *shallow, deepest = self.encoder(x)
        skips = list(reversed(shallow))  # deepest skip first

        out = self.bottleneck(deepest)
        out = self.up1(out, skips[0])
        out = self.up2(out, skips[1])
        out = self.up3(out, skips[2])
        out = self.up4(out, None)

        return self.final_activation(self.head(out))

def get_model(encoder_name="resnet34", task_type='keypoint', heatmap_size=128, classes_num=3,
              use_kan=False, use_hierarchical_head=False,
              use_vit_bottleneck=False, vit_pretrained_path=None,
              use_dropout=False, dropout_p=0.1):
    """Build a heatmap network and wrap it in ``BaseKeypointModel``.

    The architecture is chosen by case-insensitive substring match on
    ``encoder_name``, in this order:

    - ``'hrnet'`` -> HRNetForHeatmaps
    - ``'transunet'`` -> TransUNet with its default encoders (``encoder_name``
      is not forwarded)
    - ``'swin'`` -> SwinUnet
    - ``'convnext'`` -> ConvNeXtUnet
    - anything else -> CustomUnet on the smp encoder of that name

    Only the CustomUnet branch uses ``use_kan``, ``use_hierarchical_head``,
    ``use_vit_bottleneck``, ``vit_pretrained_path``, ``use_dropout`` and
    ``dropout_p``.

    Returns:
        A ``BaseKeypointModel`` around the chosen network, configured with
        ``heatmap_size`` and ``task_type``.
    """
    name = encoder_name.lower()

    if 'hrnet' in name:
        print(f"--- Creating HRNet model: {encoder_name} ---")
        network = HRNetForHeatmaps(encoder_name=encoder_name, n_classes=classes_num)
    elif 'transunet' in name:
        print(f"--- Creating TransUNet model: {encoder_name} ---")
        network = TransUNet(n_classes=classes_num)
    elif 'swin' in name:
        print(f"--- Creating SwinUnet model with encoder: {encoder_name} ---")
        network = SwinUnet(n_classes=classes_num, encoder_name=encoder_name)
    elif 'convnext' in name:
        print(f"--- Creating ConvNeXtUnet model with encoder: {encoder_name} ---")
        network = ConvNeXtUnet(n_classes=classes_num, encoder_name=encoder_name)
    else:
        # Default: U-Net on an smp encoder.
        print(f"--- Creating U-Net model with encoder: {encoder_name} ---")
        network = CustomUnet(
            encoder_name=encoder_name,
            n_classes=classes_num,
            use_kan_head=use_kan,
            use_hierarchical_head=use_hierarchical_head,
            use_vit_bottleneck=use_vit_bottleneck,
            vit_pretrained_path=vit_pretrained_path,
            use_dropout=use_dropout,
            dropout_p=dropout_p,
        )

    return BaseKeypointModel(model=network, heatmap_size=heatmap_size, task_type=task_type)


def replace_bn_with_in(module):
    """Recursively swap every ``nn.BatchNorm2d`` inside *module* for an ``nn.InstanceNorm2d``.

    Each replacement has the same ``num_features`` and ``affine=True``. Its
    affine parameters are freshly initialised, not copied from the
    BatchNorm. *module* is modified in place and the function returns
    ``None``. *module* itself is never replaced, only its descendants.
    """
    # Snapshot the children first because setattr rewrites the registry while we walk it.
    for child_name, child_module in list(module.named_children()):
        if not isinstance(child_module, nn.BatchNorm2d):
            replace_bn_with_in(child_module)
            continue
        setattr(module, child_name, nn.InstanceNorm2d(child_module.num_features, affine=True))

#==============================================================================
# 5. 测试代码 (用于验证)
#==============================================================================
# if __name__ == '__main__':
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     dummy_input = torch.randn(2, 3, 512, 512).to(device)

#     print("\n" + "="*80)
#     print("--- Testing Scenario 1: Standard Single-Head Unet ---")
#     model_1 = get_model(encoder_name="resnet34", use_hierarchical_head=False, use_kan=False).to(device)
#     output_1 = model_1(dummy_input)
#     assert output_1.shape == (2, 3, 128, 128)
#     print("Test PASSED.")

#     print("\n" + "="*80)
#     print("--- Testing Scenario 2: Standard Multi-Head Unet ---")
#     model_2 = get_model(encoder_name="resnet34", use_hierarchical_head=True, use_kan=False).to(device)
#     output_2 = model_2(dummy_input)
#     assert output_2.shape == (2, 3, 128, 128)
#     print("Test PASSED.")

#     if PYKAN_AVAILABLE:
#         print("\n" + "="*80)
#         print("--- Testing Scenario 3: FastKAN Multi-Head Unet ---")
#         model_3 = get_model(encoder_name="resnet34", use_hierarchical_head=True, use_kan=True).to(device)
#         output_3 = model_3(dummy_input)
#         assert output_3.shape == (2, 3, 128, 128)
#         print("Test PASSED.")
    
#     print("\n" + "="*80)
#     print("The final model factory is ready for high-performance experiments!")