import matplotlib.pyplot as plt
from scipy.spatial import Delaunay
from shapely.ops import cascaded_union, polygonize, unary_union
from shapely.geometry import MultiPoint, MultiLineString, mapping
from pymoo.factory import get_performance_indicator
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np
import tqdm


def _boundary_coords(geometry):
    """Return the (x, y) vertices on the boundary of a shapely hull geometry.

    Handles single-ring boundaries (Polygon -> LinearRing), multi-part
    boundaries (MultiPolygon or holed Polygon -> MultiLineString, LineString
    -> MultiPoint) and a degenerate Point hull (empty boundary).
    """
    if geometry.geom_type == 'Point':
        # A lone (possibly repeated) point has an empty boundary; the point
        # itself is the only candidate for the front.
        return [tuple(xy) for xy in geometry.coords]
    boundary = geometry.boundary
    multi = (boundary.geom_type.startswith('Multi')
             or boundary.geom_type == 'GeometryCollection')
    parts = list(boundary.geoms) if multi else [boundary]
    return [tuple(xy) for part in parts for xy in part.coords]


def frontier_builder(df, alpha=0, verbose=False):
    """
    Modified alphashape algorithm to draw Pareto Front for OFA search.

    The (latency, accuracy) points are wrapped in a hull. When alpha is 0 or
    fewer than 4 points are given, this is the convex hull. Otherwise it is
    an alpha shape: Delaunay triangles whose circumradius is below 1/alpha,
    merged together. The hull vertices on or above the line joining the
    lowest-latency vertex to the highest-accuracy vertex are kept. They are
    then thinned so that accuracy strictly increases with latency.

    Params:
    df      - DataFrame with 'latency' and 'accuracy' columns (others ignored)
    alpha   - Dictates amount of tolerable 'concave-ness' allowed.
              A fully convex front will be given if 0 (also better for runtime)
    verbose - print progress messages

    Returns:
    DataFrame with columns ['latency', 'accuracy'], sorted by latency, in which
    accuracy is strictly increasing. Empty if df is empty.

    Raises:
    ValueError - if alpha is so large that no Delaunay triangle survives.
    """
    import itertools
    import math

    if verbose:
        print('Running front builder')
    df = df[['latency', 'accuracy']]
    if df.empty:
        return pd.DataFrame(np.empty((0, 2)), columns=['latency', 'accuracy'])
    coords = df.to_numpy(dtype=float)

    if len(coords) < 4 or (alpha is not None and alpha <= 0):
        if verbose:
            print('Alpha=0 -> convex hull')
        result = MultiPoint(coords.tolist()).convex_hull
    else:
        tri = Delaunay(coords)
        edges = set()
        edge_points = []

        # Loop over triangles (`simplices`; `vertices` is a deprecated alias)
        for ia, ib, ic in tri.simplices:
            pa, pb, pc = coords[ia], coords[ib], coords[ic]

            # Lengths of sides of triangle
            a = math.hypot(pa[0] - pb[0], pa[1] - pb[1])
            b = math.hypot(pb[0] - pc[0], pb[1] - pc[1])
            c = math.hypot(pc[0] - pa[0], pc[1] - pa[1])

            # Area by Heron's formula; the value under the root is checked
            # first so zero-area (degenerate) triangles are skipped rather
            # than hitting a math domain error.
            s = (a + b + c) * 0.5
            area_sq = s * (s - a) * (s - b) * (s - c)
            if area_sq <= 0:
                continue
            area = math.sqrt(area_sq)

            # Radius filter: keep triangles whose circumradius < 1/alpha
            if a * b * c / (4.0 * area) < 1.0 / alpha:
                for i, j in itertools.combinations((ia, ib, ic), 2):
                    edge = (min(i, j), max(i, j))  # undirected edge key
                    if edge not in edges:
                        edges.add(edge)
                        edge_points.append(coords[[i, j]])

        # Create the resulting polygon(s) from the surviving edges
        triangles = list(polygonize(MultiLineString(edge_points)))
        result = unary_union(triangles)
        if result.is_empty:
            raise ValueError(
                'alpha={} removed every Delaunay triangle; use a smaller '
                'alpha'.format(alpha))

    # Boundary vertices of the hull (may be several rings for an alpha shape)
    hull = pd.DataFrame(_boundary_coords(result), columns=['x', 'y'])

    # Cutoff non-Pareto front points
    # note that extreme concave geometries will create issues if bi-sected by line
    left_x, left_y = hull.loc[hull['x'].idxmin()]
    top_x, top_y = hull.loc[hull['y'].idxmax()]
    # "On or above the left->top line", cross-multiplied instead of y=mx+b:
    # exact for both end points and no division, so a vertical line (the
    # lowest-latency point is also the most accurate) is handled as well.
    above = ((hull['y'] - left_y) * (top_x - left_x)
             >= (top_y - left_y) * (hull['x'] - left_x))

    # Within equal latency put the most accurate first, so dominated points
    # at the same latency are removed by the cleanup below.
    front = (hull[above]
             .sort_values(['x', 'y'], ascending=[True, False])
             .reset_index(drop=True))

    # Cleanup - ensure accuracy is strictly increasing with latency up the
    # front: keep a row only if it beats every row before it. The first row
    # is always kept, even if its accuracy is <= 0.
    prior_best = front['y'].cummax().shift(1, fill_value=-np.inf)
    front = front[front['y'] > prior_best].reset_index(drop=True)

    return front.rename(columns={'x': 'latency', 'y': 'accuracy'})


