# -*- coding: utf-8 -*-
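"""positional_bias.ipynb

Positional-bias experiment: GeoConv2D, CoordConv and plain Conv2D classifiers
of increasing depth are trained on centred "Greek number" glyphs (one, two or
three vertical bars) and then evaluated on every shifted placement of those
glyphs. Loss and accuracy are averaged over several random seeds.

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/10WX2cJ5HuIuFdBQQa3nwcxgsf3EiL0Ie

# Importing the libraries
"""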

import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import categorical_crossentropy
import matplotlib.pyplot as plt
from tabulate import tabulate

# Make modules under the current working directory (e.g. the project-local
# ``src`` package imported below) importable; assumes the script is launched
# from the repository root.
conf_path = os.getcwd()
sys.path.extend([conf_path])

from src.geo_conv import GeoConv2D
from src.coord_conv import CoordConv


# Seeds to repeat the whole experiment with; per-seed results are averaged
# at the end of the script.
SEED_RANGE = range(10)

# Directory that receives the per-seed and aggregated result tables.
RESULTS_DIR = "results/positional_bias/"

# Canvas size. Use at least 64x64 so that the influence of positional bias
# is clearly visible.
HEIGHT = WIDTH = 64

# Convolution layer types to compare, and the network depths (number of
# stacked convolution layers) to try for each of them.
CONV_ARCS = [GeoConv2D, CoordConv, Conv2D]
NUM_OF_CONV_LAYERS = list(range(1, 6))

"""# Function for creating the Dataset and Models
## The Greek Number Generator Function
"""

def greek_number(i, shift=(0, 0), canvas_size=(HEIGHT, WIDTH)):
    """Render the glyph for class ``i`` as a single-image batch.

    The glyph is a 10-pixel-tall pattern of 2-pixel-wide vertical bars. The
    number of bars depends on ``i % 3``: one bar when it is 1, two bars when
    it is 2 and three bars when it is 0. The pattern is centred on the canvas
    and then moved by ``shift``.

    Args:
        i: Class index; only ``i % 3`` matters.
        shift: ``(x_shift, y_shift)`` offset in pixels. ``x_shift`` moves the
            glyph along the rows (height axis) and ``y_shift`` moves it along
            the columns (width axis).
        canvas_size: ``(height, width)`` of the canvas.

    Returns:
        A float32 array of shape ``(1, height, width, 1)``. Glyph pixels are 1
        and all other pixels are 0.

    Raises:
        ValueError: If the shifted glyph does not fit entirely on the canvas.
    """
    h, w = canvas_size
    c_h, c_w = int(h / 2), int(w / 2)
    x_shift, y_shift = shift

    # Column offset (relative to the canvas centre) of the left edge of each
    # 2-pixel-wide bar. The widest glyph spans columns c_w-5 .. c_w+5.
    bar_offsets = {1: (-1,), 2: (-3, 1), 0: (-5, -1, 3)}[i % 3]

    top, bottom = c_h - 5 + x_shift, c_h + 5 + x_shift
    left = c_w + min(bar_offsets) + y_shift
    right = c_w + max(bar_offsets) + 2 + y_shift
    # Out-of-range slice bounds would not raise in NumPy: negative starts wrap
    # around and large stops are clipped. That would yield an empty or
    # truncated glyph, so reject such shifts explicitly.
    if top < 0 or bottom > h or left < 0 or right > w:
        raise ValueError(
            f"shift {shift} moves glyph {i} outside a {h}x{w} canvas"
        )

    canvas = np.zeros((h, w), dtype=np.float32)
    for offset in bar_offsets:
        col = c_w + offset + y_shift
        canvas[top:bottom, col:col + 2] = 1.
    return canvas.reshape((1, h, w, 1))


"""### A sample image from the Dataset"""

# Show one sample glyph (class 3, i.e. three bars since 3 % 3 == 0) with the
# singleton batch and channel axes dropped.
plt.imshow(np.squeeze(greek_number(3)))
plt.show()


"""## Model Creator"""


def build_model(conv_arc, num_of_conv_layers=1):
    """Build and compile a small classifier for the three glyph classes.

    The network stacks ``num_of_conv_layers`` convolutions. Each uses a 3x3
    kernel, stride 2, ReLU activation and no bias, and the filter counts
    double at each layer (1, 2, 4, ...). A ``Flatten`` layer and a 3-way
    softmax ``Dense`` layer follow. The model is compiled with the Adam
    optimizer and categorical cross-entropy loss.

    Args:
        conv_arc: Convolution layer class to stack (e.g. ``Conv2D``). It is
            called with ``filters``, ``kernel_size``, ``strides``,
            ``activation`` and ``use_bias`` keyword arguments.
        num_of_conv_layers: Number of convolution layers; must be >= 1. Each
            stride-2 layer roughly halves the spatial size, so the input must
            be large enough for all of them. For example, Keras ``Conv2D``
            with its default 'valid' padding shrinks a 64x64 input to 1x1
            after five layers.

    Returns:
        A compiled ``Sequential`` model.

    Raises:
        ValueError: If ``num_of_conv_layers`` is less than 1.
    """
    # The original if/elif chain silently returned None for unsupported
    # depths; fail loudly instead.
    if num_of_conv_layers < 1:
        raise ValueError(
            f"num_of_conv_layers must be >= 1, got {num_of_conv_layers}"
        )
    conv_layers = [
        conv_arc(filters=2 ** k, kernel_size=3, strides=2,
                 activation='relu', use_bias=False)
        for k in range(num_of_conv_layers)
    ]
    model = Sequential(conv_layers + [Flatten(), Dense(3, activation='softmax')])
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    return model


"""## Onehot Function"""


def one_hot(i):
    """Return the one-hot target row for glyph class ``i``.

    Column mapping:
        ``i % 3 == 1`` -> column 0
        ``i % 3 == 2`` -> column 1
        anything else  -> column 2

    This mirrors the one-, two- and three-bar glyphs drawn for those
    residues. The result is a float64 array of shape ``(1, 3)``.
    """
    column = {1: 0, 2: 1}.get(i % 3, 2)
    target = np.zeros((1, 3))
    target[0, column] = 1.
    return target


# Per-seed mean loss / accuracy, keyed by "seed_{seed}_{model_name}".
each_run_mean_losses = {}
each_run_mean_accuracies = {}

# Evaluation set: every glyph class (1, 2, 3) at every shift that keeps the
# glyph on the canvas. Inside greek_number, x_shift moves the glyph along the
# rows (height) and y_shift along the columns (width). Each range must
# therefore come from the matching dimension; the original code swapped
# them, which only worked because the canvas is square. The set does not
# depend on the seed, so it is built once and reused by every model.
# NOTE(review): for a 64x64 canvas this holds ~9k images (~150 MB float32).
x_range = (HEIGHT - 10) // 2
y_range = (WIDTH - 10) // 2
eval_cases = [
    (i, (x_shift, y_shift))
    for x_shift in range(-x_range, x_range + 1)
    for y_shift in range(-y_range, y_range + 1)
    for i in range(1, 4)
]
eval_images = np.concatenate([greek_number(i, s) for i, s in eval_cases])
eval_targets = np.concatenate([one_hot(i) for i, _ in eval_cases])
eval_labels = np.argmax(eval_targets, axis=1)

os.makedirs(RESULTS_DIR, exist_ok=True)

for seed in SEED_RANGE:
    # Seed TensorFlow and NumPy so that each run is reproducible.
    SEED = seed
    tf.random.set_seed(SEED)
    np.random.seed(SEED)
    print(f"<<<<< Experimenting with SEED number {SEED} >>>>>")

    # --- Training: every architecture at every depth ---
    models = {}
    for conv_arch in CONV_ARCS:
        conv_name = conv_arch.__name__
        for num_of_layers in NUM_OF_CONV_LAYERS:
            print(f"<<<<< Training {conv_name} with {num_of_layers} Conv Layers >>>>>")
            model = build_model(conv_arch, num_of_layers)
            # 600 single-sample gradient steps on the centred (unshifted)
            # glyphs, cycling through the three classes. train_on_batch does
            # the same one-step update as fit() on a one-sample input, without
            # fit()'s per-call setup overhead. Per-seed numbers may differ
            # slightly from runs that used fit().
            for i in range(600):
                model.train_on_batch(greek_number(i), one_hot(i))
            model.summary()
            models[f"{conv_name}_layers_{num_of_layers}"] = model

    # --- Evaluation on every shifted placement (batched) ---
    losses = {}
    accuracies = {}
    for model_name, model in models.items():
        print(f"<<<<< Evaluating {model_name} >>>>>")
        preds = model.predict(eval_images, batch_size=256, verbose=0)
        losses[model_name] = np.asarray(categorical_crossentropy(eval_targets, preds))
        # A prediction is correct when its arg-max class matches the target.
        accuracies[model_name] = (np.argmax(preds, axis=1) == eval_labels).astype(np.float64)

    # --- Results ---
    # NOTE: the original notebook remarked that CoordConv, despite having the
    # highest number of parameters, showed the lowest accuracy and highest
    # loss. That is an observation from its runs; the code does not check it.
    model_names = list(models)
    table_header = ["", *model_names]
    losses_summary = ["Loss"]
    accuracies_summary = ["Accuracy"]
    for model_name in model_names:
        mean_loss = np.mean(losses[model_name])
        mean_acc = np.mean(accuracies[model_name])
        losses_summary.append(mean_loss)
        accuracies_summary.append(mean_acc)
        each_run_mean_losses[f"seed_{seed}_{model_name}"] = mean_loss
        each_run_mean_accuracies[f"seed_{seed}_{model_name}"] = mean_acc
    table = [table_header, losses_summary, accuracies_summary]

    print(tabulate(table, headers='firstrow'))

    # --- Save the per-seed table ---
    results_path = os.path.join(RESULTS_DIR, f"results_seed_{SEED}.txt")
    with open(results_path, "w", encoding="utf-8") as f:
        f.write(tabulate(table, headers='firstrow'))


# Average each model's per-seed mean loss and accuracy over all seeds. Model
# names come from the last seed's `models`; every seed trains the same
# architecture/depth combinations, so the set of names is identical.
model_names = list(models)
all_runs_average_losses = {
    name: np.mean([each_run_mean_losses[f"seed_{seed}_{name}"] for seed in SEED_RANGE])
    for name in model_names
}
all_runs_average_accuracies = {
    name: np.mean([each_run_mean_accuracies[f"seed_{seed}_{name}"] for seed in SEED_RANGE])
    for name in model_names
}

table_header = ["", *model_names]
losses_summary = ["Loss", *(all_runs_average_losses[name] for name in model_names)]
accuracies_summary = ["Accuracy", *(all_runs_average_accuracies[name] for name in model_names)]
table = [table_header, losses_summary, accuracies_summary]

print(tabulate(table, headers='firstrow'))

# Do not rely on an earlier step having created the output directory.
os.makedirs(RESULTS_DIR, exist_ok=True)
with open(os.path.join(RESULTS_DIR, "results_aggregated.txt"), "w", encoding="utf-8") as f:
    f.write(tabulate(table, headers='firstrow'))

