import os

import torch


def save_model(self, dataset="CUB"):
    """Serialize the full PU learner to ``pu_learner_model_{dataset}.pt``.

    The checkpoint holds the state dicts of the three sub-networks
    (``prob_pu_learner``, ``contrastive_head``, ``classifier``), a ``config``
    dict of the learner's hyperparameters, and a CPU copy of ``att_mat``.
    The file is written to the current working directory.

    Args:
        dataset: Dataset tag embedded in the checkpoint file name.
    """
    learner = self.pu_learner
    target_file = f"pu_learner_model_{dataset}.pt"

    # Sub-network weights first, in the same key order the loader expects.
    payload = {
        part: getattr(learner, part).state_dict()
        for part in ('prob_pu_learner', 'contrastive_head', 'classifier')
    }
    # Hyperparameters are stored alongside the weights for reference.
    payload['config'] = {
        field: getattr(learner, field)
        for field in (
            'feature_dim',
            'lambda_prob_pu',
            'lambda_contrast',
            'lambda_margin',
            'lambda_reg',
            'margin',
            'temp',
        )
    }
    # Move to CPU so the file can be loaded on machines without the original device.
    payload['att_mat'] = learner.att_mat.cpu()

    torch.save(payload, target_file)
    print(f"模型已保存到: {target_file}")


def load_model(self, dataset="CUB"):
    """Restore the three PU-learner sub-networks from ``pu_learner_model_{dataset}.pt``.

    The checkpoint is read onto the CPU first. Each sub-network's state dict is
    loaded strictly (``load_state_dict`` default), and then every sub-network
    is moved to ``self.device``. The ``config`` and ``att_mat`` entries of the
    checkpoint are not used here.

    Args:
        dataset: Dataset tag embedded in the checkpoint file name.
    """
    source_file = f"pu_learner_model_{dataset}.pt"
    # Load on CPU so the checkpoint does not depend on the device it was saved from.
    ckpt = torch.load(source_file, map_location='cpu')

    learner = self.pu_learner
    parts = ('prob_pu_learner', 'contrastive_head', 'classifier')

    # Restore all weights before moving anything to the target device.
    for part in parts:
        getattr(learner, part).load_state_dict(ckpt[part])
    for part in parts:
        getattr(learner, part).to(self.device)

    print(f"模型已从 {source_file} 成功加载")

def save_model_Prob(self, dataset="CUB"):
    """Save only the probabilistic PU sub-network of ``self.pu_learner``.

    Writes ``saved_model/pu_learner_model_{dataset}_Prob.pt`` (relative to the
    working directory). It contains the ``prob_pu_learner`` state dict, a
    ``config`` dict with ``feature_dim`` and ``lambda_prob_pu``, and a CPU copy
    of ``att_mat``. The ``saved_model`` directory is created if it is missing.

    Args:
        dataset: Dataset tag embedded in the checkpoint file name.
    """
    model_path = f"saved_model/pu_learner_model_{dataset}_Prob.pt"
    save_dict = {
        'prob_pu_learner': self.pu_learner.prob_pu_learner.state_dict(),
        'config': {
            'feature_dim': self.pu_learner.feature_dim,
            'lambda_prob_pu': self.pu_learner.lambda_prob_pu
        },
        # CPU copy so the file loads on machines without the original device.
        'att_mat': self.pu_learner.att_mat.cpu()
    }

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(save_dict, model_path)
    # model_path already carries the "saved_model/" prefix; don't prepend it again.
    print(f"模型已保存到: {model_path}")


def load_model_Prob(self, dataset="CUB"):
    """Restore the probabilistic PU sub-network from ``saved_model/``.

    Reads ``saved_model/pu_learner_model_{dataset}_Prob.pt`` (the same path
    that ``save_model_Prob`` writes). It loads the ``prob_pu_learner`` entry
    strictly into ``self.pu_learner.prob_pu_learner`` and moves that module to
    ``self.device``.

    Args:
        dataset: Dataset tag embedded in the checkpoint file name.
    """
    checkpoint_file = f"saved_model/pu_learner_model_{dataset}_Prob.pt"
    # Read onto the CPU first; the device move happens after the weights are in place.
    ckpt = torch.load(checkpoint_file, map_location='cpu')

    prob_net = self.pu_learner.prob_pu_learner
    prob_net.load_state_dict(ckpt['prob_pu_learner'])
    prob_net.to(self.device)

    print(f"模型已从 {checkpoint_file} 成功加载")


