import logging
import torch
import torch.nn as nn
import torchvision
from src.utils.petta_utils import ImageNormalizer


class ResNetDomainNet126(nn.Module):
    """ResNet classifier used for DomainNet-126.

    The model is an encoder followed by a linear classification head:

    - Without a bottleneck (``bottleneck_dim`` is ``None`` or ``<= 0``) the
      encoder is the torchvision ResNet with its final ``fc`` layer removed.
    - With a bottleneck the ResNet's ``fc`` layer is replaced by a
      ``Linear(in_features, bottleneck_dim)`` and followed by ``BatchNorm1d``.

    The backbone is built in a way that works across torchvision versions
    (weights enum, then ``pretrained=True``, then default initialisation).
    Input normalisation with the ImageNet mean/std is prepended to the encoder
    after the optional checkpoint has been loaded.

    Args:
        arch: One of ``"resnet18"``, ``"resnet50"``, ``"resnet101"``
            (case-insensitive).
        checkpoint_path: Optional checkpoint passed to
            :meth:`load_from_checkpoint`.
        num_classes: Output size of the classification head.
        bottleneck_dim: Bottleneck width; ``None`` or a value ``<= 0``
            disables the bottleneck.

    Raises:
        ValueError: If ``arch`` is not supported.
    """

    # torchvision weights-enum name for each supported architecture.
    _WEIGHTS_ENUMS = {
        "resnet18": "ResNet18_Weights",
        "resnet50": "ResNet50_Weights",
        "resnet101": "ResNet101_Weights",
    }
    # Prefix found on parameter names saved from a (Distributed)DataParallel wrapper.
    _DDP_PREFIX = "module."

    def __init__(self, arch="resnet50", checkpoint_path=None, num_classes=126, bottleneck_dim=256):
        super().__init__()

        self.arch = arch.lower()
        self.bottleneck_dim = bottleneck_dim
        # 0 means "no weight norm"; a value > 0 enables it and is used as `dim`.
        # NOTE(review): with this fixed at 0 the head is never weight-normalised.
        # A checkpoint saved from a weight-normalised head would presumably carry
        # `fc.weight_g` / `fc.weight_v` instead of `fc.weight` and would then not
        # be loaded into the head -- confirm against the checkpoints in use.
        self.weight_norm_dim = 0

        # ---- backbone ----
        if not self.use_bottleneck:
            model = self._build_resnet(self.arch, with_bottleneck=False)
            modules = list(model.children())[:-1]  # drop the final fc layer
            self.encoder = nn.Sequential(*modules)
            self._output_dim = model.fc.in_features
        else:
            model = self._build_resnet(self.arch, with_bottleneck=True, bottleneck_dim=self.bottleneck_dim)
            # The ResNet's fc acts as the bottleneck and is followed by BN1d,
            # so the order is [resnet(with fc=bottleneck), bn].
            bn = nn.BatchNorm1d(self.bottleneck_dim)
            self.encoder = nn.Sequential(model, bn)
            self._output_dim = self.bottleneck_dim

        # ---- classification head ----
        self.fc = nn.Linear(self.output_dim, num_classes)
        if self.use_weight_norm:
            self.fc = nn.utils.weight_norm(self.fc, dim=self.weight_norm_dim)

        # ---- optional checkpoint ----
        # Loaded BEFORE the normaliser is prepended, so checkpoint keys are
        # matched against the un-wrapped encoder (`encoder.<i>...`).
        if checkpoint_path:
            self.load_from_checkpoint(checkpoint_path)

        # ---- input normalisation goes first ----
        self.encoder = nn.Sequential(
            ImageNormalizer((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            self.encoder
        )

    def _build_resnet(self, arch: str, with_bottleneck: bool, bottleneck_dim: int = 256) -> nn.Module:
        """Return a torchvision ResNet, compatible with several torchvision versions.

        The constructor is tried in this order:

        1. ``weights=<Enum>.IMAGENET1K_V1`` (weights-enum API),
        2. ``pretrained=True`` (older API),
        3. no arguments (default initialisation) if ``pretrained`` is rejected
           with a ``TypeError``.

        Args:
            arch: Architecture name; must be a key of ``_WEIGHTS_ENUMS``.
            with_bottleneck: If true, replace ``model.fc`` with
                ``Linear(in_features, bottleneck_dim)``.
            bottleneck_dim: Output width of the replacement ``fc``.

        Raises:
            ValueError: If ``arch`` is not supported.
        """
        enum_name = self._WEIGHTS_ENUMS.get(arch)
        if enum_name is None:
            raise ValueError(f"Unsupported arch: {arch}")

        factory = getattr(torchvision.models, arch)
        weights_enum = getattr(torchvision.models, enum_name, None)

        model = None
        if weights_enum is not None:
            try:
                model = factory(weights=weights_enum.IMAGENET1K_V1)
            except Exception as exc:
                # Deliberately best-effort: fall through to the legacy API,
                # but leave a trace of why the enum path did not work.
                logging.debug("Building %s via %s failed (%s); trying legacy API", arch, enum_name, exc)

        if model is None:
            try:
                model = factory(pretrained=True)
            except TypeError:
                # The constructor does not accept `pretrained`; only default
                # initialisation is possible.
                logging.warning("%s does not accept 'pretrained'; using default initialisation", arch)
                model = factory()

        if with_bottleneck:
            in_dim = model.fc.in_features
            model.fc = nn.Linear(in_dim, bottleneck_dim)

        return model

    def forward(self, x, return_feats=False):
        """Run the encoder and the classification head.

        Args:
            x: Input batch passed to the encoder.
            return_feats: If true, also return the flattened encoder features.

        Returns:
            ``logits``, or ``(feat, logits)`` when ``return_feats`` is true.
        """
        # 1) encoder features, flattened to (batch, -1)
        feat = self.encoder(x)
        feat = torch.flatten(feat, 1)
        # 2) classifier
        logits = self.fc(feat)
        if return_feats:
            return feat, logits
        return logits

    def load_from_checkpoint(self, checkpoint_path):
        """Load parameters from a checkpoint (non-strict).

        The state dict is read from the ``"state_dict"`` key, else the
        ``"model"`` key, else the checkpoint object itself. A leading
        ``"module."`` prefix (possibly repeated) is stripped from every key;
        occurrences elsewhere in a key are left untouched so that names such as
        ``"encoder.my_module.weight"`` are not corrupted.

        Missing keys are logged at INFO level and unexpected keys at WARNING
        level, because ``strict=False`` would otherwise hide a mismatch.
        """
        # NOTE(security): torch.load unpickles data; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Two common layouts: {'state_dict': ...} or {'model': ...}.
        model_state_dict = checkpoint.get("state_dict", checkpoint.get("model", checkpoint))

        prefix = self._DDP_PREFIX
        state_dict = {}
        for name, param in model_state_dict.items():
            # Strip only the LEADING wrapper prefix(es).
            while name.startswith(prefix):
                name = name[len(prefix):]
            state_dict[name] = param

        msg = self.load_state_dict(state_dict, strict=False)
        logging.info(f"Loaded from {checkpoint_path}; missing params: {msg.missing_keys}")
        if msg.unexpected_keys:
            logging.warning(
                "Checkpoint %s has keys that match no parameter of this model: %s",
                checkpoint_path, msg.unexpected_keys,
            )

    def get_params(self):
        """Split trainable parameters into ``(backbone_params, extra_params)``.

        Intended use (per the original note): backbone parameters get 1x lr,
        extra parameters (bottleneck + bn + classifier) get 10x lr.
        Parameters with ``requires_grad=False`` are excluded from both lists.
        """
        backbone_params = []
        extra_params = []
        if not self.use_bottleneck:
            # The encoder is the ResNet feature extractor without its fc.
            backbone_params.extend(self.encoder.parameters())
            # The classification head counts as extra.
            extra_params.extend(self.fc.parameters())
        else:
            # self.encoder = Sequential(ImageNormalizer, Sequential(resnet, bn)):
            # index 0 is the normaliser, index 1 the original encoder.
            resnet = self.encoder[1][0]
            bn = self.encoder[1][1]

            # Everything in the ResNet except its last child (the fc used as
            # bottleneck) is backbone.
            for module in list(resnet.children())[:-1]:
                backbone_params.extend(module.parameters())
            # bottleneck fc + bn + classifier fc are extra.
            extra_params.extend(resnet.fc.parameters())
            extra_params.extend(bn.parameters())
            extra_params.extend(self.fc.parameters())

        # Drop frozen parameters.
        backbone_params = [p for p in backbone_params if p.requires_grad]
        extra_params = [p for p in extra_params if p.requires_grad]
        return backbone_params, extra_params

    @property
    def num_classes(self):
        """Number of output classes, read from the classifier weight shape."""
        return self.fc.weight.shape[0]

    @property
    def output_dim(self):
        """Feature dimension fed into the classification head."""
        return self._output_dim

    @property
    def use_bottleneck(self):
        """True only when ``bottleneck_dim`` is set and strictly positive."""
        return self.bottleneck_dim is not None and self.bottleneck_dim > 0

    @property
    def use_weight_norm(self):
        """True only when ``weight_norm_dim`` is strictly positive."""
        return self.weight_norm_dim > 0
