#!/usr/bin/env python3
# -*- coding: utf-8 -*-


#%% imports

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn

from src.models.PCA import PCA
from src.models.AE import AE_anomaly_detector
from src.plotting_functions.plot_data_and_detector import plot_data_and_detector

from src.data_generators.banana import generate_banana_data


#%% set IO options
# All figures produced by this script are written into this folder.
figures_folder = "Figures/horrible_hyperplanes"
# Create the folder (and any missing parents) up front; exist_ok=True makes
# re-running the script harmless when it already exists.
os.makedirs(figures_folder, exist_ok=True)

#%% set experiment parameters
# Seed NumPy's global RNG so the NumPy draws made below are reproducible.
np.random.seed(0)

# Square plotting window centred on the origin: [-25, 25] x [-25, 25].
_plot_extent = 25
plot_min_x = plot_min_y = -_plot_extent
plot_max_x = plot_max_y = _plot_extent

# NOTE(review): not referenced in the visible part of this script.
points_per_dimension = 500


#%% Synthesize and calculate

# Every Gaussian blob below shares the same covariance literal and sample count.
# NOTE(review): this matrix is not symmetric, so it is not a valid covariance.
# NumPy's legacy np.random.multivariate_normal emits a RuntimeWarning
# ("covariance is not symmetric positive-semidefinite") and draws samples via
# the SVD of this matrix, so the effective covariance differs from the one
# written here. Left unchanged to keep the generated data (and the saved
# figures) reproducible -- confirm the intended matrix, e.g. [[0.5, 0], [0, 4]].
_blob_cov = [[0.5, 0], [4, 4]]
_blob_size = 100

# Draw order is preserved on purpose (three blobs, the banana, two more blobs)
# so the seeded NumPy stream yields the same samples, in case
# generate_banana_data also draws from np.random.
gaussian_2d_0, gaussian_2d_1, gaussian_2d_2 = (
    np.random.multivariate_normal(centre, _blob_cov, _blob_size)
    for centre in ([0, 0], [-7.5, 7.5], [7.5, -7.5])
)

banana_2d = generate_banana_data(100, 1, x_range=5, centered=True)

gaussian_2d_3, gaussian_2d_4 = (
    np.random.multivariate_normal(centre, _blob_cov, _blob_size)
    for centre in ([-7.5, -7.5], [7.5, 7.5])
)


#anomalies = np.array([[10,10]])

#contaminated_data = np.row_stack([gaussian_2d_1, anomalies])

# Named datasets to run every detector on; the two "Double" sets put blobs on
# opposite diagonals of the plotting window.
datasets = {
    "Single_2D_Gaussian": gaussian_2d_0,
    "Double_2D_Gaussians_1": np.vstack([gaussian_2d_1, gaussian_2d_2]),
    "Double_2D_Gaussians_2": np.vstack([gaussian_2d_3, gaussian_2d_4]),
    "Banana": banana_2d,
}

#%% define hyperparameters for all distinct AEs per dataset:
def _ae_spec(activation_function, learning_rate, epochs, hidden_layer_dims=(5,)):
    """Return one keyword-argument dict for ``AE_anomaly_detector``.

    All configurations in this experiment share a 1-D encoding, biased linear
    layers, empty ``activation_params`` and ``batch_size=None`` (presumably
    full-batch training -- confirm in AE_anomaly_detector). Only the
    activation, hidden-layer widths, learning rate and epoch count vary.
    A fresh dict/list is built on every call, so specs never share state.
    """
    return {
        "encoding_dim": 1,
        "hidden_layer_dims": list(hidden_layer_dims),
        "activation_function": activation_function,
        "linear_layer_params": {"bias": True},
        "activation_params": {},
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": None,
    }


