


import torch.nn as nn
import torch.nn.functional as F

from avalanche.models.dynamic_modules import (
    MultiTaskModule
)




class gpm_net(MultiTaskModule):
    """Multi-layer perceptron with a separate first layer for every task.

    The computation is ``dense_first[task] -> ReLU -> features -> ReLU ->
    classifier``. Only the first linear layer is task specific; the hidden
    layer and the 10-way classifier are shared by all tasks. Every linear
    layer is bias-free and the classifier output is divided by
    ``temperature``.

    Args:
        input_size: number of features of one flattened input sample.
        hidden_size: width of both hidden layers.
        tasks: number of task-specific first layers to create.
        temperature: divisor applied to the classifier output.
    """

    def __init__(self, input_size=28 * 28, hidden_size=256, tasks=10, temperature=1.):
        super().__init__()

        # Modules are created in a fixed order (features, classifier,
        # dense_first) so that parameter ordering and seeded initialisation
        # stay stable.
        self.features = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.classifier = nn.Linear(hidden_size, 10, bias=False)
        self.dense_first = nn.ModuleList(
            [nn.Linear(in_features=input_size, out_features=hidden_size, bias=False)
             for _ in range(tasks)]
        )
        self.temperature = temperature

    def renew_temperature(self, temperature_in=None):
        """Replace the temperature used to scale the classifier output.

        Note that the value is stored as given: passing ``None`` (the
        default) makes the division in the forward pass fail afterwards.
        """
        self.temperature = temperature_in

    def forward(self, x, task_labels):
        """Return the temperature-scaled classifier output for ``x``.

        Args:
            x: input batch; it is flattened to ``(batch, input_size)``.
            task_labels: indexable of task ids; only ``task_labels[0]`` is
                read, so the whole batch goes through the same first layer.
        """
        # Same computation as ``get_mid``; only the final output is returned.
        return self.get_mid(x, task_labels)[-1]

    def get_mid(self, x, task_labels):
        """Return ``[hidden1, hidden2, scaled_output]`` for ``x``.

        ``hidden1`` and ``hidden2`` are the post-ReLU activations of the
        task-specific first layer and of the shared hidden layer; the last
        entry equals what ``forward`` returns.
        """
        task_id = task_labels[0].item()
        first_layer = self.dense_first[task_id]

        # Flatten to the width the first layer was built with instead of a
        # hard-coded 784, so a non-default ``input_size`` works as well.
        flat = x.contiguous().view(x.size(0), first_layer.in_features)

        hidden1 = F.relu(first_layer(flat))
        hidden2 = F.relu(self.features(hidden1))
        scaled = self.classifier(hidden2) / self.temperature
        return [hidden1, hidden2, scaled]

class gpm_net_(nn.Module):
    """Bias-free MLP with one first layer per task that returns its activations.

    The computation is ``first[task] -> ReLU -> l2 -> ReLU -> last``.

    Args:
        hidden_unit: width of both hidden layers.
        in_dim: number of features of one flattened input sample.
        out_dim: number of outputs of the final layer.
        tasks: number of task-specific first layers to create.
    """

    def __init__(self, hidden_unit=256, in_dim=784, out_dim=10, tasks=10):
        super(gpm_net_, self).__init__()
        # The construction order below is kept fixed so that parameter
        # ordering and seeded initialisation stay stable.
        self.flatten = nn.Flatten()
        self.l2 = nn.Linear(in_features=hidden_unit, out_features=hidden_unit, bias=False)
        self.relu = nn.ReLU()
        self.last = nn.Linear(in_features=hidden_unit, out_features=out_dim, bias=False)
        self.first = nn.ModuleList(
            [nn.Linear(in_features=in_dim, out_features=hidden_unit, bias=False)
             for _ in range(tasks)]
        )
        # Outputs of the most recent forward pass.
        self.grades_all_layers = []

    def forward(self, x, task_number):
        """Return ``[hidden1, hidden2, output]`` for the batch ``x``.

        Args:
            x: input batch; flattened with ``nn.Flatten`` before the first layer.
            task_number: indexable of task ids; only ``task_number[0]`` is
                read, so the whole batch goes through the same first layer.

        The returned list is also stored on ``self.grades_all_layers``.
        """
        task_id = task_number[0].item()
        hidden1 = self.relu(self.first[task_id](self.flatten(x)))
        hidden2 = self.relu(self.l2(hidden1))
        output = self.last(hidden2)
        self.grades_all_layers = [hidden1, hidden2, output]
        return self.grades_all_layers



