"""
Usage:
    python3 -m dp_join.experiment.joint_counts [out_dir]

Experimentally evaluates the use of joinable private sketches to estimate
joint distributions with a range of parameters, using the Adult dataset.

Writes output to the directory at out_dir as an index.html file containing a
table and referring to some generated plots. Also saves a file "results.csv"
containing the values that were plotted. If out_dir isn't specified, a name
will be invented.

If out_dir already exists and contains results.csv, this script will redraw
the plots without rerunning the experiments. The intended use of this is to be
able to modify the plotting code and quickly regenerate the plots without
rerunning time-consuming experiments. This only works if the experiment code
hasn't changed too much since the CSV file was written. In particular, the
repetitions_per_experiment variable should not have changed, otherwise the
y axis label will be inaccurate.
"""

from ..adult import adult_column_descriptions_by_name, read_as_columns
from ..hash import salted_sha256
from ..sketch.one_hot import OneHotSketcher
from ..timing import fmttime

from collections import namedtuple
import html
from matplotlib import pyplot
import numpy
import numpy.random
import os
import pandas
import shutil
import sys
import time

# Constants

# How many times each parameter combination is repeated; also shown in the
# y axis label of the plots.
repetitions_per_experiment = 10

# Columns that are joined (one at a time) against the sketched column.
nonprivate_columns = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "race",
    "sex",
    "native-country",
]

# Names of the files written into the output directory.
plot_by_d_filename = "LaplaceErrorBySketchDimension.png"
plot_by_eps_filename = "LaplaceErrorByEpsilon.png"
results_csv_filename = "results.csv"

# Opening part of the HTML report; it links to the CSV file and both plots.
report_header = (
    "<!DOCTYPE html>\n"
    "<html>\n"
    "<head>\n"
    '\t<link href="style.css" rel="stylesheet">\n'
    "</head>\n"
    "<body>\n"
    f'<p><a href="{results_csv_filename}">Results (csv)</a></p>\n'
    "<h2>Plots</h2>\n"
    "<h3>Varying epsilon</h3>\n"
    f'<p><img src="{plot_by_eps_filename}"></p>\n'
    "<h3>Varying sketch dimension</h3>\n"
    f'<p><img src="{plot_by_d_filename}"></p>\n'
)
# Closing part of the HTML report.
report_footer = "</body>\n</html>\n"

# The settings of one experiment:
#   epsilon           - differential privacy parameter
#   sketch_dimension  - size of the sketch
#   private_column    - name of the column a sketch is generated from
#   nonprivate_column - name of the column that sketch is joined with
# Both columns should have categorical values.
ExperimentParameters = namedtuple(
    "ExperimentParameters",
    ["epsilon", "sketch_dimension", "private_column", "nonprivate_column"],
)

def joint_counts(row_data, num_rows, column_data, num_columns):
    """
    Tabulates how often each (row value, column value) pair occurs.

    row_data should hold integers from 0 to num_rows-1 and column_data
    integers from 0 to num_columns-1; the two sequences are read in lockstep.
    The result is a num_rows by num_columns numpy array of int64 whose (i,j)
    element is the number of positions k with row_data[k] == i and
    column_data[k] == j.

    For example, with row_data (0,0,0,1) and column_data (2,1,2,0) the pairs
    (0,1) and (1,0) occur once each and (0,2) occurs twice:
    >>> joint_counts((0, 0, 0, 1), 2, (2, 1, 2, 0), 3)
    array([[0, 1, 2],
           [1, 0, 0]])
    """
    tally = numpy.zeros(shape=(num_rows, num_columns), dtype=numpy.int64)
    # Each zipped pair is already the (row, column) index of the cell to bump.
    for cell in zip(row_data, column_data):
        tally[cell] += 1
    return tally

