'''
LWTA_Net for incremental learning
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from LWTA.layers import LWTA


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock, with an LWTA layer after each batch norm.

    Each block owns two learnable tensors, ``p_ti_1`` and ``p_ti_2``, which are
    passed to the two LWTA layers on every forward pass.  Their shapes are fixed
    at construction time as ``(batch_size, channels, size, size)``: the batch
    size is hard-coded to 40 and the spatial sizes come from ``_SPATIAL_SIZES``.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first convolution (and of the 1x1 shortcut conv).
        iter: 1-based position of this block inside its layer group.
        num_layer: Index of the layer group this block belongs to (1-4).
        base_train: Forwarded unchanged to both LWTA layers.
        temp_train: Forwarded unchanged to both LWTA layers.
        J: Forwarded unchanged to both LWTA layers.

    Raises:
        ValueError: If ``num_layer`` has no entry in ``_SPATIAL_SIZES``.
    '''
    expansion = 1

    # num_layer -> (spatial size seen by the FIRST block's input, spatial size
    # after that block's strided conv).  Later blocks of the same group use the
    # second value for both of their p_ti tensors.
    # NOTE(review): these values look like they assume 32x32 inputs halved by
    # each of layer groups 2-4 -- confirm before using another resolution.
    _SPATIAL_SIZES = {1: (32, 32), 2: (32, 16), 3: (16, 8), 4: (8, 4)}

    def __init__(self, in_planes, planes, stride=1, iter=1, num_layer=1, base_train=True, temp_train=0.67, J=2):
        super(PreActBlock, self).__init__()
        if num_layer not in self._SPATIAL_SIZES:
            raise ValueError(
                "num_layer must be one of {}, got {!r}".format(sorted(self._SPATIAL_SIZES), num_layer)
            )

        self.J = J
        # The batch dimension of the p_ti parameters is fixed to this value.
        self.batch_size = 40
        self.base_train = base_train
        self.temp_train = temp_train

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.lwta1 = LWTA()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.lwta2 = LWTA()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # Channel counts are simply in_planes (what bn1 sees) and planes (what
        # bn2 sees); reading them directly also works for blocks at position 3+.
        first_in_size, out_size = self._SPATIAL_SIZES[num_layer]
        # Only the first block of a group sees the not-yet-downsampled input.
        in_size = first_in_size if iter == 1 else out_size

        self.p_ti_1 = nn.Parameter(torch.randn(self.batch_size, in_planes, in_size, in_size), requires_grad=True)
        nn.init.xavier_normal_(self.p_ti_1)

        self.p_ti_2 = nn.Parameter(torch.randn(self.batch_size, planes, out_size, out_size), requires_grad=True)
        nn.init.xavier_normal_(self.p_ti_2)

        # A 1x1 projection is only created when the residual cannot be added
        # as-is; forward() detects its presence with hasattr().
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        '''Apply the residual block to ``x`` and return the result.'''
        lwta_kwargs = dict(base_train=self.base_train, temp_train=self.temp_train, J=self.J)

        out = self.lwta1(self.bn1(x), p_ti=self.p_ti_1, **lwta_kwargs)
        # The projection shortcut (when present) consumes the activated tensor,
        # otherwise the untouched input is used as the residual.
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        # NOTE(review): conv2 runs before bn2/lwta2 here, which differs from the
        # usual bn -> activation -> conv ordering -- confirm this is intended.
        out = self.lwta2(self.bn2(self.conv2(out)), p_ti=self.p_ti_2, **lwta_kwargs)
        out += shortcut
        return out
    
class PreActResNet(nn.Module):
    '''Pre-activation ResNet backbone followed by two banks of linear heads.

    The backbone is a 3x3 stem convolution, four layer groups built from
    ``block``, a batch norm and a global average pool.  ``linear_main`` and
    ``linear`` are two independent lists of ``num_classifier`` heads, each
    mapping the pooled feature to ``num_classes_per_classifier`` outputs.

    Args:
        block: Block class; called as ``block(in_planes, planes, stride,
            iter=..., num_layer=..., base_train=..., temp_train=..., J=...)``
            and expected to expose an ``expansion`` attribute.
        num_blocks: Number of blocks in each of the four layer groups.
        num_classes_per_classifier: Output width of every head.
        num_classifier: Number of heads in each bank.
        J: Forwarded to every block.
    '''

    def __init__(self, block, num_blocks, num_classes_per_classifier=2, num_classifier=5, J=2):
        super(PreActResNet, self).__init__()
        self.J = J
        # Running channel count; updated by _make_layer as groups are built.
        self.in_planes = 32
        self.base_train = True
        self.temp_train = 0.67

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1, num_layer=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, num_layer=2)
        self.layer3 = self._make_layer(block, 16, num_blocks[2], stride=2, num_layer=3)
        self.layer4 = self._make_layer(block, 16, num_blocks[3], stride=2, num_layer=4)
        self.bn = nn.BatchNorm2d(16)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear_main = nn.ModuleList([nn.Linear(16*block.expansion, num_classes_per_classifier) for i in range(num_classifier)])
        self.linear = nn.ModuleList([nn.Linear(16*block.expansion, num_classes_per_classifier) for i in range(num_classifier)])

    def _make_layer(self, block, planes, num_blocks, stride, num_layer):
        '''Build one layer group as an ``nn.Sequential`` of ``block`` instances.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        Each block is told its 1-based position via ``iter``.  Updates
        ``self.in_planes`` so the next group starts from this group's width.
        '''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for position, block_stride in enumerate(strides, start=1):
            layers.append(block(self.in_planes, planes, block_stride, iter=position, num_layer=num_layer,
                                base_train=self.base_train, temp_train=self.temp_train, J=self.J))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_idx=0, main_fc=False, is_feature = False):
        '''Run the backbone and return either the pooled feature or head logits.

        Args:
            x: Input batch fed to the 3-channel stem convolution.
            out_idx: Number of heads to evaluate (heads ``0 .. out_idx-1``).
                Ignored when ``is_feature`` is true.
            main_fc: Use the ``linear_main`` bank instead of ``linear``.
            is_feature: Return the flattened pooled feature and skip the heads.

        Returns:
            The ``(batch, features)`` tensor when ``is_feature`` is true,
            otherwise the outputs of the selected heads concatenated along
            dim 1.

        Raises:
            ValueError: If heads are requested but ``out_idx`` is not in
                ``1 .. num_classifier``.
        '''
        # Validate before the (expensive) backbone pass.  An explicit raise is
        # used rather than ``assert`` so the check survives ``python -O``.
        if not is_feature and not 1 <= out_idx <= len(self.linear):
            raise ValueError(
                "out_idx must be between 1 and {} unless is_feature is set, got {!r}".format(
                    len(self.linear), out_idx)
            )

        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.bn(out)
        out = self.avgpool(out)
        feature = out.view(out.size(0), -1)

        if is_feature:
            return feature

        heads = self.linear_main if main_fc else self.linear
        return torch.cat([heads[idx](feature) for idx in range(out_idx)], dim=1)

def PreActResNet18(num_classes_per_classifier=2, num_classifier=5, J=2):
    '''Build a PreActResNet of PreActBlocks with two blocks in each of its four layer groups.'''
    blocks_per_group = [2,2,2,2]
    return PreActResNet(
        PreActBlock,
        blocks_per_group,
        num_classes_per_classifier=num_classes_per_classifier,
        num_classifier=num_classifier,
        J=J,
    )