class wpm_net(nn.Module):
    """Bias-free MLP whose first linear layer is selected per task.

    The computation is ``dense_first[task] -> ReLU -> l2 -> ReLU -> last``.

    Args:
        in_dim: number of features of one flattened input sample.
        hidden_unit: width of both hidden layers.
        out_dim: number of outputs of the final layer.
        tasks: number of task-specific first layers to create.
    """

    def __init__(self, in_dim=784, hidden_unit=256, out_dim=10, tasks=10):
        super().__init__()
        self.relu = nn.ReLU()

        # Shared layers are built before the per-task layers.
        self.l2 = nn.Linear(hidden_unit, hidden_unit, bias=False)
        self.last = nn.Linear(hidden_unit, out_dim, bias=False)

        per_task_layers = []
        for _ in range(tasks):
            per_task_layers.append(nn.Linear(in_dim, hidden_unit, bias=False))
        self.dense_first = nn.ModuleList(per_task_layers)

        self.flatten = nn.Flatten()

    def forward(self, x, task_num):
        """Return ``[hidden1, hidden2, output]`` for the batch ``x``.

        Args:
            x: input batch; flattened with ``nn.Flatten`` before the first layer.
            task_num: index into ``dense_first`` selecting the first layer.
        """
        task_layer = self.dense_first[task_num]
        hidden1 = self.relu(task_layer(self.flatten(x)))
        hidden2 = self.relu(self.l2(hidden1))
        output = self.last(hidden2)
        return [hidden1, hidden2, output]
#CNN
class basic_cnn_net(nn.Module):
    """Three-block CNN followed by two hidden linear layers and a 10-way output.

    Each convolution block is ``conv -> ReLU -> dropout -> 2x2 max-pool``.
    All convolution and linear layers are bias-free.

    Args:
        image_size: side length of the (square) input images. It determines
            the input width of the first linear layer.
        channel: number of input channels.
    """

    def __init__(self, image_size=32, channel=3):
        super(basic_cnn_net, self).__init__()
        self.c1 = nn.Conv2d(in_channels=channel, out_channels=64, kernel_size=5, padding=2, bias=False)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.c2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False)
        self.c3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, bias=False)
        # The convolutions keep the spatial size and each of the three 2x2
        # poolings halves it (rounding down), so the flattened feature map
        # has 256 * (image_size // 8) ** 2 entries (4096 for 32x32 inputs).
        pooled_side = int(image_size) // 8
        self.l1 = nn.Linear(in_features=256 * pooled_side * pooled_side, out_features=2048, bias=False)
        self.l2 = nn.Linear(in_features=2048, out_features=2048, bias=False)
        self.last = nn.Linear(in_features=2048, out_features=10, bias=False)
        self.dropout1 = nn.Dropout(p=0.2)
        self.dropout2 = nn.Dropout(p=0.5)
        # Flattens the last three dimensions (channels, height, width).
        self.floor = nn.Flatten(start_dim=-3, end_dim=-1)
        # Outputs of the most recent forward pass.
        self.output = []

    def forward(self, x):
        """Run the network and return the input of every weighted layer.

        Returns:
            A list ``[x, in_c2, in_c3, in_l1, in_l2, in_last, logits]`` where
            ``in_<layer>`` is the tensor fed into that layer and ``logits``
            is the output of ``last``. The list is also stored on
            ``self.output``.
        """
        # Block 1
        block1 = self.pool(self.dropout1(self.relu(self.c1(x))))
        # Block 2
        block2 = self.pool(self.dropout1(self.relu(self.c2(block1))))
        # Block 3, then flatten for the linear layers
        block3 = self.pool(self.dropout2(self.relu(self.c3(block2))))
        flat = self.floor(block3)

        hidden1 = self.dropout2(self.relu(self.l1(flat)))
        hidden2 = self.dropout2(self.relu(self.l2(hidden1)))
        logits = self.last(hidden2)

        self.output = [x, block1, block2, flat, hidden1, hidden2, logits]
        return self.output

