import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow_addons as tfa


num_classes_pretrain = 100  # logits of the pre-training head (presumably CIFAR-100, per the checkpoint path used below)
num_classes = 10  # logits of the fine-tuning head
input_shape = (32, 32, 3)

# (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
x_train_raw = np.load("cifar10/cifar10/training_data.npy")
y_train_raw = np.load("cifar10/cifar10/training_label.npy")
x_test_raw = np.load("cifar10/cifar10/testing_data.npy")
y_test_raw = np.load("cifar10/cifar10/testing_label.npy")

# Background images pasted behind the CIFAR objects by add_background().
# np.load on an .npz returns a lazily-reading NpzFile that holds the file open;
# using it as a context manager closes the handle once "arr_0" is read into memory.
with np.load("imagenet_plants.npz") as _background_archive:
    background = _background_archive["arr_0"]




def mlp(x, hidden_units, dropout_rate):
    """Apply a GELU ``Dense`` layer followed by ``Dropout`` for each width.

    Args:
        x: input Keras tensor.
        hidden_units: iterable of layer widths, applied in order.
        dropout_rate: rate of the ``Dropout`` placed after every ``Dense``.

    Returns:
        The output of the last ``Dropout`` (``x`` unchanged if
        ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu)
        out = dense(out)
        out = layers.Dropout(dropout_rate)(out)
    return out

class Patches(layers.Layer):
    """Cut each image into non-overlapping square patches, flattened per patch.

    Input is a 4-D ``[batch, rows, cols, channels]`` tensor (as required by
    ``tf.image.extract_patches``). Output is ``[batch, n_patches, patch_dims]``,
    where ``patch_dims`` is the flattened length of one patch. Partial patches
    at the border are dropped (``padding="VALID"``).
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        side = self.patch_size
        window = [1, side, side, 1]
        # Stride equals window size, so patches tile the image without overlap.
        extracted = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        n_images = tf.shape(images)[0]
        # Collapse the patch grid into one axis; the last dim is statically known.
        return tf.reshape(extracted, [n_images, -1, extracted.shape[-1]])


import matplotlib.pyplot as plt




class PatchEncoder(layers.Layer):
    """Project each patch to ``projection_dim`` and add a learned position embedding.

    The sub-layer attributes ``projection`` and ``position_embedding`` keep
    their names on purpose: a pre-trained checkpoint is loaded into models
    containing this layer, and those names presumably key its variables there
    (NOTE(review): confirm before renaming).
    """

    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches,
            output_dim=projection_dim,
        )

    def call(self, patch):
        projected = self.projection(patch)
        # One id per patch slot, 0 .. num_patches-1; broadcast over the batch.
        slot_ids = tf.range(self.num_patches)
        return projected + self.position_embedding(slot_ids)

import random

