################################################################################
# Utility functions for knee postprocessing, i.e. for when DF-Locate is run
# for more iterations and the optimal iteration is obtained from the knee.
################################################################################

import os
import numpy as np
import pandas as pd
import itertools
import pickle

from sklearn.metrics import f1_score
import warnings
import matplotlib.pyplot as plt

from copy import deepcopy
from matplotlib.ticker import FuncFormatter
from typing import List
from pandas import DataFrame
from numpy import ndarray
from utils._early_stopping import EarlyStopping

from kneed import KneeLocator
from scipy.interpolate import UnivariateSpline
from scipy.stats import linregress

from scipy.signal import savgol_filter
from ._knee_location import _interpolate, _opening_left_right


def update_experiment(experiment, iteration):
    """
    Truncate an experiment record, in place, so it ends at ``iteration``.

    Parameters
    ----------
    experiment : dict
        Experiment record. The keys read and rewritten here are ``runtime_``,
        ``corrupted_features_``, ``importances_``, ``mask_``, ``ranking_``,
        ``scores_``, ``return_estimator`` and, when ``return_estimator`` is
        truthy, ``estimators_``.
    iteration : int
        Zero-based index of the last iteration to keep.

    Returns
    -------
    experiment : dict
        The same object that was passed in, after modification.
    """
    keep = iteration + 1

    # Per-iteration sequences are cut down to the first ``keep`` entries
    experiment['runtime_'] = experiment['runtime_'][:keep]
    experiment['corrupted_features_'] = experiment['corrupted_features_'][:keep]

    # Entries whose ranking is at or beyond the cut-off are reset. The mask is
    # computed once, before ``ranking_`` itself is overwritten.
    dropped = experiment['ranking_'] >= keep
    experiment['importances_'][dropped] = 0
    experiment['mask_'][dropped] = False
    experiment['ranking_'][dropped] = 0

    experiment['n_corrupted_features_'] = experiment['corrupted_features_'][-1]
    experiment['n_iters_'] = int(keep)

    scores = experiment['scores_']
    for name in scores:
        scores[name] = scores[name][:keep]

    if experiment['return_estimator']:
        estimators = experiment['estimators_']
        for name in estimators:
            estimators[name] = estimators[name][:keep]

    return experiment

def early_stopping(metric, patience=5):
    """
    Return the index at which an ``EarlyStopping`` monitor asks to stop.

    The values of ``metric`` are fed one at a time to an ``EarlyStopping``
    instance built with ``patience`` and ``verbose=False``. The index of the
    value after which its ``early_stop`` flag is set is returned; if the flag
    is never set, the index of the last value is returned.

    Parameters
    ----------
    metric : iterable
        Metric values in iteration order. Must not be empty.
    patience : int, default=5
        Forwarded to ``EarlyStopping``.

    Returns
    -------
    iteration : int
        Zero-based index of the stopping iteration.
    """
    stopper = EarlyStopping(patience=patience, verbose=False)
    for iteration, value in enumerate(metric):
        stopper(value)
        if stopper.early_stop:
            break

    return iteration

def apply_patience(shift_location_df, patience):
    """
    Truncate every experiment at the iteration selected by patience-based
    early stopping on its mean test balanced accuracy.

    Parameters
    ----------
    shift_location_df : DataFrame
        One experiment per row. Rows are read with ``.loc[i]`` for
        ``i = 0 .. n_rows - 1``, so the frame must carry the default integer
        index. Each row must provide ``corrupted_features_``, ``scores_``
        (with a ``mean_test_balanced_accuracy`` entry) and the other fields
        consumed by ``update_experiment``.
    patience : int
        Forwarded to ``early_stopping``.

    Returns
    -------
    patience_df_all : DataFrame
        One row per truncated experiment, with ``n_iters_`` cast to int.
        An empty DataFrame is returned when the input has no rows.
    """
    print("Applying patience:", patience)

    rows = []

    for i in range(shift_location_df.shape[0]):
        if i % 20 == 0:
            print("Experiment:", i)

        # Work on a deep copy so the input frame's nested objects stay intact
        experiment = deepcopy(dict(shift_location_df.loc[i]))
        accuracies = experiment['scores_']['mean_test_balanced_accuracy']

        assert len(experiment['corrupted_features_']) == len(accuracies)

        iterations = early_stopping(accuracies, patience=patience)
        experiment = update_experiment(experiment, iterations)

        # ``DataFrame.append`` no longer exists in pandas >= 2.0. Build a
        # one-row frame instead, with the same idiom ``apply_margin`` uses.
        rows.append(pd.DataFrame.from_dict(experiment, orient='index').transpose())

    if not rows:
        return pd.DataFrame()

    # Concatenate once: growing a frame row by row re-copies all earlier rows
    patience_df_all = pd.concat(rows, ignore_index=True)
    patience_df_all['n_iters_'] = patience_df_all['n_iters_'].astype(int)

    return patience_df_all

