import torch
import matplotlib
import re
from copy import copy
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.interpolate import interp1d
from collections.abc import Iterable
from itertools import compress
from tabulate import tabulate

import _pickle as pickle

from robustopt_torch.funcutils import *
from robustopt_torch.costs import eucl_norm_sq

# Metric plotter helper functions

def is_valid_data_point(data_point):
    """Return True if ``data_point`` is a plain tuple holding one or two entries.

    One-entry tuples are ``(value,)``; two-entry tuples are ``(x, value)``,
    matching how get_x_and_y_from_tuples splits them. Tuple subclasses are
    rejected because the check uses the exact type.
    """
    if type(data_point) is not tuple:
        return False
    return 0 < len(data_point) <= 2

def validate_data_format(data):
    """Check that metric data is a list of points or a list of point lists.

    Accepted layouts:
      * level 1 -- every top-level element is a tuple (one series);
      * level 2 -- every top-level element is a list (several series).
    An empty ``data`` counts as level 1. Every point must pass
    is_valid_data_point, and the point lengths are checked with allequal.

    Raises:
        ValueError: if top-level elements mix types, a point is malformed, or
            the point lengths differ.
    """
    if all(isinstance(entry, tuple) for entry in data):
        depth = 1
    elif all(isinstance(entry, list) for entry in data):
        depth = 2
    else:
        raise ValueError("Data elements at top level are not the same type.")

    points_ok = all(is_valid_data_point(point)
                    for point in flatteniter(data, levels = depth))
    if not points_ok:
        raise ValueError("Data points are not tuples or are tuples"
                         " with incorrect dimensions.")

    if not allequal(map(len, flatteniter(data, levels = depth))):
        raise ValueError("Data points for the same metric don't have equal length")


def write_to_file(file_name, metrics):
    """Pickle ``metrics`` into ``file_name``, replacing any previous contents."""
    with open(file_name, mode = "wb") as out_file:
        pickle.dump(metrics, out_file)

def read_from_file(file_name):
    """Load and return the object pickled in ``file_name``.

    NOTE(review): unpickling can execute arbitrary code -- only read files
    this program wrote itself (e.g. via write_to_file), never untrusted input.
    """
    with open(file_name, mode = "rb") as in_file:
        loaded = pickle.load(in_file)
    return loaded

def get_x_and_y_from_tuples(data):
    """Split a series of data points into x and y lists.

    ``data`` is a sequence of 1- or 2-tuples. For 2-tuples the first entries
    are x and the second entries are y. For 1-tuples the entries are y and x
    is the implicit index ``0, 1, 2, ...``.

    An empty series yields ``([], [])`` instead of failing on tuple
    unpacking. This lets an empty dataset inside a multi-series metric still
    be plotted.
    """
    data = list(data)
    if not data:
        return [], []
    d1, *d2 = zip(*data)
    d1, d2 = flattentolist(d1), flattentolist(d2)
    if not d2:
        # 1-tuples: only y values were recorded, so index them implicitly.
        return list(range(len(d1))), d1
    return d1, d2