def mask_back(b, c, rho, sz, data):
    """Zero out a random fraction ``rho`` of one image's background pixels.

    The ``sz x sz`` square with top-left corner at row ``b``, column ``c``
    (where the object was pasted) is always kept. The remaining pixels are
    thinned row by row: in each row, ``int(n_bg * (1 - rho))`` of that row's
    ``n_bg`` non-object pixels are drawn uniformly with ``random.sample`` and
    kept; all others are set to zero in every channel.

    Args:
        b: row of the protected square's top-left corner.
        c: column of the protected square's top-left corner.
        rho: fraction of background pixels to drop (0 keeps all, 1 keeps none).
        sz: side length of the protected square.
        data: one image indexed as ``[rows, cols, channels]``.

    Returns:
        A new array ``data * mask`` (mask broadcast over channels); ``data``
        itself is not modified.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    keep = np.zeros((n_rows, n_cols))
    all_cols = list(range(n_cols))
    for row in range(n_rows):
        if b <= row < b + sz:
            # Row crosses the object: keep its columns, thin the rest.
            bg_cols = all_cols[:c] + all_cols[c + sz:]
            kept = all_cols[c:c + sz] + random.sample(bg_cols, int(len(bg_cols) * (1 - rho)))
        else:
            kept = random.sample(all_cols, int((1 - rho) * n_cols))
        keep[row, kept] = 1
    # A trailing length-1 axis broadcasts the 2-D mask over any channel count.
    return np.multiply(data, keep[:, :, np.newaxis])

def _compose_split(raw_images, rho, sz):
    """Composite, paste and mask every image of one split (helper for add_background)."""
    composed = np.zeros(raw_images.shape)
    height, width = raw_images.shape[1], raw_images.shape[2]
    for i in range(raw_images.shape[0]):
        # Always drawn, even when the background is unused (rho >= 1), so the
        # sequence of ``random`` draws is independent of that branch.
        a = random.randint(0, background.shape[0] - 1)
        if rho < 1:
            composed[i, :, :, :] = background[a, :, :, :]
        # Otherwise the canvas stays the zeros it was initialised with.
        b = random.randint(0, height - sz)
        c = random.randint(0, width - sz)
        dt = raw_images[i, :, :, :] / 255
        resized = tf.image.resize(tf.convert_to_tensor([dt]), size=(sz, sz))
        composed[i, b:b + sz, c:c + sz, :] = resized[0]
        composed[i, :, :, :] = mask_back(b, c, rho, sz, composed[i, :, :, :])
    return composed


def add_background(x_train_raw, x_test_raw, rho, sz, ratio):
    """Paste shrunken images onto random backgrounds and thin the background.

    For every image in both splits: scale it by 1/255, resize it to
    ``sz x sz`` and paste it at a random position onto a canvas. The canvas is
    a random entry of the module-level ``background`` array when ``rho < 1``,
    and all zeros otherwise. ``mask_back`` then drops a ``rho`` fraction of
    the non-object pixels. All randomness comes from the ``random`` module;
    the training split is processed before the test split.

    Args:
        x_train_raw: training images indexed as ``[n, rows, cols, channels]``.
        x_test_raw: test images, same layout.
        rho: fraction of background pixels to zero out.
        sz: side length of the pasted object; must not exceed the image size.
        ratio: unused; accepted only so existing callers keep working.

    Returns:
        ``(x_train, x_test)`` as float64 arrays shaped like the inputs.
    """
    x_train = _compose_split(x_train_raw, rho, sz)
    x_test = _compose_split(x_test_raw, rho, sz)
    return x_train, x_test

def pretrain_vit_classifier():
    """Build the ViT classifier architecture used for pre-training.

    Maps ``input_shape`` images to ``num_classes_pretrain`` unnormalized logits
    via data augmentation, patch extraction, patch encoding,
    ``transformer_layers`` pre-norm Transformer blocks and an MLP head.

    Hyperparameters are not arguments. This function reads the module-level
    globals ``input_shape``, ``data_augmentation``, ``patch_size``,
    ``num_patches``, ``projection_dim``, ``transformer_layers``, ``num_heads``,
    ``transformer_units``, ``mlp_head_units`` and ``num_classes_pretrain``, so
    they must be defined before this is called.

    The main script loads "./pretrain_cifar100/checkpoint" into the returned
    model, and ``create_vit_classifier`` takes ``layers[-2]`` as its feature
    output. The layer structure and order must therefore stay in sync with
    both.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, n_patches * projection_dim] tensor: Flatten keeps
    # every patch embedding (there is no pooling or class token).
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP. Its final Dropout is the model's second-to-last layer.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs (logits; no softmax).
    logits = layers.Dense(num_classes_pretrain)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

def create_vit_classifier(pretrain):
    """Replace the pre-trained model's output layer with a ``num_classes`` head.

    The backbone is ``pretrain`` truncated at its second-to-last layer. It is
    reused, not copied, so its weights are shared with ``pretrain`` and stay
    trainable.

    Args:
        pretrain: a functional Keras model, e.g. from ``pretrain_vit_classifier``.

    Returns:
        A new ``keras.Model`` mapping ``input_shape`` images to
        ``num_classes`` logits.
    """
    image_in = layers.Input(shape=input_shape)
    backbone = keras.Model(inputs=pretrain.input, outputs=pretrain.layers[-2].output)
    features = backbone(image_in)
    head = layers.Dense(num_classes)
    return keras.Model(inputs=image_in, outputs=head(features))



