import avalanche.models
import  torch.nn as nn
from avalanche.models.dynamic_modules import (
    MultiTaskModule
)
import torch


class basic_cnn_net_batchnorm_multheadandwei(MultiTaskModule):
    """Small CNN with a task-specific first convolution and a task-specific head.

    Ten task slots are allocated up front. Each slot owns one ``c1``
    convolution and one ``last`` linear head. ``c2``, ``c3``, ``l1`` and
    ``l2`` are shared by all tasks. Despite the class name, the network
    contains no batch-norm layers.

    Args:
        image_size: Unused. ``l1`` is hard-wired to 512 input features, which
            is the flattened size produced by 28x28 and 32x32 images.
        channel: Number of input channels expected by each ``c1`` convolution.
    """

    def __init__(self, image_size=32, channel=3):
        super(basic_cnn_net_batchnorm_multheadandwei, self).__init__()
        # Submodules are registered in the same order as before, so parameter
        # initialisation order and state_dict keys are unchanged.
        self.c1 = nn.ModuleList([nn.Conv2d(in_channels=channel, out_channels=32, bias=False, kernel_size=4)
                                 for _ in range(10)])
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2)
        self.c2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, bias=False)
        self.c3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, bias=False)
        self.l1 = nn.Linear(in_features=512, out_features=2048, bias=False)
        self.l2 = nn.Linear(in_features=2048, out_features=2048, bias=False)
        self.dropout1 = nn.Dropout(p=0.2)
        self.dropout2 = nn.Dropout(p=0.5)
        self.floor = nn.Flatten()
        self.output = []  # not read by any method of this class
        self.last = nn.ModuleList([nn.Linear(in_features=2048, out_features=10, bias=False)
                                   for _ in range(10)])

    @staticmethod
    def _task_groups(task_numbers, device):
        """Split a batch by task label into a list of ``(task_id, mask)`` pairs.

        ``task_numbers`` may be an int, a 0-d tensor, or a tensor/sequence of
        per-sample labels. When every label is identical, a single pair with
        ``mask=None`` is returned, meaning "use the whole batch as-is".

        Raises:
            ValueError: if ``task_numbers`` contains no labels.
        """
        labels = torch.as_tensor(task_numbers)
        if labels.dim() == 0:
            return [(int(labels.item()), None)]
        labels = labels.reshape(-1)
        tasks = torch.unique(labels).tolist()
        if not tasks:
            raise ValueError("task_numbers must contain at least one task label")
        if len(tasks) == 1:
            return [(int(tasks[0]), None)]
        labels = labels.to(device)
        return [(int(task), labels == task) for task in tasks]

    @staticmethod
    def _scatter(pieces, batch_size):
        """Reassemble per-task ``[(mask, tensor), ...]`` results into batch order."""
        out = None
        for mask, value in pieces:
            if out is None:
                out = value.new_empty((batch_size,) + tuple(value.shape[1:]))
            out[mask] = value
        return out

    def _trace(self, x, task_id):
        """Run the whole network with task ``task_id``'s layers; return named intermediates."""
        act1 = self.dropout1(self.relu(self.c1[task_id](x)))
        pool1 = self.pool(act1)
        act2 = self.dropout1(self.relu(self.c2(pool1)))
        pool2 = self.pool(act2)
        act3 = self.relu(self.c3(pool2))  # reported before its dropout
        flat = self.floor(self.pool(self.dropout2(act3)))
        fc1 = self.relu(self.l1(flat))  # reported before its dropout
        fc1_drop = self.dropout2(fc1)
        fc2_drop = self.dropout2(self.relu(self.l2(fc1_drop)))
        logits = self.last[task_id](fc2_drop)
        return {
            'act1': act1, 'pool1': pool1, 'act2': act2, 'pool2': pool2,
            'act3': act3, 'flat': flat, 'fc1': fc1, 'fc1_drop': fc1_drop,
            'logits': logits,
        }

    def forward_single_task(self, x, task_label):
        """Return the logits for a batch whose samples all belong to ``task_label``."""
        return self._trace(x, int(task_label))['logits']

    def forward(self, x, task_numbers):
        """Compute logits, routing each sample through its own task's layers.

        Args:
            x: Input batch.
            task_numbers: An int task label, or one label per sample of ``x``.

        Returns:
            Logits from the matching ``last`` head, in the same sample order
            as ``x``.
        """
        groups = self._task_groups(task_numbers, x.device)
        if groups[0][1] is None:
            # Whole batch belongs to one task: a single pass, no masking.
            return self.forward_single_task(x, groups[0][0])
        # Mixed-task batch: applying only one task's c1/head would silently
        # misroute every other task's samples, so run each task separately.
        return self._scatter(
            [(mask, self.forward_single_task(x[mask], task)) for task, mask in groups],
            x.shape[0],
        )

    def get_mid_value(self, x, task_numbers):
        """Run the network and also return selected intermediate activations.

        Args:
            x: Input batch.
            task_numbers: An int task label, or one label per sample of ``x``.
                Mixed-task batches are processed per task and reassembled in
                input order.

        Returns:
            A tuple ``(combs, logits, batch_recoder)``:

            * ``combs`` maps integer keys to ``{'input': tensor, 'pre_layer': tag}``.
              The tensors are the inputs of ``c2`` (key 10), ``c3`` (11), ``l1``
              (12, flattened) and ``l2`` (13). NOTE(review): the keys and
              ``pre_layer`` tags are consumed outside this file; their meaning
              is not defined here.
            * ``logits`` are identical to what ``forward`` returns.
            * ``batch_recoder`` lists the post-ReLU activations of ``c1`` and
              ``c2`` (after ``dropout1``), and of ``c3`` and ``l1`` (before
              ``dropout2``).
        """
        groups = self._task_groups(task_numbers, x.device)
        if groups[0][1] is None:
            trace = self._trace(x, groups[0][0])
        else:
            per_task = [(mask, self._trace(x[mask], task)) for task, mask in groups]
            trace = {
                name: self._scatter([(mask, t[name]) for mask, t in per_task], x.shape[0])
                for name in per_task[0][1]
            }

        combs = {
            10: {'input': trace['pool1'], 'pre_layer': 'head'},
            11: {'input': trace['pool2'], 'pre_layer': 16},
            12: {'input': trace['flat'], 'pre_layer': 17},
            13: {'input': trace['fc1_drop'], 'pre_layer': 18},
        }
        batch_recoder = [trace['act1'], trace['act2'], trace['act3'], trace['fc1']]
        return combs, trace['logits'], batch_recoder
