import torch
import torch.nn as nn
import torch.nn.functional as F
from .spikeLayer import *

# Module-level default compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class VGG16_BN(nn.Module):
    """VGG16-style network with batch normalisation and a 10-way classifier.

    Five convolutional groups (2, 2, 3, 3, 3 conv+BN+ReLU layers), each closed
    by a 2x2 max-pool, are followed by three fully connected layers. The third
    convolution of groups 3-5 uses a 1x1 kernel without padding; every other
    convolution is 3x3 with padding 1. ``fc1`` takes 512 flattened features,
    so the feature map must be 1x1 after the fifth pooling (e.g. 32x32 input).
    """

    # One tuple per group; each entry is
    # (layer tag, in_channels, out_channels, kernel_size, padding).
    _LAYOUT = (
        (("1_1", 3, 64, 3, 1), ("1_2", 64, 64, 3, 1)),
        (("2_1", 64, 128, 3, 1), ("2_2", 128, 128, 3, 1)),
        (("3_1", 128, 256, 3, 1), ("3_2", 256, 256, 3, 1), ("3_3", 256, 256, 1, 0)),
        (("4_1", 256, 512, 3, 1), ("4_2", 512, 512, 3, 1), ("4_3", 512, 512, 1, 0)),
        (("5_1", 512, 512, 3, 1), ("5_2", 512, 512, 3, 1), ("5_3", 512, 512, 1, 0)),
    )

    def __init__(self):
        super(VGG16_BN, self).__init__()

        # Modules are registered as conv<tag>, BN<tag>, maxpool<group> in the
        # order conv -> BN -> ... -> pool, so parameter names and creation
        # order stay stable.
        for group_no, specs in enumerate(self._LAYOUT, start=1):
            for tag, c_in, c_out, ksize, pad in specs:
                conv = nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                 kernel_size=ksize, stride=1, padding=pad, bias=True)
                setattr(self, "conv" + tag, conv)
                setattr(self, "BN" + tag, nn.BatchNorm2d(num_features=c_out))
            setattr(self, "maxpool%d" % group_no, nn.MaxPool2d(kernel_size=2, stride=2))

        # Classifier head.
        self.fc1 = nn.Linear(in_features=512 * 1 * 1, out_features=4096, bias=True)
        self.fc2 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.fc3 = nn.Linear(in_features=4096, out_features=10, bias=True)

        self.relu = F.relu

    def forward(self, x, epoch):
        """Return the class scores for the batch ``x``.

        ``epoch`` is accepted for interface compatibility; it is not used in
        the computation.
        """
        feats = x
        for group_no, specs in enumerate(self._LAYOUT, start=1):
            for spec in specs:
                tag = spec[0]
                feats = getattr(self, "conv" + tag)(feats)
                feats = getattr(self, "BN" + tag)(feats)
                feats = self.relu(feats)
            feats = getattr(self, "maxpool%d" % group_no)(feats)

        # Flatten to (batch, features) before the fully connected layers.
        flat = feats.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)


