"""Plotting utilities for the examples."""
from collections import defaultdict
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib as mpl
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import pandas as pd
import traceback


# Type aliases for the result containers consumed by the plotting functions.
DifficultyValue = int
SolverID = str
# Solve/log time; the figures label it in seconds.
SolverTime = float
# Solution quality; plot_quality_vs_time clamps it to [0, 1] before plotting.
Quality = float
ExperimentID = int
# difficulty -> solver ID -> one entry per attempt: the time of a successful
# attempt, or None for a failed one (the plotting code treats non-None
# entries as successes).
Results = dict[DifficultyValue, dict[SolverID, list[SolverTime | None]]]
# Flat log of quality measurements, one row per logged measurement:
# (difficulty, solver ID, time, quality, experiment ID).
QualityResults = list[tuple[
    DifficultyValue, SolverID, SolverTime, Quality, ExperimentID
]]


def plot_success_plot(
        results: Results,
        output_path: Path,
        ):
    """Plot the number of successful attempts per difficulty, one line per solver.

    An attempt counts as successful when its recorded time is not None. The
    long-form table handed to seaborn has the columns "solver",
    "difficulty" and "success_n". The figure is saved to `output_path`.
    """
    long_form_results = defaultdict(dict)
    row_id = 0
    for difficulty, by_solver in results.items():
        for solver_id, attempts in by_solver.items():
            n_success = sum(1 for t in attempts if t is not None)
            record = {
                "solver": solver_id,
                "difficulty": difficulty,
                "success_n": n_success,
            }
            for column, value in record.items():
                long_form_results[column][row_id] = value
            row_id += 1

    grid = sns.relplot(
        data=long_form_results,
        kind="line",
        x="difficulty",
        y="success_n",
        hue="solver",
    )
    grid.figure.savefig(output_path)


def plot_cumulative_success_plot(
        results: Results,
        output_path: Path,
        ):
    """Plot the running count of successful attempts against difficulty.

    Solver IDs are taken from the first difficulty in `results`. For each
    solver, difficulties are visited in ascending order and the number of
    non-None times seen so far is recorded. The figure is saved to
    `output_path`.
    """
    long_form_results = defaultdict(dict)
    first_entry = list(results.values())[0]
    row_id = 0
    for solver_id in list(first_entry.keys()):
        solved_so_far = 0
        for difficulty in sorted(int(key) for key in results.keys()):
            solved_so_far += sum(
                t is not None for t in results[difficulty][solver_id]
            )
            long_form_results["solver"][row_id] = solver_id
            long_form_results["difficulty"][row_id] = difficulty
            long_form_results["cumulative_success_n"][row_id] = solved_so_far
            row_id += 1

    grid = sns.relplot(
        data=long_form_results,
        kind="line",
        x="difficulty",
        y="cumulative_success_n",
        hue="solver",
    )
    grid.figure.savefig(output_path)


def plot_time(
        results: Results,
        output_path: Path,
        ):
    """Plot successful solve times against difficulty, one line per solver.

    Every non-None time becomes one row of the long-form table; failed
    attempts (None) are dropped. The band drawn by seaborn is requested as
    ``errorbar=("pi", 50)``. The figure is saved to `output_path`.
    """
    long_form_results = defaultdict(dict)
    successful_rows = (
        (solver_id, difficulty, elapsed)
        for difficulty, by_solver in results.items()
        for solver_id, attempts in by_solver.items()
        for elapsed in attempts
        if elapsed is not None
    )
    for row_id, (solver_id, difficulty, elapsed) in enumerate(successful_rows):
        long_form_results["solver"][row_id] = solver_id
        long_form_results["difficulty"][row_id] = difficulty
        long_form_results["time"][row_id] = elapsed

    grid = sns.relplot(
        data=long_form_results,
        kind="line",
        x="difficulty",
        y="time",
        hue="solver",
        errorbar=("pi", 50),
    )
    grid.figure.savefig(output_path)