def margin_(metric, margin):
    """
    Return the index of the first value of ``metric`` below ``0.5 + margin``.

    Parameters
    ----------
    metric : iterable of float
        Metric values in iteration order. Must not be empty.
    margin : float
        Offset added to 0.5 to form the threshold.

    Returns
    -------
    iteration : int
        Zero-based index of the first value strictly below the threshold, or
        the index of the last value when none is below it.
    """
    threshold = 0.5 + margin
    for iteration, value in enumerate(metric):
        if value < threshold:
            return iteration

    # No value fell below the threshold: keep every iteration
    return iteration

def apply_margin(shift_location_df, margin):
    """
    Truncate every experiment at the first iteration whose mean test balanced
    accuracy falls below ``0.5 + margin``.

    Parameters
    ----------
    shift_location_df : DataFrame
        One experiment per row. Rows are read with ``.loc[i]`` for
        ``i = 0 .. n_rows - 1``, so the frame must carry the default integer
        index. Each row must provide ``corrupted_features_``, ``scores_``
        (with a ``mean_test_balanced_accuracy`` entry) and the other fields
        consumed by ``update_experiment``.
    margin : float
        Forwarded to ``margin_``.

    Returns
    -------
    margin_df_all : DataFrame
        One row per truncated experiment, with ``n_iters_`` cast to int.
        An empty DataFrame is returned when the input has no rows.
    """
    print("Applying margin:", margin)

    rows = []

    for i in range(shift_location_df.shape[0]):
        if i % 20 == 0:
            print("Experiment:", i)

        # Work on a deep copy so the input frame's nested objects stay intact
        experiment = deepcopy(dict(shift_location_df.loc[i]))
        accuracies = experiment['scores_']['mean_test_balanced_accuracy']

        assert len(experiment['corrupted_features_']) == len(accuracies)

        iterations = margin_(accuracies, margin=margin)
        experiment = update_experiment(experiment, iterations)
        rows.append(pd.DataFrame.from_dict(experiment, orient='index').transpose())

    if not rows:
        return pd.DataFrame()

    # Concatenate once: calling ``pd.concat`` inside the loop re-copies all
    # previously accumulated rows on every iteration.
    margin_df_all = pd.concat(rows, ignore_index=True)
    margin_df_all['n_iters_'] = margin_df_all['n_iters_'].astype(int)

    return margin_df_all

