import json
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

# Largest key (round number) that is taken into account when aggregating
# results; keys above it are ignored. Set to None to use every recorded key.
MAX_KEY = 3000


def get_stats(l):
    """
    Compute per-key statistics across a list of dictionaries.

    The keys of the first dictionary define which keys are aggregated; every
    dictionary in ``l`` must contain those keys. Keys are converted with
    ``int`` and keys larger than ``MAX_KEY`` are skipped (unless ``MAX_KEY``
    is None).

    Parameters
    ----------
    l : list of dict
        Dictionaries mapping a key convertible to ``int`` (e.g. ``"10"``) to
        a numeric value or a sequence of numeric values.

    Returns
    -------
    tuple of dict
        ``(mean_dict, stdev_dict, min_dict, max_dict)``, each mapping the
        integer key to the corresponding numpy statistic over all
        dictionaries.

    Raises
    ------
    ValueError
        If ``l`` is empty.
    KeyError
        If a key of the first dictionary is missing from a later one.

    """
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if len(l) == 0:
        raise ValueError("get_stats needs at least one dictionary")

    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
    for key in l[0]:
        int_key = int(key)
        if MAX_KEY is not None and int_key > MAX_KEY:
            continue
        values = np.array([entry[key] for entry in l])
        # Local names deliberately avoid shadowing the builtins min/max.
        mean_dict[int_key] = np.mean(values)
        stdev_dict[int_key] = np.std(values)
        min_dict[int_key] = np.min(values)
        max_dict[int_key] = np.max(values)
    return mean_dict, stdev_dict, min_dict, max_dict


def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds", xscale=None):
    """
    Draw one mean curve with a +/- one standard deviation band on the current figure.

    Parameters
    ----------
    means : dict
        Maps x positions to mean values; keys become the x axis, values the y axis.
    stdevs : dict
        Standard deviations; only the values are used and they are matched to
        ``means`` by position, so both dicts must have the same length and order.
    mins, maxs : dict
        Accepted for interface compatibility; not used by this function.
    title : str
        Title of the figure.
    label : str
        Legend label of the curve.
    loc : str
        Legend location passed to ``plt.legend``.
    xlabel : str
        Label of the x axis.
    xscale : str, optional
        If given, scale applied to the x axis (e.g. ``"log"``).

    """
    xs = np.array(list(means))
    ys = np.array(list(means.values()))
    spread = np.array(list(stdevs.values()))

    plt.title(title)
    plt.xlabel(xlabel)
    plt.plot(xs, ys, label=label)
    plt.fill_between(xs, ys - spread, ys + spread, alpha=0.4)

    # Enlarge every legend entry in one call.
    legend = plt.legend(loc=loc)
    plt.setp(legend.get_texts(), fontsize=16)

    if xscale is None:
        return
    plt.xscale(xscale)


def replace_dict_key(d_org: dict, d_other: dict):
    """
    Return a copy of ``d_org`` whose keys are translated through ``d_other``.

    Each key ``k`` of ``d_org`` becomes ``d_other[k]`` while keeping its value.
    If several keys translate to the same new key, the value of the last one
    (in iteration order) is kept.

    Raises
    ------
    KeyError
        If a key of ``d_org`` is missing from ``d_other``.

    """
    return {d_other[old_key]: value for old_key, value in d_org.items()}


def _load_node_results(folder_path):
    """Return the parsed content of every ``*_results.json`` found directly inside a sub-directory of ``folder_path``."""
    results = []
    for machine_folder in os.listdir(folder_path):
        mf_path = os.path.join(folder_path, machine_folder)
        if not os.path.isdir(mf_path):
            continue
        for name in os.listdir(mf_path):
            if not name.endswith("_results.json"):
                continue
            with open(os.path.join(mf_path, name), "r") as inf:
                results.append(json.load(inf))
    return results


def _stats_frame(means, stdevs, nr_nodes, total_bytes=None):
    """Build the table (mean, std, nr_nodes and optionally total_bytes) indexed by the keys of ``means``."""
    data = {
        "mean": list(means.values()),
        "std": list(stdevs.values()),
        "nr_nodes": [nr_nodes] * len(means),
    }
    if total_bytes is not None:
        data["total_bytes"] = total_bytes
    return pd.DataFrame(data, index=list(means.keys()), columns=list(data))


def _plot_metric(path, folder, metric, title, loc, figures, stat_source, nr_nodes, b_means):
    """Plot ``metric`` against rounds and against bytes, and write ``<metric>_<folder>.csv``."""
    rounds_fig, bytes_fig = figures

    plt.figure(rounds_fig)
    means, stdevs, mins, maxs = get_stats([x[metric] for x in stat_source])
    plot(means, stdevs, mins, maxs, title, folder, loc)

    # Look the byte counts up for the rounds of *this* metric. Reusing the
    # list built for another metric breaks the table as soon as the metrics
    # are recorded at different rounds.
    bytes_per_round = [b_means[r] for r in means]
    df = _stats_frame(means, stdevs, nr_nodes, bytes_per_round)

    plt.figure(bytes_fig)
    # Re-key both series identically so they keep the same length even when
    # two rounds map to the same byte count.
    plot(
        replace_dict_key(means, b_means),
        replace_dict_key(stdevs, b_means),
        mins,
        maxs,
        title,
        folder,
        loc,
        "Total Bytes per node",
        "log",
    )

    df.to_csv(os.path.join(path, metric + "_" + folder + ".csv"), index_label="rounds")