def plot_overlay_success_time(
        results: Results,
        output_path: Path,
        ):
    """Overlay per-difficulty solve time and success rate for every solver.

    Solver IDs are taken from the first difficulty in `results`. For each
    solver, difficulties are visited in ascending order. Two curves are
    drawn against difficulty, both in that solver's color:

    * left axis, solid line: mean time of the successful attempts, with a
      shaded band between their 25th and 75th percentiles. Difficulties
      without any success have no time point;
    * right axis, dashed line: percentage of attempts that succeeded,
      including 0% when every attempt at that difficulty failed.

    The figure is written to `output_path`.
    """
    # Create a matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)
    ax_time = fig.add_subplot()
    ax_success = ax_time.twinx()
    # Colors are passed explicitly so a solver's time line, band and
    # success line always share one color, independent of how each of the
    # two axes advances its own color cycle.
    cmap = mpl.colormaps['tab10']

    solver_ids = list(results.values())[0].keys()
    difficulties = sorted(int(n) for n in results.keys())
    for i, solver_id in enumerate(solver_ids):
        # Wrap around so more solvers than colormap entries still get colors
        color = cmap(i % cmap.N)
        x_time_mean = list()
        y_time_mean = list()
        y_time_error_lower = list()
        y_time_error_upper = list()
        x_success = list()
        y_success = list()
        for difficulty in difficulties:
            times = list(results[difficulty][solver_id])
            successful_times = [t for t in times if t is not None]
            if successful_times:
                # Time data: mean and inter-quartile range of successes only
                x_time_mean.append(difficulty)
                y_time_mean.append(
                    sum(successful_times)/len(successful_times)
                )
                low_error, hi_error = np.quantile(
                    successful_times, [0.25, 0.75]
                )
                y_time_error_lower.append(low_error)
                y_time_error_upper.append(hi_error)
            if times:
                # Success data is recorded even when nothing succeeded, so a
                # 0% success rate is plotted instead of being skipped.
                x_success.append(difficulty)
                y_success.append(len(successful_times)/len(times)*100)
        ax_time.plot(
            x_time_mean, y_time_mean,
            label=solver_id,
            color=color,
        )
        ax_time.fill_between(
            x_time_mean,
            y_time_error_upper,
            y_time_error_lower,
            alpha=0.4,
            facecolor=color,
        )
        ax_success.plot(
            x_success,
            y_success,
            linestyle='dashed',
            color=color,
        )

    # Adjust plot
    ax_time.set_xlabel('Difficulty')
    ax_time.set_ylabel('Solve time (s, solid)')
    ax_success.set_ylabel('Success rate (%, dashed)')
    ax_success.set_ylim(bottom=0.0, top=100)

    # Empty black lines so the legend explains the two line styles
    ax_success.plot([], [], label='success rate', c="black", ls='--')
    ax_success.plot([], [], label='time (s)', c="black")

    # Save plot
    fig.legend()
    fig.savefig(output_path)