def count_matrix_to_html(out, count_matrix, row_titles, column_titles):
    """
    Writes count_matrix to out as an HTML table of class "count_table".

    The first table row holds an empty corner cell followed by the
    HTML-escaped column_titles. Each following row is produced by
    count_row_to_html_row from one row of count_matrix and the matching entry
    of row_titles.
    """
    header_cells = "".join(
        f"<td>{html.escape(title)}</td>" for title in column_titles)
    out.write('<table class="count_table">\n<tr><td></td>')
    out.write(header_cells)
    out.write("</tr>\n")
    for counts, title in zip(count_matrix, row_titles):
        count_row_to_html_row(out, counts, title)
    out.write("</table>\n")

def count_row_to_html_row(out, count_row, row_title):
    """
    Writes one HTML table row to out: a title cell followed by one cell of
    class "count_element" per entry of count_row.

    row_title is converted with str() and HTML-escaped, so titles containing
    characters such as "<", ">" or "&" are displayed literally instead of
    being interpreted as markup.
    """
    out.write(f"<tr><td>{html.escape(str(row_title))}</td>")
    for count in count_row:
        out.write(f'<td class="count_element">{count}</td>')
    out.write("</tr>\n")

def parameters_to_html(parameters):
    """Returns the str() form of parameters, escaped for use inside HTML."""
    as_text = str(parameters)
    return html.escape(as_text)

def _write_counts_table(html_table_out, kind, counts, fraction_0,
                        private_column_description,
                        nonprivate_column_description):
    """
    Writes a "<kind> counts:" paragraph to html_table_out followed by a table
    holding the rows of counts plus one extra row containing fraction_0.

    kind is used verbatim in the paragraph and in the title of the extra row
    ("<kind> fraction <name of private class 0>").
    """
    html_table_out.write(f"<p>{kind} counts:</p>\n")
    table_contents = numpy.concatenate((counts, (fraction_0,)))
    private_class_names = private_column_description.class_names()
    table_row_titles = (
        private_class_names +
        (f"{kind} fraction {private_class_names[0]}",))
    count_matrix_to_html(html_table_out, table_contents, table_row_titles,
                         nonprivate_column_description.class_names())