class VGG16_optimalThres(nn.Module):
    """
    VGG16 for CIFAR that records the peak activation seen at every layer.

    All convolutions are ``nn.Conv2d`` with kernel_size=3 and padding=1, each
    followed by BatchNorm and ReLU; every group ends with a 2x2 ``MaxPool2d``.
    An adaptive average pool reduces the final map to 1x1 before the
    classifier.

    Args:
        num_classes: number of outputs of ``last_layer``.
        one_fc: if truthy, the classifier is the single ``last_layer``;
            otherwise two 512-unit hidden layers (``fc1``, ``fc2``) precede it.

    Attributes:
        max_active: Python list with one slot per tracked layer (13 conv
            layers, then ``fc1``/``fc2`` when present, then ``last_layer``):
            14 slots when ``one_fc`` else 16. After ``init_thresh`` or the
            first ``forward`` each slot is a tensor shaped like that layer's
            output *including the batch dimension*; ``forward`` keeps the
            element-wise running maximum. Because the batch dimension is
            kept, later batches must have a shape that broadcasts against
            the stored tensors.
    """

    # Layer tags per convolutional group; group g is followed by maxpool<g>.
    _CONV_GROUPS = (
        ("1_1", "1_2"),
        ("2_1", "2_2"),
        ("3_1", "3_2", "3_3"),
        ("4_1", "4_2", "4_3"),
        ("5_1", "5_2", "5_3"),
    )

    def __init__(self,  num_classes: int,  one_fc):
        super().__init__()

        # ---- GROUP 1 ----
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.BN1_1 = nn.BatchNorm2d(64)

        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.BN1_2 = nn.BatchNorm2d(64)

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 2 ----
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.BN2_1 = nn.BatchNorm2d(128)

        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.BN2_2 = nn.BatchNorm2d(128)

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 3 ----
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.BN3_1 = nn.BatchNorm2d(256)

        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.BN3_2 = nn.BatchNorm2d(256)

        # conv3_3 stays 3x3 (padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.BN3_3 = nn.BatchNorm2d(256)

        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 4 ----
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.BN4_1 = nn.BatchNorm2d(512)

        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.BN4_2 = nn.BatchNorm2d(512)

        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.BN4_3 = nn.BatchNorm2d(512)

        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- GROUP 5 ----
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.BN5_1 = nn.BatchNorm2d(512)

        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.BN5_2 = nn.BatchNorm2d(512)

        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.BN5_3 = nn.BatchNorm2d(512)

        self.maxpool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # forces the final map to 1x1

        # classifier
        self.one_fc = one_fc
        if not one_fc:
            self.fc1 = nn.Linear(512 * 1 * 1, 512, bias=True)
            self.fc2 = nn.Linear(512, 512, bias=True)
        self.last_layer = nn.Linear(512, num_classes, bias=True)

        self.relu = F.relu
        # Placeholders; replaced by tensors in init_thresh() / forward().
        self.max_active = [0] * 16 if not one_fc else [0] * 14

        self._initialize_linear_weights()

    def _initialize_linear_weights(self):
        """Re-initialise conv (Kaiming normal), BatchNorm (1/0) and linear (N(0, 0.01)) layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _record(self, slot, activation, reset):
        """Update ``max_active[slot]`` from ``activation``.

        With ``reset`` the slot becomes a zero tensor shaped like
        ``activation``; otherwise it becomes the element-wise maximum of the
        stored tensor and ``activation``.
        """
        # Detach so the stored statistics never hold on to the autograd graph;
        # otherwise every batch's graph would stay alive through this chain.
        observed = activation.detach()
        previous = self.max_active[slot]
        if reset or not torch.is_tensor(previous):
            # Slot still holds its integer placeholder (forward() called
            # before init_thresh()) or is being reset: start from zeros.
            previous = torch.zeros_like(observed)
        self.max_active[slot] = previous if reset else torch.maximum(previous, observed)

    def _traverse(self, x, reset):
        """Run the full network on ``x``, recording every tracked layer; return the logits."""
        slot = 0
        out = x
        for group_no, tags in enumerate(self._CONV_GROUPS, start=1):
            for tag in tags:
                out = getattr(self, "conv" + tag)(out)
                out = getattr(self, "BN" + tag)(out)
                out = self.relu(out)
                self._record(slot, out, reset)
                slot += 1
            out = getattr(self, "maxpool%d" % group_no)(out)

        out = self.avgpool(out)
        out = out.view(x.size(0), -1)  # flatten to [B, D]

        if not self.one_fc:
            for fc in (self.fc1, self.fc2):
                out = self.relu(fc(out))
                self._record(slot, out, reset)
                slot += 1

        # The last layer has no ReLU; its slot is still floored at zero
        # because the running maximum starts from a zero tensor.
        out = self.last_layer(out)
        self._record(slot, out, reset)
        return out

    def init_thresh(self, x: torch.Tensor):
        """Reset every ``max_active`` slot to zeros shaped like the layer outputs for ``x``.

        Runs one forward pass on ``x`` only to discover the output shapes.
        The modules are executed in their current mode, so in training mode
        the BatchNorm running statistics are updated by this call.
        """
        with torch.no_grad():  # only shapes are needed here
            self._traverse(x, reset=True)

    def forward(self, x: torch.Tensor):
        """Return the logits for ``x`` and fold each layer's output into ``max_active``.

        The maximum is element-wise over the whole activation tensor (the
        batch dimension is kept, not reduced). If ``init_thresh`` has not been
        called, the slots are lazily started from zeros.
        """
        return self._traverse(x, reset=False)


class VGG16_BN_PosNeg_spiking(nn.Module):
    '''
    Converts a passed-in optimalThres model into an SNN.
    Applicable to any dense CHT.

    Every conv/BN pair of ``model`` is wrapped in ``SPIKE_PosNeg_layer_BN`` and
    every linear layer in ``SPIKE_PosNeg_layer``; layer ``i`` receives
    ``thresh_list[i]`` and ``-thresh_list[i]`` as its first two arguments.
    '''

    # Number of spiking conv layers in each of the five groups.
    _GROUP_SIZES = (2, 2, 3, 3, 3)

    def __init__(self, thresh_list, model):
        super().__init__()

        slot = 0
        for group_no, n_layers in enumerate(self._GROUP_SIZES, start=1):
            for layer_no in range(1, n_layers + 1):
                tag = "%d_%d" % (group_no, layer_no)
                spiking = SPIKE_PosNeg_layer_BN(
                    thresh_list[slot], -thresh_list[slot],
                    getattr(model, "conv" + tag), getattr(model, "BN" + tag))
                setattr(self, "conv" + tag, spiking)
                slot += 1
            setattr(self, "pool%d" % group_no, nn.MaxPool2d(kernel_size=2, stride=2))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.one_fc = model.one_fc

        # Classifier thresholds continue right after the 13 conv thresholds.
        if self.one_fc:
            self.last_layer = SPIKE_PosNeg_layer(thresh_list[13], -thresh_list[13], model.last_layer)
        else:
            self.fc1 = SPIKE_PosNeg_layer(thresh_list[13], -thresh_list[13], model.fc1)
            self.fc2 = SPIKE_PosNeg_layer(thresh_list[14], -thresh_list[14], model.fc2)
            self.last_layer = SPIKE_PosNeg_layer(thresh_list[15], -thresh_list[15], model.last_layer)

    def _conv_groups(self):
        """Return the spiking conv layers as one list per group, in network order."""
        return [
            [getattr(self, "conv%d_%d" % (group_no, layer_no))
             for layer_no in range(1, n_layers + 1)]
            for group_no, n_layers in enumerate(self._GROUP_SIZES, start=1)
        ]

    def weight_bias_norm(self):
        """Call ``compute_Conv_weight()`` on every spiking conv layer, in network order."""
        for group in self._conv_groups():
            for layer in group:
                layer.compute_Conv_weight()

    def init_layer(self):
        """Call ``init_mem()`` on every spiking layer (conv layers first, then the classifier)."""
        for group in self._conv_groups():
            for layer in group:
                layer.init_mem()
        if not self.one_fc:
            self.fc1.init_mem()
            self.fc2.init_mem()
        self.last_layer.init_mem()

    def forward(self, x, time):
        """Run one pass through all spiking layers and return the last layer's first output.

        Each spiking layer is called as ``layer(input, time)`` and returns a
        pair; only the first element is propagated, the second is discarded.
        """
        output = x
        for group_no, group in enumerate(self._conv_groups(), start=1):
            for layer in group:
                output, _ = layer(output, time)
            output = getattr(self, "pool%d" % group_no)(output)

        output = self.avgpool(output)
        output = output.view(output.size(0), -1)

        if not self.one_fc:
            output, _ = self.fc1(output, time)
            output, _ = self.fc2(output, time)
        output, _ = self.last_layer(output, time)

        return output


    
'''
Backup: reference sketch of the SNN inference loop over T time steps (rate coding).

self.init_layer()
with torch.no_grad():
    out_spike_sum = 0
    for time in range(self.T):
        # forward pass for this time step produces `output`
        out_spike_sum += output
        if (time + 1) == self.T:
            sub_result = out_spike_sum / (time + 1)
    return sub_result  # rate coding
'''
