import warnings
import os
import time

import math as m
import numpy as np
np.random.seed(1)

import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, BatchNormalization, Activation, Flatten, Add,add, Concatenate,concatenate, Dropout
from keras.layers import Conv2D, SeparableConv2D, ZeroPadding2D, AveragePooling2D, MaxPooling2D, GlobalAveragePooling2D, LeakyReLU
from keras.regularizers import l2
from keras.initializers import Constant

from keras.datasets import cifar10, cifar100, mnist
from keras.utils.np_utils import to_categorical

def import_cifar(dataset = 10, label_mode = 'fine'):
    '''
    Load CIFAR-10 or CIFAR-100, standardize the images and one-hot encode the labels.

    dataset: 10 for CIFAR-10, 100 for CIFAR-100.
    label_mode: forwarded to cifar100.load_data; only relevant for CIFAR100 dataset.

    Images are cast to float32, then centered and scaled with the global mean and
    standard deviation of the *training* images (the same two scalars are applied
    to the test images).

    Returns x_train, y_train, x_test, y_test, with the labels converted by
    to_categorical.

    Raises ValueError if dataset is neither 10 nor 100.
    '''
    if dataset == 10:
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    elif dataset == 100:
        (x_train, y_train), (x_test, y_test) = cifar100.load_data(label_mode=label_mode)
    else:
        # Without this branch the function failed further down with an
        # UnboundLocalError on x_train, which hides the real cause.
        raise ValueError('dataset should be 10 or 100, got %r' % (dataset,))

    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    # Statistics come from the training set only and are reused for the test set.
    mean, std = x_train.mean(), x_train.std()
    x_train = x_train - mean
    x_test = x_test - mean
    x_train = x_train / std
    x_test = x_test / std

    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    return x_train, y_train, x_test, y_test

def load_task(task):
    '''
    Build the model factory and load the data for a named task.

    task: one of
        'C10-VGG'         CIFAR-10 with a VGG-style network,
        'C100-WRN'        CIFAR-100 (fine labels) with a Wide ResNet,
        'C100coarse-WRN'  CIFAR-100 (coarse labels) with a Wide ResNet.

    Returns get_model, x_train, y_train, suby_train, x_test, y_test, suby_test.
        get_model(depth_factor, width_factor, dropout, weight_decay,
        batchnorm_train_mode) builds a fresh model for the task.
        y_* are the labels of the task; suby_* are the fine labels of the
        same samples (identical to y_* unless the task uses coarse labels).

    Raises ValueError if task is not one of the names above.
    '''
    if task == 'C10-VGG':
        nb_classes = 10

        # a deep VGG style network with batchnorm
        def get_model(depth_factor=1., width_factor=1., dropout = 0., weight_decay = 0., batchnorm_train_mode = None):
            k = int(32*width_factor)
            return VGG(input_shape = (32,32,3),
                       nbstages = 4,
                       nblayers = [int(2*depth_factor)]*4,
                       nbfilters = [1*k,2*k,4*k,8*k],
                       nbclasses = nb_classes,
                       use_batchnorm = True, use_bias = False,
                       dropout = dropout,
                       weight_decay = weight_decay,
                       pool = 'max',
                       dense_layers = [],
                       filename = None,
                       batchnorm_train_mode = batchnorm_train_mode)

        # import data
        x_train, y_train, x_test, y_test = import_cifar(10)
        # no hierarchy of labels for CIFAR10
        suby_train, suby_test = y_train, y_test

    elif task in ('C100coarse-WRN', 'C100-WRN'):
        coarse = (task == 'C100coarse-WRN')
        nb_classes = 20 if coarse else 100

        def get_model(depth_factor=1., width_factor=1., dropout=0., weight_decay = 0., batchnorm_train_mode = None):
            k = int(32*width_factor)
            return WideResNet([1*k,2*k,4*k,8*k],
                              [int(2*depth_factor)]*4,
                              nbstages = 4,
                              dropout = dropout,
                              weight_decay = weight_decay,
                              nb_classes = nb_classes,
                              use_batchnorm = True, use_bias = False,
                              batchnorm_train_mode = batchnorm_train_mode)

        # import data; the fine labels are always returned as suby_*
        x_train, suby_train, x_test, suby_test = import_cifar(100)

        if coarse:
            # load coarse cifar100 superclasses
            _, y_train, _, y_test = import_cifar(100, label_mode = 'coarse')
        else:
            y_train, y_test = suby_train, suby_test

    else:
        # Previously an unknown name reached the return statement and failed
        # with an UnboundLocalError on get_model.
        raise ValueError("Unknown task %r; expected 'C10-VGG', 'C100-WRN' or 'C100coarse-WRN'." % (task,))

    return get_model, x_train, y_train, suby_train, x_test, y_test, suby_test



