import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D, UpSampling2D, Input, concatenate, Activation, add, BatchNormalization
from tensorflow.keras import Model
from models.utils import Conv, Shuffle
import numpy as np

class DataLayer(Layer):
    """Stem that turns the network input into one feature map per resolution block.

    With ``downsample`` the input goes through two stages; each stage
    concatenates a ``MaxPooling2D(pool_size)`` of its input with a strided
    ``Conv`` of that same input along axis 3. Without ``downsample`` a single
    non-strided ``Conv`` output is concatenated with the input. Block ``i`` of
    the returned list is that stem output average-pooled by ``pool_size ** i``
    (block 0 is not pooled) and projected by its own ``Conv``.

    Args:
        num_blocks: number of resolution blocks, i.e. length of the returned list.
        num_feat_base: filter count of the stem convs; also of the per-block
            convs unless ``conv_type == 'mobileV2'``, which halves it.
        kernel_size: kernel size of every conv.
        pool_size: stem downsampling factor and resolution ratio between
            consecutive blocks.
        dropout_rate: stored for ``get_config`` only; all convs here use 0.
        weight_decay: forwarded to every ``Conv``.
        conv_type: only used to pick the per-block feature count (see above).
        downsample: whether to apply the two strided stem stages.
    """

    def __init__(self, num_blocks=3, num_feat_base=29, kernel_size=3, pool_size=2, dropout_rate=0.0, weight_decay=0.0, conv_type='std', downsample=True, **kwargs):
        super(DataLayer, self).__init__(**kwargs)
        self.num_blocks = num_blocks
        self.num_feat_base = num_feat_base
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.conv_type = conv_type
        self.downsample = downsample

        if conv_type == 'mobileV2':
            num_feat_res = num_feat_base // 2
        else:
            num_feat_res = num_feat_base

        # layers
        if downsample:
            self.pool1 = MaxPooling2D(pool_size=pool_size)
            self.pool2 = MaxPooling2D(pool_size=pool_size)
            self.conv1 = Conv(num_feat_base, kernel_size=kernel_size, strides=pool_size, dropout_rate=0, weight_decay=weight_decay, name='conv_1', with_bn=True)
            self.conv2 = Conv(num_feat_base, kernel_size=kernel_size, strides=pool_size, dropout_rate=0, weight_decay=weight_decay, name='conv_2', with_bn=True)
        else:
            self.conv1 = Conv(num_feat_base, kernel_size=kernel_size, dropout_rate=0, weight_decay=weight_decay, name='conv_1', with_bn=True)

        self.convTrans = []
        self.poolList = []
        for indexBlock in range(num_blocks):
            name = 'trans_conv_b' + str(indexBlock)
            # Block i must be at 1 / pool_size**i of block 0's resolution so that
            # Mixer's pool_size**|i - j| up/down-sampling lines the blocks up for
            # concatenation. A hard-coded base of 2 only worked for pool_size=2.
            self.poolList.append(AveragePooling2D(pool_size=int(np.power(pool_size, indexBlock))))
            self.convTrans.append(Conv(num_feat_res, kernel_size=kernel_size, dropout_rate=0, weight_decay=weight_decay, name=name))

    def call(self, inputs):
        """Return a list of ``num_blocks`` feature maps, one per resolution."""
        # Concatenation is on axis 3 -- assumes channels-last (NHWC); TODO confirm.
        if self.downsample:
            # Each stage: pooled input and strided conv of the same input, side by side.
            inputsDown = self.pool1(inputs)
            output = self.conv1(inputs)
            output = concatenate([inputsDown, output], axis=3)
            outputDown = self.pool2(output)
            output = self.conv2(output)
            output = concatenate([outputDown, output], axis=3)
        else:
            output = self.conv1(inputs)
            output = concatenate([inputs, output], axis=3)

        dataLayerList = []
        for indexBlock in range(self.num_blocks):
            # Block 0 keeps the stem resolution; poolList[0] (size 1) is never used.
            if indexBlock == 0:
                outputTrans = output
            else:
                outputTrans = self.poolList[indexBlock](output)
            outputTrans = self.convTrans[indexBlock](outputTrans)
            dataLayerList.append(outputTrans)

        return dataLayerList

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(DataLayer, self).get_config()
        config.update({'num_blocks': self.num_blocks,
            'num_feat_base': self.num_feat_base,
            'kernel_size': self.kernel_size,
            'pool_size': self.pool_size,
            'dropout_rate': self.dropout_rate,
            'weight_decay': self.weight_decay,
            'conv_type': self.conv_type,
            'downsample': self.downsample})
        return config

