import random
from backbone.cil_sparsity import *

class MySequential(nn.Sequential):
    """Sequential container that forwards ``session``/``args`` to ``Layer`` modules.

    Child modules that are instances of ``Layer`` are called with the extra
    ``session`` and ``args`` keyword arguments; every other child is called
    with the running tensor only.
    """

    def forward(self, input, session=0, args=None):
        """Run ``input`` through every child module in registration order.

        Args:
            input: Value fed to the first child module.
            session: Passed through as a keyword argument to each ``Layer``.
            args: Passed through as a keyword argument to each ``Layer``.

        Returns:
            The output of the last child module (``input`` itself if the
            container is empty).
        """
        out = input
        for stage in self:
            if not isinstance(stage, Layer):
                # Plain modules (e.g. pooling wrappers) take no extra kwargs.
                out = stage(out)
                continue
            out = stage(out, session=session, args=args)
        return out

class VGG9SNN(nn.Module):
    """VGG-style spiking network: seven ``Layer`` blocks, three pooling stages
    and a linear head producing 1024 features.

    The input is expanded by ``add_dimention`` with ``self.T`` before the
    feature extractor runs (presumably replicating the batch over ``T`` time
    steps -- TODO confirm against ``add_dimention``).
    """

    # Input image side length (in pixels) for each supported dataset.
    _INPUT_SIZES = {'cifar100': 32, 'mini_imagenet': 84}

    def __init__(self, time_step=4, session=0, adapt_ratio=0.1, lr=0.01,args=None):
        """Build the feature extractor and the linear head.

        Args:
            time_step: Stored as ``self.T`` and handed to ``add_dimention``
                in ``forward``.
            session: Accepted for interface compatibility; not used by this
                constructor.
            adapt_ratio: Stored as ``self.adapt_ratio``.
            lr: Stored as ``self.lr``.
            args: Configuration object; must expose a ``dataset`` attribute
                equal to ``'cifar100'`` or ``'mini_imagenet'``.

        Raises:
            ValueError: If ``args`` is ``None``, has no ``dataset`` attribute,
                or names an unsupported dataset.
        """
        super().__init__()

        # Validate up front so a bad configuration fails with a clear message
        # instead of an AttributeError / unbound-variable error further down.
        dataset = getattr(args, 'dataset', None)
        if dataset not in self._INPUT_SIZES:
            raise ValueError(
                "VGG9SNN requires args.dataset to be one of "
                f"{sorted(self._INPUT_SIZES)}, got {dataset!r}"
            )
        size = self._INPUT_SIZES[dataset]

        # One pooling module instance is reused at all three pooling positions.
        pool = SeqToANNContainer(nn.AvgPool2d(2))

        self.features = MySequential(
            Layer(3, 64, 3, 1, 1),
            Layer(64, 64, 3, 1, 1),
            pool,
            Layer(64, 128, 3, 1, 1),
            Layer(128, 128, 3, 1, 1),
            pool,
            Layer(128, 256, 3, 1, 1),
            Layer(256, 256, 3, 1, 1),
            Layer(256, 256, 3, 1, 1),
            pool
        )

        # Spatial side length after three 2x2 average pools (each floors the
        # halved size); assumes the Layer blocks keep the spatial size
        # unchanged -- TODO confirm against Layer.
        W = size // 8
        self.T = time_step
        self.adapt_ratio = adapt_ratio
        self.lr = lr
        self.classifier = SeqToANNContainer(nn.Linear(256 * W * W, 1024))
        self.args = args
        # Average firing rate on old classes, one entry per Layer.
        self.base_rates = [0.0 for block in self.features if isinstance(block, Layer)]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input, session=0, args=None):
        """Compute the 1024-dimensional head output for ``input``.

        Args:
            input: Batch of images; presumably ``(batch, 3, H, W)`` with
                ``H == W`` equal to the configured dataset size -- TODO confirm.
            session: Forwarded to every ``Layer`` in ``self.features``.
            args: Forwarded to every ``Layer`` in ``self.features``. Note that
                this is the call-time value, not ``self.args``.

        Returns:
            Output of ``self.classifier``; its last dimension has size 1024.
        """
        input = add_dimention(input, self.T)
        x = self.features(input, session=session,args=args)
        # Collapse everything from dimension 2 onward into a single axis so
        # the linear head sees one feature vector per leading (dim 0, dim 1) pair.
        x = torch.flatten(x, 2)

        x = self.classifier(x)
        return x



# Script entry point: running this module directly performs no work.
if __name__ == '__main__':
    pass
