"""
Compute metrics of the masker's performance using a set of ground-truth labels.

Usage: python eval_masker.py --model "path/to/a/model"

"""
print("Imports...", end="")
import os
import sys

# The default data / experiment paths used by this script are built from the
# group name, so nothing useful can be done without it.
GROUP = os.environ.get("CLIMATEGAN_GROUP")
if GROUP is None:
    print("CLIMATEGAN_GROUP is unknown. Please set env variable")
    # Non-zero exit status: a missing configuration is a failure, and shells /
    # job schedulers must be able to detect it.
    sys.exit(1)

import os.path
from argparse import ArgumentParser
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from comet_ml import Experiment
import torch
import yaml
from skimage.color import rgba2rgb
from skimage.io import imread, imsave
from skimage.transform import resize
from skimage.util import img_as_ubyte
from torchvision.transforms import ToTensor

from climategan.data import encode_mask_label
from climategan.eval_metrics import (
    masker_classification_metrics,
    get_confusion_matrix,
    edges_coherence_std_min,
    boxplot_metric,
    clustermap_metric,
)
from climategan.transforms import PrepareTest
from climategan.trainer import Trainer
from climategan.utils import find_images

dict_metrics = dict(
    # Human-readable label of every metric, used in plot titles and axes.
    # Insertion order matters: the model-comparison plots iterate over these keys.
    names=dict(
        tpr="TPR, Recall, Sensitivity",
        tnr="TNR, Specificity, Selectivity",
        fpr="FPR",
        fpt="False positives relative to image size",
        fnr="FNR, Miss rate",
        fnt="False negatives relative to image size",
        mpr="May positive rate (MPR)",
        mnr="May negative rate (MNR)",
        accuracy="Accuracy (ignoring may)",
        error="Error (ignoring may)",
        f05="F0.05 score",
        precision="Precision",
        edge_coherence="Edge coherence",
        accuracy_must_may="Accuracy (ignoring cannot)",
    ),
    # Per-metric thresholds. In this script only tpr (<=), fpr (>=) and
    # edge_coherence (>=) are used, to flag images some model handles poorly.
    threshold=dict(
        tpr=0.95,
        tnr=0.95,
        fpr=0.05,
        fpt=0.01,
        fnr=0.05,
        fnt=0.01,
        accuracy=0.95,
        error=0.05,
        f05=0.95,
        precision=0.95,
        edge_coherence=0.02,
        accuracy_must_may=0.5,
    ),
    # Not referenced elsewhere in this script.
    key_metrics=["f05", "error", "edge_coherence", "mnr"],
)

print("Ok.")