def VGG(input_shape, nbstages, nblayers, nbfilters, nbclasses, weight_decay=0.0, use_batchnorm = True, use_bias=False, dropout=0.0, kernel_size=(3, 3), pool='average', dense_layers = [], filename = 'model_init.h5',batchnorm_train_mode = None, neural_activity_control = False, neural_activity_monitoring = False, softmax_norm = False, threshold_percentiles = None):
    """
    VGG-style convolutional neural network
    
    input_shape is the shape of one input sample (passed to keras Input)
    nbstages is the number of spatial dimension levels
    nblayers is a list with nbstages elements containing the 
        number of convolutional layers per stage
    nbfilters is a list with nbstages elements containing the 
        number of filters of every convolutional layer in a stage
    nbclasses is the number of units of the final softmax layer
    
    weight_decay > 0 adds an l2 kernel regularizer to every Conv2D and
        Dense layer; otherwise no regularizer is used
    
    if use_batchnorm, batchnorm is applied after each convolutional / hidden 
        dense layer, before the relu. batchnorm_train_mode is passed as the 
        `training` argument when the batchnorm layers are called.
    
    dropout > 0 adds a Dropout layer after each convolutional relu; a 
        Dropout(dropout) layer is always added after each hidden dense layer
    
    pool is either 'max' or 'average' and is applied after every stage but 
        the last one (any other value results in no pooling between stages)
    
    dense_layers lists the number of units of the hidden dense layers. If 
        empty, the convolutional features are globally average pooled, 
        otherwise they are flattened and fed to the dense layers. 
        The default list is never mutated.
    
    filename, if not None, is a file inside the 'saved_weights' directory: 
        existing weights are loaded into the model, otherwise the freshly 
        initialized weights are saved there
    
    neural_activity_control disables the learned offset and scale 
        (center / scale) of every batchnorm layer. When it is set, 
        threshold_percentiles must contain one element per stage.
    neural_activity_monitoring and softmax_norm are accepted so that callers 
        passing them do not fail; they are not used by this function.
    
    returns the (uncompiled) keras Model
    """        
    assert len(nblayers) == nbstages, 'nblayers should contain one element per stage.'
    assert len(nbfilters) == nbstages, 'nbfilters should contain one element per stage.'
    if neural_activity_control:
        assert threshold_percentiles is not None and len(threshold_percentiles) == nbstages, 'threshold_percentiles should contain one element per stage.'
    
    regularizer = None
    if weight_decay > 0.0:
        regularizer = l2(weight_decay)
    
    # with neural activity control the batchnorm layers only normalize
    learn_affine = not neural_activity_control
    
    input_model = Input(shape=input_shape)
    x = input_model
    
    for s in range(nbstages):
        for l in range(nblayers[s]):
            prefix = 'stage' + str(s) + '_layer' + str(l)
            x = Conv2D(nbfilters[s], kernel_size=kernel_size, padding='same', name=prefix + '_conv',
              kernel_regularizer=regularizer,
              use_bias=use_bias)(x)
            
            if use_batchnorm:
                x = BatchNormalization(axis=-1, momentum=0.99, 
                                       center = learn_affine,
                                       scale = learn_affine,
                                       name=prefix + '_batch')(x,training=batchnorm_train_mode)
            x = Activation('relu',name=prefix + '_relu')(x)
            
            if dropout > 0.0:
                x = Dropout(dropout)(x)

        # halve the spatial resolution between stages, not after the last one
        if s != nbstages - 1:
            if pool == 'max':
                x = MaxPooling2D((2, 2), strides=(2, 2), name=('stage' + str(s) + '_pool'))(x)
            elif pool == 'average':
                x = AveragePooling2D((2, 2), strides=(2, 2), name=('stage' + str(s) + '_pool'))(x)
    
    if len(dense_layers)==0:
        x = GlobalAveragePooling2D(name='global_pool')(x)
    else:
        x = Flatten()(x)
        for i,units in enumerate(dense_layers):
            x = Dense(units, name='dense_' + str(i) + '_kernel',
                  kernel_regularizer=regularizer,
                  use_bias=use_bias)(x)

            if use_batchnorm:
                x = BatchNormalization(axis=-1, momentum=0.99, 
                                       center = learn_affine,
                                       scale = learn_affine,
                                       name='dense_' + str(i) + '_batch')(x,training=batchnorm_train_mode)
            x = Activation('relu',name='dense_' + str(i)+ '_relu')(x)
            
            x = Dropout(dropout)(x)

    x = Dense(nbclasses, name='last_dense', use_bias=use_bias, kernel_regularizer=regularizer)(x)
    
    x = Activation('softmax', name='predictions')(x)
    
    model = Model(input_model,x)
    
    if filename is not None:
        # Reuse a stored initialization when there is one, so that repeated
        # calls with the same filename start from identical weights.
        if not os.path.exists("saved_weights"):
            os.makedirs("saved_weights")

        weights_path = os.path.join("saved_weights", filename)
        if os.path.exists(weights_path):
            model.load_weights(weights_path)
        else:
            model.save_weights(weights_path)
    return model

