import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Any, Dict, Optional, Tuple, Type, Union
from ._registry import register_model
from ._builder import build_model_with_cfg
from timm.layers import LayerType
# Public model entrypoints of this module (each is registered with @register_model below).
__all__ = ['birealnet18', 'birealnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (padding 1 keeps spatial size when stride is 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding).

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer

def grad_scale(x, scale):
    """Return ``x`` unchanged in the forward pass, but scale its gradient by ``scale``.

    Uses the detach trick: the detached terms cancel numerically in the
    forward pass, while only ``x * scale`` contributes to the backward pass.

    Args:
        x: Input tensor.
        scale: Multiplier applied to the gradient flowing back into ``x``.
    """
    scaled = x * scale
    return x.detach() - scaled.detach() + scaled

class BinaryActivation(nn.Module):
    """Sign activation with a piecewise-quadratic straight-through estimator.

    Forward value is ``torch.sign(x)``. The backward pass uses the gradient of
    the piecewise approximation of sign:

        -1            for x < -1
        x**2 + 2x     for -1 <= x < 0
        -x**2 + 2x    for 0 <= x < 1
        1             for x >= 1

    i.e. dy/dx = 2 + 2x on [-1, 0), 2 - 2x on [0, 1), and 0 elsewhere.
    """

    def __init__(self):
        super(BinaryActivation, self).__init__()

    def forward(self, x):
        out_forward = torch.sign(x)
        ones = torch.ones_like(x)
        # Pick each region with torch.where rather than multiplying by float32
        # masks: this keeps x's dtype (fp16/bf16 inputs are no longer promoted
        # to float32) and discards the unselected branches, so an overflowing
        # x*x outside [-1, 1) can no longer produce inf * 0 = NaN.
        approx = torch.where(x < 0, x * x + 2 * x, -x * x + 2 * x)
        approx = torch.where(x < -1, -ones, approx)
        approx = torch.where(x < 1, approx, ones)
        # Straight-through: value of sign(x), gradient of the approximation.
        return out_forward.detach() - approx.detach() + approx


class HardBinaryConv(nn.Module):
    """2D convolution with binarized weights and a clipped straight-through gradient.

    The latent real-valued weight has shape ``(out_chn, in_chn, k, k)``. In the
    forward pass it is replaced by ``alpha * sign(weight)``, where ``alpha`` is
    the mean absolute weight of each output channel (detached, so treated as a
    constant). The backward pass follows ``clamp(weight, -1, 1)``.
    No bias term is used.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        # Small uniform initialization in [0, 0.001).
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)

    def forward(self, x):
        w = self.weight
        # Per-output-channel scale: mean |w| reduced over dims 3, 2, then 1.
        alpha = (
            w.abs()
            .mean(dim=3, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=1, keepdim=True)
            .detach()
        )
        w_sign = alpha * torch.sign(w)
        w_clipped = torch.clamp(w, -1.0, 1.0)
        # Forward uses the binarized weight; gradient flows through the clipped weight.
        w_ste = w_sign.detach() - w_clipped.detach() + w_clipped
        return F.conv2d(x, w_ste, stride=self.stride, padding=self.padding)

class BasicBlock(nn.Module):
    """Bi-Real residual block: one binary 3x3 conv plus a real-valued "agent" conv.

    Output is ``bn1(binary_conv(binary_activation(x))) + shortcut(x)``, where the
    shortcut is ``downsample(x)`` when provided and ``x`` otherwise. The agent
    conv runs on the raw (non-binarized) input; it is mixed in with a
    straight-through trick, so it leaves the forward value unchanged (up to
    floating-point rounding) but receives a gradient scaled by
    ``10 / sqrt(agent_conv.weight.numel())``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Keep this submodule registration order: it fixes state_dict/parameter order.
        self.binary_activation = BinaryActivation()
        self.binary_conv = HardBinaryConv(inplanes, planes, stride=stride)
        self.agent_conv = conv3x3(inplanes, planes, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def copy_binary_to_agent(self):
        """Copy the latent weights of ``binary_conv`` into ``agent_conv``.

        Meant to be called after weights have been loaded.

        Raises:
            ValueError: If the two weight tensors have different shapes.
        """
        src = self.binary_conv.weight
        dst = self.agent_conv.weight
        with torch.no_grad():
            if dst.shape != src.shape:
                raise ValueError("Shape mismatch between binary_conv and agent_conv weights.")
            dst.copy_(src)

    def forward(self, x):
        binary_out = self.binary_conv(self.binary_activation(x))

        # Real-valued agent path on the raw input, gradient scaled by gain.
        gain = 10.0 / math.sqrt(self.agent_conv.weight.numel())
        agent_out = grad_scale(self.agent_conv(x), gain)

        # Forward value stays binary_out; backward also reaches agent_out.
        out = self.bn1(binary_out - agent_out.detach() + agent_out)

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out

class BiRealNet(nn.Module):
    """Bi-Real Net image classifier built from a residual ``block`` type.

    Layout:
    - stem: 7x7/stride-2 conv -> BatchNorm -> 3x3/stride-2 max-pool, with no
      activation in the stem;
    - four stages from ``_make_layer``, with strides 1, 2, 2, 2;
    - global average pooling, optional dropout, and a linear classifier.

    Only ``block``, ``layers``, ``num_classes``, ``in_chans``, ``channels`` and
    ``drop_rate`` affect the model. The remaining keyword arguments mirror
    timm's ResNet signature so the same kwargs can be passed through, but
    they are currently ignored.
    """

    def __init__(
            self,
            block,
            layers: Tuple[int, ...],
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,
            global_pool: str = 'avg',
            cardinality: int = 1,
            base_width: int = 64,
            stem_width: int = 64,
            stem_type: str = '',
            replace_stem_pool: bool = False,
            block_reduce_first: int = 1,
            down_kernel_size: int = 1,
            avg_down: bool = False,
            channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[Type[nn.Module]] = None,
            drop_rate: float = 0.0,
            drop_path_rate: float = 0.,
            drop_block_rate: float = 0.,
            zero_init_last: bool = True,
            block_args: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            block: Block class with an ``expansion`` attribute. It is constructed as
                ``block(inplanes, planes, stride, downsample)`` for the first block
                of a stage and ``block(inplanes, planes)`` for the rest.
            layers: Number of blocks in each of the four stages.
            num_classes: Output size of the final linear layer.
            in_chans: Number of input image channels.
            channels: Base width of each of the four stages; ``None`` means
                ``(64, 128, 256, 512)``.
            drop_rate: Dropout probability applied to pooled features before ``fc``.
        """
        super(BiRealNet, self).__init__()
        if channels is None:
            channels = (64, 128, 256, 512)
        stem_chs = 64
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.inplanes = stem_chs
        self.conv1 = nn.Conv2d(in_chans, stem_chs, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, channels[0], layers[0])
        self.layer2 = self._make_layer(block, channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        self.num_features = channels[3] * block.expansion
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.num_features, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first block may change stride/width.

        Updates ``self.inplanes`` to the stage's output width.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Shortcut: avg-pool (only when striding) -> 1x1 conv -> BN.
            # ceil_mode=True makes the pooled size equal to a 3x3/stride-2/pad-1
            # main path (as in BasicBlock) for odd spatial sizes too; for even
            # sizes the output is identical to a plain AvgPool2d(2, 2). With
            # stride 1 a kernel-2 pool would shrink the map by one pixel, so use
            # Identity instead (it keeps the Sequential indices unchanged).
            if stride != 1:
                pool = nn.AvgPool2d(kernel_size=2, stride=stride,
                                    ceil_mode=True, count_include_pad=False)
            else:
                pool = nn.Identity()
            downsample = nn.Sequential(
                pool,
                conv1x1(self.inplanes, out_planes),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def copy_binary_to_agent(self):
        """Call ``copy_binary_to_agent`` on every ``BasicBlock`` in the four stages."""
        for layer in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for block in layer:
                if isinstance(block, BasicBlock):
                    block.copy_binary_to_agent()

    def forward_features(self, x):
        """Run the stem and the four stages; return the final feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool and flatten the features, apply dropout, then ``fc`` unless ``pre_logits``."""
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        return x if pre_logits else self.fc(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
    
def _create_birealnet(variant, pretrained: bool = False, **kwargs) -> BiRealNet:
    """Build the ``BiRealNet`` variant named ``variant`` via ``build_model_with_cfg``.

    ``pretrained`` and all extra keyword arguments are handed to
    ``build_model_with_cfg`` together with the ``BiRealNet`` class.
    NOTE(review): presumably the remaining kwargs end up in ``BiRealNet.__init__``;
    confirm against ``build_model_with_cfg``.
    """
    model_cls = BiRealNet
    return build_model_with_cfg(model_cls, variant, pretrained, **kwargs)

# @register_model
# def birealnet18(pretrained=False, **kwargs):
#     """Constructs a BiRealNet-18 model. """
#     model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
#     return model
@register_model
def birealnet18(pretrained: bool = False, **kwargs) -> BiRealNet:
    """Constructs a BiRealNet-18 model (BasicBlock, stage depths 4, 4, 4, 4).

    Args:
        pretrained: Forwarded to ``_create_birealnet``.
        **kwargs: Extra model arguments; these override the defaults above.
    """
    model_args = dict(block=BasicBlock, layers=(4, 4, 4, 4))
    return _create_birealnet('birealnet18', pretrained, **dict(model_args, **kwargs))

@register_model
def birealnet34(pretrained: bool = False, **kwargs) -> BiRealNet:
    """Constructs a BiRealNet-34 model (BasicBlock, stage depths 6, 8, 12, 6).

    Built through ``_create_birealnet`` like ``birealnet18``, so ``pretrained`` is
    passed on instead of being silently ignored.

    Args:
        pretrained: Forwarded to ``_create_birealnet``.
        **kwargs: Extra model arguments; these override the defaults above.
    """
    # NOTE(review): the previous direct ``BiRealNet(...)`` call sent every extra
    # kwarg straight to ``BiRealNet.__init__``, which has no ``**kwargs``. If the
    # model factory forwards config kwargs (e.g. ``pretrained_cfg``), that call
    # would raise TypeError.
    model_args = dict(block=BasicBlock, layers=(6, 8, 12, 6))
    return _create_birealnet('birealnet34', pretrained, **dict(model_args, **kwargs))

