import csv
import pathlib

import click
import matplotlib.pyplot as plt
import numpy as np
import tensorboard as tb
from tensorflow.python.summary.summary_iterator import summary_iterator

from microsoft_nlp import paths
from microsoft_nlp.utils import Location


@click.group()
def cli(*args, **kwargs):
    # Root command group for this script: it does no work itself and only
    # serves as the attachment point for subcommands registered with
    # ``@cli.command()``.
    pass


@cli.command()
@click.option(
    "--rows",
    type=str,
    default=None,
    help="Particular rows in the experiment journal to plot, e.g. '1,5,24'",
    show_default=True,
)
@click.option(
    "--parameter_ranges",
    type=str,
    default="None-None",
    help="Range(s) of parameters to plot, e.g. '1000-2000,4000-5000'",
    show_default=True,
)
@click.option(
    "--model_types",
    type=str,
    default=None,
    help="Types of models to plot, e.g. 'GPT1,LMU2Dli'",
    show_default=True,
)
@click.option(
    "--tokens_per_epoch",
    type=int,
    default=200 * 1024 * 512,  # 104,857,600 tokens (~105M)
    help="Number of tokens to process per training epoch",
    show_default=True,
)
@click.option(
    "--max_loss",
    type=float,
    default=9.0,
    help="Results with a final epoch loss above this number will not be shown",
    show_default=True,
)
@click.option(
    "--sync/--no_sync",
    default=True,
    help="Whether to sync results from remote server(s)",
    show_default=True,
)
@click.option(
    "--verbosity",
    type=int,
    default=1,
    help="How verbose to be",
    show_default=True,
)
def figure2(
    rows, parameter_ranges, model_types, tokens_per_epoch, max_loss, sync, verbosity
):
    """Plot journal training losses on top of the original figure 2 curves."""
    # NOTE: click shows the docstring above as this command's --help description.
    #
    # This function only converts the string-valued options into Python values
    # and hands them to ``plot_figure2``. Malformed option values are reported
    # as ``click.BadParameter`` (a usage error naming the offending option)
    # instead of escaping as a bare ``ValueError`` traceback.

    def parse_endpoint(x):
        # "none" (in any letter case) leaves that side of the range unbounded
        x = x.strip().lower()
        return None if x == "none" else int(x)

    parsed_parameter_ranges = []
    for s in parameter_ranges.split(","):
        try:
            # unpacking fails unless there is exactly one "-", and int() fails
            # on non-integer endpoints; both raise ValueError
            a, b = s.split("-")
            endpoints = (parse_endpoint(a), parse_endpoint(b))
        except ValueError as err:
            raise click.BadParameter(
                f"expected 'MIN-MAX' with integer or 'None' endpoints, got '{s}'",
                param_hint="--parameter_ranges",
            ) from err
        parsed_parameter_ranges.append(endpoints)

    if model_types is not None:
        # strip so that "GPT1, LMU2Dli" matches the same as "GPT1,LMU2Dli"
        model_types = [mt.strip() for mt in model_types.split(",")]

    if rows is not None:
        try:
            rows = [int(x) for x in rows.split(",")]
        except ValueError as err:
            raise click.BadParameter(
                f"expected comma-separated integers, got '{rows}'",
                param_hint="--rows",
            ) from err

    plot_figure2(
        row_indices=rows,
        param_ranges=parsed_parameter_ranges,
        model_types=model_types,
        tokens_per_epoch=tokens_per_epoch,
        max_loss=max_loss,
        sync=sync,
        verbosity=verbosity,
        print_fn=click.echo,
    )