#==============================================================================
# WideResnet model: https://arxiv.org/abs/1605.07146
#==============================================================================
def block(inp,nbfilters,dropout,weight_decay,channel_axis,subsample = (1,1), 
          use_batchnorm=True, use_bias = True,
          batchnorm_train_mode = None, neural_activity_control = False): 
    """
    Pre-activation residual block: two (batchnorm, relu, 3x3 conv) units
    plus a shortcut connection.
    
    inp is the input tensor and nbfilters the number of filters of both convs
    dropout > 0 adds a Dropout layer before the second conv
    weight_decay is the l2 factor applied to all conv kernels
    channel_axis is the axis holding the channels of inp
    subsample is the stride of the first conv and of the shortcut projection; 
        None is treated as no subsampling, i.e. (1,1)
    batchnorm_train_mode is passed as the `training` argument of the 
        batchnorm layers
    neural_activity_control disables center / scale in the batchnorm layers
    
    The shortcut is the identity when the block neither subsamples nor 
    changes the number of channels, otherwise a 1x1 convolution.
    """
    strides = (1,1) if subsample is None else subsample
    
    x = inp
    
    for i in [1,2]:
        if use_batchnorm:
            # normalize over the channel axis given by the caller, so that 
            # channels_first inputs are handled as well
            x = BatchNormalization(axis=channel_axis, momentum=0.99, 
                                   center = not neural_activity_control,
                                   scale = not neural_activity_control)(x,training=batchnorm_train_mode)
        x = Activation('relu')(x)
                    
        if dropout>0. and i==2:
            x = Dropout(dropout)(x)
        
        # explicit 1-pixel padding followed by an unpadded 3x3 conv
        x = ZeroPadding2D((1,1))(x)
        # only the first conv of the block may subsample
        conv_strides = strides if i==1 else (1,1)
        x = Conv2D(nbfilters,(3,3),strides=conv_strides,kernel_regularizer=l2(weight_decay), use_bias = use_bias)(x)
    
    # checks for subsampling or change in nb of filters
    if strides==(1,1) and K.int_shape(inp)[channel_axis] == nbfilters:
        return add([x,inp])
    
    shortcut = Conv2D(nbfilters,(1,1),strides = strides,kernel_regularizer=l2(weight_decay), use_bias = use_bias)(inp)
    return add([x,shortcut])

