import torch.nn as nn
from .cht_layer import Conv2d_CHT, CHTConfig

def _square_conv_param(conv: nn.Conv2d, attr: str, layer_name: str) -> int:
    """Return the single int shared by both entries of a Conv2d hyper-parameter.

    Args:
        conv: the convolution whose attribute is inspected.
        attr: attribute name, e.g. ``"kernel_size"``, ``"padding"`` or ``"stride"``.
        layer_name: name of the layer in its parent, used only in the error message.

    Returns:
        The first entry of the attribute when both entries are equal.

    Raises:
        ValueError: if the attribute is a string (e.g. ``padding='same'``) or
            its two entries differ, because only one int can be forwarded.
    """
    value = getattr(conv, attr)
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # which would silently keep only the first entry of a non-square setting.
    if isinstance(value, str) or value[0] != value[1]:
        raise ValueError(
            f"replace_conv2cht: layer {layer_name!r} has {attr}={value!r}; "
            "only settings with identical height/width values are supported"
        )
    return value[0]


def replace_conv2cht(module: nn.Module, cht_config: CHTConfig) -> nn.Module:
    """Recursively replace every nn.Conv2d inside ``module`` with a Conv2d_CHT.

    Each replacement is built from the original layer's ``in_channels``,
    ``out_channels``, ``kernel_size``, ``padding`` and ``stride`` plus the
    shared ``cht_config``. Nothing else is forwarded: ``dilation``, ``groups``,
    ``bias`` and ``padding_mode`` of the original layer are ignored, and the
    original weights/bias are not copied into the new layer.

    Args:
        module: any nn.Module instance (e.g. a ResNet or VGG model).
        cht_config: configuration passed to every Conv2d_CHT that is created.

    Returns:
        The same ``module`` object, modified in place.

    Raises:
        ValueError: if a Conv2d has a non-square kernel size, padding or
            stride, or uses string padding such as ``'same'``.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Conv2d):
            # Recurse into containers; leaf modules such as Linear or ReLU
            # have no children, so the recursive call is a no-op for them.
            replace_conv2cht(child, cht_config)
            continue

        # Validate in the same order as the constructor arguments are read.
        kernel_size = _square_conv_param(child, "kernel_size", name)
        padding = _square_conv_param(child, "padding", name)
        stride = _square_conv_param(child, "stride", name)

        new_conv = Conv2d_CHT(
            child.in_channels,
            child.out_channels,
            kernel_size,
            cht_config,
            padding=padding,
            stride=stride,
        )
        # NOTE: the pretrained weights and bias of `child` are intentionally
        # not copied here; the new layer keeps whatever its constructor set up.

        # Swap the new layer in under the same attribute name.
        setattr(module, name, new_conv)

    return module

def replace_linearcht(module: nn.Module, cht_config: CHTConfig) -> nn.Module:
    """Placeholder for swapping nn.Linear layers for a CHT variant; always raises.

    Linear CHT replacement is not implemented in this module. The function is
    kept so that callers get an explicit error (whose message points them to
    ``dst_scheduler``) instead of an AttributeError on a missing name.

    Args:
        module: accepted for signature compatibility; never inspected.
        cht_config: accepted for signature compatibility; never inspected.

    Raises:
        NotImplementedError: unconditionally.
    """
    # The former draft body after this raise was unreachable and referenced a
    # Linear_CHT name that this module does not import, so it was removed.
    raise NotImplementedError('Linear CHT is not implemented. Use dst_scheduler.')

def CHT_evolve(model: nn.Module):
    """Call ``evolve()`` on the Conv2d_CHT layers found under ``model``.

    Direct children that are Conv2d_CHT are evolved when their
    ``link_update_ratio`` is non-zero; every other child is searched
    recursively. If ``model`` has at least one direct Conv2d_CHT child, a
    banner and the list of child names evolved at this level are printed
    (the list may be empty when all ratios are zero).

    Args:
        model: module whose sub-tree is searched.
    """
    evolved_names = []
    found_cht = False
    for child_name, submodule in model.named_children():
        if not isinstance(submodule, Conv2d_CHT):
            # Not a CHT layer itself: look for CHT layers deeper down.
            CHT_evolve(submodule)
            continue
        found_cht = True
        if submodule.link_update_ratio != 0:
            submodule.evolve()
            evolved_names.append(child_name)

    if not found_cht:
        return
    print('==========Conv Evolved==========')
    print(evolved_names)
        
def set_initilized_true_convcht(model):
    """Set ``mask_initialized = True`` on every Conv2d_CHT reachable from ``model``.

    The module tree is walked depth-first in child order. A Conv2d_CHT layer
    is flagged but not descended into; any other child is descended into.

    Args:
        model: module whose sub-tree is searched.
    """
    # Explicit stack of child iterators; the top entry is the level currently
    # being scanned, and a paused parent resumes where it left off.
    levels = [iter(model.named_children())]
    while levels:
        for _, submodule in levels[-1]:
            if isinstance(submodule, Conv2d_CHT):
                submodule.mask_initialized = True
            else:
                # Descend before continuing with the remaining siblings.
                levels.append(iter(submodule.named_children()))
                break
        else:
            # Current level exhausted: go back up to the parent level.
            levels.pop()