#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import warnings

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import numpy as np
from sklearn.inspection import DecisionBoundaryDisplay

def _build_contour_levels(scores, n_contour_levels, upper_perfect_threshold, log_scale):
    """Return contour levels ``[0, threshold, <levels above threshold>...]``.

    ``n_contour_levels`` candidate levels are spread (linearly, or
    logarithmically when *log_scale* is true) between the smallest and largest
    value in *scores*; only candidates strictly greater than
    *upper_perfect_threshold* are kept.

    Raises
    ------
    ValueError
        If *log_scale* is true and *scores* contains no strictly positive value.
    """
    scores = np.asarray(scores, dtype=float)
    highest = np.max(scores)

    if log_scale:
        # log10(0) is -inf and would turn every level into NaN, so the lower
        # end of the log range is the smallest strictly positive score.
        positive_scores = scores[scores > 0]
        if positive_scores.size == 0:
            raise ValueError(
                "log_scale=True requires at least one strictly positive model score"
            )
        candidates = np.logspace(
            np.log10(np.min(positive_scores)), np.log10(highest), n_contour_levels
        )
    else:
        candidates = np.linspace(np.min(scores), highest, n_contour_levels)

    # Boolean-mask selection: unlike ``candidates[np.argmax(mask):]`` this
    # yields an empty tail (instead of *all* candidates) when no candidate
    # exceeds the threshold, so the result is always increasing.
    above_threshold = candidates[candidates > upper_perfect_threshold]
    return np.concatenate([[0.0, upper_perfect_threshold], above_threshold])


def _scatter_groups(ax, points, labels, single_label, group_prefix, **style):
    """Scatter 2-D *points* on *ax*, one seaborn call per distinct label.

    Without *labels* all points are drawn in black under *single_label*;
    otherwise each unique label gets its own call (seaborn picks the color)
    with the legend entry ``group_prefix + str(label)``.
    """
    points = np.asarray(points)
    if labels is None:
        sns.scatterplot(x=points[:, 0], y=points[:, 1], color="black",
                        label=single_label, ax=ax, **style)
        return

    labels = np.asarray(labels)  # lets ``labels == label`` work for plain lists too
    for label in np.unique(labels):
        selected = labels == label
        sns.scatterplot(x=points[selected, 0], y=points[selected, 1],
                        label=group_prefix + str(label), ax=ax, **style)