def load_csv_to_df(filepath, add_config=False, normalize=False, fit=False,
                     scaler=None, sort=False, verbose=False, idx_slicer=None,
                     col_list=('config', 'date', 'latency', 'accuracy')):
    """
    Load a header-less results CSV into a DataFrame.

    Params:
    filepath   - path or file-like object accepted by ``pd.read_csv``
    add_config - also keep the 'config' column (ahead of latency/accuracy)
    normalize  - transform the 'latency' column with a scaler
    fit        - with ``normalize``: fit a new MinMaxScaler on this file's
                 latencies and return it together with the DataFrame
    scaler     - already-fitted scaler (anything with ``transform``) used when
                 ``normalize`` is set and ``fit`` is not
    sort       - sort rows by ascending latency
    verbose    - print the path, the max accuracy and the min latency
    idx_slicer - if given, keep only the first ``idx_slicer`` rows
    col_list   - names assigned, in order, to the CSV's columns

    Returns:
    ``(df, scaler)`` when both ``normalize`` and ``fit`` are set, else ``df``.

    Raises:
    ValueError - if ``normalize`` is set without ``fit`` and no ``scaler``.
    """
    df = pd.read_csv(filepath, names=list(col_list))
    if idx_slicer is not None:
        df = df.iloc[:idx_slicer]

    if sort:
        df = df.sort_values(by=['latency']).reset_index(drop=True)
    if verbose:
        print(filepath)
        print('acc max = {}'.format(df['accuracy'].max()))
        print('lat min = {}'.format(df['latency'].min()))

    columns = ['config', 'latency', 'accuracy'] if add_config else ['latency', 'accuracy']
    # Explicit copy: the latency column may be overwritten below, and writing
    # into a column-selection slice triggers pandas' SettingWithCopy behavior.
    df = df[columns].copy()

    if not normalize:
        return df

    if fit:
        # NOTE(review): MinMaxScaler is not among this file's visible imports;
        # presumably sklearn.preprocessing.MinMaxScaler — confirm and import.
        scaler = MinMaxScaler()
        scaler.fit(df['latency'].values.reshape(-1, 1))
    elif scaler is None:
        raise ValueError('normalize=True without fit=True requires a fitted scaler')

    # Scalers expect a 2D (n_samples, 1) array; squeeze back to a column.
    df['latency'] = scaler.transform(df['latency'].values.reshape(-1, 1)).squeeze()
    return (df, scaler) if fit else df



def collect_hv(df_results, max_idx=20000, ref_point=None):
    '''
    Calculates the 2D Hypervolume (HV) of the front as evaluations accumulate.

    For every evaluation count in a schedule, the front of the first ``evals``
    rows of ``df_results`` is built with ``frontier_builder`` and its HV is
    computed on (latency, -accuracy). The schedule is every 10 evaluations up
    to ``max_idx``. When ``max_idx`` > 1000 it is instead every 10 below 500
    and every 100 from 500 up to ``max_idx``.

    df_results - a dataframe with 'latency' and 'accuracy' results, in
                 evaluation order
    max_idx    - largest evaluation count in the schedule
    ref_point  - a HV reference point in the 2D (latency, -accuracy) space

    Returns:
    (hv_list, full_interval) - equal-length numpy arrays. Both start with an
    extra (1 evaluation, HV 0) entry, and hv_list is made non-decreasing.

    Raises:
    ValueError - if ref_point is not given.
    '''
    if ref_point is None:
        raise ValueError('collect_hv requires a ref_point')
    hv = get_performance_indicator("hv", ref_point=np.array(ref_point))

    if max_idx > 1000:
        full_interval = np.concatenate([np.arange(10, 500, 10),
                                        np.arange(500, max_idx + 1, 100)])
    else:
        full_interval = np.arange(10, max_idx + 1, 10)

    hv_list = []
    for evals in tqdm.tqdm(full_interval):
        front = frontier_builder(df_results.iloc[:int(evals)])
        # HV is computed on objectives to be minimized, so accuracy is negated
        front['n_accuracy'] = -front['accuracy']
        hv_list.append(hv.do(front[['latency', 'n_accuracy']].values))

    # Running maximum: report HV as non-decreasing in the number of evaluations
    hv_list = np.maximum.accumulate(np.asarray(hv_list, dtype=float))

    full_interval = np.insert(full_interval, 0, 1)
    hv_list = np.insert(hv_list, 0, 0)

    return hv_list, full_interval