# def load_model_Prob(self, dataset="CUB"):
#     model_path = f"pu_learner_model_{dataset}_Prob.pt"
#     # 加载模型
#     checkpoint = torch.load(model_path, map_location='cpu')
#
#     # 获取当前模型的状态字典
#     current_state_dict = self.pu_learner.prob_pu_learner.state_dict()
#     saved_state_dict = checkpoint['prob_pu_learner']
#
#     # 过滤出形状匹配的参数
#     filtered_state_dict = {}
#     skipped_keys = []
#
#     for key, param in saved_state_dict.items():
#         if key in current_state_dict:
#             if param.shape == current_state_dict[key].shape:
#                 filtered_state_dict[key] = param
#             else:
#                 skipped_keys.append(f"{key}: saved {param.shape} vs current {current_state_dict[key].shape}")
#         else:
#             skipped_keys.append(f"{key}: not found in current model")
#
#     # 加载过滤后的参数
#     self.pu_learner.prob_pu_learner.load_state_dict(filtered_state_dict, strict=False)
#
#     # 打印跳过的参数信息
#     if skipped_keys:
#         print(f"跳过以下不匹配的参数:")
#         for key in skipped_keys:
#             print(f"  - {key}")
#
#     print(f"成功加载 {len(filtered_state_dict)} 个匹配的参数，跳过 {len(skipped_keys)} 个不匹配的参数")
#
#     # 将模型移动到指定设备
#     self.pu_learner.prob_pu_learner.to(self.device)
#
#     print(f"模型已从 {model_path} 成功加载")


def _flexible_load_prob_pu_learner(module, checkpoint_state_dict):
    """Best-effort, non-strict load of a (possibly legacy) state dict into ``module``.

    Keys in ``ignore_keys`` are dropped. Keys in ``key_mapping`` are renamed
    from the old ``f2_hidden.*`` / ``f2_output.*`` names to ``f2_net.*``. Any
    other key is used unchanged. An entry is loaded only when its target key
    exists in ``module``, its value is a tensor, and the shapes agree. Every
    other entry is reported and skipped. Parameters that receive no value keep
    their current values because the final load uses ``strict=False``.
    """
    model_state_dict = module.state_dict()

    print("=== 调试信息 ===")
    print("当前模型的键:")
    for key in sorted(model_state_dict):
        print(f"  {key}: {model_state_dict[key].shape}")

    print("\nCheckpoint中的键:")
    for key in sorted(checkpoint_state_dict):
        value = checkpoint_state_dict[key]
        if isinstance(value, torch.Tensor):
            print(f"  {key}: {value.shape}")
        else:
            print(f"  {key}: {type(value)}")

    # Old checkpoint key -> current model key (f2 sub-network was restructured).
    key_mapping = {
        'f2_hidden.0.weight': 'f2_net.0.weight',
        'f2_hidden.0.bias': 'f2_net.0.bias',
        'f2_hidden.1.weight': 'f2_net.1.weight',
        'f2_hidden.1.bias': 'f2_net.1.bias',
        'f2_hidden.4.weight': 'f2_net.4.weight',
        'f2_hidden.4.bias': 'f2_net.4.bias',
        'f2_hidden.5.weight': 'f2_net.5.weight',
        'f2_hidden.5.bias': 'f2_net.5.bias',
        'f2_hidden.8.weight': 'f2_net.8.weight',
        'f2_hidden.8.bias': 'f2_net.8.bias',
        'f2_hidden.9.weight': 'f2_net.9.weight',
        'f2_hidden.9.bias': 'f2_net.9.bias',
        'f2_output.weight': 'f2_net.11.weight',
        'f2_output.bias': 'f2_net.11.bias',
    }

    # Keys that exist only in old checkpoints and have no counterpart in the current model.
    ignore_keys = {'neg_proto', 'att_matrix', 'attr_head.weight', 'attr_head.bias'}

    new_state_dict = {}
    mapped_count = direct_match_count = ignored_count = 0

    for old_key, value in checkpoint_state_dict.items():
        if old_key in ignore_keys:
            print(f"忽略键: {old_key}")
            ignored_count += 1
            continue

        # Resolve the key this entry should be loaded under.
        renamed = old_key in key_mapping
        if renamed:
            new_key = key_mapping[old_key]
            if new_key not in model_state_dict:
                print(f"目标键不存在，跳过映射: {old_key} -> {new_key}")
                continue
        elif old_key in model_state_dict:
            new_key = old_key
        else:
            print(f"未知键，跳过: {old_key}")
            continue

        # Non-tensor entries have no .shape and cannot be loaded as parameters.
        if not isinstance(value, torch.Tensor):
            print(f"非张量值，跳过: {old_key} ({type(value).__name__})")
            continue

        expected_shape = model_state_dict[new_key].shape
        if value.shape != expected_shape:
            if renamed:
                print(f"形状不匹配，跳过: {old_key} {value.shape} -> {new_key} {expected_shape}")
            else:
                print(f"形状不匹配，跳过直接匹配: {old_key} {value.shape} vs {expected_shape}")
            continue

        new_state_dict[new_key] = value
        if renamed:
            print(f"映射: {old_key} -> {new_key}")
            mapped_count += 1
        else:
            print(f"直接匹配: {old_key}")
            direct_match_count += 1

    print("\n=== 加载统计 ===")
    print(f"键名映射加载: {mapped_count}")
    print(f"直接匹配加载: {direct_match_count}")
    print(f"忽略的键: {ignored_count}")
    print(f"总共加载: {len(new_state_dict)} / {len(model_state_dict)} 个参数")

    missing_keys, unexpected_keys = module.load_state_dict(new_state_dict, strict=False)

    if missing_keys:
        print(f"仍然缺失的键: {missing_keys}")
        print("这些参数将保持随机初始化状态")

    if unexpected_keys:
        print(f"意外的键: {unexpected_keys}")

    print("灵活加载完成!")