def _knee_locator(x, y, curve="convex", direction="decreasing", 
                  online=False, S=1.0) -> tuple:
    """
    Find the knee point (or "elbow") in a curve. The knee point is the point of 
    maximum curvature and is used as a heuristic to determine the optimal number 
    of features to be removed.

    Before the knee is searched, the curve is pre-processed:

    1. ``x`` and ``y`` are rounded to two decimals and runs of equal 
       consecutive ``y`` values are collapsed to their first point.
    2. Points before the first ``y < 0.8`` are discarded (all points are kept 
       when no value is below 0.8).
    3. ``int(S) + 50`` flat points repeating the last ``y`` are appended.
    4. An artificial anchor ``(x[0] - 1, 1.0)`` is prepended.
    
    Parameters
    ----------
    x : ndarray
        Sorted x-axis values in increasing order.
    y : ndarray
        Corresponding y-axis values.
    curve : str, default='convex'
        Type of curve to fit. Can be "concave" or "convex".
    direction : str, default='decreasing'
        The direction to look for a knee point. Can be "increasing" or 
        "decreasing".
    online : default=False
        When set to True, it updates the old knee values if necessary.
    S : default=1.0
        Sensitivity for knee location. It is a measure of how many "flat" points
        are expected in the unmodified data curve before declaring a knee.

    Returns
    -------
    knee : number
        The x-axis coordinate of the knee point. When no knee is found, or the 
        knee falls on the artificial anchor, the first point of the cropped 
        curve is returned instead.
    warning_knee : str
        Warning text used for debugging knee locator (newline-joined messages
        of the warnings raised while locating the knee).
    x : ndarray
        The pre-processed x values actually given to ``KneeLocator``.
    y : ndarray
        The pre-processed y values actually given to ``KneeLocator``.
    """
    x = np.round(x, 2)
    y = np.round(y, 2)

    # Collapse plateaus: keep a point only if its y differs from the previous one
    keep = np.ones(len(y), dtype=bool)
    keep[1:] = y[1:] != y[:-1]
    x, y = x[keep], y[keep]

    # Find the first index such that y < 0.8 in order to locate the knee after
    # this. If no index is found, the knee is searched across all points.
    below = np.flatnonzero(y < 0.8)
    index = below[0] if below.size else 0
    x, y = x[index:], y[index:]

    # Number of flat points appended to the right of the curve. ``S`` may be a
    # float (the default is 1.0), but a point count has to be an integer.
    n_pad = int(S) + 50

    # Spacing of the padding: the largest gap between consecutive x values,
    # and never less than 1 so the padded x values keep increasing.
    if len(x) == 1:
        m = 1
    else:
        m = max(1, int(np.max(np.diff(x))))

    pad_x = int(x[-1] + m) + m * np.arange(n_pad)
    x = np.concatenate((x, pad_x))
    y = np.concatenate((y, np.full(n_pad, y[-1])))

    # Prepend the artificial anchor one unit to the left with y = 1.0
    x = np.concatenate(([x[0] - 1], x))
    y = np.concatenate(([1.0], y))

    # Record (rather than print) the warnings emitted by KneeLocator
    with warnings.catch_warnings(record=True) as caught:
        knee_locator = KneeLocator(x, y, curve=curve, direction=direction, 
                                   online=online, S=S)
        knee = knee_locator.knee

    # Concatenate all warning messages from knee location
    warning_knee = "\n".join(str(warning.message) for warning in caught)

    if knee is None:
        print('Knee not found.')

    # ``x[0]`` is the artificial anchor, not a real observation. Fall back to
    # the first real point when no knee was found or the knee is the anchor.
    if knee is None or knee == x[0]:
        knee = x[1]

    return knee, warning_knee, x, y