def plot_overlay_cumulative_success_mean_time(
        results: Results,
        output_path: Path,
        ):
    """Overlay per-difficulty mean solve time and running success count.

    Solver IDs are taken from the first difficulty in `results`. For each
    solver, difficulties are visited in ascending order. Two curves are
    drawn against difficulty, both in that solver's color:

    * left axis, solid line: mean time of the successful attempts, with a
      shaded band between their 25th and 75th percentiles. Difficulties
      without any success have no time point;
    * right axis, dashed line: number of successful attempts at all
      difficulties up to and including the current one.

    The figure is written to `output_path`.
    """
    # Create a matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)
    ax_time = fig.add_subplot()
    ax_success = ax_time.twinx()
    # Colors are passed explicitly so a solver's curves on both axes share
    # one color; line style (solid/dashed) tells time and success apart.
    cmap = mpl.colormaps['tab10']

    solver_ids = list(results.values())[0].keys()
    difficulties = sorted(int(n) for n in results.keys())
    for i, solver_id in enumerate(solver_ids):
        # Wrap around so more solvers than colormap entries still get colors
        color = cmap(i % cmap.N)
        x_time_mean = list()
        y_time_mean = list()
        y_time_error_lower = list()
        y_time_error_upper = list()
        x_success = list()
        y_success = list()
        cumulative_success = 0
        for difficulty in difficulties:
            successful_times = [
                t
                for t in results[difficulty][solver_id]
                if t is not None
            ]
            if successful_times:
                # Time data: mean and inter-quartile range of successes only
                x_time_mean.append(difficulty)
                y_time_mean.append(
                    sum(successful_times)/len(successful_times)
                )
                low_error, hi_error = np.quantile(
                    successful_times, [0.25, 0.75]
                )
                y_time_error_lower.append(low_error)
                y_time_error_upper.append(hi_error)

            # Success data is accumulated at every difficulty
            cumulative_success += len(successful_times)
            x_success.append(difficulty)
            y_success.append(cumulative_success)
        ax_time.plot(
            x_time_mean, y_time_mean,
            label=solver_id,
            color=color,
        )
        ax_time.fill_between(
            x_time_mean,
            y_time_error_upper,
            y_time_error_lower,
            alpha=0.4,
            facecolor=color,
        )
        ax_success.plot(
            x_success,
            y_success,
            linestyle='dashed',
            color=color,
        )

    # Adjust plot
    ax_time.set_xlabel('Difficulty')
    ax_time.set_ylabel('Mean time (s, solid)')
    ax_success.set_ylabel('Cumulative successful tasks (dashed)')

    # Empty black lines so the legend explains the two line styles
    ax_success.plot([], [], label='success', c="black", ls='--')
    ax_success.plot([], [], label='time (s)', c="black")

    # Save plot
    fig.legend()
    fig.savefig(output_path)


def plot_overlay_cumulative_success_cumulative_time(
        results: Results,
        output_path: Path,
        ):
    """Overlay running totals of solve time and solved tasks per solver.

    Solver IDs are taken from the first difficulty in `results`. For each
    solver, difficulties are visited in ascending order and two running
    totals are plotted against difficulty:

    * left axis, solid line: summed time of all successful attempts so far;
    * right axis, dashed line: number of successful attempts so far.

    Two black proxy lines are added so the legend explains the line styles.
    The figure is written to `output_path`.
    """
    fig = Figure()
    _ = FigureCanvas(fig)
    ax_time = fig.add_subplot()
    ax_success = ax_time.twinx()

    solver_ids = list(results.values())[0].keys()
    ordered_difficulties = sorted(int(n) for n in results.keys())
    for solver_id in solver_ids:
        xs = []
        time_totals = []
        success_totals = []
        time_so_far = 0
        solved_so_far = 0
        for difficulty in ordered_difficulties:
            wins = [t for t in results[difficulty][solver_id] if t is not None]
            time_so_far += sum(wins, start=0)
            solved_so_far += len(wins)
            xs.append(difficulty)
            time_totals.append(time_so_far)
            success_totals.append(solved_so_far)
        ax_time.plot(xs, time_totals, label=solver_id)
        ax_success.plot(xs, success_totals, linestyle='dashed')

    ax_time.set_xlabel('Difficulty')
    ax_time.set_ylabel('Cumulative time (s)')
    ax_success.set_ylabel('Cumulative successful tasks')

    # Label-only proxies documenting the dashed/solid convention
    ax_success.plot([], [], label='success', c="black", ls='--')
    ax_success.plot([], [], label='time (s)', c="black")

    fig.legend()
    fig.savefig(output_path)


