import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('..')
from CHT import hanming_config, Conv2d_CHT



class VGG16_CHT_optimalThres(nn.Module):
    """VGG16 for CIFAR that records per-layer maximum activations.

    Attribute-per-layer style: every convolution is a ``Conv2d_CHT`` with
    ``kernel_size=3`` and ``padding=1``, each followed by ``BatchNorm2d`` and
    ReLU; every group ends with a 2x2 ``MaxPool2d``.

    ``max_active`` is a plain Python list with 16 slots (it is not a
    registered buffer, so it is neither saved in ``state_dict`` nor moved by
    ``.to()``):

    * slots 0-12: outputs of the 13 conv/BN/ReLU blocks, in network order;
    * ``one_fc`` false: slot 13 = ReLU(fc1), slot 14 = ReLU(fc2),
      slot 15 = ``last_layer`` output;
    * ``one_fc`` true: slot 13 = ``last_layer`` output, slots 14-15 stay ``0``.

    Each slot holds the element-wise running maximum (starting from zeros) of
    the activations seen by ``forward`` since the last ``init_thresh`` call.
    Because the trackers start at zero, negative ``last_layer`` outputs are
    recorded as 0.

    Args:
        num_classes: Number of output units of ``last_layer``.
        one_fc: If truthy, the classifier is only ``last_layer``; otherwise
            ``fc1`` and ``fc2`` (512 units each, with ReLU) precede it.
        cht_config: Configuration object forwarded to every ``Conv2d_CHT``.
    """

    def __init__(self, num_classes: int, one_fc: bool, cht_config=hanming_config):
        super().__init__()

        # ---- GROUP 1 ----
        self.conv1_1 = Conv2d_CHT(3, 64, 3, cht_config, padding=1)
        self.BN1_1 = nn.BatchNorm2d(64)

        self.conv1_2 = Conv2d_CHT(64, 64, 3, cht_config, padding=1)
        self.BN1_2 = nn.BatchNorm2d(64)

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 2 ----
        self.conv2_1 = Conv2d_CHT(64, 128, 3, cht_config, padding=1)
        self.BN2_1 = nn.BatchNorm2d(128)

        self.conv2_2 = Conv2d_CHT(128, 128, 3, cht_config, padding=1)
        self.BN2_2 = nn.BatchNorm2d(128)

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 3 ----
        self.conv3_1 = Conv2d_CHT(128, 256, 3, cht_config, padding=1)
        self.BN3_1 = nn.BatchNorm2d(256)

        self.conv3_2 = Conv2d_CHT(256, 256, 3, cht_config, padding=1)
        self.BN3_2 = nn.BatchNorm2d(256)

        # conv3_3 stays 3x3 (padding=1)
        self.conv3_3 = Conv2d_CHT(256, 256, 3, cht_config, padding=1)
        self.BN3_3 = nn.BatchNorm2d(256)

        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 4 ----
        self.conv4_1 = Conv2d_CHT(256, 512, 3, cht_config, padding=1)
        self.BN4_1 = nn.BatchNorm2d(512)

        self.conv4_2 = Conv2d_CHT(512, 512, 3, cht_config, padding=1)
        self.BN4_2 = nn.BatchNorm2d(512)

        self.conv4_3 = Conv2d_CHT(512, 512, 3, cht_config, padding=1)
        self.BN4_3 = nn.BatchNorm2d(512)

        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 5 ----
        self.conv5_1 = Conv2d_CHT(512, 512, 3, cht_config, padding=1)
        self.BN5_1 = nn.BatchNorm2d(512)

        self.conv5_2 = Conv2d_CHT(512, 512, 3, cht_config, padding=1)
        self.BN5_2 = nn.BatchNorm2d(512)

        self.conv5_3 = Conv2d_CHT(512, 512, 3, cht_config, padding=1)
        self.BN5_3 = nn.BatchNorm2d(512)

        self.maxpool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Guarantees a 1x1 spatial map before the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # classifier
        self.one_fc = one_fc
        if not one_fc:
            self.fc1 = nn.Linear(512 * 1 * 1, 512, bias=True)
            self.fc2 = nn.Linear(512, 512, bias=True)
        self.last_layer = nn.Linear(512, num_classes, bias=True)

        self.relu = F.relu
        # One slot per tracked layer; see the class docstring for the layout.
        self.max_active = [0] * 16

        self._initialize_linear_weights()

    def _initialize_linear_weights(self):
        """Initialise conv, batch-norm and linear parameters.

        Despite the name, this covers all three layer families:
        Kaiming-normal (fan_out, ReLU) for convolutions, weight=1 / bias=0 for
        batch norm, and N(0, 0.01) weights with zero bias for linear layers.
        """
        for m in self.modules():
            if isinstance(m, (Conv2d_CHT, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    # ------------------------------------------------------------------
    # max-activation bookkeeping
    # ------------------------------------------------------------------
    def _reset_max(self, slot: int, act: torch.Tensor) -> None:
        """Set tracker ``slot`` to zeros with the shape/dtype/device of ``act``."""
        self.max_active[slot] = torch.zeros_like(act)

    def _update_max(self, slot: int, act: torch.Tensor) -> None:
        """Fold ``act`` into tracker ``slot`` as an element-wise running maximum."""
        # Detach so the tracker never keeps the autograd graph of earlier
        # batches alive (otherwise memory grows with every training step).
        act = act.detach()
        prev = self.max_active[slot]

        if not torch.is_tensor(prev):
            # init_thresh() was not called: start from zeros, exactly what
            # init_thresh() would have stored for this slot.
            prev = torch.zeros_like(act)
        elif prev.device != act.device:
            # The model was moved after the tracker was created.
            prev = prev.to(act.device)

        rows = act.size(0)
        same_tail = prev.dim() == act.dim() and prev.shape[1:] == act.shape[1:]
        if same_tail and 1 < rows < prev.size(0):
            # Short (e.g. final) batch: update only the rows that are present
            # and leave the remaining rows of the tracker untouched.
            merged = prev.clone()
            merged[:rows] = torch.maximum(prev[:rows], act)
        else:
            # Equal shapes, or shapes that broadcast (e.g. a batch of one).
            merged = torch.maximum(prev, act)
        self.max_active[slot] = merged

    # ------------------------------------------------------------------
    # shared traversal
    # ------------------------------------------------------------------
    def _conv_bn_relu(self, x: torch.Tensor, conv, bn, slot: int, track) -> torch.Tensor:
        """Apply conv -> BN -> ReLU and report the result to ``track(slot, out)``."""
        out = self.relu(bn(conv(x)))
        track(slot, out)
        return out

    def _forward_impl(self, x: torch.Tensor, track) -> torch.Tensor:
        """Run the whole network, calling ``track(slot, activation)`` per tracked layer."""
        # Group 1
        out = self._conv_bn_relu(x, self.conv1_1, self.BN1_1, 0, track)
        out = self._conv_bn_relu(out, self.conv1_2, self.BN1_2, 1, track)
        out = self.maxpool1(out)

        # Group 2
        out = self._conv_bn_relu(out, self.conv2_1, self.BN2_1, 2, track)
        out = self._conv_bn_relu(out, self.conv2_2, self.BN2_2, 3, track)
        out = self.maxpool2(out)

        # Group 3
        out = self._conv_bn_relu(out, self.conv3_1, self.BN3_1, 4, track)
        out = self._conv_bn_relu(out, self.conv3_2, self.BN3_2, 5, track)
        out = self._conv_bn_relu(out, self.conv3_3, self.BN3_3, 6, track)
        out = self.maxpool3(out)

        # Group 4
        out = self._conv_bn_relu(out, self.conv4_1, self.BN4_1, 7, track)
        out = self._conv_bn_relu(out, self.conv4_2, self.BN4_2, 8, track)
        out = self._conv_bn_relu(out, self.conv4_3, self.BN4_3, 9, track)
        out = self.maxpool4(out)

        # Group 5
        out = self._conv_bn_relu(out, self.conv5_1, self.BN5_1, 10, track)
        out = self._conv_bn_relu(out, self.conv5_2, self.BN5_2, 11, track)
        out = self._conv_bn_relu(out, self.conv5_3, self.BN5_3, 12, track)
        out = self.maxpool5(out)
        out = self.avgpool(out)

        # Flatten to [B, D]; the FC trackers keep this per-sample shape
        # (no reduction over the batch dimension is performed).
        out = out.view(x.size(0), -1)

        if not self.one_fc:
            out = self.relu(self.fc1(out))
            track(13, out)

            out = self.relu(self.fc2(out))
            track(14, out)

            out = self.last_layer(out)
            track(15, out)
        else:
            out = self.last_layer(out)
            track(13, out)

        return out

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def init_thresh(self, x: torch.Tensor) -> None:
        """Reset every used ``max_active`` slot to zeros shaped like its layer output.

        Runs ``x`` through all layers once (so BatchNorm layers in training
        mode update their running statistics as in a normal pass) and
        discards the network output.

        Args:
            x: Input batch; its batch size fixes the leading dimension of
                every tracker.
        """
        self._forward_impl(x, self._reset_max)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the ``last_layer`` output for ``x`` and update ``max_active``.

        Args:
            x: Input batch.

        Returns:
            The output of ``last_layer`` (no activation applied).
        """
        return self._forward_impl(x, self._update_max)