def max_col_layout(figure, num_plots, max_num_col):
    """Add a grid of ``num_plots`` axes to ``figure``, at most ``max_num_col`` per row.

    The grid has ``min(num_plots, max_num_col)`` columns and as many rows as
    needed. Unused cells in the last row are removed from the figure.

    Returns:
        A flat array holding the ``num_plots`` axes in row-major order.

    Raises:
        ValueError: if ``num_plots`` or ``max_num_col`` is not positive.
    """
    if num_plots < 1:
        raise ValueError(f"num_plots must be positive, got {num_plots}")
    if max_num_col < 1:
        raise ValueError(f"max_num_col must be positive, got {max_num_col}")
    row_size = min(num_plots, max_num_col)
    num_row = -(-num_plots // row_size)  # ceiling division
    axes = figure.subplots(num_row, row_size, squeeze = False)
    for unused_axis in axes.flat[num_plots:]:
        unused_axis.remove()
    return axes.flat[:num_plots]

# Metric plotter class

class metricPlotter:
    """Collect named metric series and plot them, optionally on a live figure.

    A metric's data is either a list of point tuples (one series) or a list of
    such lists (several series); see validate_data_format. Metrics are given
    positionally as alternating ``label, data`` arguments or as ``label=data``
    keywords. The reserved keyword ``update_plots`` controls whether the live
    figure is redrawn afterwards.
    """

    def __init__(self, *metric_args, **metric_kwds):
        self._metrics = {}             # label -> accumulated data
        self._metric_to_axes = {}      # label -> Axes on the live figure
        self._live_figure = None       # (Figure, axes) once a live figure exists
        self._max_plots_per_row = 3
        self._auto_update_plots = False
        self._show_n_most_recent = 0   # > 0: tabulate the newest values per plot
        self.append_to_metric(*metric_args, **metric_kwds)

    @property
    def metrics(self):
        """Dict mapping metric labels to their accumulated data."""
        return self._metrics

    @property
    def metric_to_axes(self):
        """Dict mapping metric labels to their axis on the live figure."""
        return self._metric_to_axes

    @property
    def live_figure(self):
        """``(figure, axes)`` of the live figure, or None before the first update."""
        return self._live_figure

    @property
    def max_plots_per_row(self):
        return self._max_plots_per_row

    @max_plots_per_row.setter
    def max_plots_per_row(self, plots_per_row):
        if isinstance(plots_per_row, int) and plots_per_row > 0:
            self._max_plots_per_row = plots_per_row
        else:
            print(f"Invalid maximum number of plots per figure: {plots_per_row}")

    @property
    def auto_update_plots(self):
        return self._auto_update_plots

    @auto_update_plots.setter
    def auto_update_plots(self, auto_update):
        if isinstance(auto_update, bool):
            self._auto_update_plots = auto_update
            if auto_update:
                print("Requested auto updating of plots. pyplot.ion() may be necessary!")
        else:
            print("auto_update_plots must be a bool!")

    @property
    def show_n_most_recent(self):
        return self._show_n_most_recent

    @show_n_most_recent.setter
    def show_n_most_recent(self, n_most_recent):
        if isinstance(n_most_recent, int) and n_most_recent >= 0:
            self._show_n_most_recent = n_most_recent
        else:
            print("The number of recent metrics to show must be a nonnegative integer!")

    def _parse_args(self, args, kwds):
        """Split ``args``/``kwds`` into metric labels, metric data and the update flag.

        ``args`` alternate label, data, label, data, ... The reserved keyword
        ``update_plots`` (default: ``auto_update_plots``) is removed from
        ``kwds`` before the remaining keywords are treated as metrics.
        """
        arg_labels, arg_data = args[::2], args[1::2]
        if len(arg_labels) != len(arg_data):
            raise ValueError("Received an unbalanced number of metric and data pairs")
        update_plots = kwds.pop("update_plots", self.auto_update_plots)
        metric_labels = [*arg_labels, *kwds.keys()]
        metric_data = [*arg_data, *kwds.values()]
        return metric_labels, metric_data, update_plots

    def append_to_metric(self, *metric_args, **metric_kwds):
        """Append data to the given metrics, creating metrics that do not exist yet."""
        metric_labels, metric_data, update_plots = self._parse_args(metric_args,
                                                                    metric_kwds)
        for metric, data in zip(metric_labels, metric_data):
            self._append_to_metric(metric, data)
        if update_plots: self.update_plots()

    def set_metric(self, *metric_args, **metric_kwds):
        """Replace the data of the given metrics; other metrics are untouched."""
        metric_labels, metric_data, update_plots = self._parse_args(metric_args,
                                                                    metric_kwds)
        # Drop only the named metrics, directly: clear_metrics() with no labels
        # would wipe every metric, and it may redraw a half-replaced state.
        for metric in metric_labels:
            self.metrics.pop(metric, None)
        for metric, data in zip(metric_labels, metric_data):
            self._append_to_metric(metric, data)
        if update_plots: self.update_plots()

    def _append_to_metric(self, metric, data):
        """Validate ``data`` and append it to ``metric``; empty results are not stored."""
        new_data = ensurelist(data)
        validate_data_format(new_data)
        updated_data = self.metrics.pop(metric, [])
        updated_data += new_data
        if updated_data: self.metrics[metric] = updated_data

    def clear_metrics(self, *metrics, update_plots = False):
        """Remove the named metrics, or every metric when none are named."""
        if not metrics:
            metrics = [*self.metrics.keys()]
        for metric in metrics:
            self.metrics.pop(metric, None)
        if update_plots or self.auto_update_plots: self.update_plots()

    def _require_metrics(self):
        """Raise ValueError when there are no metrics to lay out axes for."""
        if not self.metrics:
            raise ValueError("There are no metrics to plot.")

    def set_up_figure_and_axes(self):
        """Create a new figure with one axis per metric.

        Returns ``(figure, axes)``. ``axes`` is a flat array with one entry per
        metric, laid out at most ``max_plots_per_row`` per row.
        """
        self._require_metrics()
        fig = plt.figure()
        axes = max_col_layout(fig, len(self.metrics), self.max_plots_per_row)
        fig.set_tight_layout(True)
        return fig, axes

    def set_up_axes(self, figure):
        """Add one axis per metric to an existing ``figure`` and return them flat."""
        self._require_metrics()
        return max_col_layout(figure, len(self.metrics), self.max_plots_per_row)

    @staticmethod
    def _flatten_axes(axes):
        """Return ``axes`` as a flat list, whether an array, a sequence or one Axes."""
        if hasattr(axes, "flatten"):
            return list(axes.flatten())
        if isinstance(axes, Iterable):
            return list(axes)
        return [axes]

    def plot_metrics(self, fig_and_axes = None):
        """Plot every metric on its own axis and return ``(figure, axes)``.

        Args:
            fig_and_axes: optional ``(figure, axes)`` to draw into. ``axes`` may
                be an array of any shape, a sequence, or a single Axes. It must
                hold exactly one axis per metric. A new figure is created when
                omitted.
        """
        fig, axes = self.set_up_figure_and_axes() if fig_and_axes is None else fig_and_axes
        flat_axes = self._flatten_axes(axes)
        if len(self.metrics) != len(flat_axes):
            raise ValueError("The number of metrics and the number of provided axes " \
                             "must be the same.")

        for metric, axis in zip(self.metrics.keys(), flat_axes):
            self._plot_metric(metric, axis)
        return fig, axes

    def new_live_figure(self):
        """Create a fresh live figure and map each current metric to an axis."""
        fig, axes = self.set_up_figure_and_axes()
        self._live_figure = (fig, axes)
        self._metric_to_axes = {metric : axis for metric, axis in
                                zip(self.metrics.keys(), axes)}

    def _execute_redraw(self, figure):
        figure.canvas.draw()
        plt.pause(0.1)  # let the GUI event loop process the redraw

    def _close_live_figure(self):
        """Close the live figure (if any) and forget its axes."""
        if self.live_figure is not None:
            plt.close(self.live_figure[0])
        self._live_figure = None
        self._metric_to_axes = {}

    def update_plots(self):
        """Redraw every metric on the live figure, creating it if needed.

        The figure is rebuilt when the set of metrics changed since it was laid
        out. When no metrics are left, the live figure is closed instead.
        """
        if not self.metrics:
            self._close_live_figure()
            return
        if self.live_figure is None:
            self.new_live_figure()
        elif self.metrics.keys() != self.metric_to_axes.keys():
            # Figure has no close() method; it must be closed through pyplot.
            plt.close(self.live_figure[0])
            self.new_live_figure()
        for metric, axis in self.metric_to_axes.items():
            self._plot_metric(metric, axis)

        self._execute_redraw(self.live_figure[0])

    def _plot_metric(self, metric, axis):
        """Draw ``metric`` on ``axis``, replacing whatever the axis showed."""
        axis.clear()
        axis.set_title(f"Metric: {metric}")
        axis.set_ylabel(f"{metric}")
        axis.set_xlabel("Iteration")

        data = self.metrics[metric]
        if isinstance(data[0], list):
            # Several series: one marked line per series plus a legend.
            for count, dataset in enumerate(data):
                x, y = get_x_and_y_from_tuples(dataset)
                axis.plot(x, y, marker='.', label=f"{metric} dataset {count}")
            axis.legend()
        else:
            x, y = get_x_and_y_from_tuples(data)
            axis.plot(x, y)
            if self._show_n_most_recent > 0:
                # Tabulate the newest values on the axis in place of the x label.
                # The two blank leading rows come from the original layout
                # (presumably spacing -- confirm).
                cell_text, row_labels = [" "] * 2, [" "] * 2
                cell_text.extend([f"{val:4.3e}"] for val in y[-self._show_n_most_recent:])
                row_labels.extend(f"Iteration {i}" for i in
                                  x[-self._show_n_most_recent:])
                axis.set_xlabel(None)
                axis.table(cellText = cell_text, rowLabels = row_labels, edges = "open")

# Particle plotter helper functions
def not_empty_particle(particle):
    """Return True when ``particle`` is not None and holds at least one element."""
    if particle is None:
        return False
    return particle.numel() > 0

def format_dim(tensor):
    """Return ``tensor`` unchanged if it has shape ``(N, 2)``.

    Raises:
        ValueError: for 1-D tensors, and for any other shape that is not
            ``(N, 2)``.
    """
    rank = tensor.dim()
    if rank == 1:
        raise ValueError("Only two dimensional particles are supported")
    if rank != 2 or tensor.shape[1] != 2:
        raise ValueError("Particle has incorrect dimensions!")
    return tensor

def format_and_validate_particles(*particles):
    """Detach each particle tensor and check that it has shape ``(N, 2)``.

    Returns:
        A tuple of the detached particle tensors, in input order.

    Raises:
        ValueError: from format_dim for wrongly shaped particles, or when the
            particles' last dimensions differ.
    """
    formatted_parts = tuple(format_dim(part.detach()) for part in particles)
    last_dims = (part.shape[-1] for part in formatted_parts)
    if not allequal(last_dims):
        raise ValueError("Particles must all have the same ambient dimension!")
    return formatted_parts

def format_and_validate_weights(particles, weights):
    """Detach, copy and squeeze the optional weights of each particle set.

    Args:
        particles: sequence of particle tensors, one row per particle.
        weights: sequence parallel to ``particles``. Each entry is either empty
            (no weights) or a one-element container holding a weight tensor.

    Returns:
        A tuple parallel to ``weights``. Empty entries are passed through
        as-is; the others become ``[w]``, where ``w`` is a fresh copy with
        singleton dimensions removed and at least one dimension.

    Raises:
        ValueError: if a weight count differs from its particle count, or if
            a weight is negative.
    """
    def format_wt(wt):
        # squeeze drops singleton dims such as (N, 1). atleast_1d keeps the
        # weight of a single particle 1-D, so the len() check below works.
        return torch.atleast_1d(wt.detach().clone().squeeze())

    formatted_wts = tuple([format_wt(*wt)] if wt else wt for wt in weights)
    # Check lengths before min(): min() of an empty weight tensor would raise
    # a RuntimeError instead of this ValueError.
    for part, wt in zip(particles, formatted_wts):
        if wt and len(part) != len(wt[0]):
            raise ValueError("Number of particles and number of weights do not match!")
    if any(wt[0].min() < 0 for wt in formatted_wts if wt):
        raise ValueError("Received a negative weight for a particle!")
    return formatted_wts

# Particle plotter class

class particlePlotter:
    """Store labelled 2-D particle sets (optionally weighted) and scatter-plot them.

    Unlabelled particle sets get the label ``"<labeling_prefix> <k>"``, where
    ``k`` continues from the number of labels already stored. Adding to an
    existing label appends the new particles (and weights) to it.
    """

    # Digit runs used to order labels such as "Time 12" numerically.
    _NUM_RE = re.compile(r"\d+")

    def __init__(self, *particles, **particles_with_labels):
        self._labeling_prefix = "Time"
        self._particles = {}   # label -> (particles,) or (particles, weights)
        self.add_particles(*particles, **particles_with_labels)

    @property
    def particles(self):
        """Dict mapping labels to ``(particles,)`` or ``(particles, weights)``."""
        return self._particles

    @property
    def labeling_prefix(self):
        return self._labeling_prefix

    @labeling_prefix.setter
    def labeling_prefix(self, label):
        if type(label) is str:
            self._labeling_prefix = label
        else:
            print("Invalid labeling prefix. Prefix is not a string!")

    def add_particles(self, *particles, **particles_with_labels):
        """Validate and store particle sets; empty sets are skipped.

        Each set is a particle tensor of shape ``(N, 2)``, optionally paired
        with weights. The pairing is normalized by ensuretuples (presumably
        ``(particles, weights)`` tuples -- confirm).
        """
        if not particles and not particles_with_labels: return
        parts, wts, labels = self._parts_wts_and_labels(*particles,
                                                        **particles_with_labels)

        parts = format_and_validate_particles(*parts)
        wts = format_and_validate_weights(parts, wts)

        self._add_valid_particles(**{label: (part, *wt) for part, wt, label
                                     in zip(parts, wts, labels)})

    def _parts_wts_and_labels(self, *particles, **particles_with_labels):
        """Pair each non-empty particle set with its weights and label.

        Returns the result of components on the filtered triples (presumably
        three parallel sequences ``parts, wts, labels`` -- confirm).
        """
        new_labels = (self.labeling_prefix + " " + str(len(self.particles) + i)
                      for i in range(len(particles)))
        labels = (*new_labels, *particles_with_labels.keys())

        data_tuples = ensuretuples((*particles, *particles_with_labels.values()))
        parts, wts = components((part, wt) for part, *wt in data_tuples)
        ne_parts = map(not_empty_particle, parts)

        return components(compress(zip(parts, wts, labels), ne_parts))

    def _add_valid_particles(self, **particles_with_labels):
        """Store already-validated sets, appending to any existing label."""
        for label, particle_set in particles_with_labels.items():
            particles, *weights = particle_set
            curr_particles, *curr_weights = self.particles.get(label, (None, None))
            if curr_particles is None:
                self.particles[label] = (particles.clone(), *weights)
                continue
            if bool(curr_weights) != bool(weights):
                raise ValueError(f"Weights for label: {label} are not compatible!")
            if curr_weights:
                # Concatenate along dim 0. vstack would turn 1-D weight vectors
                # into rows and fail (or misalign) for differing lengths.
                curr_weights = (torch.cat((*curr_weights, *weights)),)
            self.particles[label] = (torch.vstack((curr_particles, particles)),
                                     *curr_weights)

    def set_up_axes(self, figure, num_ax):
        """Add one row of ``num_ax`` axes to ``figure`` and return the (1, num_ax) array."""
        return figure.subplots(1, num_ax, squeeze=False)

    def plot_particles(self, times = None, color = None, marker_size = 50.0,
                       weighted = False, fig = None, axes = None):
        """Scatter-plot stored particle sets, one axis per label.

        Args:
            times: labels to plot. Empty/None plots only the label that sorts
                last by break_into_nums. ``"all"`` plots every label in that
                order. Any other string is treated as a single label.
            color: RGB tensor; defaults to red ``[1, 0, 0]``.
            marker_size: scatter marker size.
            weighted: use stored weights as per-point alpha values.
            fig, axes: optional target figure/axes. Without ``fig`` a new
                figure is created. Without ``axes``, axes are added to ``fig``.

        Returns:
            ``(fig, axes)``.

        Raises:
            ValueError: if there is nothing to plot, or if the number of
                provided axes does not match the number of labels.
        """
        if color is None:
            color = torch.as_tensor([1.0, 0, 0])
        if not times:
            times = sorted(self.particles.keys(),
                           key=particlePlotter.break_into_nums, reverse=True)[0:1]
        elif isinstance(times, str):
            times = (sorted(self.particles.keys(), key=particlePlotter.break_into_nums)
                     if times == "all" else [times])
        if not times:
            raise ValueError("There are no particles to plot.")

        if fig is None:
            fig, axes = plt.subplots(1, len(times), squeeze=False)
        elif axes is None:
            axes = fig.subplots(1, len(times), squeeze=False)
        elif len(times) != len(axes.flatten()):
            raise ValueError("Number of times to plot and number of axes " \
                             "are not equal.")

        for time, axis in zip(times, axes.flatten()):
            points, *wts = self.particles[time]
            if not weighted: wts = ()
            self._plot_points(points, axis, *wts, color = color, marker_size =
                              marker_size)

            axis.set_title(f"{time}")

        return fig, axes

    def _plot_points(self, points, axis, weights = None, color = None,
                     marker_size = 50.0):
        """Scatter ``(N, 2)`` ``points`` on ``axis`` in ``color`` (default red).

        With ``weights``, each point's alpha is its weight divided by the
        total weight.
        """
        if color is None:
            color = torch.as_tensor([1.0, 0, 0])
        x, y = torch.chunk(points, 2, dim=1)
        colors = torch.tile(color, (len(points), 1))
        if weights is not None:
            if len(points) != len(weights):
                raise ValueError("Number of data points and number of weights " \
                                 "do not match!")
            norm_wts = weights / weights.sum()
            colors = torch.hstack((colors, norm_wts.unsqueeze(-1)))
        axis.scatter(x, y, c=colors, marker='.', s=marker_size)

    @staticmethod
    def break_into_nums(target_str):
        """Return every run of digits in ``target_str`` as a float, in order."""
        return [float(number) for number in particlePlotter._NUM_RE.findall(target_str)]


class Animator:
    """FuncAnimation callback that redraws ``figure`` from a pickled file.

    Each frame re-reads ``filename`` and, after clearing the figure, calls
    ``plotter(figure, data)`` with the loaded object.
    """

    def __init__(self, filename, figure, plotter):
        self.filename = filename
        self.figure = figure
        self.plotter = plotter

    def animate(self, frame):
        """Redraw for one animation frame; ``frame`` itself is unused.

        The frame is skipped (figure left unchanged) when the file does not
        exist yet or cannot be unpickled. The latter happens while a writer is
        midway through rewriting it, because write_to_file truncates the file
        before dumping.
        """
        try:
            data = read_from_file(self.filename)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            return

        self.figure.clear()
        self.plotter(self.figure, data)

def plot_realtime(filenames, plotters, intervals = None):
    """Open one live-updating figure per pickled file and block in ``plt.show()``.

    Args:
        filenames: file name(s) re-read every frame by Animator (normalized
            with ensurelist).
        plotters: callable(s) ``plotter(figure, data)``, one per file.
        intervals: FuncAnimation refresh interval(s) in milliseconds. A single
            value applies to every file. Defaults to 1000.

    Raises:
        ValueError: if the numbers of files, plotters and intervals disagree.
    """
    filenames = ensurelist(filenames)
    plotters = ensurelist(plotters)
    if len(filenames) != len(plotters):
        raise ValueError("Number of files and number of plotters must be equal!")

    intervals = ensurelist(1000 if intervals is None else intervals)
    if len(intervals) == 1:
        # Build a new list: "*=" would also grow a list passed in by the caller.
        intervals = intervals * len(filenames)
    if len(intervals) != len(filenames):
        raise ValueError("Number of files and number of intervals must be equal!")

    figures = [plt.figure() for _ in filenames]
    for fig in figures:
        fig.set_tight_layout(True)
    animators = [Animator(fname, fig, plotter) for fname, fig, plotter in
                 zip(filenames, figures, plotters)]
    # Hold references until plt.show() returns; unreferenced animations can be
    # garbage-collected and stop updating.
    animations = [FuncAnimation(fig, animator.animate, interval = it) for fig,
                  animator, it in zip(figures, animators, intervals)]
    plt.show()

def animateMetricPlotter(figure, data, show_n_most_recent = 0):
    """Animator plotter: draw the metrics dict ``data`` onto ``figure``.

    ``data`` maps metric labels to metric data, e.g. the dict metric_callback
    pickles to its log file. Nothing is drawn while there are no metrics,
    because laying out zero axes would fail.
    """
    if not data:
        return
    m_plotter = metricPlotter(**data)
    if not m_plotter.metrics:
        return
    m_plotter.show_n_most_recent = show_n_most_recent
    axes = m_plotter.set_up_axes(figure)
    m_plotter.plot_metrics((figure, axes))

def animateParticlePlotter(figure, data):
    """Animator plotter: scatter the particle sets in ``data`` onto ``figure``.

    Uses plot_particles' default selection (the label that sorts last
    numerically). Nothing is drawn while there are no particles.
    """
    if not data:
        return
    p_plotter = particlePlotter(**data)
    if not p_plotter.particles:
        return
    p_plotter.plot_particles(fig = figure)


class particlesWithReference:
    """Animator plotter that overlays a fixed reference particle set.

    Args:
        parts: mapping whose ``"Reference"`` entry is ``(particles, *weights)``.
            Only the particles are drawn; reference weights are ignored.
        axis_limits: ``(low, high)`` applied to both axes of the first plot.
            Defaults to ``(-13, 13)``.
    """

    def __init__(self, parts, axis_limits = (-13, 13)):
        self.reference = parts["Reference"]
        self.axis_limits = axis_limits

    def animateParticlePlotterWithRef(self, figure, data):
        """Plot the particle sets in ``data`` on ``figure``, then draw the
        reference particles in black on the first axis."""
        if not data:
            return
        p_plotter = particlePlotter(**data)
        if not p_plotter.particles:
            return
        _, axes = p_plotter.plot_particles(fig = figure, marker_size = 100.0)

        ref_parts, *_ = self.reference  # reference weights are not drawn
        axis = axes.flatten()[0]
        p_plotter._plot_points(ref_parts, axis, color =
                               torch.as_tensor([0.0, 0.0, 0.0]))

        low, high = self.axis_limits
        axis.set_xlim(low, high)
        axis.set_ylim(low, high)

def metric_callback(iteration_variables, metric_plotter, metric_calc_funcs,
                    log_file = None, print_most_recent = False):
    """Run metric calculators for one iteration, then optionally log and print.

    Each function in ``metric_calc_funcs`` is called as
    ``func(iteration_variables, metric_plotter)``. With ``log_file``, the
    plotter's full metrics dict is pickled there, overwriting the file. With
    ``print_most_recent``, the newest point of every metric whose last entry
    is a tuple is printed as a table. The header gets the point's iteration
    appended when the point has one.
    """
    for calc in metric_calc_funcs:
        calc(iteration_variables, metric_plotter)

    if log_file is not None:
        write_to_file(log_file, metric_plotter.metrics)

    if not print_most_recent:
        return

    def header_for(label, iteration_part):
        # iteration_part holds everything but the last entry of the newest point.
        if isinstance(label, str) and len(iteration_part) > 0:
            return f"{label} (Iter: {iteration_part[0]})"
        return label

    newest = {}
    for label, points in metric_plotter.metrics.items():
        last_point = points[-1]
        if not isinstance(last_point, tuple):
            continue
        newest[header_for(label, last_point[0:-1])] = last_point[-1:]
    print(tabulate(newest, headers = "keys"))

# Density plotting tools

def gauss_dens(points, parts, weights = None, bandwidth = 1.0):
    """Evaluate a weighted Gaussian kernel density estimate at ``points``.

    Args:
        points: query locations. ``points.shape[-1]`` is used as the dimension
            ``d`` in the normalizing constant.
        parts: kernel centres (particles).
        weights: one weight per particle; uniform when None. The weights are
            normalized to sum to one.
        bandwidth: kernel bandwidth ``h``.

    Returns:
        ``K @ w`` with ``K = (2*pi*h)**(-d/2) * exp(-eucl_norm_sq(points, parts) / h)``.
        Presumably this is one density per query point, assuming eucl_norm_sq
        returns the (num_points, num_parts) matrix of squared distances --
        TODO confirm.

    NOTE(review): if eucl_norm_sq is the plain squared distance, the constant
    (2*pi*h)**(-d/2) pairs with exp(-r**2 / (2*h)), not exp(-r**2 / h). The
    estimate would then integrate to 2**(-d/2) instead of 1. This matters for
    absolute density cutoffs such as decon_den_plot's ``vmin`` -- confirm.
    """
    if weights is None:
        weights = torch.ones(len(parts))
    weights = weights / weights.sum()
    # pi as a tensor: acos(0) = pi / 2.
    pi_est = torch.acos(torch.zeros(1)) * 2.0
    # (2*pi*h)^(-d/2), computed through exp/log.
    scaling = torch.exp(- points.shape[-1] / 2.0 * torch.log(2.0 * pi_est * bandwidth))
    return torch.matmul(scaling * torch.exp(-eucl_norm_sq(points, parts) /
                                            bandwidth), weights)

def get_dens_plot_data(particles, x_lim, y_lim, num_pts_per_dim, particle_weights = None,
                       bandwidth = 1.0):
    """Evaluate gauss_dens on a regular square grid for contour plotting.

    Args:
        particles: kernel centres passed to gauss_dens.
        x_lim, y_lim: ``(low, high)`` ranges of the grid.
        num_pts_per_dim: number of grid points along each axis.
        particle_weights: optional particle weights for gauss_dens.
        bandwidth: kernel bandwidth for gauss_dens.

    Returns:
        ``(x, y, z)``: the ``torch.meshgrid`` coordinates (default indexing)
        and the densities reshaped to ``(num_pts_per_dim, num_pts_per_dim)``.
    """
    x_axis = torch.linspace(*x_lim, num_pts_per_dim)
    y_axis = torch.linspace(*y_lim, num_pts_per_dim)
    x, y = torch.meshgrid(x_axis, y_axis)
    grid_points = torch.stack((x.flatten(), y.flatten()), dim = -1)
    densities = gauss_dens(grid_points, particles, particle_weights,
                           bandwidth = bandwidth)
    z = densities.reshape(num_pts_per_dim, num_pts_per_dim)
    return x, y, z

def decon_den_plot(ref_part_and_wt, iterate_part_and_wt, x_lim, y_lim,
                   num_pts_per_dim, max_num_col = 3,
                   ref_bandwidth = 1.0, iter_bandwidth = 1.0,
                   ref_den_cutoff = torch.exp(-5 * torch.ones(1)).item(),
                   part_den_cutoff = torch.exp(-5 * torch.ones(1)).item(),
                   ref_den_color_palette = None,
                   part_den_color_palette = None,
                   figsize = (9, 9)):
    """Contour-plot particle densities over a shared reference density.

    Each entry of ``iterate_part_and_wt`` gets one axis. The reference density
    is drawn filled (``contourf``) and the entry's density as contour lines.
    Densities are Gaussian KDEs on a ``num_pts_per_dim`` square grid over
    ``x_lim`` x ``y_lim`` (see get_dens_plot_data).

    Args:
        ref_part_and_wt: ``(particles,)`` or ``(particles, weights)``.
        iterate_part_and_wt: dict label -> ``(particles,)`` or
            ``(particles, weights)``. One subplot per entry, titled by label.
        x_lim, y_lim, num_pts_per_dim: evaluation grid.
        max_num_col: maximum number of subplots per row.
        ref_bandwidth, iter_bandwidth: KDE bandwidths.
        ref_den_cutoff, part_den_cutoff: passed as ``vmin`` to the contour calls.
        ref_den_color_palette, part_den_color_palette: colormaps. The defaults
            are copies of 'Reds' / 'viridis' with a fully transparent white
            under-color.
        figsize: figure size in inches.

    Returns:
        ``(figure, axes)``.
    """
    ref_parts, *ref_wts = ref_part_and_wt
    ref_wts = ref_wts[0] if ref_wts else None
    ref_x, ref_y, ref_z = get_dens_plot_data(ref_parts, x_lim, y_lim,
                                             num_pts_per_dim, particle_weights =
                                             ref_wts, bandwidth = ref_bandwidth)
    if ref_den_color_palette is None:
        ref_den_color_palette = copy(plt.get_cmap('Reds'))
        ref_den_color_palette.set_under('white', 0.0)

    if part_den_color_palette is None:
        part_den_color_palette = copy(plt.get_cmap('viridis'))
        part_den_color_palette.set_under('white', 0.0)

    fig = plt.figure(figsize=figsize)
    fig.set_tight_layout(True)
    axes = max_col_layout(fig, len(iterate_part_and_wt), max_num_col)

    for (label, parts_and_wts), ax in zip(iterate_part_and_wt.items(), axes):
        parts, *wts = parts_and_wts
        # Unwrap the optional weight tensor. Passing the list itself made
        # gauss_dens fail on ``weights.sum()``.
        wts = wts[0] if wts else None
        x, y, z = get_dens_plot_data(parts, x_lim, y_lim, num_pts_per_dim,
                                     particle_weights = wts, bandwidth =
                                     iter_bandwidth)
        ax.contourf(ref_x, ref_y, ref_z, vmin = ref_den_cutoff,
                    cmap = ref_den_color_palette)
        ax.contour(x, y, z, vmin = part_den_cutoff, cmap =
                   part_den_color_palette)
        ax.set_title(label)
    return fig, axes

def plot_averaged_metrics(metric_replications, axis, num_interp_pts = 200,
                          mean_curve = True, mean_curve_options = None,
                          individual_curves = False, individual_curve_options =
                          None, confidence_band_width = 1.0,
                          confidence_band_options = None, color="blue"):
    """Plot the mean of several metric replications with an optional std band.

    Replications are linearly interpolated onto a common grid of
    ``num_interp_pts`` iterations (see get_interpolated_metrics). The
    per-iteration mean and standard deviation are then computed on that grid.

    Args:
        metric_replications: sequence of replications, each a sequence of
            ``(iteration, value)`` tuples.
        axis: matplotlib Axes to draw on.
        num_interp_pts: number of interpolation points.
        mean_curve: draw the mean curve.
        mean_curve_options: extra ``axis.plot`` kwargs for the mean curve.
        individual_curves: also draw every interpolated replication.
        individual_curve_options: extra ``axis.plot`` kwargs for those curves.
        confidence_band_width: band half-width in standard deviations. None or
            a value <= 0 disables the band.
        confidence_band_options: extra ``axis.fill_between`` kwargs.
        color: default color of the mean curve and of the band.

    Returns:
        ``axis``.
    """
    iter_interp, metric_interp = get_interpolated_metrics(metric_replications,
                                                          num_interp_pts)
    stds, means = torch.std_mean(metric_interp, dim = 0)

    if individual_curves:
        ic_options = {"color" : "grey", "alpha" : 0.4,
                      **(individual_curve_options or {})}
        for replication in metric_interp:
            axis.plot(iter_interp, replication, **ic_options)

    # Resolve the mean-curve options even when the curve is not drawn: the band
    # reuses its color. This used to raise NameError with mean_curve=False.
    mc_options = {"color" : color, "label" : "mean", **(mean_curve_options or {})}
    if mean_curve:
        axis.plot(iter_interp, means, '-', **mc_options)

    if confidence_band_width is not None and confidence_band_width > 0.0:
        cb_options = {"alpha" : 0.2,
                      "label" : f"{confidence_band_width} standard deviations",
                      "color" : mc_options["color"]}
        cb_options.update(confidence_band_options or {})
        axis.fill_between(iter_interp, means - confidence_band_width * stds,
                          means + confidence_band_width * stds, **cb_options)
    return axis

def get_interpolated_metrics(metric_replications, num_points):
    """Linearly interpolate every replication onto one shared iteration grid.

    The grid has ``num_points`` evenly spaced points over the iteration range
    covered by *all* replications (the intersection of their ranges).

    Args:
        metric_replications: non-empty sequence of replications, each a
            sequence of ``(iteration, value)`` tuples.
        num_points: number of grid points.

    Returns:
        ``(grid, values)``. ``grid`` is a 1-D float tensor of length
        ``num_points``. ``values`` stacks the interpolated replications and has
        shape ``(num_replications, num_points)``.

    Raises:
        ValueError: for no replications, malformed points, or replications
            whose iteration ranges do not overlap.
    """
    interpolants = []
    range_min, range_max = float("-inf"), float("inf")
    for replication in metric_replications:
        try:
            iters, metric = zip(*replication)
        except ValueError as err:
            raise ValueError("Incorrectly formatted input!") from err

        iters_tensor = torch.as_tensor(iters, dtype=torch.float)
        interpolants.append(interp1d(iters_tensor,
                                     torch.as_tensor(metric, dtype=torch.float),
                                     kind="linear"))
        range_min = max(range_min, iters_tensor.min().item())
        range_max = min(range_max, iters_tensor.max().item())

    if not interpolants:
        raise ValueError("Received no metric replications to interpolate!")
    if range_min > range_max:
        raise ValueError("Metric replications share no common iteration range!")

    interp_pts = torch.linspace(range_min, range_max, num_points)
    tensor_of_metrics = torch.stack(
        tuple(torch.as_tensor(interp(interp_pts)) for interp in interpolants),
        dim = 0)
    return interp_pts, tensor_of_metrics

def get_uninterpolated_metrics(metric_replications):
    """Stack replications that were recorded at identical iterations.

    Args:
        metric_replications: non-empty sequence of replications, each a
            sequence of ``(iteration, value)`` tuples with the same
            iterations.

    Returns:
        ``(iters, values)``. ``iters`` is the shared iteration indices as a
        tensor. ``values`` is a float tensor of shape
        ``(num_replications, num_iterations)``.

    Raises:
        ValueError: for no replications, malformed points, or replications
            whose iteration indices differ (use get_interpolated_metrics).
    """
    metrics = []
    shared_iters = None
    for replication in metric_replications:
        try:
            iters, metric = zip(*replication)
        except ValueError as err:
            raise ValueError("Incorrectly formatted input!") from err
        if shared_iters is None:
            shared_iters = iters
        elif iters != shared_iters:
            raise ValueError("Metrics must all have the same iteration indexes! "
                             "Use get_interpolated_metrics for interpolation!")
        metrics.append(metric)

    if shared_iters is None:
        raise ValueError("Received no metric replications!")

    tensor_of_iters = torch.as_tensor(shared_iters)
    tensor_of_metrics = torch.stack(tuple(torch.as_tensor(met, dtype=torch.float)
                                          for met in metrics), dim = 0)
    return tensor_of_iters, tensor_of_metrics