def plot_overlay_cumulative_success_rate_cumulative_time(
        results: Results,
        output_path: Path,
        ):
    """Overlay running solve time and running success rate per solver.

    Solver IDs are taken from the first difficulty in `results`. For each
    solver, difficulties are visited in ascending order. Two curves are
    drawn against difficulty, both in that solver's color:

    * left axis, solid line: summed time of all successful attempts at
      difficulties up to and including the current one;
    * right axis, dashed line: those successes as a percentage of the
      solver's attempts over *all* difficulties, so the curve ends at the
      solver's overall success rate. A solver with no attempts stays at 0%.

    The figure is written to `output_path`.
    """
    # Create a matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)
    ax_time = fig.add_subplot()
    ax_success = ax_time.twinx()

    solver_ids = list(results.values())[0].keys()
    difficulties = sorted(int(n) for n in results.keys())
    cmap = mpl.colormaps['tab10']
    for i, solver_id in enumerate(solver_ids):
        # Wrap around: indexing past the colormap's length would give every
        # extra solver the same saturated color.
        color = cmap(i % cmap.N)
        x_time = list()
        y_time = list()
        x_success = list()
        y_success = list()
        cumulative_success = 0
        cumulative_time = 0
        # Denominator: every attempt of this solver, across all difficulties
        total_tasks = sum(
            len(by_solver[solver_id]) for by_solver in results.values()
        )
        for difficulty in difficulties:
            successful_times = [
                t
                for t in results[difficulty][solver_id]
                if t is not None
            ]

            # Time data
            cumulative_time += sum(successful_times, start=0)
            x_time.append(difficulty)
            y_time.append(cumulative_time)

            # Success data (guard against a solver with no attempts at all)
            cumulative_success += len(successful_times)
            cumulative_success_rate = (
                cumulative_success/total_tasks*100 if total_tasks else 0.0
            )
            x_success.append(difficulty)
            y_success.append(cumulative_success_rate)
        ax_time.plot(
            x_time,
            y_time,
            label=solver_id,
            linestyle='solid',
            color=color,
        )
        ax_success.plot(
            x_success,
            y_success,
            linestyle='dashed',
            color=color,
        )

    # Adjust plot. Time is drawn solid and success dashed (see above).
    ax_time.set_xlabel('Difficulty')
    ax_time.set_ylabel('Cumulative time (seconds, solid)')
    ax_success.set_ylabel('Cumulative successful tasks % (dashed)')
    ax_success.set_ylim(bottom=0.0, top=100)

    # Empty black lines so the legend explains the two line styles
    ax_success.plot([], [], label='success', c="black", ls='--')
    ax_success.plot([], [], label='time (s)', c="black")

    # Save plot
    fig.legend()
    fig.savefig(output_path)


def plot_time_vs_success_rate(
        results: Results,
        output_path: Path,
        ):
    """Plot, per solver, the percentage of attempts solved within a time budget.

    All attempts of a solver, across every difficulty, are pooled. The
    success rate is evaluated at a set of budgets ``t``: every success time
    of that solver, plus 0.0 and the slowest success time of any solver, so
    all curves span the same range. At each budget it is the share of the
    solver's attempts, failures included, that succeeded in at most ``t``.
    Solvers with no recorded attempts are skipped, since their rate is
    undefined.

    Calls ``sns.set_theme()``, which changes seaborn's global plotting
    style. The figure is saved to `output_path`.
    """
    from bisect import bisect_right

    # Pool every attempt (None = failure) per solver across difficulties
    times = defaultdict(list)
    for by_solver in results.values():
        for solver_id, solver_times in by_solver.items():
            times[solver_id].extend(solver_times)

    # Common right end for all curves: the slowest success of any solver.
    # Solvers without any success do not contribute a placeholder value.
    max_t = max(
        (t for ts in times.values() for t in ts if t is not None),
        default=1.0,
    )

    # Identify the success rate for each time budget
    long_form_results = defaultdict(dict)
    row_id = 0
    for solver_id, solver_times in times.items():
        if not solver_times:
            # No attempts: the success rate would divide by zero
            continue
        # Sorted so the number of successes within t is a binary search
        success_times = sorted(t for t in solver_times if t is not None)
        for t in solver_times + [0.0, max_t]:
            if t is None:
                continue
            n_solved = bisect_right(success_times, t)
            success_rate = 100*n_solved/len(solver_times)
            long_form_results["solver"][row_id] = solver_id
            long_form_results["time [s]"][row_id] = t
            long_form_results["success rate [%]"][row_id] = success_rate
            row_id += 1
    sns.set_theme()
    plot = sns.relplot(
        data=long_form_results,
        kind="line",
        x="time [s]",
        y="success rate [%]",
        hue="solver",
    )
    plot.set(ylim=(-5, 105))
    plot.figure.savefig(output_path)


