import torch
from typing import Dict, Any
from  models.VGG import VGG16_optimalThres
from models.CHT_Model import VGG16_CHT_optimalThres

def vgg16_to_thre(src_state_dict: Dict[str, torch.Tensor],
                 target_model: torch.nn.Module,
                 verbose: bool = True) -> Any:
    """
    Map the tensors of ``src_state_dict`` onto ``target_model`` and load them.

    Keys are not compared by name. For every key of
    ``target_model.state_dict()`` (in order) a source tensor is chosen that has
    exactly the same shape and the same floating-point-ness:

    1. In-order pass: scan the source entries forward from the position right
       after the previous in-order match and take the first compatible one.
    2. Fallback pass: if nothing is found there, take the first still-unused
       compatible entry located before that position.

    The in-order position only advances when pass 1 succeeds, so a target key
    that has no counterpart in the source does not disturb the alignment of
    the keys that follow it. Target keys without any match keep the value the
    target model already holds.

    Args:
        src_state_dict: Source state_dict, or a checkpoint dict that stores the
            state_dict under the ``"state_dict"`` key. It is not modified;
            mapped tensors are cloned.
        target_model: Module that receives the mapped tensors.
        verbose: When True, print every mapping decision and a load summary.

    Returns:
        The value returned by ``target_model.load_state_dict`` (an object
        exposing ``missing_keys`` and ``unexpected_keys``).
    """
    # Accept a wrapped checkpoint as well as a bare state_dict (same rule as
    # _extract_state_dict, inlined so this function has no sibling dependency).
    if isinstance(src_state_dict, dict):
        nested = src_state_dict.get("state_dict")
        if isinstance(nested, dict):
            src_state_dict = nested

    # Snapshot the source as an ordered list of (key, tensor) pairs.
    src_items = [(key, src_state_dict[key]) for key in src_state_dict.keys()]
    tgt_state = target_model.state_dict()

    if verbose:
        print(f"[vgg16_to_thre] src has {len(src_items)} keys, target expects {len(tgt_state)} keys")

    def _compatible(src_tensor, tgt_tensor):
        # Shape alone is not enough: an integer scalar buffer (for example a
        # BatchNorm ``num_batches_tracked`` counter) must never be copied into
        # a floating-point scalar parameter, and vice versa.
        return (src_tensor.shape == tgt_tensor.shape
                and src_tensor.is_floating_point() == tgt_tensor.is_floating_point())

    used_src = set()
    new_state = {}
    # Index right after the last in-order match. Every used entry lies before
    # this index, so the in-order pass never needs to check ``used_src``.
    cursor = 0

    for tgt_key, tgt_tensor in tgt_state.items():
        match_idx = None
        label = "mapped"

        # Pass 1: look forward from the cursor without committing the position.
        for idx in range(cursor, len(src_items)):
            src_key, src_tensor = src_items[idx]
            if _compatible(src_tensor, tgt_tensor):
                match_idx = idx
                break
            if verbose:
                print(f"[skip]  {src_key} (shape={src_tensor.shape}, dtype={src_tensor.dtype})  !=  "
                      f"{tgt_key} (shape={tgt_tensor.shape}, dtype={tgt_tensor.dtype})")

        if match_idx is not None:
            cursor = match_idx + 1
        else:
            # Pass 2: entries from the cursor onward were just rejected, so only
            # the unused entries before the cursor can still match.
            label = "mapped-fallback"
            for idx in range(cursor):
                src_key, src_tensor = src_items[idx]
                if src_key not in used_src and _compatible(src_tensor, tgt_tensor):
                    match_idx = idx
                    break

        if match_idx is None:
            # No counterpart: the target keeps its current value for this key.
            if verbose:
                print(f"[WARN] 没有找到匹配的源参数来填充 target {tgt_key} (shape={tgt_tensor.shape})")
            continue

        src_key, src_tensor = src_items[match_idx]
        new_state[tgt_key] = src_tensor.clone().contiguous()
        used_src.add(src_key)
        if verbose:
            print(f"[{label}] {src_key} --> {tgt_key}  shape {tgt_tensor.shape}")

    # Start from the target's own state so unmatched keys are still present,
    # then overwrite the matched ones; this keeps strict loading valid.
    merged = tgt_state.copy()
    merged.update(new_state)

    load_res = target_model.load_state_dict(merged, strict=True)
    if verbose:
        print(f"[load_state_dict] missing_keys={load_res.missing_keys}, unexpected_keys={load_res.unexpected_keys}")
        unused = [key for key, _ in src_items if key not in used_src]
        if unused:
            print(f"[INFO] {len(unused)} source keys were NOT used (examples): {unused[:10]}")
    return load_res