def plot_figure2(
    row_indices=None,
    param_ranges=((None, None),),
    model_types=None,
    tokens_per_epoch=200 * 1024 * 512,
    max_loss=9.0,
    sync=True,
    print_fn=print,
    verbosity=1,
):
    """Plot journal training losses on top of the original figure 2 curves.

    Reads the reference curves from ``figure2.npz`` and the experiment journal
    from ``Experiment Journal - Sheet1.csv`` (both looked up in the current
    working directory), selects the journal rows that pass every filter, reads
    each selected run's ``epoch_loss`` from its log directory, and saves one
    PDF per entry of ``param_ranges`` into ``paths.plots``.

    Args:
        row_indices: Spreadsheet row numbers to include, or None for all rows.
            The header is row 1, so the first data row is row 2.
        param_ranges: Sequence of ``(min, max)`` non-embedding parameter
            ranges (inclusive); a None endpoint leaves that side unbounded.
            A row is kept if it falls in any range; one figure is produced
            per range.
        model_types: Model names to include (compared case-insensitively
            against the journal's "Model" column), or None for all models.
        tokens_per_epoch: Tokens processed per epoch; converts epoch indices
            to the token counts used on the x axis.
        max_loss: Runs whose final epoch loss is greater than or equal to
            this value are not plotted.
        sync: If True, each run's log directory is synced from the remote
            location into ``paths.logs_sync`` first; if False, only what is
            already present under ``paths.logs_sync`` is used.
        print_fn: Callable used for all messages emitted by this function.
        verbosity: 0 prints nothing; >= 1 reports saved figures and
            unexpected problems; >= 2 also explains every skipped row.
    """
    remote_location = Location("abrghost", "/home/transfer/rnd")

    # figure2.npz is available in 'reverse-engineering-figure2-left.zip' in Google Drive
    # np.load keeps the .npz archive open until closed, so read the arrays
    # inside a context manager to release the file handle right away
    with np.load("figure2.npz") as data:
        colors = data["colors"]
        params = data["params"]
        x = data["x"]
        curves = data["curves"]

    # experiment journal must be downloaded from Google Spreadsheet as CSV
    experiment_csv = "Experiment Journal - Sheet1.csv"
    # newline="" is what the csv module asks for; it matters here because the
    # "Filename" cells may contain embedded newlines
    with open(experiment_csv, "r", newline="") as csvfile:
        rows = list(csv.reader(csvfile))

    index_row, rows = rows[0], rows[1:]

    # column positions and the lowercased model filter do not change per row
    filename_col = index_row.index("Filename")
    ne_params_col = index_row.index("Parameters (Non - Embedding)")
    model_col = index_row.index("Model")
    wanted_model_types = (
        None if model_types is None else {mt.lower() for mt in model_types}
    )

    row_info = {}
    # start=2: spreadsheet rows are 1-based and row 1 is the header
    for row_i, row in enumerate(rows, start=2):
        filename = row[filename_col]
        filename = filename.split("\n")[-1]  # if multiple lines, take last
        filename = pathlib.Path(filename)
        if filename.suffix == ".hdf5":
            filename = filename.with_suffix("")
        filename = str(filename)  # note: an empty cell becomes "."

        ne_params = row[ne_params_col]
        ne_params = ne_params.replace(",", "").strip()
        # -1 stands in for an empty parameter cell
        ne_params = int(ne_params) if len(ne_params) > 0 else -1
        model_type = row[model_col].lower()

        if row_indices is not None and row_i not in row_indices:
            if verbosity >= 2:
                print_fn(f"Row {row_i}: Not in specified row indices")
            continue

        if not any(is_in_range(ne_params, prange) for prange in param_ranges):
            if verbosity >= 2:
                print_fn(
                    f"Row {row_i}: Nonembedding params {ne_params} not in any range"
                )
            continue

        if wanted_model_types is not None and model_type not in wanted_model_types:
            if verbosity >= 2:
                print_fn(f"Row {row_i}: Model type '{model_type}' not in model types")
            continue

        if filename.lower() in ("", "na", "."):
            if verbosity >= 2:
                print_fn(f"Row {row_i}: Skipping empty filename '{filename}'")
            continue

        if len(filename) < 8:
            if verbosity >= 2:
                print_fn(f"Row {row_i}: Skipping too short filename '{filename}'")
            continue

        if sync:
            # try to sync from remote location
            if filename not in remote_location.contents:
                if verbosity >= 2:
                    print_fn(
                        f"Row {row_i}: Could not find '{filename}' in remote location"
                    )
                continue

            filepath = remote_location.sync_remote_subdir(
                filename, target_dir=paths.logs_sync
            )
        else:
            filepath = paths.logs_sync / filename

        if not filepath.exists():
            # a missing directory right after a sync is unexpected, so it is
            # reported at the lower verbosity threshold
            if sync:
                if verbosity >= 1:
                    print_fn(
                        f"Row {row_i}: Should be synced, but does not exist: '{filepath}'"
                    )
            else:
                if verbosity >= 2:
                    print_fn(f"Row {row_i}: Log does not exist locally: '{filepath}'")

            continue

        # deliberately broad: one unreadable log must not abort the whole plot
        try:
            epoch_loss = epoch_loss_from_log_dir(filepath)
        except Exception as e:
            if verbosity >= 1:
                print_fn(f"Row {row_i}: Could not get epoch loss: {e}")
            continue

        if len(epoch_loss) == 0:
            if verbosity >= 2:
                print_fn(f"Row {row_i}: Epoch loss is empty")
            continue

        final_epoch, final_loss = epoch_loss[-1]
        if final_loss >= max_loss:
            if verbosity >= 2:
                print_fn(
                    f"Row {row_i}: Final epoch loss {final_loss} "
                    f"exceeds max loss {max_loss}"
                )
            continue

        name = f"row {row_i}: {filename}"
        row_info[name] = {
            "i": row_i,
            # epoch_loss is a float array, so cast the epoch back to int to
            # avoid labels such as "epoch 10.0"
            "label": f"{name} ({ne_params:0.1e} params, "
            f"{final_loss:0.3f} loss at epoch {int(final_epoch) + 1})",
            "model_type": model_type,
            "nonembedding_params": ne_params,
            "epoch_loss": epoch_loss,
        }

    # --- plot the figures
    for p_range in param_ranges:
        plt.figure(figsize=(12, 8))

        # --- plot original figure 2 transformer results
        # `params` holds log10 parameter counts, hence the 10 ** conversion
        inds = {
            k for k, log_p in enumerate(params) if is_in_range(10 ** log_p, p_range)
        }
        # NOTE(review): reference curves 0 and 4 are always left out; the reason
        # is not recorded in this file -- confirm before changing
        inds = inds.difference([0, 4])

        # `colors` from the original figure is loaded but not applied; the
        # curves use matplotlib's default color cycle
        for k, (_color, curve) in enumerate(zip(colors, curves)):
            if k not in inds:
                continue

            label = f"fig2 transformer ({10 ** params[k]:0.1e} params)"
            plt.plot(x, curve, label=label)

        # --- plot our results overtop
        for info in row_info.values():
            if not is_in_range(info["nonembedding_params"], p_range):
                continue

            epochs, loss = info["epoch_loss"].T
            # epochs are 0-based, so epoch k has processed (k + 1) epochs of tokens
            tokens = (epochs + 1) * tokens_per_epoch
            plt.plot(np.log10(tokens), loss, label=info["label"])

        p_min, p_max = p_range
        plt.xlim([7, 11])
        plt.xlabel("tokens processed (log10)")
        plt.ylabel("train loss")
        plt.legend(loc=1)

        model_str = "" if model_types is None else ",".join(model_types)
        row_str = "" if row_indices is None else ",".join(map(str, row_indices))
        title_model_str = f"models {model_str}" if model_str else "all models"
        title_row_str = f"(rows {row_str}) " if row_str else ""
        title = f"{title_model_str} {title_row_str}"
        if p_min is not None or p_max is not None:
            title = f"{title} with {p_min}-{p_max} relevant params"

        plt.title(title)
        plt.tight_layout()

        # the active filters are encoded in the file name so that differently
        # filtered plots do not overwrite each other
        name_model_str = f"_models={model_str}" if model_str else ""
        name_row_str = f"_rows={row_str}" if row_str else ""
        figname = paths.plots / (
            f"figure2_params={p_min}-{p_max}{name_model_str}{name_row_str}.pdf"
        )
        plt.savefig(str(figname))
        if verbosity >= 1:
            print_fn(f"Saved figure: '{figname}'")