def load_model_Prob_new(self, dataset="CUB"):
    """Load ``self.pu_learner.prob_pu_learner`` from ``pu_learner_model_{dataset}_Prob.pt``.

    A strict ``load_state_dict`` is tried first. If it raises ``RuntimeError``
    (missing/unexpected keys, shape mismatches, ...), the checkpoint is loaded
    non-strictly by ``_flexible_load_prob_pu_learner``, which renames legacy
    keys. On success the module is moved to ``self.device``.

    NOTE(review): this reads from the working directory, while
    ``save_model_Prob`` writes under ``saved_model/``. Confirm which location
    is intended.

    Args:
        dataset: Dataset tag embedded in the checkpoint file name.

    Returns:
        True if parameters were loaded (strictly or flexibly). False if the
        checkpoint could not be read or any step of loading raised.
    """
    model_path = f"pu_learner_model_{dataset}_Prob.pt"

    try:
        # Read the checkpoint before trying any load, so a RuntimeError from
        # torch.load (e.g. a truncated file) is not mistaken for a key mismatch.
        checkpoint = torch.load(model_path, map_location='cpu')
        print(f"成功加载checkpoint文件: {model_path}")
        checkpoint_state_dict = checkpoint['prob_pu_learner']
        module = self.pu_learner.prob_pu_learner

        try:
            # Fast path: the checkpoint matches the current architecture exactly.
            module.load_state_dict(checkpoint_state_dict)
            print("模型参数直接加载成功!")
        except RuntimeError as e:
            print(f"直接加载失败: {e}")
            print("尝试使用键名映射进行灵活加载...")
            _flexible_load_prob_pu_learner(module, checkpoint_state_dict)

    except Exception as e:
        # Any failure (missing file, corrupt checkpoint, errors in the flexible
        # path) is reported and signalled to the caller via the return value.
        print(f"加载过程中发生错误: {e}")
        print("请检查模型文件是否存在且完整")
        return False

    module.to(self.device)
    print(f"模型已从 {model_path} 成功加载并移动到设备: {self.device}")
    return True
