import tensorflow as tf
import torch
# tf.enable_eager_execution()
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
# import tensorflow.keras.backend as K
import numpy as np
import re
import pdb

def WRN(weights_path='wide-resnet-50-2-export-5ae25d50.pth'):
    '''Build a Keras bottleneck WRN-50-2 and load weights exported from PyTorch.

    The state dict is expected to contain ``conv0``, ``fc`` and
    ``group<g>.block<b>.conv{0,1,2,_dim}`` entries, each with a ``.weight``
    and (optionally, for convolutions) a ``.bias`` tensor. The number of
    blocks per group is inferred from the keys that are present.

    Args:
        weights_path: Path of the exported PyTorch state dict. Defaults to
            the file name this function previously had hard-coded.

    Returns:
        A ``tensorflow.keras`` ``Model`` that maps 224x224x3 channels-last
        images to the output of the final ``Dense`` layer (no activation is
        applied to it).
    '''
    def tr(v):
        # PyTorch conv kernels are (out, in, h, w); Keras wants (h, w, in, out).
        # PyTorch linear weights are (out, in); Keras wants (in, out).
        if v.ndim == 4:
            return v.transpose(2, 3, 1, 0)
        elif v.ndim == 2:
            return v.transpose()
        return v

    # NOTE: torch.load unpickles the file, so only point this at checkpoints
    # from a trusted source. map_location/detach/cpu make the conversion to
    # NumPy work regardless of the device the tensors were saved from.
    state = torch.load(weights_path, map_location='cpu')
    params = {k: v.detach().cpu().numpy() for k, v in state.items()}

    # (layer, parameter-name) pairs, recorded while the graph is built so the
    # weights can later be assigned by name instead of by position.
    conv_layers = []

    def conv2d(x, params, name, stride=1, padding=0):
        '''Zero-pad `x` and apply the convolution described by params[name].'''
        kernel = params['%s.weight' % name]
        # The padding layer is created even for padding=0 so that the layer
        # layout of the returned model stays the same as before.
        x = layers.ZeroPadding2D(padding=padding)(x)
        conv = layers.Conv2D(
            kernel.shape[0],       # output channels
            kernel.shape[-1],      # (square) kernel size
            strides=stride,
            padding='valid',
            use_bias=('%s.bias' % name) in params,
        )
        conv_layers.append((conv, name))
        return conv(x)

    def group(inp, params, base, stride, n):
        '''Stack `n` bottleneck blocks; only the first one is strided.'''
        o = inp
        for i in range(n):
            b_base = '%s.block%d.conv' % (base, i)
            x = o
            o = conv2d(x, params, b_base + '0')
            o = layers.ReLU()(o)
            o = conv2d(o, params, b_base + '1',
                       stride=stride if i == 0 else 1, padding=1)
            o = layers.ReLU()(o)
            o = conv2d(o, params, b_base + '2')
            # The first block changes the channel count (and possibly the
            # resolution), so its shortcut goes through a projection conv.
            if i == 0:
                shortcut = conv2d(x, params, b_base + '_dim', stride=stride)
            else:
                shortcut = x
            o = layers.Add()([o, shortcut])
            o = layers.ReLU()(o)
        return o

    # Determine the network size from the parameters: one conv0.weight entry
    # exists per block of each group.
    blocks = [
        sum(re.match(r'group%d\.block\d+\.conv0\.weight' % j, k) is not None
            for k in params)
        for j in range(4)
    ]

    # torch stores the classifier weight as (out_features, in_features).
    num_classes, feature_dim = params['fc.weight'].shape

    # Keras model
    inputs = layers.Input([224, 224, 3])
    o = conv2d(inputs, params, 'conv0', 2, 3)
    o = layers.ReLU()(o)
    o = layers.ZeroPadding2D(padding=1)(o)
    o = layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='valid')(o)
    o_g0 = group(o, params, 'group0', 1, blocks[0])
    o_g1 = group(o_g0, params, 'group1', 2, blocks[1])
    o_g2 = group(o_g1, params, 'group2', 2, blocks[2])
    o_g3 = group(o_g2, params, 'group3', 2, blocks[3])
    o = layers.AveragePooling2D(pool_size=(7, 7), strides=(1, 1), padding='valid')(o_g3)
    o = layers.Reshape(target_shape=[feature_dim])(o)
    fc = layers.Dense(num_classes)
    o = fc(o)

    model = models.Model(inputs, o)

    # Load the weights by parameter name. This does not depend on the order
    # of model.layers or on how the state-dict keys sort (lexicographic
    # sorting would put 'block10' before 'block2').
    for conv, name in conv_layers:
        # The kernel must be transposed with (2, 3, 1, 0); the original
        # author's notes record that (3, 2, 1, 0) gave far lower accuracy
        # (61.0 vs 96.34).
        weights = [tr(params['%s.weight' % name])]
        bias = params.get('%s.bias' % name)
        if bias is not None:
            weights.append(bias)
        conv.set_weights(weights)

    fc.set_weights([tr(params['fc.weight']), params['fc.bias']])
    return model
    


if __name__ == '__main__':
    # Evaluate the converted WRN model on a directory of validation images
    # and print its top-1 accuracy.

    # Per-channel mean and standard deviation (last axis of the image),
    # applied after scaling pixel values to [0, 1].
    CHANNEL_MEAN = np.array([0.485, 0.456, 0.406])
    CHANNEL_STD = np.array([0.229, 0.224, 0.225])

    def normalize_WRN(x):
        '''Scale an image to [0, 1] and standardize each channel.

        Returns a new array; the input image is left untouched.
        '''
        return (x / 255. - CHANNEL_MEAN) / CHANNEL_STD

    test_data_preprocessor = ImageDataGenerator(
        preprocessing_function=normalize_WRN,  # for WideResNet only; other models need their own preprocess_input
        samplewise_center=False,
    )
    # NOTE: machine-specific path; adjust to the local validation directory.
    val_ds = test_data_preprocessor.flow_from_directory(
        '/home/toobaml/PycharmProjects/GreedyFool/val_data',
        target_size=(224, 224),
        batch_size=1,
        shuffle=False,
    )

    model = WRN()

    num_batches = len(val_ds)
    correct = 0
    total = 0
    # Keras directory iterators keep yielding batches indefinitely, so the
    # loop is bounded to a single pass by zipping with the batch count.
    for idx, (x, y) in zip(range(num_batches), val_ds):
        print(idx)

        # predict_on_batch avoids the per-call setup cost of predict() when
        # it is invoked once per batch inside a Python loop.
        logits = model.predict_on_batch(x)
        # Compare per-sample class indices so the count stays correct for
        # any batch size, not only batch_size=1.
        hits = np.argmax(logits, axis=-1) == np.argmax(y, axis=-1)
        correct += int(np.sum(hits))
        total += len(hits)

    if total:
        print('Accuracy = ', correct / total)
    else:
        print('No validation images were found; accuracy is undefined.')