class BuildingBlock(Layer):
    """One iteration of the multi-resolution residual network.

    For every resolution block the layer optionally fuses the mixed features
    of the previous iteration into the running features ("trans" stage:
    conv, residual add, ReLU) and then applies a stack of residual units
    ("res" stage).

    Args:
        num_blocks: number of resolution blocks processed side by side.
        num_repeats: residual units per block (see ``_get_num_convs``).
        num_feat_base: feature count; halved when ``conv_type == 'mobileV2'``.
        kernel_size: kernel size of every conv.
        pool_size: stored for ``get_config`` only.
        dropout_rate: forwarded to the residual-unit convs (trans convs use 0).
        weight_decay: forwarded to every conv.
        conv_type: ``'std'`` gives each residual unit two convs; any other
            value gives a single conv of that type.
    """

    def __init__(self, num_blocks=3, num_repeats=1, num_feat_base=29, kernel_size=3, pool_size=2, dropout_rate=0.0, weight_decay=0.0, conv_type='std', **kwargs):
        super(BuildingBlock, self).__init__(**kwargs)
        self.num_blocks = num_blocks
        self.num_repeats = num_repeats
        self.num_feat_base = num_feat_base
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.conv_type = conv_type

        num_feat_res = num_feat_base // 2 if conv_type == 'mobileV2' else num_feat_base

        # Trans stage: one projection conv (no activation) + ReLU per block.
        # TODO: make sure parameters are not counted for first block
        self.trans_conv, self.trans_act = [], []
        for b in range(num_blocks):
            tag = f'trans_conv_b{b:02d}'
            self.trans_conv.append(Conv(num_feat_res, kernel_size=kernel_size, dropout_rate=0, weight_decay=weight_decay, with_act=False, name=tag, conv_type='std'))
            self.trans_act.append(Activation('relu', name=tag + '_act'))

        # Res stage: per block, a list of residual units indexed by k.
        self.conv_1, self.conv_2, self.act = [], [], []
        for b in range(num_blocks):
            first_convs, second_convs, acts = [], [], []
            for k in range(self._get_num_convs(b)):
                suffix = f'_b{b:02d}_k{k:02d}'
                if conv_type == 'std':
                    # leading conv of a standard two-conv residual unit
                    first_convs.append(Conv(num_feat_res, kernel_size=kernel_size, dropout_rate=dropout_rate, weight_decay=weight_decay, name='res_conv_1' + suffix, conv_type='std'))
                tag = 'res_conv_2' + suffix
                second_convs.append(Conv(num_feat_res, kernel_size=kernel_size, dropout_rate=dropout_rate, weight_decay=weight_decay, with_act=False, name=tag, conv_type=conv_type))
                acts.append(Activation('relu', name=tag + '_act'))
            self.conv_1.append(first_convs)
            self.conv_2.append(second_convs)
            self.act.append(acts)

    def _get_num_convs(self, indexBlock):
        """Return the number of residual units for block ``indexBlock``.

        Currently always ``num_repeats``; the disabled alternative would scale
        it with depth as ``(indexBlock // 2 + 1) * num_repeats``.
        """
        scaling_repeat = False  # hard-wired switch for depth-dependent repeats
        if not scaling_repeat:
            return self.num_repeats
        return (indexBlock // 2 + 1) * self.num_repeats

    def call(self, downList, upConcatList, with_trans=True):
        """Run the trans and res stages and return the new per-block features.

        ``upConcatList`` is only read when ``with_trans`` is true.
        """
        if with_trans:
            downListNew = []
            for b in range(self.num_blocks):
                fused = add([downList[b], self.trans_conv[b](upConcatList[b])])
                downListNew.append(self.trans_act[b](fused))
        else:
            downListNew = [downList[b] for b in range(self.num_blocks)]

        for b in range(self.num_blocks):
            x = downListNew[b]
            for k in range(self._get_num_convs(b)):
                h = self.conv_1[b][k](x) if self.conv_type == 'std' else x
                x = self.act[b][k](add([x, self.conv_2[b][k](h)]))
            downListNew[b] = x

        return downListNew

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(BuildingBlock, self).get_config()
        config.update(dict(
            num_blocks=self.num_blocks,
            num_repeats=self.num_repeats,
            num_feat_base=self.num_feat_base,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            weight_decay=self.weight_decay,
            conv_type=self.conv_type,
        ))
        return config

class BN(Layer):
    """Apply an independent ``BatchNormalization`` to each of ``num_blocks`` inputs."""

    def __init__(self, num_blocks=3, **kwargs):
        super(BN, self).__init__(**kwargs)
        self.num_blocks = num_blocks
        self.bn_list = [BatchNormalization(name=f'bn_b{b:02d}') for b in range(num_blocks)]

    def call(self, downList):
        """Return a new list with block ``b`` normalized by its own BN layer."""
        return [self.bn_list[b](downList[b]) for b in range(self.num_blocks)]

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(BN, self).get_config()
        config.update(dict(num_blocks=self.num_blocks))
        return config

class Permute(Layer):
    """Apply an independent ``Shuffle`` layer to each of ``num_blocks`` inputs."""

    def __init__(self, num_blocks=3, **kwargs):
        super(Permute, self).__init__(**kwargs)
        self.num_blocks = num_blocks
        self.shuffle_list = [Shuffle(name=f'shuffle_b{b:02d}') for b in range(num_blocks)]

    def call(self, downList):
        """Return a new list with block ``b`` passed through its own ``Shuffle``."""
        return [self.shuffle_list[b](downList[b]) for b in range(self.num_blocks)]

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(Permute, self).get_config()
        config.update(dict(num_blocks=self.num_blocks))
        return config

class Mixer(Layer):
    """Exchange features between resolution blocks.

    For each target block ``b``, every block's features are resampled to
    block ``b``'s resolution by a factor of ``pool_size ** |other - b|``.
    Higher-index blocks are bilinearly upsampled and lower-index blocks are
    average-pooled. The results are concatenated along axis 3 in block order.
    """

    def __init__(self, num_blocks=3, pool_size=2, **kwargs):
        super(Mixer, self).__init__(**kwargs)
        self.num_blocks = num_blocks
        self.pool_size = pool_size

        # up[b][o] / pool[b][o] bring block o to block b's resolution; the
        # diagonal entries (factor 1) are built but never called.
        self.up, self.pool = [], []
        for b in range(num_blocks):
            ups, pools = [], []
            for other in range(num_blocks):
                factor = int(np.power(pool_size, np.abs(other - b)))
                ups.append(UpSampling2D(size=factor, interpolation='bilinear'))
                pools.append(AveragePooling2D(pool_size=factor))
            self.up.append(ups)
            self.pool.append(pools)

    def call(self, downList):
        """Return one concatenated feature map per target block."""
        upConcatList = []
        for b in range(self.num_blocks):
            gathered = []
            for other in range(self.num_blocks):
                feat = downList[other]
                if other > b:
                    feat = self.up[b][other](feat)
                elif other < b:
                    feat = self.pool[b][other](feat)
                gathered.append(feat)
            # A single block needs no concatenation.
            upConcatList.append(concatenate(gathered, axis=3) if len(gathered) > 1 else gathered[0])
        return upConcatList

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(Mixer, self).get_config()
        config.update(dict(num_blocks=self.num_blocks, pool_size=self.pool_size))
        return config

class ClassLayer(Layer):
    """Prediction head.

    With ``global_pool`` the features are globally average-pooled and fed to a
    bias-free ``Dense(num_classes)``. Otherwise a 1x1 ``Conv`` maps them to
    ``num_classes`` channels. When ``downsample`` is also set, that output is
    bilinearly upsampled by ``pool_size ** 2``, presumably to undo the two
    strided stem stages of ``DataLayer``.

    Args:
        num_feat_base: stored for ``get_config`` only.
        kernel_size: stored for ``get_config`` only.
        num_classes: number of output channels / units.
        pool_size: base of the ``pool_size ** 2`` upsampling factor.
        dropout_rate: stored for ``get_config`` only; the conv uses 0.
        weight_decay: forwarded to the 1x1 ``Conv``.
        downsample: whether the dense-prediction path upsamples its output.
        global_pool: use the pooled ``Dense`` head instead of the conv head.
    """

    def __init__(self, num_feat_base=29, kernel_size=3, num_classes=11, pool_size=2, dropout_rate=0.0, weight_decay=0.0, downsample=True, global_pool=False, **kwargs):
        super(ClassLayer, self).__init__(**kwargs)
        self.num_feat_base = num_feat_base
        self.kernel_size = kernel_size
        self.num_classes = num_classes
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.downsample = downsample
        self.global_pool = global_pool

        # layers
        # The conv head is used by call() whenever global_pool is False, so it must
        # exist in that case even without downsample (otherwise AttributeError).
        # It is still built for downsample=True regardless of global_pool so the
        # layer set of existing configurations is unchanged.
        if downsample or not global_pool:
            self.conv = Conv(num_classes, kernel_size=1, dropout_rate=0, weight_decay=weight_decay, with_act=False, name='class_conv')
        if downsample:
            self.up = UpSampling2D(size=int(np.power(pool_size, 2)), interpolation='bilinear', name='class_up')
        if global_pool:
            self.pool = GlobalAveragePooling2D()
            self.dense = Dense(num_classes, use_bias=False)

    def call(self, inputs):
        """Return class scores: per-pixel maps, or a vector with ``global_pool``."""
        output = inputs
        if self.global_pool:
            output = self.pool(output)
            output = self.dense(output)
        else:
            output = self.conv(output)
            if self.downsample:
                output = self.up(output)

        return output

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(ClassLayer, self).get_config()
        config.update({'num_feat_base': self.num_feat_base,
            'kernel_size': self.kernel_size,
            'num_classes': self.num_classes,
            'pool_size': self.pool_size,
            'dropout_rate': self.dropout_rate,
            'weight_decay': self.weight_decay,
            'downsample': self.downsample,
            'global_pool': self.global_pool})
        return config

class IterNetwork(Model):
    """Iterative multi-resolution network emitting one prediction per iteration.

    ``DataLayer`` builds ``num_blocks`` feature maps. Each of the
    ``num_iterations`` steps then applies, in order:

    - a ``BuildingBlock``, one instance shared by all steps when ``share_parameters``;
    - per-block batch norm (``BN``);
    - an optional ``Permute`` when ``shuffle`` is set;
    - the shared ``Mixer``.

    A per-step ``ClassLayer`` then reads the mixed features of block 0.

    Args:
        num_classes: output classes of every ``ClassLayer``.
        num_blocks: number of resolution blocks.
        num_repeats: residual units per block in each ``BuildingBlock``.
        num_iterations: number of steps, and hence of outputs.
        num_feat_base, kernel_size, pool_size, dropout_rate, weight_decay,
            conv_type: forwarded to the sub-layers.
        downsample: forwarded to ``DataLayer`` and ``ClassLayer``.
        global_pool: forwarded to ``ClassLayer``.
        share_parameters: reuse one ``BuildingBlock`` for all steps.
        shuffle: insert a ``Permute`` after batch norm in every step.
    """

    def __init__(self, num_classes=11, num_blocks=3, num_repeats=1, num_iterations=17, num_feat_base=29, kernel_size=3, pool_size=2, dropout_rate=0.0, weight_decay=0.0, conv_type='std', downsample=True, global_pool=False, share_parameters=False, shuffle=False, **kwargs):
        super(IterNetwork, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.num_repeats = num_repeats
        self.num_iterations = num_iterations
        self.num_feat_base = num_feat_base
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.conv_type = conv_type
        self.downsample = downsample
        self.global_pool = global_pool
        self.share_parameters = share_parameters
        self.shuffle = shuffle

        self.data_layer = DataLayer(num_blocks=num_blocks, num_feat_base=num_feat_base, kernel_size=kernel_size, pool_size=pool_size, dropout_rate=dropout_rate, weight_decay=weight_decay, conv_type=conv_type, downsample=downsample)
        self.building_block_list = []
        self.classification_layer_list = []
        self.bn_list = []
        self.shuffle_list = []
        self.mixer = Mixer(num_blocks=num_blocks, pool_size=pool_size)
        if share_parameters:
            # a single instance appended once per iteration => shared weights
            bb = BuildingBlock(num_blocks=num_blocks, num_repeats=num_repeats, num_feat_base=num_feat_base, kernel_size=kernel_size, pool_size=pool_size, dropout_rate=dropout_rate, weight_decay=weight_decay, conv_type=conv_type)
        for indexLayer in range(num_iterations):
            if share_parameters:
                self.building_block_list.append(bb)
            else:
                self.building_block_list.append(BuildingBlock(num_blocks=num_blocks, num_repeats=num_repeats, num_feat_base=num_feat_base, kernel_size=kernel_size, pool_size=pool_size, dropout_rate=dropout_rate, weight_decay=weight_decay, conv_type=conv_type))
            self.bn_list.append(BN(num_blocks=num_blocks))
            if self.shuffle:
                self.shuffle_list.append(Permute(num_blocks=num_blocks))
            self.classification_layer_list.append(ClassLayer(num_feat_base=num_feat_base, kernel_size=kernel_size, num_classes=num_classes, pool_size=pool_size, dropout_rate=dropout_rate, weight_decay=weight_decay, downsample=downsample, global_pool=global_pool))

    def call(self, inputs):
        """Return a list with one ``ClassLayer`` output per iteration."""
        downList = self.data_layer(inputs)

        # Mixed features from the previous iteration. Nothing is mixed yet at the
        # first step, where BuildingBlock runs with with_trans=False and never
        # reads this list. Size it by num_blocks, not a hard-coded 3.
        upConcatList = [None] * self.num_blocks
        outputList = []
        for indexLayer in range(self.num_iterations):
            downList = self.building_block_list[indexLayer](downList, upConcatList, with_trans=indexLayer > 0)
            downList = self.bn_list[indexLayer](downList)
            if self.shuffle:
                downList = self.shuffle_list[indexLayer](downList)
            upConcatList = self.mixer(downList)
            # each step's prediction comes from the block-0 mixed features
            outputList.append(self.classification_layer_list[indexLayer](upConcatList[0]))

        return outputList

    def get_config(self):
        """Return the constructor arguments for Keras serialization."""
        config = super(IterNetwork, self).get_config()
        config.update({'num_classes': self.num_classes,
            'num_blocks': self.num_blocks,
            'num_repeats': self.num_repeats,
            'num_iterations': self.num_iterations,
            'num_feat_base': self.num_feat_base,
            'kernel_size': self.kernel_size,
            'pool_size': self.pool_size,
            'dropout_rate': self.dropout_rate,
            'weight_decay': self.weight_decay,
            'conv_type': self.conv_type,
            'downsample': self.downsample,
            'global_pool': self.global_pool,
            'share_parameters': self.share_parameters,
            'shuffle': self.shuffle})
        return config
