'''
LWTA_Net for incremental learning
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from LWTA.layers import LWTA


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock, with LWTA layers as activations.

    Args:
        in_planes: number of channels of the block input.
        planes: number of output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the 1x1 shortcut.
        iter: 1-based position of the block inside its layer. Accepted for
            backward compatibility only; construction no longer depends on it
            (the name shadows the builtin but is part of the public signature).
        base_train: forwarded unchanged to every LWTA call.
        temp_train: forwarded unchanged to every LWTA call.
        J: forwarded unchanged to every LWTA call.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, iter=1, base_train=True, temp_train=0.67, J=2):
        super().__init__()
        self.J = J
        # NOTE(review): the leading dimension of p_ti_1 / p_ti_2 is fixed to
        # this value; presumably inputs must use the same batch size --
        # confirm against LWTA.layers.
        self.batch_size = 40
        self.base_train = base_train
        self.temp_train = temp_train

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.lwta1 = LWTA()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.lwta2 = LWTA()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # The channel counts used to be recovered by scanning state_dict() for
        # the BatchNorm bias shapes, and only when iter was 1 or 2; any other
        # position left the names unbound and crashed. The bias of
        # BatchNorm2d(n) has n entries, so the constructor arguments already
        # hold the same values for every block position.
        self.p_ti_1 = self._make_p_ti(in_planes)
        self.p_ti_2 = self._make_p_ti(planes)

        # A projection shortcut is only created when the residual branch
        # changes the spatial size or the channel count.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def _make_p_ti(self, channels):
        '''Create one Xavier-initialised learnable tensor handed to LWTA as `p_ti`.

        The shape is (batch_size, channels, 2048 // channels, 2048 // channels),
        e.g. a 32x32 spatial extent for 64 channels and 4x4 for 512 channels.
        '''
        side = 2048 // channels
        # torch.randn is immediately overwritten by xavier_normal_; it is kept
        # so the random-number stream (and therefore seeded initialisation of
        # everything created afterwards) matches earlier versions of this class.
        p_ti = nn.Parameter(torch.randn(self.batch_size, channels, side, side), requires_grad=True)
        nn.init.xavier_normal_(p_ti)
        return p_ti

    def forward(self, x):
        '''Apply the block to `x` and return the residual sum.'''
        # Read the LWTA settings at call time so that changing the attributes
        # after construction still takes effect.
        lwta_kwargs = dict(base_train=self.base_train, temp_train=self.temp_train, J=self.J)

        out = self.lwta1(self.bn1(x), p_ti=self.p_ti_1, **lwta_kwargs)
        # The projection shortcut consumes the activated tensor; the identity
        # shortcut uses the raw block input.
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        # NOTE(review): both convolutions run before bn2/lwta2 here, whereas
        # the usual pre-activation block applies bn2 + activation between
        # conv1 and conv2. Left as is because trained weights depend on this
        # order -- confirm it is intentional.
        out = self.lwta2(self.bn2(self.conv2(out)), p_ti=self.p_ti_2, **lwta_kwargs)
        out += shortcut
        return out
    
class PreActResNet(nn.Module):
    '''Four-stage pre-activation ResNet backbone with two banks of linear heads.

    Args:
        block: callable building one residual block; must expose an
            `expansion` attribute and accept
            (in_planes, planes, stride, iter=, base_train=, temp_train=, J=).
        num_blocks: number of blocks in each of the four stages (indexed 0..3).
        num_classes_per_classifier: output size of every linear head.
        num_classifier: number of heads in each bank (`linear_main`, `linear`).
        J: forwarded to every block.
    '''

    def __init__(self, block, num_blocks, num_classes_per_classifier=2, num_classifier=5, J=2):
        super().__init__()
        # Settings read by _make_layer; they must exist before any stage is built.
        self.in_planes = 64
        self.J = J
        self.base_train = True
        self.temp_train = 0.67

        # Stem followed by four stages of widths 64 / 128 / 256 / 512.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = nn.BatchNorm2d(512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Two independent banks of heads over the same pooled feature vector.
        feature_dim = 512*block.expansion
        self.linear_main = nn.ModuleList(
            [nn.Linear(feature_dim, num_classes_per_classifier) for _ in range(num_classifier)]
        )
        self.linear = nn.ModuleList(
            [nn.Linear(feature_dim, num_classes_per_classifier) for _ in range(num_classifier)]
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage: the first block uses `stride`, the remaining ones stride 1.'''
        stride_plan = [stride] + [1]*(num_blocks-1)
        stage = []
        # `position` is the 1-based index of the block inside this stage.
        for position, block_stride in enumerate(stride_plan, start=1):
            stage.append(
                block(
                    self.in_planes,
                    planes,
                    block_stride,
                    iter=position,
                    base_train=self.base_train,
                    temp_train=self.temp_train,
                    J=self.J,
                )
            )
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x, out_idx=0, main_fc=False, is_feature = False):
        '''Return pooled features, or the concatenated outputs of the first `out_idx` heads.

        Args:
            x: input batch fed to the 3-channel stem convolution.
            out_idx: number of heads (starting from index 0) whose outputs are
                concatenated along dim 1.
            main_fc: use the `linear_main` bank instead of `linear`.
            is_feature: return the flattened pooled features and skip the heads.
        '''
        # At least one head must be requested unless only features are wanted.
        assert out_idx or is_feature

        hidden = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
        pooled = self.avgpool(self.bn(hidden))
        feature = pooled.view(pooled.size(0), -1)

        if is_feature:
            return feature

        heads = self.linear_main if main_fc else self.linear
        return torch.cat([heads[idx](feature) for idx in range(out_idx)], dim=1)

def PreActResNet18(num_classes_per_classifier=2, num_classifier=5, J=2):
    '''Build a PreActResNet of PreActBlocks with two blocks in each of the four stages.'''
    blocks_per_stage = [2,2,2,2]
    settings = dict(
        num_classes_per_classifier=num_classes_per_classifier,
        num_classifier=num_classifier,
        J=J,
    )
    return PreActResNet(PreActBlock, blocks_per_stage, **settings)

