import os
import sys
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers
import json
from tabulate import tabulate

conf_path = os.getcwd()
sys.path.append(conf_path)

from src.coord_conv import CoordConv
from src.geo_conv import GeoConv2D
from utils.mass_centre import mass_centre_dataset
from utils.plot_tools import show_image


"""# 1. Constants"""
# ---- Reproducibility ----
SEED = 47

# ---- Synthetic dataset ----
INPUT_SHAPE = (32, 32, 1)      # shape of one image, without the batch dimension
IMAGE_SIZE = INPUT_SHAPE[0]    # side length handed to mass_centre_dataset
NUM_TRAIN_IMAGES = 100_000
NUM_TEST_IMAGES = 20_000
VAL_SPLIT = 0.1                # fraction of the training images Keras holds out for validation

# ---- Models and optimisation ----
LR = 0.0001
OPTIMIZER = tf.keras.optimizers.Adam       # a class: instantiated per model as OPTIMIZER(LR)
LOSS = tf.keras.losses.MeanAbsoluteError   # a class: instantiated per model as LOSS()
CONV_ARCS = [GeoConv2D, CoordConv, layers.Conv2D]
# Point ratios swept for both the training and the test sets.
# Full sweep used elsewhere: [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 0.9]
RATIOS = [0.003, 0.03, 0.3, 0.9]
# Parallel per-model conv configurations: entry i describes a model whose depth
# (number of conv layers) is len(FILTERS[i]).
FILTERS = [[1], [1, 2]]
KERNEL_SIZES = [[3], [3, 3]]
STRIDES = [[2], [2, 2]]
BATCH_SIZE = 256
EPOCHS = 2  # 75 — NOTE(review): 2 looks like a quick-test value; confirm before full runs


"""# 2. Defining the Model Creator"""
# A function to define the models
def build_model(conv_arch, filters, kernel_sizes, strides, input_shape=INPUT_SHAPE, dropout_rate=.2):
    """Build and compile a small convolutional regressor.

    Architecture: ``InputLayer -> [conv_arch -> Dropout] * len(filters) -> Flatten ->
    Dense(2, relu)``, compiled with ``OPTIMIZER(LR)`` and ``LOSS()``.

    Args:
        conv_arch: Convolution layer class used for every hidden layer (e.g.
            ``GeoConv2D``, ``CoordConv`` or ``layers.Conv2D``). It must accept the
            ``filters``, ``kernel_size``, ``strides`` and ``activation`` keyword arguments.
        filters: Number of filters of each conv layer; its length sets the model depth.
        kernel_sizes: Kernel size of each conv layer, parallel to ``filters``.
        strides: Stride of each conv layer, parallel to ``filters``.
        input_shape: Shape of one input sample, without the batch dimension.
        dropout_rate: Dropout rate applied after every conv layer.

    Returns:
        The compiled ``keras.Sequential`` model.

    Raises:
        ValueError: If ``filters``, ``kernel_sizes`` and ``strides`` differ in length.
    """
    # The three lists describe the same layers, so a length mismatch is a config error.
    if not len(filters) == len(kernel_sizes) == len(strides):
        raise ValueError(
            "filters, kernel_sizes and strides must have the same length, got "
            f"{len(filters)}, {len(kernel_sizes)} and {len(strides)}"
        )

    model = keras.Sequential([layers.InputLayer(input_shape=input_shape)])

    # One conv layer, followed by dropout, per entry of the parallel config lists.
    for num_filters, kernel_size, stride in zip(filters, kernel_sizes, strides):
        model.add(
            conv_arch(
                filters=num_filters,
                kernel_size=kernel_size,
                strides=stride,
                activation='relu',
            )
        )
        model.add(layers.Dropout(dropout_rate))

    # Regression head with 2 outputs (presumably the mass-centre coordinates — confirm
    # against mass_centre_dataset's labels). The ReLU keeps predictions non-negative.
    model.add(layers.Flatten())
    model.add(layers.Dense(2, activation='relu'))

    model.compile(optimizer=OPTIMIZER(LR), loss=LOSS())
    return model


"""# 3. Preparing the results directory"""
# All outputs of this script live under RESULTS_DIR. The sub-directories are
# created up front; existing ones are left as they are.
RESULTS_DIR = 'results/mass_centre/'
HISTORIES_DIR = f"{RESULTS_DIR}histories/"
MODELS_DIR = f"{RESULTS_DIR}models/"
GRAPHS_DIR = f"{RESULTS_DIR}graphs/"

for _subdir in (HISTORIES_DIR, MODELS_DIR, GRAPHS_DIR):
    os.makedirs(_subdir, exist_ok=True)


