"""Plotting utilities for the example scripts."""
from pathlib import Path
import seaborn as sns
import pandas as pd
import csv
import argparse


def _read_eval_curve(csv_path: Path) -> list[tuple[int, float]]:
    """Read ``(timestep, mean reward)`` pairs from one StableBaselines3 CSV.

    Args:
        csv_path: CSV with ``time/total_timesteps`` and ``eval/mean_reward``
            columns.

    Returns:
        One pair per row in which the timestep parses as ``int`` and the
        reward as ``float``, in file order. Rows whose value in either column
        is blank or missing are skipped.

    Raises:
        KeyError: If the CSV has data rows but lacks either column.
    """
    points = []
    # The csv module requires newline="" so quoted fields are parsed correctly.
    with open(csv_path, newline="") as csvfile:
        for row in csv.DictReader(csvfile):
            raw_timestep = row['time/total_timesteps']
            raw_reward = row['eval/mean_reward']
            try:
                point = (int(raw_timestep), float(raw_reward))
            except (TypeError, ValueError):
                # Blank cells raise ValueError. Short rows are filled with
                # None by DictReader, which raises TypeError. Neither row
                # carries a usable evaluation point.
                continue
            points.append(point)
    return points


def plot_sb3_csvs(
        algorithm_to_csvs: dict[str, list[Path]],
        output_path: Path,
        ) -> None:
    """Plot evaluation mean reward against timestep, one hue per algorithm.

    Every CSV is treated as a separate experiment (run). Runs are not passed
    to seaborn as ``units``, so ``sns.lineplot`` aggregates the runs of an
    algorithm that share a timestep using its default estimator.

    Args:
        algorithm_to_csvs: Maps an algorithm label (used as the ``hue``) to
            the CSV files of its training runs.
        output_path: Image file to write, e.g. PNG or SVG.

    Raises:
        KeyError: If a CSV lacks one of the required columns.
    """
    columns = ["experiment_id", "timestep", "mean reward", "algorithm"]

    # Build a long-form table: one row per (run, evaluation point).
    longform = []
    experiment_id = 0
    for algorithm_id, csvs in algorithm_to_csvs.items():
        for csv_path in csvs:
            experiment_id += 1
            longform.extend(
                (experiment_id, timestep, mean_reward, algorithm_id)
                for timestep, mean_reward in _read_eval_curve(csv_path)
            )

    df = pd.DataFrame(longform, columns=columns)

    # NOTE(review): without ``ax=`` seaborn draws on matplotlib's current
    # axes, so calling this twice in one process overlays the plots.
    plot = sns.lineplot(
        x="timestep",
        y="mean reward",
        hue="algorithm",
        data=df,
    )

    fig = plot.get_figure()
    fig.savefig(output_path)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        prog='Reward plotting',
        description=(
            'Plot the evaluation metrics in the given CSVs from'
            ' StableBaselines3 training.'
        ),
    )

    # One optional multi-path flag per algorithm, declared in --help order.
    # The flag is --<key>_csvs; the label only appears in the help text.
    csv_flags = (
        ('swmpo', 'SWMPO'),
        ('biased_swmpo', 'Biased SWMPO'),
        ('rl', 'RL'),
        ('swmpo_ground_truth', 'SWMPO Ground Truth'),
        ('biased_swmpo_ground_truth', 'Biased SWMPO Ground Truth'),
    )
    for algorithm, label in csv_flags:
        cli.add_argument(
            f'--{algorithm}_csvs',
            type=Path,
            nargs='*',
            help=f'CSV files from StableBaselines3 training for {label}.',
        )
    cli.add_argument(
        '--output_path',
        type=Path,
        required=True,
        help='PNG or SVG file to write with the plot.'
    )
    parsed = cli.parse_args()

    # Group the CSVs by algorithm. This key order is the order in which
    # plot_sb3_csvs emits rows, so it is kept independent of the flag order
    # above. A flag that was not given parses as None, meaning "no CSVs".
    plot_order = (
        'rl',
        'swmpo',
        'biased_swmpo',
        'swmpo_ground_truth',
        'biased_swmpo_ground_truth',
    )
    algorithm_to_csvs = {}
    for algorithm in plot_order:
        given = getattr(parsed, f'{algorithm}_csvs')
        algorithm_to_csvs[algorithm] = [] if given is None else given

    plot_sb3_csvs(
        algorithm_to_csvs=algorithm_to_csvs,
        output_path=parsed.output_path,
    )