def knee_location(experiment, find_best='knee-balanced', window_length=5, 
                      polyorder=4, S=7, online=False):
    """
    Find the correct number of corrupted features by finding the knee of the 
    curve representing the balanced accuracy of the estimator vs the number 
    of removed features, and truncate the experiment at that iteration.

    Parameters
    ----------
    experiment : dict
        Experiment record; it is modified in place. It must provide 
        ``n_iters_``, ``corrupted_features_`` (a list) and ``scores_`` with a 
        ``mean_test_balanced_accuracy`` entry. When ``n_iters_`` is greater 
        than 1 it must also provide the fields consumed by 
        ``update_experiment``.
    find_best : None or 'knee-balanced', default='knee-balanced'
        Stored in ``experiment['find_best']``; it is not otherwise used by 
        this function.
    window_length : None or int, default=5
        Used to determine the length of the filter window for Savitzky-Golay 
        filter. The window length is computed as: 
        `max(5, (delta*window_length)// 2*2+1)`, where delta is the mean 
        distance between ``corrupted_features_`` points.
    polyorder : None or int, default=4
        The polyorder used to fit the samples for Savitzky-Golay filter.
    S : None or int, default=7
        Sensitivity for knee location. It is a measure of how many "flat" 
        points are expected in the unmodified data curve before declaring a 
        knee.
    online : None or bool, default=False
        When set to True, it "corrects" old knee values if necessary.

    Returns
    -------
    experiment : dict
        The same object that was passed in. The call parameters are stored 
        under their own names, ``plot_`` holds the curves needed for plotting 
        and ``warning_knee_`` the knee-locator warnings. Unless ``n_iters_`` 
        is 1, the per-iteration fields are truncated at the knee.

    Raises
    ------
    ValueError
        If the knee returned by ``_knee_locator`` is not one of the values in 
        ``corrupted_features_``.
    """
    experiment['find_best'] = find_best
    experiment['window_length'] = window_length
    experiment['polyorder'] = polyorder
    experiment['S'] = S
    experiment['online'] = online

    features = experiment['corrupted_features_']
    accuracies = experiment['scores_']['mean_test_balanced_accuracy']

    if experiment['n_iters_'] == 1:
        # A single iteration has no curve to search: keep the experiment as it
        # is and only record the attributes needed for plotting.
        experiment['plot_'] = {
            'corrupted_features': features,
            'mean_test_balanced_accuracy': accuracies,
            'mean_test_balanced_accuracy_smooth': accuracies,
            'x': features,
            'y': accuracies,
        }
        experiment['warning_knee_'] = ''
        return experiment

    # Compute the mean distance between the points in ``corrupted_features_``
    delta = features[-1] / (experiment['n_iters_'] - 1)

    # Compute the window size for the Savitzky-Golay smoothing. ``delta`` is a
    # float, so the result is cast back to the integer the filter expects.
    window = int(max(5, (delta * window_length) // 2 * 2 + 1))

    # Interpolate the data so that the distance between interpolated points 
    # is 1; only the interpolated accuracies are needed below.
    _, interpolated_accuracies = _interpolate(features, accuracies)

    # Apply the Savitzky-Golay smoothing
    smooth = savgol_filter(interpolated_accuracies, window, polyorder, 
                           mode='nearest')

    # Force each point in the right to be <= each point in the left
    smooth = _opening_left_right(smooth)

    # Truncate all values below 0.5 to 0.5
    smooth[smooth < 0.5] = 0.5

    # Remove interpolation values: the feature counts are used as positions 
    # into the unit-spaced interpolated curve.
    balanced_accuracy_smooth = smooth[features]

    # Find the knee from the smoothed balanced accuracy curve
    corrupted_features_knee, warning_knee, x, y = _knee_locator(
        features, balanced_accuracy_smooth, curve="convex", 
        direction="decreasing", online=online, S=S
    )

    # Find the iteration with the correct number of corrupted features
    iteration_knee = features.index(corrupted_features_knee) + 1

    # Keep the full (untruncated) curves for plotting
    experiment['plot_'] = {
        'corrupted_features': features,
        'mean_test_balanced_accuracy': accuracies,
        'mean_test_balanced_accuracy_smooth': balanced_accuracy_smooth,
        'x': x,
        'y': y,
    }

    # Drop the information from iterations after the knee. The helper takes 
    # the zero-based index of the last iteration to keep.
    update_experiment(experiment, iteration_knee - 1)
    experiment['n_corrupted_features_'] = corrupted_features_knee
    experiment['warning_knee_'] = warning_knee

    return experiment

def knee_postprocessing(shift_location_df, knee_location_params, stopping_criterion_params):
    """
    Apply the optional stopping criteria and the knee location to every 
    experiment, and collect the results in a new DataFrame.

    Parameters
    ----------
    shift_location_df : DataFrame
        One experiment per row. The input frame is not modified: the function 
        works on a copy with a reset index. Each row must provide the fields 
        consumed by ``knee_location`` plus ``runtime_`` and ``y_true``.
    knee_location_params : dict
        Must contain ``find_best``, ``window_length``, ``polyorder``, ``S`` 
        and ``online``; they are forwarded to ``knee_location``.
    stopping_criterion_params : dict
        Must contain ``patience`` and ``margin``. Each criterion is applied 
        (patience first) unless its value is None.

    Returns
    -------
    knee_location_df_all : DataFrame
        One row per experiment after knee location, with two extra columns: 
        ``Runtime`` and ``F1 Score``. An empty DataFrame is returned when 
        there are no rows.
    """
    shift_location_df = shift_location_df.reset_index(drop=True)

    patience = stopping_criterion_params["patience"]
    if patience is not None:
        shift_location_df = apply_patience(shift_location_df, patience)

    margin = stopping_criterion_params["margin"]
    if margin is not None:
        shift_location_df = apply_margin(shift_location_df, margin)

    # Add total runtime across all iterations. This is computed before the 
    # knee location shortens ``runtime_``.
    shift_location_df["Runtime"] = shift_location_df["runtime_"].apply(lambda x : np.sum(x))

    rows = []

    for i in range(shift_location_df.shape[0]):
        # ``knee_location`` modifies its argument in place, so hand it a deep 
        # copy of the row; one copy is enough.
        experiment = deepcopy(dict(shift_location_df.loc[i]))

        result = knee_location(experiment, find_best=knee_location_params["find_best"], 
                               window_length=knee_location_params["window_length"], 
                               polyorder=knee_location_params["polyorder"], 
                               S=knee_location_params["S"], online=knee_location_params["online"])

        # Add F1 Score in the manipulated feature location
        result["F1 Score"] = f1_score(result["y_true"], result["mask_"], zero_division=1)

        rows.append(pd.DataFrame.from_dict(result, orient='index').transpose())

    if not rows:
        return pd.DataFrame()

    # Concatenate once: calling ``pd.concat`` inside the loop re-copies all
    # previously accumulated rows on every iteration.
    knee_location_df_all = pd.concat(rows, ignore_index=True)

    return knee_location_df_all