def parsed_args(argv=None):
    """Parse and returns command-line args

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None, in
            which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, help="Path to a pre-trained model")
    parser.add_argument(
        "--images_dir",
        default=f"/miniscratch/_groups/{GROUP}/data/climategan/masker-test-set/imgs",
        type=str,
        help="Directory containing the original test images",
    )
    parser.add_argument(
        "--labels_dir",
        default=f"/miniscratch/_groups/{GROUP}/data/climategan/masker-test-set/labels",
        type=str,
        help="Directory containing the labeled images",
    )
    # NOTE(review): args.image_size is not read anywhere in this script; the
    # pre-processing does not use it — confirm whether it should be wired in.
    parser.add_argument(
        "--image_size",
        default=640,
        type=int,
        help="The height and width of the pre-processed images",
    )
    parser.add_argument(
        "--max_files", default=-1, type=int, help="Limit loaded samples"
    )
    parser.add_argument(
        "--bin_value", default=0.5, type=float, help="Mask binarization threshold"
    )
    parser.add_argument(
        "-y",
        "--yaml",
        default=None,
        type=str,
        help="load a yaml file to parametrize the evaluation",
    )
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    parser.add_argument(
        "-p",
        "--plot",
        action="store_true",
        default=False,
        help="Plot masker images & their metrics overlays",
    )
    parser.add_argument(
        "--no_paint",
        action="store_true",
        default=False,
        help="Do not log painted images",
    )
    parser.add_argument(
        "--write_metrics",
        action="store_true",
        default=False,
        help="If True, write CSV file and maps images in model's path directory",
    )
    parser.add_argument(
        "--load_metrics",
        action="store_true",
        default=False,
        help="If True, load predictions and metrics instead of re-computing",
    )
    parser.add_argument(
        "--prepare_torch",
        action="store_true",
        default=False,
        help="If True, pre-process images as torch tensors",
    )
    parser.add_argument(
        "--output_csv",
        default=None,
        type=str,
        help="Filename of the output CSV with the metrics of all models",
    )

    return parser.parse_args(argv)


def uint8(array):
    """Return ``array`` cast to unsigned 8-bit integers.

    This is a plain ``astype`` conversion: no clipping or rounding is applied,
    so in-range floats simply lose their fractional part.
    """
    return array.astype(dtype=np.uint8)


def crop_and_resize(image_path, label_path, size=640):
    """
    Resizes an image so that it keeps its aspect ratio and its smallest dimension
    is ``size``, then crops this resized image in its center so that the output
    is ``size x size`` without aspect ratio distortion. The label is transformed
    the same way.

    The label is resized to the *image's* target shape, with order-0
    interpolation and no anti-aliasing. This keeps both arrays aligned even when
    the two files' shapes disagree; a warning is printed in that case.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label
        size (int, optional): Side of the square output. Defaults to 640.

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label), both uint8
    """
    img = imread(image_path)
    lab = imread(label_path)

    if img.shape[:2] != lab.shape[:2]:
        # Path(...) so that the warning also works when str paths are given
        print(
            "\nWARNING: shape mismatch: im -> ({}) {}, lab -> ({}) {}".format(
                img.shape[:2],
                Path(image_path).name,
                lab.shape[:2],
                Path(label_path).name,
            )
        )

    # Target shape: keep the image's aspect ratio, smallest side == size
    h, w = img.shape[:2]
    if h < w:
        target = (size, int(size * w / h))
    else:
        target = (int(size * h / w), size)

    r_img = uint8(resize(img, target, preserve_range=True, anti_aliasing=True))
    # order=0 and no anti-aliasing: label values must stay discrete class ids
    r_lab = uint8(
        resize(lab, target, preserve_range=True, anti_aliasing=False, order=0)
    )

    # Crop in the center; any trailing channel axis is kept as is
    H, W = r_img.shape[:2]
    top = (H - size) // 2
    left = (W - size) // 2

    rc_img = r_img[top : top + size, left : left + size]
    rc_lab = r_lab[top : top + size, left : left + size]

    return rc_img, rc_lab


def plot_images(
    output_filename,
    img,
    label,
    pred,
    metrics_dict,
    maps_dict,
    edge_coherence=-1,
    pred_edge=None,
    label_edge=None,
    dpi=300,
    alpha=0.5,
    vmin=0.0,
    vmax=1.0,
    fontsize="xx-small",
    cmap=None,
):
    """
    Save a 1x5 figure of overlays on the input image: false positives, false
    negatives, may-flood maps, predicted mask and labels.

    Args:
        output_filename (str or Path): where the figure is saved
        img: background image drawn under every overlay
        label: label map, drawn with matplotlib's default colormap
        pred: predicted mask
        metrics_dict (dict): must contain "fpr", "fnr", "mnr" and "mpr"
        maps_dict (dict): must contain "fp", "fn", "may_neg" and "may_pos"
        edge_coherence (float): added to the may-flood title unless it is -1
        pred_edge, label_edge: accepted for API compatibility; not drawn
        dpi (int): figure resolution
        alpha (float): overlay opacity
        vmin, vmax (float): color scale bounds of the colormapped overlays
        fontsize: matplotlib font size of the panel titles
        cmap (dict, optional): colormap per overlay ("fp", "fn", "may_neg",
            "may_pos", "pred"). Given keys override the defaults
            Reds / Reds / Oranges / Purples / Greens.
    """
    default_cmap = {
        "fp": "Reds",
        "fn": "Reds",
        "may_neg": "Oranges",
        "may_pos": "Purples",
        "pred": "Greens",
    }
    # None default (instead of a shared mutable dict) + merge so that callers
    # may override only some colormaps
    cmap = {**default_cmap, **(cmap or {})}
    overlay = dict(vmin=vmin, vmax=vmax, alpha=alpha)

    fig, axes = plt.subplots(1, 5, dpi=dpi)

    # FPR (predicted mask on cannot flood)
    axes[0].imshow(img)
    axes[0].imshow(maps_dict["fp"], cmap=cmap["fp"], **overlay)
    axes[0].set_title("FPR: {:.4f}".format(metrics_dict["fpr"]), fontsize=fontsize)

    # FNR (missed mask on must flood)
    axes[1].imshow(img)
    axes[1].imshow(maps_dict["fn"], cmap=cmap["fn"], **overlay)
    axes[1].set_title("FNR: {:.4f}".format(metrics_dict["fnr"]), fontsize=fontsize)

    # May flood; the edge-coherence line is only shown when it was computed
    title = "MNR: {:.2f} | MPR: {:.2f}".format(
        metrics_dict["mnr"], metrics_dict["mpr"]
    )
    if edge_coherence != -1:
        title += "\nEdge coh.: {:.4f}".format(edge_coherence)
    axes[2].imshow(img)
    axes[2].imshow(maps_dict["may_neg"], cmap=cmap["may_neg"], **overlay)
    axes[2].imshow(maps_dict["may_pos"], cmap=cmap["may_pos"], **overlay)
    axes[2].set_title(title, fontsize=fontsize)

    # Prediction
    axes[3].imshow(img)
    axes[3].imshow(pred, cmap=cmap["pred"], **overlay)
    axes[3].set_title("Predicted mask", fontsize=fontsize)

    # Labels (default colormap, no vmin / vmax)
    axes[4].imshow(img)
    axes[4].imshow(label, alpha=alpha)
    axes[4].set_title("Labels", fontsize=fontsize)

    for ax in axes:
        ax.axis("off")

    fig.savefig(
        output_filename,
        dpi=fig.dpi,
        bbox_inches="tight",
        facecolor="white",
        transparent=False,
    )
    plt.close(fig)


def load_ground(ground_output_path, ref_image_path):
    """
    Load the pre-computed binary mask matching a reference image.

    The mask is looked up as ``<ground_output_path>/eval-metrics/pred/<stem>.jpg``
    or ``.png``, where ``<stem>`` is the reference image's stem. It is then
    cropped and resized like the reference image with ``crop_and_resize``.

    Args:
        ground_output_path (Path or str): directory holding the masks
        ref_image_path (Path or str): image whose mask is wanted

    Raises:
        ValueError: if zero or more than one matching file is found

    Returns:
        torch.Tensor: (1, 1, H, W) float32 CUDA tensor, 1.0 where the stored
        mask (first channel if it has channels) is > 0, else 0.0
    """
    ground_dir = Path(ground_output_path)
    ref_path = Path(ref_image_path)
    pred_dir = ground_dir / "eval-metrics" / "pred"

    matches = []
    for suffix in (".jpg", ".png"):
        matches.extend(pred_dir.glob(f"{ref_path.stem}{suffix}"))

    if not matches:
        raise ValueError(
            f"Could not find a ground match in {str(ground_dir)} for image {str(ref_path)}"
        )
    if len(matches) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {str(ground_dir)} for image {str(ref_path)}:"
            + f" {list(map(str, matches))}"
        )

    _, mask = crop_and_resize(ref_path, matches[0])
    if mask.ndim == 3:
        mask = mask[..., 0]
    binary = (mask > 0).astype(np.float32)
    return torch.from_numpy(binary)[None, None].cuda()


def get_inferences(
    image_arrays, model_path, image_paths, paint=False, bin_value=0.5, verbose=0
):
    """
    Obtains the mask predictions (and optionally painted images) of a model for
    a set of images

    Parameters
    ----------
    image_arrays : array-like
        Images accepted by torchvision's ``ToTensor`` (e.g. (H, W, CH) numpy
        arrays). Each one becomes a (1, CH, H, W) float32 tensor and is mapped
        with ``(x - 0.5) * 2``, i.e. from [0, 1] to [-1, 1] assuming the
        ``ToTensor`` output lies in [0, 1].

    model_path : str
        The path to a pre-trained model. If its name contains "ground" or
        "instagan", masks are not predicted but loaded from that directory
        with ``load_ground``, and a fixed reference model is loaded for painting.

    image_paths : list(Path)
        A list of paths for images, in the same order as image_arrays. Only used
        to find the pre-computed masks in the "ground" / "instagan" case.

    paint : bool
        If True, also paint every image with ``trainer.G.paint``.

    bin_value : float
        Threshold binarizing the mask given to the painter.

    verbose : int
        If > 0, print progress.

    Returns
    -------
    masks : list
        Squeezed CPU tensors, one mask per image
    painted : list
        Squeezed CPU tensors, one painted image per image (empty if not paint)
    """
    device = torch.device("cuda:0")
    # Inference only. Note: this disables gradients globally, not just here.
    torch.set_grad_enabled(False)
    to_tensor = ToTensor()

    model_name = Path(model_path).name
    is_ground = "ground" in model_name
    is_instagan = "instagan" in model_name
    use_precomputed = is_ground or is_instagan

    if use_precomputed:
        # Masks come from disk: the model is only needed for its painter
        ground_path = model_path
        model_path = (
            f"/miniscratch/_groups/{GROUP}/experiments/runs/ablation-v1/out--38858350"
        )

    image_arrays = list(image_arrays)
    trainer = Trainer.resume_from_path(
        model_path, inference=True, new_exp=None, device=device
    )
    masks = []
    painted = []
    for idx, array in enumerate(image_arrays):
        if verbose > 0:
            print(idx, "/", len(image_arrays), end="\r")

        # Convert lazily so that only the current image lives on the GPU,
        # instead of the whole dataset at once
        x = to_tensor(array).unsqueeze(0).to(torch.float32).to(device)
        x = (x - 0.5) * 2

        if use_precomputed:
            m = load_ground(ground_path, image_paths[idx])
        else:
            m = trainer.G.mask(x=x)

        masks.append(m.squeeze().cpu())
        if paint:
            p = trainer.G.paint(m > bin_value, x)
            painted.append(p.squeeze().cpu())
    return masks, painted


if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f"    {k:20}: {v}" for k, v in vars(args).items()]))

    # Without either option, the evaluation list would be [None] and the first
    # Path(None) would crash with an unhelpful TypeError.
    if not args.yaml and not args.model:
        sys.exit("Please provide a model with --model or a list of models with --yaml")

    # Resolve the list of models to evaluate before any expensive pre-processing,
    # so that a bad YAML file fails fast. Explicit raises (not asserts) so the
    # checks survive `python -O`.
    if args.yaml:
        y_path = Path(args.yaml)
        if not y_path.exists():
            raise FileNotFoundError(f"YAML file not found: {y_path}")
        if y_path.suffix not in {".yaml", ".yml"}:
            raise ValueError(f"Expected a .yaml or .yml file, got: {y_path}")
        with y_path.open("r") as fh:
            data = yaml.safe_load(fh)
        if not isinstance(data, dict) or "models" not in data:
            raise ValueError(f"{y_path} must contain a 'models' key")
        evaluations = list(data["models"])
    else:
        evaluations = [args.model]

    # Determine output dir: $SLURM_TMPDIR if set, otherwise ask the user
    try:
        tmp_dir = Path(os.environ["SLURM_TMPDIR"])
    except KeyError as e:
        print(e)
        tmp_dir = Path(input("Enter tmp output directory: ")).resolve()

    plot_dir = tmp_dir / "plots"
    plot_dir.mkdir(parents=True, exist_ok=True)

    # Build paths to data. Both lists are sorted on the same key (labels drop
    # their "_labeled" suffix) so that images and labels are paired by position.
    imgs_paths = sorted(
        find_images(args.images_dir, recursive=False), key=lambda x: x.name
    )
    labels_paths = sorted(
        find_images(args.labels_dir, recursive=False),
        key=lambda x: x.name.replace("_labeled.", "."),
    )
    if args.max_files > 0:
        imgs_paths = imgs_paths[: args.max_files]
        labels_paths = labels_paths[: args.max_files]

    print(f"Loading {len(imgs_paths)} images and labels...")

    # Pre-process images: resize + crop
    # TODO: ? make cropping more flexible, not only central
    if not args.prepare_torch:
        ims_labs = [
            crop_and_resize(im_path, lab_path)
            for im_path, lab_path in zip(imgs_paths, labels_paths)
        ]
        imgs = [d[0] for d in ims_labs]
        labels = [d[1] for d in ims_labs]
    else:
        prepare = PrepareTest()
        imgs = prepare(imgs_paths, normalize=False, rescale=False)
        labels = prepare(labels_paths, normalize=False, rescale=False)

        imgs = [i.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for i in imgs]
        labels = [
            lab.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for lab in labels
        ]
    imgs = [rgba2rgb(img) if img.shape[-1] == 4 else img for img in imgs]
    print(" Done.")

    # Encode labels
    print("Encode labels...", end="", flush=True)
    # HW label
    labels = [np.squeeze(encode_mask_label(label, "flood")) for label in labels]
    print("Done.")

    # Columns of the per-image metrics table (one row per image)
    metric_columns = [
        "tpr",
        "tpt",
        "tnr",
        "tnt",
        "fpr",
        "fpt",
        "fnr",
        "fnt",
        "mnr",
        "mpr",
        "accuracy",
        "error",
        "precision",
        "f05",
        "accuracy_must_may",
        "edge_coherence",
        "filename",
    ]
    # Columns copied as-is from masker_classification_metrics' output
    classification_columns = [
        c for c in metric_columns if c not in {"edge_coherence", "filename"}
    ]

    for e, eval_path in enumerate(evaluations):
        print("\n>>>>> Evaluation", e, ":", eval_path)
        print("=" * 50)
        print("=" * 50)

        model_metrics_path = Path(eval_path) / "eval-metrics"
        model_metrics_path.mkdir(exist_ok=True)
        if args.load_metrics:
            f_csv = model_metrics_path / "eval_masker.csv"
            pred_out = model_metrics_path / "pred"
            if f_csv.exists() and pred_out.exists():
                print("Skipping model because pre-computed metrics exist")
                continue

        # Initialize New Comet Experiment
        exp = Experiment(
            project_name="climategan-masker-metrics", display_summary_level=0
        )

        # Obtain mask predictions
        # TODO: remove (debug)
        print("Obtain mask predictions", end="", flush=True)

        preds, painted = get_inferences(
            imgs,
            eval_path,
            imgs_paths,
            paint=not args.no_paint,
            bin_value=args.bin_value,
            verbose=1,
        )
        preds = [pred.numpy() for pred in preds]
        print(" Done.")

        if args.bin_value > 0:
            preds = [pred > args.bin_value for pred in preds]

        # Compute metrics
        rows = []
        print("Compute metrics and plot images")
        for idx, (img, label, pred) in enumerate(zip(imgs, labels, preds)):
            print(idx, "/", len(imgs), end="\r")
            img_name = imgs_paths[idx].name

            # Basic classification metrics
            metrics_dict, maps_dict = masker_classification_metrics(
                pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
            )

            # Edges coherence
            edge_coherence, pred_edge, label_edge = edges_coherence_std_min(pred, label)

            series_dict = {k: metrics_dict[k] for k in classification_columns}
            series_dict["edge_coherence"] = edge_coherence
            series_dict["filename"] = str(img_name)
            rows.append(series_dict)

            for k, v in series_dict.items():
                if k == "filename":
                    continue
                exp.log_metric(f"img_{k}", v, step=idx)

            # Confusion matrix. The may-rates are passed in the same order
            # (..., mpr, mnr) as in the summary call after this loop.
            confmat, _ = get_confusion_matrix(
                metrics_dict["tpr"],
                metrics_dict["tnr"],
                metrics_dict["fpr"],
                metrics_dict["fnr"],
                metrics_dict["mpr"],
                metrics_dict["mnr"],
            )
            confmat = np.around(confmat, decimals=3)
            exp.log_confusion_matrix(
                file_name=img_name + ".json",
                title=img_name,
                matrix=confmat,
                labels=["Cannot", "Must", "May"],
                row_label="Predicted",
                column_label="Ground truth",
            )

            if args.plot:
                # Plot prediction images
                fig_filename = plot_dir / img_name
                plot_images(
                    fig_filename,
                    img,
                    label,
                    pred,
                    metrics_dict,
                    maps_dict,
                    edge_coherence,
                    pred_edge,
                    label_edge,
                )
                exp.log_image(fig_filename)
            if not args.no_paint:
                masked = img * (1 - pred[..., None])
                flooded = img_as_ubyte(
                    (painted[idx].permute(1, 2, 0).cpu().numpy() + 1) / 2
                )
                combined = np.concatenate([img, masked, flooded], 1)
                exp.log_image(combined, img_name)

            if args.write_metrics:
                img_stem = imgs_paths[idx].stem
                pred_out = model_metrics_path / "pred"
                pred_out.mkdir(exist_ok=True)
                imsave(pred_out / f"{img_stem}_pred.png", pred.astype(np.uint8))
                for k, v in maps_dict.items():
                    metric_out = model_metrics_path / k
                    metric_out.mkdir(exist_ok=True)
                    imsave(metric_out / f"{img_stem}_{k}.png", v.astype(np.uint8))

            # --------------------------------
            # -----  END OF IMAGES LOOP  -----
            # --------------------------------

        # Build the table once from plain dicts so pandas infers numeric dtypes
        # (row-by-row `.loc` enlargement of an empty frame can leave `object`
        # columns, which break numeric reductions)
        df = pd.DataFrame(rows, columns=metric_columns)

        if args.write_metrics:
            print(f"Writing metrics in {str(model_metrics_path)}")
            f_csv = model_metrics_path / "eval_masker.csv"
            df.to_csv(f_csv, index_label="idx")

        print(" Done.")
        # Summary statistics; numeric_only skips the string "filename" column
        means = df.mean(axis=0, numeric_only=True)
        confmat_mean, confmat_std = get_confusion_matrix(
            df.tpr, df.tnr, df.fpr, df.fnr, df.mpr, df.mnr
        )
        confmat_mean = np.around(confmat_mean, decimals=3)
        confmat_std = np.around(confmat_std, decimals=3)

        # Log to comet
        exp.log_confusion_matrix(
            file_name="confusion_matrix_mean.json",
            title="confusion_matrix_mean.json",
            matrix=confmat_mean,
            labels=["Cannot", "Must", "May"],
            row_label="Predicted",
            column_label="Ground truth",
        )
        exp.log_confusion_matrix(
            file_name="confusion_matrix_std.json",
            title="confusion_matrix_std.json",
            matrix=confmat_std,
            labels=["Cannot", "Must", "May"],
            row_label="Predicted",
            column_label="Ground truth",
        )
        exp.log_metrics(dict(means))
        exp.log_table("metrics.csv", df)
        exp.log_html(df.to_html(col_space="80px"))
        exp.log_parameters(vars(args))
        exp.log_parameter("eval_path", str(eval_path))
        exp.add_tag("eval_masker")
        if args.tags:
            exp.add_tags(args.tags)
        exp.log_parameter("model_id", Path(eval_path).name)

        # Close comet
        exp.end()

        # --------------------------------
        # -----  END OF MODElS LOOP  -----
        # --------------------------------

    # Compare models
    if (args.load_metrics or args.write_metrics) and len(evaluations) > 1:
        print(
            "Plots for comparing the input models will be created and logged to comet"
        )

        # Initialize New Comet Experiment
        exp = Experiment(
            project_name="climategan-masker-metrics", display_summary_level=0
        )
        if args.tags:
            exp.add_tags(args.tags)

        # Build DataFrame with all models
        print("Building pandas DataFrame...")
        models_df = {}
        for m, model_path in enumerate(evaluations):
            model_path = Path(model_path)
            with open(model_path / "opts.yaml", "r") as fh:
                opt = yaml.safe_load(fh)
            model_feats = ", ".join(
                [
                    t
                    for t in sorted(opt["comet"]["tags"])
                    if "branch" not in t and "ablation" not in t and "trash" not in t
                ]
            )
            model_id = f"{model_path.parent.name[-2:]}/{model_path.name}"
            df_m = pd.read_csv(
                model_path / "eval-metrics" / "eval_masker.csv", index_col=False
            )
            df_m["model"] = [model_id] * len(df_m)
            df_m["model_idx"] = [m] * len(df_m)
            df_m["model_feats"] = [model_feats] * len(df_m)
            models_df[model_id] = df_m
        df = pd.concat(list(models_df.values()), ignore_index=True)
        df["model_img_idx"] = df.model.astype(str) + "-" + df.idx.astype(str)
        df.rename(columns={"idx": "img_idx"}, inplace=True)
        dict_models_labels = {
            k: f"{v['model_idx'][0]}: {v['model_feats'][0]}"
            for k, v in models_df.items()
        }
        print("Done")

        if args.output_csv:
            print(f"Writing DataFrame to {args.output_csv}")
            df.to_csv(args.output_csv, index_label="model_img_idx")

        # Determine images with low metrics in any model: one vectorized mask
        # over all rows instead of re-filtering the whole frame per image
        print("Constructing filter based on metrics thresholds...")
        # TODO: rethink thresholds
        bad_rows = (
            (df.tpr <= dict_metrics["threshold"]["tpr"])
            | (df.fpr >= dict_metrics["threshold"]["fpr"])
            | (df.edge_coherence >= dict_metrics["threshold"]["edge_coherence"])
        )
        bad_img_idxs = set(df.loc[bad_rows, "img_idx"])
        idx_not_good_in_any = [i for i in df.img_idx.unique() if i in bad_img_idxs]
        filters = {"all": df.img_idx.unique(), "not_good_in_any": idx_not_good_in_any}
        print("Done")

        models_order = list(df.model.unique())

        # Boxplots of metrics
        print("Plotting boxplots of metrics...")
        for filter_name, img_idxs in filters.items():
            print(f"\tDistribution of [{filter_name}] images...")
            df_filtered = df.loc[df.img_idx.isin(img_idxs)]
            for metric in dict_metrics["names"].keys():
                fig_filename = plot_dir / f"boxplot_{metric}_{filter_name}.png"
                if metric in ["mnr", "mpr", "accuracy_must_may"]:
                    extra_kwargs = {"do_stripplot": True}
                else:
                    extra_kwargs = {"fliersize": 1.0}
                boxplot_metric(
                    fig_filename,
                    df_filtered,
                    metric=metric,
                    dict_metrics=dict_metrics["names"],
                    dict_models=dict_models_labels,
                    order=models_order,
                    **extra_kwargs,
                )
                exp.log_image(fig_filename)
        print("Done")

        # Cluster Maps
        print("Plotting clustermaps...")
        for filter_name, img_idxs in filters.items():
            print(f"\tDistribution of [{filter_name}] images...")
            df_filtered = df.loc[df.img_idx.isin(img_idxs)]
            for metric in dict_metrics["names"].keys():
                fig_filename = plot_dir / f"clustermap_{metric}_{filter_name}.png"
                # Keyword arguments: positional pivot args were removed in pandas 2
                df_mf = df_filtered.pivot(
                    index="img_idx", columns="model", values=metric
                )
                clustermap_metric(
                    output_filename=fig_filename,
                    df=df_mf,
                    metric=metric,
                    dict_metrics=dict_metrics["names"],
                    method="average",
                    cluster_metric="euclidean",
                    dict_models=dict_models_labels,
                    row_cluster=False,
                )
                exp.log_image(fig_filename)
        print("Done")

        # Close comet
        exp.end()