# Outer key: dataset name (must match ``datasets``); inner key: model label,
# used in the figure file names. Commented entries are alternative runs.
# NOTE(review): the "Linear" alternatives pass nn.Linear as the activation;
# nn.Linear requires in/out feature sizes, so if the AE instantiates
# activation_function(**activation_params) these would fail -- nn.Identity may
# have been intended. Verify before re-enabling.
model_specs = {
    "Single_2D_Gaussian": {
        # "Linear": _ae_spec(nn.Linear, learning_rate=0.05, epochs=100),
        # "ReLU": _ae_spec(nn.ReLU, learning_rate=0.02, epochs=10000),
        "Sigmoid": _ae_spec(nn.Sigmoid, learning_rate=0.1, epochs=10000),
    },
    "Double_2D_Gaussians_1": {
        # "Linear": _ae_spec(nn.Linear, learning_rate=0.001, epochs=100),
        # "ReLU": _ae_spec(nn.ReLU, learning_rate=0.01, epochs=30000),
        "Sigmoid": _ae_spec(nn.Sigmoid, learning_rate=0.1, epochs=300),
    },
    "Double_2D_Gaussians_2": {
        # "Linear": _ae_spec(nn.Linear, learning_rate=0.001, epochs=100),
        # "ReLU": _ae_spec(nn.ReLU, learning_rate=0.002, epochs=40000),
        "Sigmoid": _ae_spec(nn.Sigmoid, learning_rate=0.1, epochs=300),
    },
    "Banana": {
        # "Linear": _ae_spec(nn.Linear, learning_rate=0.001, epochs=1000),
        # "ReLU": _ae_spec(nn.ReLU, learning_rate=0.002, epochs=10000),
        "Sigmoid": _ae_spec(nn.Sigmoid, learning_rate=0.01, epochs=10000),
        # "ReLU2layer": _ae_spec(nn.ReLU, learning_rate=0.005, epochs=40000,
        #                        hidden_layer_dims=[100, 20]),
        "Sigmoid2layer": _ae_spec(nn.Sigmoid, learning_rate=0.05, epochs=10000,
                                  hidden_layer_dims=[100, 20]),
        # "ReLU3layer": _ae_spec(nn.ReLU, learning_rate=0.005, epochs=40000,
        #                        hidden_layer_dims=[100, 50, 20]),
    },
}

model_seed = 0

#%% run all hyperplane experiments:

# Plot window and styling shared by every detector plot below.
_detector_plot_kwargs = dict(
    anomalies=None,
    plot_min_x=plot_min_x, plot_max_x=plot_max_x,
    plot_min_y=plot_min_y, plot_max_y=plot_max_y,
    log_scale=True, colorbar_spacing="uniform", normal_data_markersize=30,
)


def _save_and_show(fig_name):
    """Lay out the current figure, save it as PNG and PDF, and show it.

    ``fig_name`` is a path without extension; ".png" and ".pdf" are appended.
    Under a non-interactive backend (e.g. a headless Agg run) ``plt.show()``
    returns immediately and would leave every figure open for the rest of the
    run, so figures are closed afterwards in that case. Interactive sessions
    keep their windows.
    """
    plt.tight_layout()
    plt.savefig(fig_name + ".png", format="png")
    plt.savefig(fig_name + ".pdf", format="pdf")
    plt.show()
    if not plt.isinteractive():
        # Close all of them, in case a plotting helper opened its own figure
        # in addition to the one created by plt.figure() (not visible here).
        plt.close("all")


for dataset_name, dataset in datasets.items():
    # PCA baseline: one principal component, plotted with its eigenvector.
    pca = PCA(n_PCs=1)
    pca.fit(dataset)

    plt.figure()
    plot_data_and_detector(model=pca, normal_data=dataset, plot_eigenvector=True,
                           **_detector_plot_kwargs)
    _save_and_show(os.path.join(figures_folder, f"{dataset_name}_PCA"))

    # Show various hyperplanes that extend beyond the expected boundaries:
    for model_name, model_spec in model_specs[dataset_name].items():
        # Re-seed before every model so each AE starts from the same torch RNG state.
        torch.manual_seed(model_seed)
        ae = AE_anomaly_detector(**model_spec)

        print(model_name)
        ae.fit(dataset)

        plt.figure()
        ae.plot_loss()
        # NOTE: no "_" before "loss_plot"; kept so existing file names don't change.
        _save_and_show(os.path.join(
            figures_folder,
            f"{dataset_name}_{model_name}_hyperplane_seed{model_seed}loss_plot"))

        plt.figure()
        plot_data_and_detector(model=ae, normal_data=dataset, plot_eigenvector=False,
                               **_detector_plot_kwargs)
        _save_and_show(os.path.join(
            figures_folder,
            f"{dataset_name}_{model_name}{model_spec['epochs']}_hyperplane_seed{model_seed}"))