# """# 4. The Mass Centre dataset"""
# # Create a dataset
# TRAIN_DATA = mass_centre_dataset(num_imgs=10, length=32, pts_ratio=.01, seed=SEED)
#
# train_imgs = TRAIN_DATA["imgs"]
# train_labels = TRAIN_DATA["labels"]
#
# del TRAIN_DATA
#
#
# """## 4.1. Show some images from the dataset"""
# idx = 0
# show_image(np.squeeze(train_imgs[idx]))
# print(train_labels[idx])


"""# 5. Evaluation
## 5.1. Setting up the seed
"""
# Seed TensorFlow's global seed and NumPy's global RNG with the same value so runs
# repeat. The datasets additionally receive SEED explicitly via mass_centre_dataset.
tf.random.set_seed(SEED)
np.random.seed(SEED)


"""## 5.2. Train the Different Architectures over Various Datasets, and save the results"""
# For every training points ratio, generate one training set, then train every
# (architecture, depth) combination on it and save its weights and training history.
for train_pts_ratio in RATIOS:
    print(f"<<< Train Dataset Ratio: {train_pts_ratio} >>>")

    # Create the train dataset
    TRAIN_DATA = mass_centre_dataset(NUM_TRAIN_IMAGES, IMAGE_SIZE, train_pts_ratio, SEED)
    train_imgs = TRAIN_DATA["imgs"]
    train_labels = TRAIN_DATA["labels"]
    del TRAIN_DATA

    for conv_arc in CONV_ARCS:
        conv_name = conv_arc.__name__

        # FILTERS / KERNEL_SIZES / STRIDES are parallel per-depth configurations
        for filters, kernel_sizes, strides in zip(FILTERS, KERNEL_SIZES, STRIDES):
            num_layers = len(filters)
            print(f"<<< Training {conv_name} Model with {num_layers} Conv Layers >>>")

            model = build_model(conv_arc, filters, kernel_sizes, strides)

            # Keras takes the validation split from the end of the data; shuffle=False
            # keeps the training order fixed between epochs.
            history = model.fit(
                train_imgs,
                train_labels,
                batch_size=BATCH_SIZE,
                epochs=EPOCHS,
                validation_split=VAL_SPLIT,
                shuffle=False
            ).history

            # Weights and history share one file stem so the evaluation step can find them.
            run_name = f"{conv_name}-ratio_{train_pts_ratio}-layers_{num_layers}"
            model.save_weights(MODELS_DIR + run_name + ".h5")

            # Context manager guarantees the history file is flushed and closed.
            with open(HISTORIES_DIR + run_name + ".json", 'w') as history_file:
                json.dump(history, history_file)

"""# Evaluate the models for different train/test ratios"""
# Generate each test set once and reuse it for every model, so all models are scored
# on identical data. All len(RATIOS) test sets are held in memory at the same time.
# NOTE(review): assumes mass_centre_dataset is deterministic for a fixed seed — TODO confirm.
test_sets = {}
for test_pts_ratio in RATIOS:
    TEST_DATA = mass_centre_dataset(NUM_TEST_IMAGES, IMAGE_SIZE, test_pts_ratio, SEED)
    test_sets[test_pts_ratio] = (TEST_DATA["imgs"], TEST_DATA["labels"])
    del TEST_DATA

# evaluations[conv_name]["layers:N"]["train:R"]["test:R"] -> test loss rounded to 4 dp
evaluations = {}
for conv_arc in CONV_ARCS:
    conv_name = conv_arc.__name__

    evals_layers = {}
    for filters, kernel_sizes, strides in zip(FILTERS, KERNEL_SIZES, STRIDES):
        num_layers = len(filters)
        print(f"<<< Evaluating {conv_name} with {num_layers} Conv Layers >>>")

        evals_train = {}
        for train_pts_ratio in RATIOS:
            print(f"<<< Train Ratio: {train_pts_ratio} >>>")

            # Rebuild the architecture and restore the weights saved by the training loop
            model = build_model(conv_arc, filters, kernel_sizes, strides)
            model.load_weights(MODELS_DIR + f"{conv_name}-ratio_{train_pts_ratio}-layers_{num_layers}.h5")

            evals_test = {}
            for test_pts_ratio in RATIOS:
                print(f"<<< Test Ratio: {test_pts_ratio} >>>")
                test_imgs, test_labels = test_sets[test_pts_ratio]
                # The models are compiled without metrics, so evaluate() returns a scalar loss.
                evals_test[f"test:{test_pts_ratio}"] = round(model.evaluate(test_imgs, test_labels), 4)

            evals_train[f"train:{train_pts_ratio}"] = evals_test

        evals_layers[f"layers:{num_layers}"] = evals_train

    evaluations[conv_name] = evals_layers

print(evaluations)


