import numpy as np
import matplotlib.pyplot as plt
import postprocess.util, postprocess.stats
from matplotlib.patches import Ellipse
import matplotlib.cm as cm
from matplotlib.legend_handler import HandlerBase
from matplotlib.patches import Rectangle, Circle


# Registry mapping a statistic's display name (used as an axis label by the
# scatter helpers below) to the function in ``postprocess.stats`` that
# computes it.
# NOTE(review): the key "SuffixFail75" calls suffix_fail with t=60; confirm
# whether the name or the threshold is the intended one.
fn_map = dict(
    GreedyFrac=postprocess.stats.greedy_frac,
    LastOpt=postprocess.stats.last_opt,
    SuffixFail=postprocess.stats.suffix_fail,
    SuffixFail75=lambda x: postprocess.stats.suffix_fail(x, t=60),
    AveRew=postprocess.stats.ave_reward,
    AveRewardLastHalf=postprocess.stats.ave_reward_last_half,
    OptCount=postprocess.stats.opt_count,
    OptCountLastHalf=postprocess.stats.opt_count_last_half,
    MinCount=postprocess.stats.min_count,
    MinCountLastHalf=postprocess.stats.min_count_last_half,
)


### Plotting helpers
def get_ellipse_params(x, y, n_std=np.sqrt(2)):
    """Return the parameters of the covariance ellipse of the points (x, y).

    The result is shaped for ``matplotlib.patches.Ellipse(mean, width,
    height, angle=angle)``. Matplotlib rotates the ``width`` axis by
    ``angle``, so ``width`` is the extent along the major principal axis and
    ``height`` is the extent along the minor one.

    Args:
        x, y: equal-length 1-D sequences of coordinates.
        n_std: half-axis length in standard deviations along each principal
            axis. The default ``sqrt(2)`` reproduces the previous fixed scale.

    Returns:
        ``(mean, width, height, angle)``, where ``mean`` is the length-2
        array of coordinate means, ``width``/``height`` are full axis
        lengths, and ``angle`` is the major-axis direction in degrees
        (counter-clockwise from +x). With fewer than two points the
        covariance is undefined, and a zero-size ellipse at the mean is
        returned.
    """
    arr = np.vstack([np.asarray(x, dtype=float), np.asarray(y, dtype=float)])
    mean = np.mean(arr, axis=1)
    if arr.shape[1] < 2:
        # Sample covariance needs at least two points; avoid NaN output.
        return (mean, 0.0, 0.0, 0.0)
    cov_matrix = np.cov(arr)
    eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
    # Round-off can push a (near-)zero eigenvalue slightly negative, which
    # would make sqrt return NaN for e.g. collinear data.
    eigenvalues = np.clip(eigenvalues, 0.0, None)
    # eigh returns eigenvalues in ascending order; pick major/minor explicitly.
    major_idx, minor_idx = np.argsort(eigenvalues)[::-1]
    major_axis = eigenvectors[:, major_idx]
    angle = np.arctan2(major_axis[1], major_axis[0])
    # ``width`` must pair with ``angle`` (the major-axis direction).
    width = 2 * n_std * np.sqrt(eigenvalues[major_idx])
    height = 2 * n_std * np.sqrt(eigenvalues[minor_idx])

    return (mean, width, height, np.degrees(angle))


def scatter_averages(dl, x_fn, y_fn):
    """Plot one marker per algorithm at the mean of its (x, y) statistics.

    Args:
        dl: project data loader (assumed to expose ``get_scatter_data``,
            ``alg_names`` and ``algs`` as used below -- not shown here).
        x_fn, y_fn: keys of ``fn_map`` choosing the statistic for each axis;
            also used as the axis labels.

    Algorithms with no data points are skipped and left out of the legend.
    """
    output = dl.get_scatter_data(fn_map[x_fn], fn_map[y_fn])
    lines = {}
    for alg in dl.alg_names:
        xs, ys = output[alg][0], output[alg][1]
        if len(xs) == 0:
            continue
        artist = plt.scatter(np.mean(xs), np.mean(ys),
                             color=postprocess.util.get_color(alg),
                             label=postprocess.util.get_short_name(alg))
        # Handles are keyed by the name prefix before the first "_"; when
        # several variants share a prefix, the last one plotted is kept.
        lines[alg.split("_")[0]] = artist
    plt.xlabel(x_fn)
    plt.ylabel(y_fn)
    # Skip algorithms that were never plotted instead of raising KeyError.
    plt.legend(handles=[lines[alg] for alg in dl.algs if alg in lines])
    
