"""
This script plots the result of the human evaluation on Amazon Mechanical Turk, where
human participants chose between an image from ClimateGAN or from a different method.
"""
# Progress message emitted before the imports below. The line is deliberately left
# open (end=""), so flush explicitly: without a newline the text would otherwise
# stay in the stdout buffer until a later print.
print("Imports...", end="", flush=True)
from argparse import ArgumentParser
import yaml
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.special import comb
from scipy.stats import trim_mean
from tqdm import tqdm
from collections import OrderedDict
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import os

# -----------------------
# -----  Constants  -----
# -----------------------

# (identifier, display name) pairs for the methods compared against ClimateGAN.
# The identifiers are the keys looked up when the Y-tick labels are replaced by
# their display names.
_comparable_labels = (
    ("munit_flooded", "MUNIT"),
    ("cyclegan", "CycleGAN"),
    ("instagan", "InstaGAN"),
    ("instagan_copypaste", "Mask-InstaGAN"),
    ("painted_ground", "Painted ground"),
)
comparables_dict = dict(_comparable_labels)


# Colors
# Every color is picked from seaborn's "colorblind" palette, which is looked up
# once and indexed below.
palette_colorblind = sns.color_palette("colorblind")

# Color used for the ClimateGAN bars
color_climategan = palette_colorblind[9]

# One color per compared method
color_munit = palette_colorblind[1]
color_cyclegan = palette_colorblind[2]
color_instagan = palette_colorblind[3]
color_maskinstagan = palette_colorblind[6]
color_paintedground = palette_colorblind[8]
palette_comparables = [
    color_munit,
    color_cyclegan,
    color_instagan,
    color_maskinstagan,
    color_paintedground,
]

# Lighter variant of each method color: the middle entry of a 3-color
# `sns.light_palette` built from that color.
palette_comparables_light = [
    sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
]


def parsed_args():
    """
    Parse and return the command-line arguments

    Options:
        --input_csv: pre-processed CSV with the results of the human evaluation
        --output_dir: output directory (None when not given)
        --dpi: DPI for the output images
        --n_bs: number of bootstrap samples
        --bs_seed: bootstrap random seed

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="amt_omni-vs-other.csv",
        type=str,
        help="CSV containing the results of the human evaluation, pre-processed",
    )
    parser.add_argument("--output_dir", default=None, type=str, help="Output directory")
    parser.add_argument(
        "--dpi", default=200, type=int, help="DPI for the output images"
    )
    parser.add_argument(
        "--n_bs",
        # argparse applies `type` only to string defaults, so the default has to
        # be an int already: a float literal such as 1e6 would be stored as-is
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args()


if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f"    {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir: --output_dir when given, $SLURM_TMPDIR otherwise
    if args.output_dir is None:
        slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
        if slurm_tmpdir is None:
            raise KeyError(
                "SLURM_TMPDIR is not set: pass --output_dir to choose the output dir"
            )
        output_dir = Path(slurm_tmpdir)
    else:
        output_dir = Path(args.output_dir)
    # exist_ok=True makes this a no-op when the directory already exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "args_human_evaluation.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV
    df = pd.read_csv(args.input_csv)

    # Only the rows flagged as valid are plotted
    df_valid = df.loc[df.is_valid]

    # Sort Y labels: methods ordered by decreasing sum of `is_climategan`
    # (the sum is taken over all the rows of the CSV, valid or not)
    comparables = df.comparable.unique()
    is_climategan_sum = [
        df.loc[df.comparable == c, "is_climategan"].sum() for c in comparables
    ]
    comparables = comparables[np.argsort(is_climategan_sum)[::-1]]

    # Plot setup
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # Two separate entries: without the comma, Python concatenates
                # the adjacent literals into a single bogus font name
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )
    fontsize = "medium"

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)

    # Plot the total (right): the mean of `is_valid` over valid rows, drawn as a
    # light bar per method behind the ClimateGAN bar
    # NOTE(review): the palette is applied in the order of `comparables`, which
    # depends on the data, so a method is not guaranteed to get the color named
    # after it (color_munit, ...) -- confirm this is intended
    sns.barplot(
        data=df_valid,
        x="is_valid",
        y="comparable",
        order=comparables,
        orient="h",
        label="comparable",
        palette=palette_comparables_light,
        ci=None,
        ax=ax,
    )

    # Plot the left: mean of `is_climategan` per method, with a bootstrapped 99%
    # confidence interval
    sns.barplot(
        data=df_valid,
        x="is_climategan",
        y="comparable",
        order=comparables,
        orient="h",
        label="climategan",
        color=color_climategan,
        ci=99,
        # argparse does not convert non-string defaults, so make sure the number
        # of bootstrap samples is an int
        n_boot=int(args.n_bs),
        seed=args.bs_seed,
        errcolor="black",
        errwidth=1.5,
        capsize=0.1,
        ax=ax,
    )

    # Draw line at 0.5: a dotted vertical line sampled every 0.1 between the
    # two Y limits of the axes
    y = np.arange(ax.get_ylim()[1] + 0.1, ax.get_ylim()[0], 0.1)
    x = 0.5 * np.ones(y.shape[0])
    ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")

    # Change Y-Tick labels: replace the method identifiers by their display
    # names, right-aligned at x=0.96
    yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
    yticklabels_text = ax.set_yticklabels(
        yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
    )
    for ytl in yticklabels_text:
        ax.add_artist(ytl)

    # Remove Y-label
    ax.set_ylabel(ylabel="")

    # Change X-Tick labels
    xlim = [0.0, 1.1]
    xticks = np.arange(xlim[0], xlim[1], 0.1)
    ax.set(xticks=xticks)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize)

    # Set X-label
    ax.set_xlabel(None)

    # Change spines
    sns.despine(left=True, bottom=True)

    # Save figure
    output_fig = output_dir / "human_evaluation_rate_climategan.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