def do_experiment(data, parameters, rng, html_table_out = None):
    """
    Does a single experiment according to parameters, which must be an
    instance of ExperimentParameters. Returns the average error: the mean,
    over the classes of the nonprivate column, of the absolute difference
    between the true and the estimated fraction of rows in private class 0.

    data is indexed by column name; the selected columns must have a
    one-element shape (one value per row) and the same length.

    rng is used to draw the hash salt and is also handed to the sketcher.

    If html_table_out is given, it should be a file, and a table of results
    will be written to it.

    This code has only been tested when parameters.private_column has two
    classes, i.e. the column description's num_classes() method returns 2.
    """
    def write_table_html(text):
        if html_table_out:
            html_table_out.write(text)

    write_table_html("<p>Experiment with parameters: ")
    write_table_html(f"{parameters_to_html(parameters)}</p>\n")

    private_column_description = \
        adult_column_descriptions_by_name[parameters.private_column]
    nonprivate_column_description = \
        adult_column_descriptions_by_name[parameters.nonprivate_column]
    # The class counts are needed repeatedly (including once per row in the
    # estimation loop below), so look them up a single time.
    num_private_classes = private_column_description.num_classes()
    num_nonprivate_classes = nonprivate_column_description.num_classes()

    # A fresh salt is drawn for every experiment.
    hash_salt = str(rng.integers(2**31))
    sketcher = OneHotSketcher(
        hash_function = salted_sha256(hash_salt),
        rng = rng,
    )

    private_data = data[parameters.private_column]
    nonprivate_data = data[parameters.nonprivate_column]
    num_rows, = private_data.shape
    assert nonprivate_data.shape == (num_rows,)

    # Invent some identities to join on.
    identities = tuple(f"Identity #{i}" for i in range(num_rows))

    private_sketch = sketcher.sketch_values(
        epsilon = parameters.epsilon,
        identities = identities,
        num_categories = num_private_classes,
        values = private_data,
        num_buckets = parameters.sketch_dimension,
    )

    true_counts = joint_counts(
        row_data = private_data,
        num_rows = num_private_classes,
        column_data = nonprivate_data,
        num_columns = num_nonprivate_classes,
    )
    # NOTE(review): unlike the estimated counts below, the true counts are not
    # smoothed, so a nonprivate class with no rows gives 0/0 = nan here.
    true_fraction_0 = true_counts[0, :] / numpy.sum(true_counts, 0)
    if html_table_out:
        _write_counts_table(
            html_table_out, "True", true_counts, true_fraction_0,
            private_column_description, nonprivate_column_description)

    # NOTE(review): this accumulator is int64; if estimate_membership returns
    # non-integer values they are truncated on every in-place addition.
    # Confirm that is intended.
    estimated_counts = numpy.zeros(
        (num_private_classes, num_nonprivate_classes),
        dtype=numpy.int64)
    private_values = range(num_private_classes)
    for identity, nonprivate_value in zip(identities, nonprivate_data):
        for private_value in private_values:
            estimated_counts[private_value, nonprivate_value] += \
                sketcher.estimate_membership(
                    private_sketch, identity, private_value)
    estimated_counts[estimated_counts < 0] = 0  # Change negative values to 0.
    estimated_counts = estimated_counts+1  # Smoothing to avoid 0/0 nan error.

    estimated_fraction_0 = (
        estimated_counts[0, :] / numpy.sum(estimated_counts, 0))
    if html_table_out:
        _write_counts_table(
            html_table_out, "Estimated", estimated_counts,
            estimated_fraction_0,
            private_column_description, nonprivate_column_description)

    write_table_html("<p>Average Error:</p>\n")
    avg_error = numpy.average(abs(true_fraction_0 - estimated_fraction_0))
    write_table_html("<p>" + str(avg_error) + "</p>\n")

    return avg_error

def make_plots(out_dir, results_csv_path):
    """
    Reads the individual experiment results from the CSV file at
    results_csv_path, averages AvgError over the rows sharing the same
    nonprivate column, sketch dimension and epsilon, and saves two plots into
    out_dir:

    - plot_by_eps_filename: error against epsilon, for sketch dimension
      500,000;
    - plot_by_d_filename: error against sketch dimension, for epsilon 1.0.

    Each plot has one line per entry of nonprivate_columns.
    """
    individual_results = pandas.read_csv(results_csv_path)
    aggregated = (
        individual_results
            .groupby(["NonPrivate", "Sketch Dimension", "Epsilon"])
            .mean()
            # Call .reset_index() to get a dataframe with the same four columns
            # as individual_results.
            .reset_index())
    y_label = f"Average Error over {repetitions_per_experiment} runs"

    # Varying epsilon

    sketch500 = aggregated[aggregated["Sketch Dimension"]==500_000]
    for column in nonprivate_columns:
        to_plot = sketch500[sketch500["NonPrivate"]==column]
        pyplot.plot(to_plot["Epsilon"], to_plot["AvgError"], label=column)

    pyplot.xlabel("$\\epsilon$")
    pyplot.ylabel(y_label)
    pyplot.title("Joint Distribution: Avg Error as a function of $\\epsilon$"
                 " (dimension: 500K)")
    pyplot.xscale("log")
    pyplot.legend()
    pyplot.savefig(f"{out_dir}/{plot_by_eps_filename}")

    # Varying number of identity buckets

    # Start a new figure so the lines above don't appear in the second plot.
    pyplot.figure()
    eps1 = aggregated[aggregated["Epsilon"]==1.0]
    for column in nonprivate_columns:
        to_plot = eps1[eps1["NonPrivate"]==column]
        pyplot.plot(to_plot["Sketch Dimension"], to_plot["AvgError"],
                    label=column)

    pyplot.xlabel("Sketch Dimension")
    pyplot.ylabel(y_label)
    pyplot.xscale("log")
    # The backslash is doubled: "\e" is not a valid escape sequence.
    pyplot.title("Joint Distribution: Avg Error as a function of Sketch"
                 " Dimension ($\\epsilon = 1$)")
    pyplot.legend()
    pyplot.savefig(f"{out_dir}/{plot_by_d_filename}")