def epoch_loss_from_experiment_id(experiment_id):
    """Return the training ``epoch_loss`` scalars of an uploaded experiment.

    The scalars are fetched through tensorboard's ``ExperimentFromDev`` and
    reduced to the rows whose run is ``"train"`` and whose tag is
    ``"epoch_loss"``. The result is an array with one ``(step, value)`` row
    per remaining scalar.
    """
    scalars = tb.data.experimental.ExperimentFromDev(experiment_id).get_scalars()

    from_train_run = scalars["run"] == "train"
    is_epoch_loss = scalars["tag"] == "epoch_loss"
    selected = scalars.loc[from_train_run & is_epoch_loss, ["step", "value"]]
    return np.array(selected)


def read_summary_file(summary_file, tag="epoch_loss"):
    """Collect the ``(step, value)`` pairs recorded under ``tag`` in an event file.

    Every event yielded by ``summary_iterator`` is scanned, and each summary
    value whose tag equals ``tag`` contributes its event's step and its
    ``simple_value``. Returns an array with one row per match, or an empty
    array of shape ``(0, 2)`` when nothing matches.
    """
    pairs = [
        (event.step, value.simple_value)
        for event in summary_iterator(str(summary_file))
        for value in event.summary.value
        if value.tag == tag
    ]

    if len(pairs) > 0:
        return np.array(pairs)
    # keep the two-column shape so callers can still unpack/transposed it
    return np.zeros((0, 2))


def epoch_loss_from_log_dir(log_dir):
    """Read the training ``epoch_loss`` series from a log directory.

    Looks for ``events.out.tfevents.*.v2`` files in ``<log_dir>/train`` and
    reads the first one in sorted filename order. If several exist, a notice
    is printed and the others are ignored.

    Args:
        log_dir: Path (``str`` or ``pathlib.Path``) of the run's log directory.

    Returns:
        The result of ``read_summary_file`` for the chosen file with tag
        ``"epoch_loss"``.

    Raises:
        AssertionError: If ``log_dir`` or its ``train`` subdirectory does not
            exist, or if no summary file is found. It is raised explicitly
            rather than through ``assert`` so the checks still run under
            ``python -O``; the exception type is unchanged for callers.
    """
    log_dir = pathlib.Path(log_dir)
    if not log_dir.exists():
        raise AssertionError(f"log_dir does not exist: '{log_dir}'")

    train_dir = log_dir / "train"
    if not train_dir.exists():
        raise AssertionError(f"train dir does not exist: '{train_dir}'")

    summary_files = sorted(train_dir.glob("events.out.tfevents.*.v2"))
    if not summary_files:
        raise AssertionError("No summary files found")

    if len(summary_files) > 1:
        print("Multiple summary files exist, only using the first")

    return read_summary_file(summary_files[0], tag="epoch_loss")


def is_in_range(x, x_range):
    """Tell whether ``x`` lies within the inclusive ``(min, max)`` pair ``x_range``.

    A ``None`` endpoint means that side of the range is unbounded.
    """
    lower, upper = x_range

    above_lower = True if lower is None else x >= lower
    if not above_lower:
        return above_lower

    return True if upper is None else x <= upper


# Run the click command group only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()