def _extract_state_dict(src: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """如果传入 checkpoint（带 'state_dict'），返回真实 state_dict，否则原样返回。"""
    if not isinstance(src, dict):
        raise TypeError("src_state_dict must be a dict-like object.")
    if "state_dict" in src and isinstance(src["state_dict"], dict):
        return src["state_dict"]
    return src

def vgg16_Hanming_to_CHT_thre(src_state_dict: Dict[str, torch.Tensor],
                             target_model: torch.nn.Module,
                             verbose: bool = True) -> Dict[str, Any]:
    """
    Map a source state_dict onto ``target_model`` by tensor shape and load it.

    Intended for moving the weights of a VGG16_CIFAR_BN-style model (layers
    inside ``features``) onto a VGG16_CHT_optimalThres instance (one attribute
    per layer).

    Strategy: for every key of ``target_model.state_dict()``, in order, take
    the first source tensor that has not been used yet, has exactly the same
    shape, and has the same floating-point-ness as the target tensor. Key names
    are not compared. Target keys without a match keep the value the target
    model already holds.

    Args:
        src_state_dict: Source state_dict, or a checkpoint dict wrapping it
            under ``"state_dict"`` (unwrapped via ``_extract_state_dict``).
            It is not modified; mapped tensors are cloned.
        target_model: Module that receives the mapped tensors.
        verbose: When True, print every mapping decision and the load summary.

    Returns:
        A dict with three entries:
        ``"load_result"`` (return value of ``load_state_dict``),
        ``"mapping"`` (source key -> target key), and
        ``"unused_src_keys"`` (source keys that were never mapped).
    """
    src = _extract_state_dict(src_state_dict)
    src_items = [(key, src[key]) for key in src.keys()]
    target_state = target_model.state_dict()

    if verbose:
        print(f"[vgg16_Hanming_to_CHT_thre] src keys: {len(src_items)}, target keys: {len(target_state)}")

    used_src = set()
    mapping = {}          # src_key -> tgt_key
    new_state = {}        # tgt_key -> tensor (to load)

    for tgt_key, tgt_tensor in target_state.items():
        tgt_shape = tuple(tgt_tensor.shape)
        tgt_is_float = tgt_tensor.is_floating_point()

        # One scan over the whole source table is sufficient: it already
        # considers every unused entry, so a second identical scan could never
        # find anything the first one missed.
        match = next(
            (
                (src_key, src_tensor)
                for src_key, src_tensor in src_items
                if src_key not in used_src
                and tuple(src_tensor.shape) == tgt_shape
                # Shape alone is not enough: keep integer buffers (for example
                # step counters) out of floating-point parameters.
                and src_tensor.is_floating_point() == tgt_is_float
            ),
            None,
        )

        if match is None:
            if verbose:
                print(f"[WARN] no matching src param for target '{tgt_key}'  expected shape={tgt_shape}")
            continue

        src_key, src_tensor = match
        # No explicit device move: load_state_dict copies the values into the
        # target's existing tensors, wherever those live.
        new_state[tgt_key] = src_tensor.clone().contiguous()
        mapping[src_key] = tgt_key
        used_src.add(src_key)
        if verbose:
            print(f"[mapped] {src_key} -> {tgt_key}  shape={tgt_shape}")

    # Start from the target's own state so unmatched keys are still present,
    # then overwrite the matched ones; this keeps strict loading valid.
    merged_state = target_state.copy()
    merged_state.update(new_state)

    load_res = target_model.load_state_dict(merged_state, strict=True)

    unused_src = [key for key, _ in src_items if key not in used_src]
    if verbose:
        print(f"[load_state_dict] missing_keys ({len(load_res.missing_keys)}): {load_res.missing_keys}")
        print(f"[load_state_dict] unexpected_keys ({len(load_res.unexpected_keys)}): {load_res.unexpected_keys}")
        print(f"[INFO] used {len(used_src)} / {len(src_items)} source keys. unused_examples: {unused_src[:10]}")

    # Return the load result together with the mapping for later inspection.
    return {
        "load_result": load_res,
        "mapping": mapping,
        "unused_src_keys": unused_src
    }


if __name__ == '__main__':
    # Manual smoke test: load two saved checkpoints and map each one onto a
    # freshly constructed target model.
    # Each entry is (checkpoint path, target model class, mapping function).
    # NOTE(review): both runs print under the label 'dense', including the
    # CHT run — confirm that this label is intended for the second run.
    _runs = (
        (
            "../input/VGG-16/CIFAR10/conv_0.0/s_0.0/d_0.0/onefc_True/lr_0.1/bs_32/best_model.pth",
            VGG16_optimalThres,
            vgg16_to_thre,
        ),
        (
            "../input/VGG-16/CIFAR10/conv_0.5/s_0.0/d_0.0/onefc_True/lr_0.1/bs_32/best_model.pth",
            VGG16_CHT_optimalThres,
            vgg16_Hanming_to_CHT_thre,
        ),
    )
    for _ckpt_path, _model_cls, _mapper in _runs:
        src = torch.load(_ckpt_path, map_location="cpu")  # your old vgg16 state_dict
        model_new = _model_cls(num_classes=10, one_fc=True)  # target instance
        result = _mapper(src, model_new, verbose=True)
        print('dense', result)