"""## Evaluation Tables"""
# One grid per training ratio: a header row of test ratios, then one row per
# (architecture, depth) model holding its rounded test loss for each test ratio.
tables = {}
for train_pts_ratio in RATIOS:
    header = [""] + RATIOS
    body = [
        [f"{arc.__name__} with {len(layer_filters)} Conv Layers"]
        + [
            evaluations[arc.__name__][f"layers:{len(layer_filters)}"][f"train:{train_pts_ratio}"][f"test:{test_ratio}"]
            for test_ratio in RATIOS
        ]
        for arc in CONV_ARCS
        for layer_filters in FILTERS
    ]
    tables[train_pts_ratio] = [header] + body

# Print each grid and save the same rendering to a text file.
for train_pts_ratio in RATIOS:
    print(f"Training set points ratio: {train_pts_ratio}")
    rendered = tabulate(tables[train_pts_ratio], headers="firstrow", tablefmt='fancy_grid')
    print(rendered)
    with open("%slosses_train_ratio:_%f.txt" % (RESULTS_DIR, train_pts_ratio), "w") as f:
        f.write(rendered)


"""## Calculate the average"""
# Mean test loss of every (architecture, depth) model over all (train ratio, test ratio)
# pairs. Results are read from the nested `evaluations` dict built by the evaluation loop;
# the depths come from FILTERS.
avgs = {}
for conv_arc in CONV_ARCS:
    conv_name = conv_arc.__name__
    for filters in FILTERS:
        num_of_layers = len(filters)
        losses = [
            evaluations[conv_name][f"layers:{num_of_layers}"][f"train:{train_pts_ratio}"][f"test:{test_pts_ratio}"]
            for train_pts_ratio in RATIOS
            for test_pts_ratio in RATIOS
        ]
        avgs[f"{conv_name} with {num_of_layers} Conv Layers"] = sum(losses) / len(losses)

print(avgs)


"""## Calculating normalised average"""
# Normalise every loss by the summed loss of all (architecture, depth) models on the same
# (train ratio, test ratio) pair, then average the normalised losses per model. This stops
# ratio pairs with large absolute losses from dominating the comparison.
# Losses are read from the nested `evaluations` dict; the depths come from FILTERS.
layer_counts = [len(filters) for filters in FILTERS]

# Total loss of all models per (train ratio, test ratio) pair
sum_over_arcs = {}
for train_pts_ratio in RATIOS:
    for test_pts_ratio in RATIOS:
        sum_over_arcs[f"trainRatio:{train_pts_ratio}_testRatio:{test_pts_ratio}"] = sum(
            evaluations[conv_arc.__name__][f"layers:{num_of_layers}"][f"train:{train_pts_ratio}"][f"test:{test_pts_ratio}"]
            for conv_arc in CONV_ARCS
            for num_of_layers in layer_counts
        )

# Each model's share of its pair's total loss
normalised_evaluation = {}
for train_pts_ratio in RATIOS:
    for test_pts_ratio in RATIOS:
        pair_total = sum_over_arcs[f"trainRatio:{train_pts_ratio}_testRatio:{test_pts_ratio}"]
        for conv_arc in CONV_ARCS:
            conv_name = conv_arc.__name__
            for num_of_layers in layer_counts:
                loss = evaluations[conv_name][f"layers:{num_of_layers}"][f"train:{train_pts_ratio}"][f"test:{test_pts_ratio}"]
                normalised_evaluation[f"{conv_name}_trainRatio:{train_pts_ratio}_testRatio:{test_pts_ratio}_layers:{num_of_layers}"] = \
                    loss / pair_total

# Mean normalised loss per model over all ratio pairs
normalised_avgs = {}
for conv_arc in CONV_ARCS:
    conv_name = conv_arc.__name__
    for num_of_layers in layer_counts:
        normalised_avgs[f"{conv_name} with {num_of_layers} Conv Layers"] = sum(
            normalised_evaluation[f"{conv_name}_trainRatio:{train_pts_ratio}_testRatio:{test_pts_ratio}_layers:{num_of_layers}"]
            for train_pts_ratio in RATIOS
            for test_pts_ratio in RATIOS
        ) / len(RATIOS) ** 2

print(normalised_avgs)


"""## Evaluation Results"""
# Column labels are exactly the keys of `avgs` / `normalised_avgs`, grouped by depth and then
# by architecture in CONV_ARCS order. The header and both rows are built from this one list
# so every value sits under its own column, whatever FILTERS and CONV_ARCS contain.
column_labels = [
    f"{conv_arc.__name__} with {len(filters)} Conv Layers"
    for filters in FILTERS
    for conv_arc in CONV_ARCS
]
avg_results = ['Avg. loss'] + [avgs[label] for label in column_labels]
nor_avg_results = ['Normalised avg. loss'] + [normalised_avgs[label] for label in column_labels]

table = [[''] + column_labels, avg_results, nor_avg_results]

# Render once; print it and save the identical text.
summary = tabulate(table, headers='firstrow')
print(summary)
with open(RESULTS_DIR + "avg_performances.txt", "w") as f:
    f.write(summary)