def run_experiment(model, x_train, y_train, x_test, y_test):
    """Train ``model`` with AdamW, restore its best epoch and report test accuracy.

    Training holds out the last 10% of ``x_train`` for validation. The weights
    with the highest ``val_accuracy`` are saved to ``./tmp/checkpoint`` and
    reloaded before evaluating on the test split.

    Reads the module-level globals ``learning_rate``, ``weight_decay``,
    ``batch_size`` and ``num_epochs``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    checkpoint_filepath = "./tmp/checkpoint"

    model.compile(
        optimizer=tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    keep_best = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[keep_best],
    )

    # Evaluate the best epoch rather than the last one.
    model.load_weights(checkpoint_filepath)
    _loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return history

if __name__ == "__main__":
    # Sweep: images per class used for fine-tuning x background drop fraction.
    N_per_class = [2000, 3000, 4000, 5000]
    RHO = [0.1675, 0.5144, 0.7842, 1]

    # Loop-invariant settings. They must stay module-level globals because
    # pretrain_vit_classifier() and run_experiment() read them by name.
    sz = 26  # side length of the pasted CIFAR object
    learning_rate = 0.001
    weight_decay = 0.00001
    batch_size = 256
    num_epochs = 200
    image_size = 72  # We'll resize input images to this size
    patch_size = 6  # Size of the patches to be extract from the input images
    num_patches = (image_size // patch_size) ** 2
    projection_dim = 64
    num_heads = 4
    transformer_units = [
        projection_dim * 2,
        projection_dim,
    ]  # Size of the transformer layers
    transformer_layers = 5
    mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

    for n_per_class in N_per_class:
        for rho in RHO:
            x_train_un, x_test_un = add_background(x_train_raw, x_test_raw, rho, sz, ratio=0.1)

            # Draw n_per_class training images from every class, then shuffle.
            # random.shuffle works in place and returns None, so its result is
            # not assigned.
            list_train = []
            for i in range(num_classes):
                ind = np.where(y_train_raw == np.array([i]))[0]
                list_train += random.sample(list(ind), n_per_class)
            random.shuffle(list_train)
            x_train = x_train_un[list_train]
            y_train = y_train_raw[list_train]

            # Use every test image, grouped by class and then shuffled.
            list_test = []
            for i in range(num_classes):
                list_test += list(np.where(y_test_raw == np.array([i]))[0])
            random.shuffle(list_test)
            x_test = x_test_un[list_test]
            y_test = y_test_raw[list_test]

            data_augmentation = keras.Sequential(
                [
                    layers.Normalization(),
                    layers.Resizing(image_size, image_size),
                    layers.RandomFlip("horizontal"),
                    layers.RandomRotation(factor=0.05),
                    layers.RandomZoom(
                        height_factor=0.2, width_factor=0.2
                    ),
                ],
                name="data_augmentation",
            )
            # Compute the mean and the variance of the training data for normalization.
            # NOTE(review): if the Normalization statistics are stored as layer
            # weights, load_weights() below replaces them with the values from
            # the pre-training checkpoint. Confirm which statistics are intended.
            data_augmentation.layers[0].adapt(x_train)

            # Preview one training image and its patch grid.
            # NOTE(review): these figures are never shown or closed in this loop.
            plt.figure(figsize=(4, 4))
            image = x_train[np.random.choice(range(x_train.shape[0]))]
            # add_background divides the CIFAR pixels by 255, so casting straight
            # to uint8 would truncate them to 0 (black). Rescale normalized
            # images to 0-255 before the cast.
            display_scale = 255.0 if image.max() <= 1.0 else 1.0
            plt.imshow(np.clip(image * display_scale, 0, 255).astype("uint8"))
            plt.axis("off")

            resized_image = tf.image.resize(
                tf.convert_to_tensor([image]), size=(image_size, image_size)
            )
            patches = Patches(patch_size)(resized_image)
            print(f"Image size: {image_size} X {image_size}")
            print(f"Patch size: {patch_size} X {patch_size}")
            print(f"Patches per image: {patches.shape[1]}")
            print(f"Elements per patch: {patches.shape[-1]}")

            n = int(np.sqrt(patches.shape[1]))
            plt.figure(figsize=(4, 4))
            for i, patch in enumerate(patches[0]):
                plt.subplot(n, n, i + 1)
                patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
                plt.imshow(np.clip(patch_img.numpy() * display_scale, 0, 255).astype("uint8"))
                plt.axis("off")

            # Rebuild the pre-training architecture, restore its weights and
            # fine-tune it with a fresh num_classes head.
            vit_classifier = pretrain_vit_classifier()
            vit_classifier.load_weights("./pretrain_cifar100/checkpoint")
            new_model = create_vit_classifier(vit_classifier)
            history = run_experiment(new_model, x_train, y_train, x_test, y_test)
            print('N: ', n_per_class, '| rho: ', rho)