##VGG-16
class Vgg16_net(nn.Module):
    """VGG-16 style network: 13 convolution blocks and three linear layers.

    Every convolution block is a bias-free 3x3 convolution (stride 1,
    padding 1) followed by ``BatchNorm2d(affine=False)`` and an in-place
    ReLU. Blocks 2, 4, 7, 10 and 13 additionally end with a 2x2 max-pool
    (stride 2); block 13 also flattens its output.

    Args:
        image_size: side length of the (square) input images. The five
            poolings shrink it by a factor of 32 (rounding down), which
            determines the input width of ``fc1``.
        channel: number of input channels.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, pool=False, flatten=False):
        """Build one conv -> batch-norm -> ReLU block, optionally pooled/flattened."""
        parts = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels, affine=False),
            nn.ReLU(inplace=True),
        ]
        if pool:
            parts.append(nn.MaxPool2d(kernel_size=2, stride=2))
        if flatten:
            # Flattens the last three dimensions (channels, height, width).
            parts.append(nn.Flatten(start_dim=-3, end_dim=-1))
        return nn.Sequential(*parts)

    def __init__(self, image_size=32, channel=3):
        super(Vgg16_net, self).__init__()
        block = self._conv_block
        # Blocks are created in order layer1..layer13 so that parameter
        # ordering and seeded initialisation stay stable.
        self.layer1 = block(channel, 64)
        self.layer2 = block(64, 64, pool=True)
        self.layer3 = block(64, 128)
        self.layer4 = block(128, 128, pool=True)
        self.layer5 = block(128, 256)
        self.layer6 = block(256, 256)
        self.layer7 = block(256, 256, pool=True)
        self.layer8 = block(256, 512)
        self.layer9 = block(512, 512)
        self.layer10 = block(512, 512, pool=True)
        self.layer11 = block(512, 512)
        self.layer12 = block(512, 512)
        self.layer13 = block(512, 512, pool=True, flatten=True)

        # Integer floor division mirrors the floor-mode poolings; float
        # division gave a wrong width for sizes that are not multiples of 32.
        pooled_side = int(image_size) // 32
        self.fc1 = nn.Sequential(
            nn.Linear(512 * pooled_side * pooled_side, 512, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256, bias=False),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.last = nn.Linear(256, 10, bias=False)
        # Outputs of the most recent forward pass.
        self.out = []

    def forward(self, x):
        """Run the network and return the input plus every stage's output.

        Returns:
            A 17-element list: ``x`` followed by the outputs of ``layer1``
            .. ``layer13``, ``fc1``, ``fc2`` and ``last`` in that order. The
            list is also stored on ``self.out``.
        """
        stages = (
            self.layer1, self.layer2, self.layer3, self.layer4, self.layer5,
            self.layer6, self.layer7, self.layer8, self.layer9, self.layer10,
            self.layer11, self.layer12, self.layer13,
            self.fc1, self.fc2, self.last,
        )
        activations = [x]
        for stage in stages:
            # Each stage consumes the previous stage's output.
            activations.append(stage(activations[-1]))
        self.out = activations
        return self.out

class basic_cnn_net_batchnorm(nn.Module):
    """Three-convolution CNN with batch norm and one output head per task.

    The convolution stages are ``conv -> BatchNorm2d(affine=False) -> ReLU``;
    the first two are followed by a 2x2 max-pool, the third is flattened.
    Two hidden linear layers with dropout follow, and the final 10-way linear
    layer is selected per task. All convolution and linear layers are
    bias-free.

    Args:
        image_size: side length of the (square) input images. It determines
            the input width of the first linear layer.
        channel: number of input channels.
        task_number: number of output heads to create.
    """

    def __init__(self, image_size=32, channel=3, task_number=10):
        super(basic_cnn_net_batchnorm, self).__init__()
        self.c1 = nn.Conv2d(in_channels=channel, out_channels=64, kernel_size=5, padding=2, bias=False)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(num_features=64, affine=False)
        self.bn2 = nn.BatchNorm2d(num_features=128, affine=False)
        self.bn3 = nn.BatchNorm2d(num_features=256, affine=False)

        self.c2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False)
        self.c3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, bias=False)
        # Only two poolings are applied in forward, each halving the side
        # (rounding down), so the flattened feature map has
        # 256 * (image_size // 4) ** 2 entries (16384 for 32x32 inputs).
        pooled_side = int(image_size) // 4
        self.l1 = nn.Linear(in_features=256 * pooled_side * pooled_side, out_features=2048, bias=False)
        self.l2 = nn.Linear(in_features=2048, out_features=2048, bias=False)
        # dropout1 is kept as an attribute but is not used in forward.
        self.dropout1 = nn.Dropout(p=0)
        self.dropout2 = nn.Dropout(p=0.5)
        # Flattens the last three dimensions (channels, height, width).
        self.floor = nn.Flatten(start_dim=-3, end_dim=-1)
        # Outputs of the most recent forward pass.
        self.output = []
        # One head per task; previously the count was fixed at 10 and the
        # ``task_number`` argument was ignored.
        self.last = nn.ModuleList(
            [nn.Linear(in_features=2048, out_features=10, bias=False)
             for _ in range(task_number)]
        )

    def forward(self, x, task_number):
        """Run the network using the head at index ``task_number``.

        Returns:
            A list ``[x, pooled1, pooled2, hidden1, hidden2, logits]`` where
            ``pooled1``/``pooled2`` are the outputs of the first two
            convolution stages, ``hidden1``/``hidden2`` are the post-ReLU
            activations of ``l1``/``l2`` *before* dropout, and ``logits`` is
            the output of the selected head. The list is also stored on
            ``self.output``.
        """
        pooled1 = self.pool(self.relu(self.bn1(self.c1(x))))
        pooled2 = self.pool(self.relu(self.bn2(self.c2(pooled1))))
        flat = self.floor(self.relu(self.bn3(self.c3(pooled2))))

        hidden1 = self.relu(self.l1(flat))
        hidden2 = self.relu(self.l2(self.dropout2(hidden1)))
        logits = self.last[task_number](self.dropout2(hidden2))

        self.output = [x, pooled1, pooled2, hidden1, hidden2, logits]
        return self.output
# model_=basic_cnn_net_batchnorm()
# print(len(list(model_.parameters())))
# print(list(model_.parameters()))
class basic_cnn_net_batchnorm_multheadandwei(nn.Module):
    """CNN with a per-task first convolution and a per-task output head.

    The convolution stages are ``conv -> BatchNorm2d(affine=True) -> ReLU``;
    the first two are followed by a 2x2 average-pool, the third is
    flattened. Two hidden linear layers with dropout follow. Both the first
    convolution and the final 10-way linear layer are selected per task. All
    convolution and linear layers are bias-free.

    Args:
        image_size: side length of the (square) input images. It determines
            the input width of the first linear layer.
        channel: number of input channels.
        task_number: number of task-specific first convolutions and output
            heads to create.
    """

    def __init__(self, image_size=32, channel=3, task_number=10):
        super(basic_cnn_net_batchnorm_multheadandwei, self).__init__()
        self.c1 = nn.ModuleList(
            [nn.Conv2d(in_channels=channel, out_channels=64, bias=False, kernel_size=3, padding=1)
             for _ in range(task_number)]
        )
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(num_features=64, affine=True)
        self.bn2 = nn.BatchNorm2d(num_features=128, affine=True)
        self.bn3 = nn.BatchNorm2d(num_features=256, affine=True)

        self.c2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False)
        self.c3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, bias=False)
        # Two poolings are applied in forward, each halving the side
        # (rounding down), so the flattened feature map has
        # 256 * (image_size // 4) ** 2 entries (16384 for 32x32 inputs).
        pooled_side = int(image_size) // 4
        self.l1 = nn.Linear(in_features=256 * pooled_side * pooled_side, out_features=2048, bias=False)
        self.l2 = nn.Linear(in_features=2048, out_features=2048, bias=False)

        self.dropout2 = nn.Dropout(p=0.5)
        # Flattens the last three dimensions (channels, height, width).
        self.floor = nn.Flatten(start_dim=-3, end_dim=-1)
        # Kept as an attribute; forward does not write to it.
        self.output = []
        self.last = nn.ModuleList(
            [nn.Linear(in_features=2048, out_features=10, bias=False)
             for _ in range(task_number)]
        )

    def forward(self, x, task_number):
        """Run the network using the first convolution and head of one task.

        Args:
            x: input image batch.
            task_number: index selecting ``c1[task_number]`` and
                ``last[task_number]``.

        Returns:
            A tuple ``(combs, logits)``. ``combs`` maps the integer keys
            16..19 to dicts with the tensor fed into ``c2``, ``c3``, ``l1``
            and ``l2`` respectively under ``'input'`` (the ``l2`` entry is
            the post-ReLU activation before dropout) and the previous key,
            or ``'head'`` for the first entry, under ``'pre_layer'``.
            ``logits`` is the output of the selected head.
        """
        pooled1 = self.pool(self.relu(self.bn1(self.c1[task_number](x))))
        pooled2 = self.pool(self.relu(self.bn2(self.c2(pooled1))))
        flat = self.floor(self.relu(self.bn3(self.c3(pooled2))))

        hidden1 = self.relu(self.l1(flat))
        hidden2 = self.relu(self.l2(self.dropout2(hidden1)))
        logits = self.last[task_number](self.dropout2(hidden2))

        combs = {
            16: {'input': pooled1, 'pre_layer': 'head'},
            17: {'input': pooled2, 'pre_layer': 16},
            18: {'input': flat, 'pre_layer': 17},
            19: {'input': hidden1, 'pre_layer': 18},
        }
        return combs, logits
# model=basic_cnn_net_batchnorm_multheadandwei()
# import torch
# # # for m in model.modules():
# # #     if isinstance(m, (nn.Conv2d, nn.Linear)):
# # #         nn.init.xavier_uniform_(m.weight)
# # # print(model.c1[0].weight.unsqueeze(dim=0).repeat((128,1,1,1,1)).size())
# for para in model.named_parameters():
#     a,b=para
#     print(a,b.size())