def scatter_quantile(dl, x_fn, y_fn, alpha):
    """Plot one marker per algorithm at the ``alpha``-quantile of each statistic.

    The x and y quantiles are computed independently, so the marker is not
    necessarily an observed data point.

    Args:
        dl: project data loader (assumed to expose ``get_scatter_data``,
            ``alg_names`` and ``algs`` as used below -- not shown here).
        x_fn, y_fn: keys of ``fn_map`` choosing the statistic for each axis;
            also used as the axis labels.
        alpha: quantile level passed to ``np.quantile`` (in ``[0, 1]``).

    Algorithms with no data points are skipped and left out of the legend.
    """
    output = dl.get_scatter_data(fn_map[x_fn], fn_map[y_fn])
    lines = {}
    for alg in dl.alg_names:
        xs, ys = output[alg][0], output[alg][1]
        if len(xs) == 0:
            continue
        artist = plt.scatter(np.quantile(xs, alpha), np.quantile(ys, alpha),
                             color=postprocess.util.get_color(alg),
                             marker=postprocess.util.get_marker(alg),
                             s=postprocess.util.get_marker_size(alg),
                             label=postprocess.util.get_short_name(alg))
        # Handles are keyed by the name prefix before the first "_"; when
        # several variants share a prefix, the last one plotted is kept.
        lines[alg.split("_")[0]] = artist
    plt.xlabel(x_fn)
    plt.ylabel(y_fn)
    # Skip algorithms that were never plotted instead of raising KeyError.
    plt.legend(handles=[lines[alg] for alg in dl.algs if alg in lines])

class HandlerColormap(HandlerBase):
    """Legend handler that draws an entry as a row of colormap-coloured stripes.

    Args:
        cmap: a callable colormap mapping a float in ``[0, 1]`` to a colour.
        num_stripes: how many equal-width stripes to draw.
        **kw: forwarded to ``HandlerBase``.
    """

    def __init__(self, cmap, num_stripes=8, **kw):
        super().__init__(**kw)
        self.cmap = cmap
        self.num_stripes = num_stripes

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return one ``Rectangle`` per stripe, tiling the handle box left to right.

        Stripe ``k`` is filled with the colormap sampled at the centre of
        its interval, ``(2k + 1) / (2n)``.
        """
        n = self.num_stripes
        return [
            Rectangle([xdescent + k * width / n, ydescent],
                      width / n,
                      height,
                      fc=self.cmap((2 * k + 1) / (2 * n)),
                      transform=trans)
            for k in range(n)
        ]

def scatter_individual(dl, x_fn, y_fn, baselines='ellipse'):
    """Scatter per-run results, optionally summarising baseline algorithms.

    Args:
        dl: project data loader (assumed to expose ``get_scatter_data``,
            ``alg_names``, ``algs`` and ``baselines`` as used below -- not
            shown here). An algorithm counts as a baseline when its name
            prefix (text before the first "_") is in ``dl.baselines``.
        x_fn, y_fn: keys of ``fn_map`` choosing the statistic for each axis;
            also used as the axis labels.
        baselines: how baseline algorithms are drawn:

            * ``"ellipse"`` -- a marker at the mean plus a translucent
              covariance ellipse (see ``get_ellipse_params``);
            * ``"scatter"`` -- every point, in the algorithm's own colour,
              like every other algorithm;
            * any other value (e.g. ``"means"``) -- only the marker at the
              mean.

            Outside ``"scatter"`` mode, non-baseline algorithms are coloured
            along a viridis gradient. The first of them gets a striped
            colormap legend entry.

    Algorithms with no data points are skipped and left out of the legend.
    """
    output = dl.get_scatter_data(fn_map[x_fn], fn_map[y_fn])
    lines = {}

    to_recolor = [alg for alg in dl.alg_names
                  if alg.split("_")[0] not in dl.baselines
                  and len(output[alg][0]) > 0]
    # ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9;
    # ``pyplot.get_cmap`` is the supported spelling. The lookup table needs
    # at least one entry, so guard the case where nothing is recoloured.
    colormap = plt.get_cmap('viridis', max(len(to_recolor), 1))
    i = 0
    for alg in dl.alg_names:
        xs, ys = output[alg][0], output[alg][1]
        if len(xs) == 0:
            continue
        key = alg.split("_")[0]
        label = postprocess.util.get_short_name(alg)
        if key in dl.baselines and baselines != "scatter":
            # Summarise the baseline by its mean (and optionally an ellipse).
            mean, width, height, angle = get_ellipse_params(xs, ys)
            if baselines == "ellipse":
                ellipse = Ellipse(mean, width, height, angle=angle, alpha=0.3,
                                  color=postprocess.util.get_color(alg),
                                  label=label)
                plt.gca().add_patch(ellipse)
            artist = plt.scatter(mean[0], mean[1],
                                 color=postprocess.util.get_color(alg),
                                 label=label)
        elif baselines != "scatter":
            # Non-baseline runs take successive colours from the gradient;
            # i / len is true division, so the colormap sees a float fraction.
            artist = plt.scatter(xs, ys,
                                 color=colormap(i / len(to_recolor)),
                                 label=label)
            i += 1
        else:
            artist = plt.scatter(xs, ys,
                                 color=postprocess.util.get_color(alg),
                                 label=label)
        # When several variants share a prefix, the last one plotted is kept.
        lines[key] = artist
    plt.xlabel(x_fn)
    plt.ylabel(y_fn)
    # Skip algorithms that were never plotted instead of raising KeyError.
    handles = [lines[alg] for alg in dl.algs if alg in lines]
    handler_map = None
    if baselines != "scatter" and to_recolor:
        # Draw the gradient-coloured group's legend entry as colour stripes.
        gradient_name = to_recolor[0].split("_")[0]
        handler_map = {lines[gradient_name]:
                       HandlerColormap(plt.cm.viridis, num_stripes=8)}
    plt.legend(handles=handles, handler_map=handler_map)