def plot_data_and_detector(model,
                           normal_data,
                           normal_labels=None,
                           anomalies=None,
                           anomaly_labels=None,
                           plot_min_x=-12,
                           plot_max_x=12,
                           plot_min_y=-12,
                           plot_max_y=12,
                           resolution=500,
                           n_contour_levels=20,
                           score_label="Reconstruction loss (MSE)",
                           plot_eigenvector=False,
                           upper_perfect_threshold=0.1,
                           log_scale=False,
                           colorbar_spacing="proportional",
                           normal_data_opacity=1,
                           anomaly_opacity=1,
                           normal_data_markersize=10,
                           anomaly_data_markersize=100,
                           scatter_edgecolor=None,
                           pure_plot=False,
                           no_xticks=True,
                           scatter_edgewidth=1):
    """Draw a detector's score surface with normal data and anomalies on top.

    The plot is drawn on the current matplotlib axes (``plt.gca()``). The
    region whose score lies between 0 and *upper_perfect_threshold* is filled
    red; higher score bands use the viridis colormap.

    Parameters
    ----------
    model
        Object with a ``predict(X)`` method returning one score per row of an
        ``(n, 2)`` array. It is also handed to
        ``DecisionBoundaryDisplay.from_estimator`` with
        ``response_method="predict"``. When *plot_eigenvector* is true it must
        additionally expose ``V`` such that ``V[:, 0]`` has at least two entries.
    normal_data
        Two-column array of normal points (column 0 is x, column 1 is y).
    normal_labels
        Optional per-point labels for *normal_data*; each distinct label is
        plotted separately as ``"Normal <label>"``.
    anomalies
        Optional two-column array of anomalous points, drawn with "X" markers.
    anomaly_labels
        Optional per-point labels for *anomalies*; each distinct label is
        plotted separately as ``"Anomaly <label>"``.
    plot_min_x, plot_max_x, plot_min_y, plot_max_y
        Axis limits of the plot.
    resolution
        Number of grid points per axis used to sample ``model.predict`` when
        determining the score range for the contour levels (and the number of
        points of the eigenvector line). The displayed surface itself is
        computed by ``DecisionBoundaryDisplay`` on its own grid.
    n_contour_levels
        Number of candidate contour levels between the minimum and maximum
        sampled score. The effective number may be lower, because candidates
        not exceeding *upper_perfect_threshold* are dropped.
    score_label
        Label of the colorbar.
    plot_eigenvector
        If true, draw a line derived from ``model.V[:, 0]``.
    upper_perfect_threshold
        Upper bound of the red band. NOTE(review): should be > 0, since the
        levels start with ``[0, upper_perfect_threshold]`` and matplotlib
        expects increasing levels.
    log_scale
        Space the candidate contour levels logarithmically instead of linearly.
    colorbar_spacing
        Passed as ``spacing`` to ``plt.colorbar``.
    normal_data_opacity, anomaly_opacity
        Alpha of the respective scatter markers.
    normal_data_markersize, anomaly_data_markersize
        Marker size (``s``) of the respective scatter markers.
    scatter_edgecolor, scatter_edgewidth
        Edge color and line width of all scatter markers.
    pure_plot
        If true, omit the colorbar, legend, axis labels and ticks.
    no_xticks
        If true, remove the ticks of *both* axes.

    Raises
    ------
    AttributeError
        If *plot_eigenvector* is true but *model* has no usable ``V`` attribute.
    ValueError
        If *log_scale* is true and the model produces no strictly positive
        score on the sampling grid.
    """
    # Validate the eigenvector request first so a misconfigured call fails
    # before anything has been drawn (the original built the exception without
    # raising it and then crashed on an undefined local).
    eigenvector = None
    if plot_eigenvector:
        try:
            # The factor 40 lengthens the drawn segment (see the plot call below).
            eigenvector = model.V[:, 0] * 40
        except AttributeError as err:
            raise AttributeError(
                "plot_eigenvector can only be True if model has attribute 'V' containing eigenvectors"
            ) from err

    scatter_plot_legend = False if pure_plot else "auto"

    x_linspace = np.linspace(plot_min_x, plot_max_x, resolution)
    y_linspace = np.linspace(plot_min_y, plot_max_y, resolution)

    # Sample the model on a dense grid solely to learn the score range.
    grid_x, grid_y = np.meshgrid(x_linspace, y_linspace)
    grid_scores = model.predict(np.c_[grid_x.ravel(), grid_y.ravel()])

    contour_levels = _build_contour_levels(
        grid_scores, n_contour_levels, upper_perfect_threshold, log_scale
    )

    # One color per filled band (len(levels) - 1 bands): red for the
    # [0, threshold] band, viridis spread over the remaining ones.
    viridis_colors = plt.cm.viridis(np.linspace(0, 1, len(contour_levels) - 2))
    band_colors = np.vstack([[1, 0, 0, 1], viridis_colors])

    # Only the two corners of the plot window are passed as "data" to
    # from_estimator.
    window_corners = np.array([[plot_min_x, plot_min_y], [plot_max_x, plot_max_y]])

    ax = plt.gca()
    disp = DecisionBoundaryDisplay.from_estimator(model, window_corners,
                                                  response_method="predict",
                                                  xlabel="X", ylabel="Y",
                                                  alpha=0.5,
                                                  ax=ax,
                                                  levels=contour_levels,
                                                  colors=band_colors)
    if not pure_plot:
        # Use the display's own surface rather than ax.collections[0], which
        # is only the contour set when the axes held no earlier artists.
        colorbar = plt.colorbar(disp.surface_, ax=ax, spacing=colorbar_spacing)
        colorbar.set_label(score_label)

    shared_style = {"edgecolor": scatter_edgecolor,
                    "legend": scatter_plot_legend,
                    "linewidth": scatter_edgewidth}

    _scatter_groups(ax, normal_data, normal_labels, "Normal data", "Normal ",
                    alpha=normal_data_opacity, s=normal_data_markersize,
                    **shared_style)

    if anomalies is not None:
        _scatter_groups(ax, anomalies, anomaly_labels, "Anomaly", "Anomaly ",
                        marker="X", alpha=anomaly_opacity,
                        s=anomaly_data_markersize, **shared_style)

    if eigenvector is not None:
        ax.plot(x_linspace * eigenvector[0], y_linspace * eigenvector[1],
                label="eigenvector")

    ax.set_xlim(plot_min_x, plot_max_x)
    ax.set_ylim(plot_min_y, plot_max_y)
    ax.set_aspect('equal', adjustable='box')

    if not pure_plot:
        handles, _ = ax.get_legend_handles_labels()
        # Extra legend entry explaining the red band.
        red_patch = mpatches.Patch(color=[1, 0, 0, 1], alpha=0.5,
                                   label="MSE < {threshold:.2f}".format(threshold=upper_perfect_threshold))
        ax.legend(handles=handles + [red_patch])
    else:
        ax.set_xlabel("")
        ax.set_ylabel("")

    if no_xticks or pure_plot:
        ax.set_xticks([])
        ax.set_yticks([])