def stage(x, nbfilters, N, dropout, weight_decay, channel_axis, subsample = True, use_batchnorm=True, use_bias = True, batchnorm_train_mode = None, neural_activity_control = False):
    """
    Stack of residual blocks sharing the same number of filters.
    
    The first block is always created and uses stride (2,2) if subsample is 
    true; it is followed by N-1 blocks without subsampling. All remaining 
    arguments are forwarded to block().
    """
    shared = dict(use_batchnorm=use_batchnorm, use_bias=use_bias,
                  batchnorm_train_mode=batchnorm_train_mode,
                  neural_activity_control=neural_activity_control)
    
    first_strides = (2,2) if subsample else (1,1)
    x = block(x, nbfilters, dropout, weight_decay, channel_axis, subsample=first_strides, **shared)
    
    for _ in range(1,N):
        x = block(x, nbfilters, dropout, weight_decay, channel_axis, **shared)
    return x

def WideResNet(nbfilters,nbblocks,nb_classes,nbstages=3,weight_decay=0., use_batchnorm=True, use_bias = False,dropout=0., batchnorm_train_mode=None, neural_activity_control = False, neural_activity_monitoring = False, softmax_norm = False):
    """
    Wide ResNet architecture for 32x32 RGB inputs.
    
    nbfilters and nbblocks are lists with one element per stage (nbstages 
        elements): the number of filters and of residual blocks of the stage
    nb_classes is the number of units of the final softmax layer
    weight_decay is the l2 factor applied to all conv and dense kernels
    dropout is forwarded to the residual blocks
    batchnorm_train_mode is passed as the `training` argument of the 
        batchnorm layers
    neural_activity_control disables center / scale in every batchnorm layer
    neural_activity_monitoring and softmax_norm are accepted so that callers 
        passing them do not fail; they are not used by this function.
    
    The input convolution always has 16 filters. Every stage but the first 
    one subsamples by 2. The input layout follows K.image_data_format().
    
    returns the (uncompiled) keras Model
    """
    assert len(nbfilters) == nbstages, 'nbfilters should contain one element per stage.'
    assert len(nbblocks) == nbstages, 'nbblocks should contain one element per stage.'
        
    if K.image_data_format() == 'channels_first':
        input_model = Input(shape = (3,32,32))
        channel_axis = 1
    else:
        input_model = Input(shape = (32,32,3))
        channel_axis = -1
    
    # input convolution
    x = ZeroPadding2D((1,1))(input_model)
    x = Conv2D(16, (3, 3),kernel_regularizer=l2(weight_decay), use_bias = use_bias)(x)
    
    for s in range(nbstages):
        # the first stage keeps the input resolution
        x = stage(x,nbfilters[s],nbblocks[s], dropout, weight_decay, channel_axis, subsample = (s != 0), 
                  use_batchnorm=use_batchnorm, use_bias = use_bias, batchnorm_train_mode=batchnorm_train_mode,
                  neural_activity_control=neural_activity_control)
    
    # the blocks are pre-activated, so a last batchnorm + relu closes the network
    if use_batchnorm:
        x = BatchNormalization(axis = channel_axis,center = not neural_activity_control, scale = not neural_activity_control)(x,training=batchnorm_train_mode)
    x = Activation('relu')(x)
    
    x = GlobalAveragePooling2D()(x)
    x = Dense(nb_classes,kernel_regularizer=l2(weight_decay), use_bias = use_bias)(x)
    
    x = Activation('softmax')(x)
    
    return Model(input_model,x)