def _final_value(series, empty_value=None):
    """Return the value stored under the largest key of ``series`` that does not exceed ``MAX_KEY``."""
    if not series and empty_value is not None:
        return empty_value
    eligible = [k for k in series if MAX_KEY is None or int(k) <= MAX_KEY]
    # Runs shorter than MAX_KEY have no str(MAX_KEY) entry, so pick the last
    # recorded key instead of indexing MAX_KEY directly.
    return series[max(eligible, key=int)]


def _final_stats(results, field, empty_value=None):
    """Return mean and standard deviation over all nodes of the final value of ``field``."""
    values = np.array([_final_value(x[field], empty_value) for x in results])
    return np.mean(values), np.std(values)


def plot_results(path, centralized, data_machine="machine0", data_node=0):
    """
    Aggregate the per-node results of every experiment in ``path`` and plot them.

    Every sub-directory of ``path`` (except one named ``weights``) is treated
    as one experiment. Its ``<machine>/*_results.json`` files are loaded; each
    is expected to provide the dictionaries ``total_bytes``, ``train_loss``,
    ``test_loss``, ``test_acc``, ``total_meta`` and ``total_data_per_n``,
    keyed by round. Sub-directories without any result file are skipped.

    For each experiment the CSV files ``total_bytes_<folder>.csv``,
    ``train_loss_<folder>.csv``, ``test_loss_<folder>.csv`` and
    ``test_acc_<folder>.csv`` are written to ``path``. The combined figures
    are saved to ``path`` as ``total_bytes.pdf``, ``train_loss.pdf``,
    ``test_loss.pdf``, ``test_acc.pdf``, ``bytes_train_loss.pdf``,
    ``bytes_test_loss.pdf``, ``bytes_test_acc.pdf``, ``data_shared.pdf`` and
    ``parameters_metadata.pdf``.

    Parameters
    ----------
    path : str
        Directory that contains one sub-directory per experiment.
    centralized : str or bool
        Truthy values ("true", "1", "t", "y", "yes", case-insensitive, or
        True) mean the test metrics are read from a single result file
        instead of being averaged over all nodes.
    data_machine : str
        Machine sub-directory holding that single result file.
    data_node : int
        Node whose ``<data_node>_results.json`` holds the centralized test
        metrics. Folders whose name starts with "FL" or "Parameter Server"
        always use node -1.

    """
    # str() lets callers pass a real bool as well as a command-line string.
    centralized = str(centralized).lower() in ("true", "1", "t", "y", "yes")
    if centralized:
        print("Centralized")

    folders = sorted(os.listdir(path))
    print("Reading folders from: ", path)
    print("Folders: ", folders)

    bytes_means, bytes_stdevs = {}, {}
    meta_means, meta_stdevs = {}, {}
    data_means, data_stdevs = {}, {}
    for folder in folders:
        folder_path = Path(path) / folder
        if not folder_path.is_dir() or folder_path.name == "weights":
            continue

        results = _load_node_results(folder_path)
        if not results:
            # Not an experiment directory; do not abort the whole run.
            print("Skipping folder without results: ", folder)
            continue

        if centralized:
            # The single result file is only needed (and only required to
            # exist) in centralized mode.
            if folder.startswith(("FL", "Parameter Server")):
                node = -1
            else:
                node = data_node
            with open(folder_path / data_machine / f"{node}_results.json", "r") as f:
                test_source = [json.load(f)]
        else:
            test_source = results

        # Plotting bytes over time
        plt.figure(10)
        b_means, b_stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
        plot(b_means, b_stdevs, mins, maxs, "Total Bytes", folder, "lower right")
        _stats_frame(b_means, b_stdevs, len(results)).to_csv(
            os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
        )

        # Training loss is always averaged over all nodes.
        _plot_metric(
            path, folder, "train_loss", "Training Loss", "upper right",
            (1, 11), results, len(results), b_means,
        )
        _plot_metric(
            path, folder, "test_loss", "Testing Loss", "upper right",
            (2, 12), test_source, len(results), b_means,
        )
        _plot_metric(
            path, folder, "test_acc", "Testing Accuracy", "lower right",
            (3, 13), test_source, len(results), b_means,
        )

        # Collect the totals reached at the last considered round.
        bytes_means[folder], bytes_stdevs[folder] = _final_stats(results, "total_bytes")
        # Nodes that recorded no metadata count as 0.
        meta_means[folder], meta_stdevs[folder] = _final_stats(
            results, "total_meta", empty_value=0
        )
        data_means[folder], data_stdevs[folder] = _final_stats(
            results, "total_data_per_n"
        )

    for number, filename in (
        (10, "total_bytes.pdf"),
        (11, "bytes_train_loss.pdf"),
        (12, "bytes_test_loss.pdf"),
        (13, "bytes_test_acc.pdf"),
        (1, "train_loss.pdf"),
        (2, "test_loss.pdf"),
        (3, "test_acc.pdf"),
    ):
        plt.figure(number)
        plt.savefig(os.path.join(path, filename), dpi=300)

    # Bar heights and error bars are floored to whole MBs.
    mb = 1024 * 1024
    x_pos = np.arange(len(bytes_means))

    # Plot total_bytes
    plt.figure(4)
    plt.title("Data Shared")
    plt.bar(
        x_pos,
        np.array(list(bytes_means.values())) // mb,
        yerr=np.array(list(bytes_stdevs.values())) // mb,
        align="center",
    )
    plt.ylabel("Total data shared in MBs")
    plt.xlabel("Fraction of Model Shared")
    plt.xticks(x_pos, list(bytes_means.keys()))
    plt.savefig(os.path.join(path, "data_shared.pdf"), dpi=300)

    # Plot stacked_bytes: metadata stacked on top of the parameters
    plt.figure(5)
    plt.title("Data Shared per Neighbor")
    parameters_mb = np.array(list(data_means.values())) // mb
    plt.bar(
        x_pos,
        parameters_mb,
        yerr=np.array(list(data_stdevs.values())) // mb,
        align="center",
        label="Parameters",
    )
    plt.bar(
        x_pos,
        np.array(list(meta_means.values())) // mb,
        bottom=parameters_mb,
        yerr=np.array(list(meta_stdevs.values())) // mb,
        align="center",
        label="Metadata",
    )
    plt.ylabel("Data shared in MBs")
    plt.xlabel("Fraction of Model Shared")
    plt.xticks(x_pos, list(meta_means.keys()))
    plt.savefig(os.path.join(path, "parameters_metadata.pdf"), dpi=300)