class basic_cnn_net_batchnorm(MultiTaskModule):
    """CNN with batch norm after every convolution and one linear head per task.

    All convolutional and hidden linear layers are shared. ``last`` holds ten
    task-specific heads. The batch-norm layers are built with
    ``track_running_stats=False``, so they normalise with the current batch's
    statistics in both training and evaluation mode.

    Args:
        image_size: Unused. ``l1`` is hard-wired to 1024 input features, which
            is the flattened size produced by 28x28 and 32x32 images.
        channel: Number of input channels expected by ``c1``.
    """

    def __init__(self, image_size=32, channel=3):
        super(basic_cnn_net_batchnorm, self).__init__()
        # Submodules are registered in the same order as before, so parameter
        # initialisation order and state_dict keys are unchanged.
        self.c1 = nn.Conv2d(in_channels=channel, out_channels=64, kernel_size=4, bias=False)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(num_features=64, affine=True, track_running_stats=False)
        self.bn2 = nn.BatchNorm2d(num_features=128, affine=True, track_running_stats=False)
        self.bn3 = nn.BatchNorm2d(num_features=256, affine=True, track_running_stats=False)
        self.c2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, bias=False)
        self.c3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, bias=False)
        self.l1 = nn.Linear(in_features=1024, out_features=2048, bias=False)
        self.l2 = nn.Linear(in_features=2048, out_features=2048, bias=False)
        self.dropout1 = nn.Dropout(p=0.2)  # registered but not applied by this class
        self.dropout2 = nn.Dropout(p=0.5)
        self.floor = nn.Flatten()
        self.output = []  # not read by any method of this class
        self.last = nn.ModuleList([nn.Linear(in_features=2048, out_features=10, bias=False)
                                   for _ in range(10)])

    @staticmethod
    def _task_groups(task_numbers, device):
        """Split a batch by task label into a list of ``(task_id, mask)`` pairs.

        ``task_numbers`` may be an int, a 0-d tensor, or a tensor/sequence of
        per-sample labels. When every label is identical, a single pair with
        ``mask=None`` is returned, meaning "use the whole batch as-is".

        Raises:
            ValueError: if ``task_numbers`` contains no labels.
        """
        labels = torch.as_tensor(task_numbers)
        if labels.dim() == 0:
            return [(int(labels.item()), None)]
        labels = labels.reshape(-1)
        tasks = torch.unique(labels).tolist()
        if not tasks:
            raise ValueError("task_numbers must contain at least one task label")
        if len(tasks) == 1:
            return [(int(tasks[0]), None)]
        labels = labels.to(device)
        return [(int(task), labels == task) for task in tasks]

    @staticmethod
    def _scatter(pieces, batch_size):
        """Reassemble per-task ``[(mask, tensor), ...]`` results into batch order."""
        out = None
        for mask, value in pieces:
            if out is None:
                out = value.new_empty((batch_size,) + tuple(value.shape[1:]))
            out[mask] = value
        return out

    def _trunk(self, x):
        """Run the shared layers; return ``(pool1, pool2, flat, fc1_drop, fc2_drop)``."""
        pool1 = self.pool(self.relu(self.bn1(self.c1(x))))
        pool2 = self.pool(self.relu(self.bn2(self.c2(pool1))))
        flat = self.floor(self.pool(self.relu(self.bn3(self.c3(pool2)))))
        fc1_drop = self.dropout2(self.relu(self.l1(flat)))
        fc2_drop = self.dropout2(self.relu(self.l2(fc1_drop)))
        return pool1, pool2, flat, fc1_drop, fc2_drop

    def _heads(self, features, task_numbers):
        """Apply each sample's task head to the shared ``features``."""
        groups = self._task_groups(task_numbers, features.device)
        if groups[0][1] is None:
            return self.last[groups[0][0]](features)
        # Mixed-task batch: the trunk already ran on the full batch (so the
        # batch-norm statistics are unchanged); only the heads differ per task.
        return self._scatter(
            [(mask, self.last[task](features[mask])) for task, mask in groups],
            features.shape[0],
        )

    def forward_single_task(self, x, task_label):
        """Return the logits for a batch whose samples all belong to ``task_label``."""
        return self.last[int(task_label)](self._trunk(x)[-1])

    def forward(self, x, task_numbers):
        """Compute logits, using each sample's own task head.

        Args:
            x: Input batch.
            task_numbers: An int task label, or one label per sample of ``x``.

        Returns:
            Logits in the same sample order as ``x``.
        """
        return self._heads(self._trunk(x)[-1], task_numbers)

    def get_mid_value(self, x, task_numbers):
        """Run the network and also return selected intermediate activations.

        Returns:
            A tuple ``(combs, logits)``:

            * ``combs`` maps integer keys to dicts whose ``'input'`` is the raw
              input ``x`` (key 0), or the input of ``c2`` (8), ``c3`` (9),
              ``l1`` (10, flattened) or ``l2`` (11). NOTE(review): the keys and
              ``pre_layer`` tags are consumed outside this file; their meaning
              is not defined here.
            * ``logits`` are identical to what ``forward`` returns.
        """
        pool1, pool2, flat, fc1_drop, fc2_drop = self._trunk(x)
        logits = self._heads(fc2_drop, task_numbers)

        combs = {
            0: {'input': x},
            8: {'input': pool1, 'pre_layer': 'head'},
            9: {'input': pool2, 'pre_layer': 16},
            10: {'input': flat, 'pre_layer': 17},
            11: {'input': fc1_drop, 'pre_layer': 18},
        }
        return combs, logits

# model=basic_cnn_net_batchnorm_multheadandwei()
# print(model)
# for name,parm in model.named_parameters():
#     print(name,parm.size())
class Bench_net(nn.Module):
    """Two conv -> batch-norm -> ReLU -> max-pool stages followed by an MLP classifier.

    The classifier expects 1024 flattened features (64 channels x 4 x 4), which
    is what a single-channel 28x28 input produces. It outputs 10 logits per
    sample.
    """

    def __init__(self):
        super(Bench_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm1 = nn.BatchNorm2d(num_features=32)
        self.batch_norm2 = nn.BatchNorm2d(num_features=64)
        self.linear = nn.Sequential(
            nn.Linear(in_features=1024, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=128, out_features=10),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for a ``(batch, 1, H, W)`` input."""
        x = self.pool1(torch.nn.functional.relu(self.batch_norm1(self.conv1(x))))
        x = self.pool2(torch.nn.functional.relu(self.batch_norm2(self.conv2(x))))
        # Flatten per sample, keeping the batch dimension. A fixed
        # ``view(-1, 1024)`` would silently re-split features across samples
        # when the input is not 28x28; here such inputs fail in ``self.linear``.
        x = torch.flatten(x, 1)
        return self.linear(x)