def plot_results(
        results: Results,
        output_dir: Path,
        ):
    """Render every time/success figure for `results` into `output_dir`.

    `output_dir` must be an existing directory. Existing files may be
    overwritten. One SVG is written per plotting function, in the order
    listed below. An exception from any plotter propagates immediately and
    the remaining figures are not produced.
    """
    plotters_and_filenames = (
        (plot_success_plot, "success.svg"),
        (plot_cumulative_success_plot, "cumulative_success.svg"),
        (plot_time, "mean_times.svg"),
        (plot_overlay_success_time, "mean_time_success__overlay.svg"),
        (plot_overlay_cumulative_success_mean_time,
         "mean_times_cumulative_success_overlay.svg"),
        (plot_overlay_cumulative_success_cumulative_time,
         "cumulative_time_cumulative_success_overlay.svg"),
        (plot_overlay_cumulative_success_rate_cumulative_time,
         "cumulative_time_cumulative_success_rate_overlay.svg"),
        (plot_time_vs_success_rate, "success_rate_vs_time.svg"),
    )
    for plotter, filename in plotters_and_filenames:
        plotter(results, output_dir/filename)


def plot_quality_vs_time(
        results: QualityResults,
        output_path: Path,
        ):
    """Plot solution quality against time, one line per solver.

    Each row of `results` is (difficulty, solver ID, time, quality,
    experiment ID). Rows are sorted by time, and qualities are clamped to
    [0, 1] before plotting; seaborn's line plot then aggregates the rows
    that share a solver and a time.

    Calls ``sns.set_theme()``, which changes seaborn's global plotting
    style. The figure is saved to `output_path`.
    """
    # Build the long-form table, one row per result, in time order
    long_form_results = defaultdict(dict)
    rows = sorted(results, key=lambda r: r[2])
    for row_id, (difficulty, solver_id, solver_time, solver_quality,
                 experiment_id) in enumerate(rows):
        # Enforce quality bounds
        solver_quality = max(0.0, min(1.0, solver_quality))
        long_form_results["solver"][row_id] = solver_id
        long_form_results["time [s]"][row_id] = solver_time
        long_form_results["quality"][row_id] = solver_quality
        long_form_results["difficulty"][row_id] = difficulty
        long_form_results["experiment_id"][row_id] = experiment_id
    sns.set_theme()
    plot = sns.relplot(
        data=pd.DataFrame(long_form_results),
        x="time [s]",
        y="quality",
        hue="solver",
        kind="line",
    )
    plot.set(ylim=(-0.05, 1.05))
    plot.figure.savefig(output_path)


def plot_quality_results(
        results: QualityResults,
        output_dir: Path,
        ):
    """Resample the raw quality log and plot quality over time.

    `output_dir` must be an existing directory. Existing files may be
    overwritten. The figure is written to ``quality_vs_time.svg`` inside it.
    """
    preprocessed = get_preprocessed_quality_results(results)
    target = output_dir/"quality_vs_time.svg"
    plot_quality_vs_time(preprocessed, target)