def do_experiments(html_out, results_csv_path):
    """
    Runs every experiment and writes the report to html_out, which should be a
    file object. The report refers to images that the make_plots function
    generates. The result of each individual experiment is also saved as one
    row of a new CSV file at results_csv_path.
    """
    html_out.write(report_header)

    rng = numpy.random.default_rng()
    with open("datasets/adult/adult.data") as adult_file:
        adult_columns = read_as_columns(adult_file)

    # One experiment whose tables are included in the report.
    html_out.write("<h2>Table</h2>\n")
    table_parameters = ExperimentParameters(
        epsilon = 1.0,
        sketch_dimension = 500_000,
        private_column = "income",
        nonprivate_column = "race",
    )
    do_experiment(
        data = adult_columns,
        parameters = table_parameters,
        rng = rng,
        html_table_out = html_out,
    )

    # The series of experiments behind the plots: epsilon varies at a fixed
    # sketch dimension, then the sketch dimension varies at a fixed epsilon.
    varying_epsilon = [(eps, 500_000) for eps in (0.25, 0.5, 0.75, 1.0)]
    varying_dimension = [
        (1.0, dimension)
        for dimension in (250_000, 100_000, 50_000, 10_000)
    ]
    rows = []
    for column_name in nonprivate_columns:
        for epsilon, dimension in varying_epsilon + varying_dimension:
            experiment = ExperimentParameters(
                epsilon = epsilon,
                sketch_dimension = dimension,
                private_column = "income",
                nonprivate_column = column_name,
            )
            for _ in range(repetitions_per_experiment):
                print(f"{fmttime()} {column_name}: "
                      f"eps={epsilon}, d={dimension}...")
                error = do_experiment(
                    data = adult_columns,
                    parameters = experiment,
                    rng = rng,
                )
                rows.append((column_name, dimension, epsilon, error))

    # Header line first, then one comma-separated line per experiment.
    csv_lines = [",".join(
        ("NonPrivate", "Sketch Dimension", "Epsilon", "AvgError"))]
    csv_lines.extend(",".join(map(str, row)) for row in rows)

    # Mode "x" refuses to overwrite an existing results file.
    with open(results_csv_path, "x") as results_csv_out:
        for line in csv_lines:
            results_csv_out.write(line + "\n")

    html_out.write(report_footer)

def main():
    """
    Command-line entry point. Takes the output directory from the optional
    command-line argument (inventing a name when there is none), copies the
    stylesheet into it, runs the experiments unless the directory already
    holds a results CSV file, and then draws the plots.
    """
    argument_count = len(sys.argv)
    if argument_count not in (1, 2):
        raise RuntimeError("Too many command-line arguments.")

    if argument_count == 2:
        # A directory named by the user may already exist.
        out_dir = sys.argv[1]
        os.makedirs(out_dir, exist_ok = True)
    else:
        out_dir = f"out/joins_{fmttime()}"
        os.makedirs(out_dir)

    results_csv_path = f"{out_dir}/{results_csv_filename}"
    print(f"Using {out_dir} for data and report.")

    shutil.copyfile("resources/style.css", f"{out_dir}/style.css")

    # Existing results are reused: only the plots are redrawn.
    have_results = os.path.exists(results_csv_path)
    if not have_results:
        index_html_path = f"{out_dir}/index.html"
        with open(index_html_path, "x") as html_out:
            do_experiments(html_out, results_csv_path)
        print(f"Wrote a report to {index_html_path}.")
    make_plots(out_dir, results_csv_path)

if __name__ == "__main__":
    main()