def _count_shared_indices(rounds):
    """Count, for every index, in how many entries of ``rounds`` it appears."""
    # Check membership first so a missing "0" gives a clear error, not a KeyError.
    if "0" not in rounds or len(rounds["0"]) == 0:
        raise ValueError('shared params must contain a non-empty entry for key "0"')

    counts = np.zeros(len(rounds["0"]))
    for shared in rounds.values():
        indices = np.array(shared)
        if indices.size == 0:
            # Nothing recorded for this key; np.max would fail on an empty array.
            continue
        # counts needs max index + 1 slots. Grow only at the end: a scalar
        # pad width would also pad the front and shift the existing counts.
        missing = int(np.max(indices)) + 1 - counts.shape[0]
        if missing > 0:
            counts = np.pad(counts, (0, missing), "constant", constant_values=0)
        counts[indices] += 1
    return counts


def plot_parameters(path):
    """
    Scatter-plot how often each parameter index occurs in the shared-params files.

    For every sub-directory of ``path``, each ``*_shared_params.json`` file is
    loaded, its ``order`` and ``shapes`` entries are dropped and the remaining
    entries are read as lists of indices. The number of entries containing
    each index is drawn on figure 4, and the figure is saved as
    ``shared_params.pdf`` inside that sub-directory. The figure is not cleared
    between sub-directories, so later PDFs also contain the earlier points.

    Parameters
    ----------
    path : str
        Directory whose sub-directories contain the ``*_shared_params.json`` files.

    Raises
    ------
    ValueError
        If a file has no non-empty entry for key ``"0"``.

    """
    plt.figure(4)
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path):
            continue
        for name in os.listdir(folder_path):
            if not name.endswith("_shared_params.json"):
                continue
            filepath = os.path.join(folder_path, name)
            print("Working with ", filepath)
            with open(filepath, "r") as inf:
                loaded_dict = json.load(inf)
            # These two entries are bookkeeping, not index lists.
            loaded_dict.pop("order", None)
            loaded_dict.pop("shapes", None)
            counts = _count_shared_indices(loaded_dict)
            plt.plot(np.arange(0, counts.shape[0]), counts, ".")
        print("Saving scatterplot")
        plt.savefig(os.path.join(folder_path, "shared_params.pdf"))


if __name__ == "__main__":
    # Command-line arguments:
    #   1: the folder with the data
    #   2: True/False; if True, the evaluation on the test set was centralized
    # For federated learning, the folder name must start with "FL"!
    if len(sys.argv) != 3:
        # Explicit check instead of `assert`, which is stripped under
        # `python -O` and gives no hint about the expected arguments.
        sys.exit(f"Usage: {sys.argv[0]} <results_folder> <centralized: True/False>")
    plot_results(sys.argv[1], sys.argv[2])
    # plot_parameters(sys.argv[1])