def get_preprocessed_quality_results(
        results: QualityResults,
        max_timestamps: int = 1000,
        ) -> QualityResults:
    """Resample per-experiment quality logs onto a shared time grid.

    Experiments log qualities at different times. To characterize the
    distribution of qualities as a function of time, every experiment needs
    a quality value at the same set of timestamps. This function:

    1. Builds the grid from every distinct logged timestamp plus ``0.0``,
       so every experiment starts from zero quality. When the grid has more
       than `max_timestamps` points, it is subsampled at equal index
       intervals; the first and last timestamps are always kept.
    2. Turns each experiment's log into a best-so-far curve. The quality at
       time ``t`` is the maximum quality logged at or before ``t`` (never
       below 0.0), and 0.0 before the first log entry.
    3. Returns one ``(difficulty, solver_id, t, quality, experiment_id)``
       tuple per experiment per grid timestamp. Experiments appear in order
       of first appearance in `results`, and timestamps in ascending order.

    `results` is not modified.

    Args:
        results: raw quality log rows
            ``(difficulty, solver_id, time, quality, experiment_id)``.
        max_timestamps: soft cap on the number of grid timestamps (must be
            >= 1). The subsampled grid may hold one extra point so that the
            last timestamp is kept.

    Raises:
        ValueError: if one experiment ID appears with more than one
            difficulty or solver ID.
    """
    from bisect import bisect_right

    # Group the log by experiment, checking that each experiment has a
    # single difficulty and solver.
    log_by_eid = defaultdict(list)
    meta_by_eid = dict()
    for difficulty, solver_id, t, quality, eid in results:
        log_by_eid[eid].append((t, quality))
        meta = meta_by_eid.setdefault(eid, (difficulty, solver_id))
        if meta != (difficulty, solver_id):
            raise ValueError(
                f"Experiment {eid!r} is logged with inconsistent "
                f"(difficulty, solver) values: {meta!r} and "
                f"{(difficulty, solver_id)!r}"
            )

    # Shared timestamp grid: distinct logged times plus the zero start time
    timestamps = sorted({t for _, _, t, _, _ in results} | {0.0})
    # If there are too many timestamps, sample at equal index intervals
    if len(timestamps) > max_timestamps:
        step = -(-len(timestamps) // max_timestamps)  # ceiling division
        sampled = timestamps[::step]
        if sampled[-1] != timestamps[-1]:
            sampled.append(timestamps[-1])
        timestamps = sampled

    # Evaluate every experiment's best-so-far quality on the grid
    new_results = list()
    for eid, log in log_by_eid.items():
        difficulty, solver_id = meta_by_eid[eid]
        log.sort(key=lambda tq: tq[0])
        log_times = [t for t, _ in log]
        best_so_far = list()
        current_quality = 0.0
        for _, q in log:
            current_quality = max(current_quality, q)
            best_so_far.append(current_quality)

        for t in timestamps:
            # i = number of log entries at or before t
            i = bisect_right(log_times, t)
            q = best_so_far[i-1] if i else 0.0
            new_results.append((difficulty, solver_id, t, q, eid))

    return new_results


if __name__ == "__main__":
    import argparse
    import json
    import sys
    parser = argparse.ArgumentParser(description='Plot experiment results')
    parser.add_argument('--time_results_json', type=Path, required=True)
    parser.add_argument('--quality_results_json', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    args = parser.parse_args()
    args.output_dir.mkdir(exist_ok=True)

    # Load time results (JSON). JSON object keys are always strings, so the
    # difficulty keys are converted back to ints as `Results` expects.
    with open(args.time_results_json, "rt", encoding="utf-8") as fp:
        raw_results = json.load(fp)
    results = {
        int(difficulty): data
        for difficulty, data in raw_results.items()
    }

    # Output directory for the time figures
    time_output_dir = args.output_dir/"time_figures"
    time_output_dir.mkdir(exist_ok=True)

    # Load quality results (JSON)
    with open(args.quality_results_json, "rt", encoding="utf-8") as fp:
        quality_results = json.load(fp)

    # Output directory for the quality figures
    quality_output_dir = args.output_dir/"quality_figures"
    quality_output_dir.mkdir(exist_ok=True)

    # Best effort: a failure in one family of plots must not prevent the
    # other from being produced. Failures are reported on stderr and make
    # the script exit with a non-zero status.
    failed = False
    try:
        plot_results(results, time_output_dir)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        failed = True
    try:
        plot_quality_results(quality_results, quality_output_dir)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        failed = True
    if failed:
        sys.exit(1)
