# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import torch
from torchvision import datasets
import os
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

# Name of the run whose timing pickles are analysed.
run = "run_60"
# Root of the POC workspace; each participant (server, site-1, site-2) writes
# its timing pickle under <root>/<participant>/<run>/app_<participant>/.
_poc_workspace = "/workspace/Code/fl_security_vertical/Vertical/vertical_training/workspaces/poc_workspace"
wf_times_file = f"{_poc_workspace}/server/{run}/app_server/wf_times.pkl"
learner_times_site1_file = f"{_poc_workspace}/site-1/{run}/app_site-1/learner_times.pkl"
learner_times_site2_file = f"{_poc_workspace}/site-2/{run}/app_site-2/learner_times.pkl"


def load_pickle(filename):
    """Open ``filename`` in binary mode and return the object unpickled from it.

    NOTE: ``pickle`` can execute arbitrary code while loading, so only point
    this at timing files produced by trusted runs.
    """
    with open(filename, "rb") as handle:
        return pickle.load(handle)


def compute_time_diffs(times, offset=0):
    """Return the gaps between consecutive timestamps, starting at ``offset``.

    Args:
        times: Non-empty sequence of timestamps, expected to be strictly
            increasing.
        offset: Number of leading timestamps to skip before differencing
            (e.g. ``1`` to drop a first-step outlier).

    Returns:
        List ``[times[i + 1] - times[i] for i >= offset]``. It is empty when
        fewer than two timestamps remain after ``offset``.

    Raises:
        AssertionError: If ``times`` is empty or two consecutive timestamps
            are not strictly increasing.
    """
    assert len(times) > 0, "times must not be empty"
    # Pair the tail with itself shifted by one. Unlike slicing with
    # ``offset + 1``, this never wraps around for offset == -1.
    tail = times[offset:]
    t_diffs = []
    for t0, t1 in zip(tail, tail[1:]):
        t_diff = t1 - t0
        assert t_diff > 0, f"timestamps not strictly increasing: {t0} -> {t1}"
        t_diffs.append(t_diff)
    # n timestamps give n - 1 gaps. An empty tail (offset past the end) gives none.
    assert len(t_diffs) == max(len(tail) - 1, 0)
    return t_diffs


def compute_time_diffs2(times1, times2, offset=0):
    """Return element-wise durations ``times2[i] - times1[i]`` for ``i >= offset``.

    Args:
        times1: Non-empty sequence of start timestamps.
        times2: Non-empty sequence of end timestamps, same length as ``times1``.
        offset: Number of leading pairs to skip.

    Returns:
        List of durations, one per remaining pair.

    Raises:
        AssertionError: If either sequence is empty, their lengths differ,
            or any end timestamp is not strictly after its start.
    """
    assert len(times1) > 0
    assert len(times2) > 0
    assert len(times1) == len(times2)
    starts, ends = times1[offset::], times2[offset::]

    def _positive_gap(start, end):
        # Checked one pair at a time, in order.
        gap = end - start
        assert gap > 0
        return gap

    durations = [_positive_gap(start, end) for start, end in zip(starts, ends)]
    assert len(durations) == len(starts)
    return durations


# Load the pickled timestamp records from the server workflow and both site learners.
wf_times = load_pickle(wf_times_file)
learner_times_site1 = load_pickle(learner_times_site1_file)
learner_times_site2 = load_pickle(learner_times_site2_file)

# Workflow (server) step durations. offset=1 skips the first step, treated as an outlier.
t_wf_train_steps = compute_time_diffs(wf_times["wf_before_data_step"], offset=1)
t_wf_data_steps = compute_time_diffs2(
    times1=wf_times["wf_before_data_step"],
    times2=wf_times["wf_before_label_step"],
    offset=1,
)
# Measure from each label-step timestamp (all but the last) to the following
# data-step timestamp.
t_wf_label_steps = compute_time_diffs2(
    times1=wf_times["wf_before_label_step"][:-1],
    times2=wf_times["wf_before_data_step"][1:],
    offset=1,
)

# Learner-side step durations (start -> end).
# NOTE(review): unlike the workflow/communication series, no offset is applied
# here, so the first step is kept — confirm whether that is intended.
t_learner_data_steps = compute_time_diffs2(
    times1=learner_times_site1["learner_start_data_step"],
    times2=learner_times_site1["learner_end_data_step"],
)
t_learner_label_steps = compute_time_diffs2(
    times1=learner_times_site2["learner_start_label_step"],
    times2=learner_times_site2["learner_end_label_step"],
)
t_learner_backward_steps = compute_time_diffs2(
    times1=learner_times_site1["learner_start_backward_step"],
    times2=learner_times_site1["learner_end_backward_step"],
)

# Learner aux-handler durations (start -> end).
t_aux_hdl_learner_data_back_steps = compute_time_diffs2(
    times1=learner_times_site1["aux_hdl_learner_start_data_train_back_step"],
    times2=learner_times_site1["aux_hdl_learner_end_data_train_back_step"],
)
t_aux_hdl_learner_label_steps = compute_time_diffs2(
    times1=learner_times_site2["aux_hdl_learner_start_label_train_step"],
    times2=learner_times_site2["aux_hdl_learner_end_label_train_step"],
)

# Communication latencies. These subtract timestamps recorded by different
# processes, which presumes their clocks are comparable (e.g. all POC
# participants on one host) — confirm for distributed runs.
t_comm_data_steps = compute_time_diffs2(
    times1=wf_times["wf_before_data_step"],
    times2=learner_times_site1["learner_start_data_step"],
    offset=1,
)
t_comm_label_steps = compute_time_diffs2(
    times1=wf_times["wf_before_label_step"],
    times2=learner_times_site2["learner_start_label_step"],
    offset=1,
)
t_comm_activ_steps = compute_time_diffs2(
    times1=learner_times_site1["learner_end_data_step"],
    times2=wf_times["wf_before_label_step"],
    offset=1,
)
t_comm_grad_steps = compute_time_diffs2(
    times1=learner_times_site2["learner_end_label_step"][:-1],
    times2=wf_times["wf_before_data_step"][1:],
    offset=1,
)

# Plots: one subplot per timed step in a 3x3 grid, titled with mean ± std.
def _plot_series(position, series, label):
    """Plot ``series`` in subplot ``position`` of a 3x3 grid, titled with its mean and std."""
    plt.subplot(3, 3, position)
    plt.plot(series)
    # Raw f-string: matplotlib needs the literal mathtext "$\pm$", and a
    # non-raw literal would contain an invalid "\p" escape sequence.
    plt.title(rf"{label} mean={np.mean(series):.2f}$\pm${np.std(series):.2f}")


plt.figure()

# Workflow (server-side) times.
_plot_series(1, t_wf_train_steps, "WF Total train")
_plot_series(2, t_wf_data_steps, "WF data")
_plot_series(3, t_wf_label_steps, "WF label")

# Learner-side local times.
_plot_series(4, t_learner_data_steps, "Learner data")
_plot_series(5, t_learner_label_steps, "Learner label")
_plot_series(6, t_learner_backward_steps, "Learner backward")

# Learner aux-handler times. Plot the aux series themselves so each curve
# matches the statistics reported in its title.
_plot_series(7, t_aux_hdl_learner_data_back_steps, "Learner aux data & backward")
_plot_series(8, t_aux_hdl_learner_label_steps, "Learner aux label")


# Compute times: long-form table with one row per (duration, step, location)
# sample, drawn as a grouped bar plot.
times = [
    t_wf_train_steps,
    t_wf_data_steps,
    t_wf_label_steps,
    t_learner_data_steps,
    t_learner_label_steps,
    t_learner_backward_steps,
    # The "aux" bars must use the aux-handler series, not the learner series.
    t_aux_hdl_learner_data_back_steps,
    t_aux_hdl_learner_label_steps,
]
steps = [
    "WF Total train",
    "WF data",
    "WF label",
    "Learner data",
    "Learner label",
    "Learner backward",
    "Learner aux data & backward",
    "Learner aux label",
]
locations = ["server", "server", "server", "site-1", "site-2", "site-1", "site-1", "site-2"]

timings = {
    "time": [],
    "step": [],
    "location": [],
}
for _times, _s, _loc in zip(times, steps, locations):
    # Repeat the step/location label once per sample of that series.
    timings["time"].extend(_times)
    timings["step"].extend([_s] * len(_times))
    timings["location"].extend([_loc] * len(_times))

plt.figure()
sns.barplot(x="step", y="time", hue="location", data=timings)
plt.xticks(rotation=15)
plt.grid()
plt.xlabel("Compute step")
plt.ylabel("Time (sec)")

# Communication times: one bar group per message type, coloured by direction.
times = [t_comm_data_steps, t_comm_label_steps, t_comm_activ_steps, t_comm_grad_steps]
steps = ["data", "label", "act", "grad"]
locations = ["server->site-1", "server->site-2", "site-1->server", "site-2->server"]

# Long-form table: every sample of a series is tagged with that series' step
# and location, preserving series order and within-series order.
timings = {
    "time": [_t for _times in times for _t in _times],
    "step": [_s for _times, _s in zip(times, steps) for _ in _times],
    "location": [_loc for _times, _loc in zip(times, locations) for _ in _times],
}

plt.figure()
sns.barplot(x="step", y="time", hue="location", data=timings)
plt.xlabel("Communication step")
plt.ylabel("Time (sec)")
plt.xticks(rotation=15)
plt.grid()
plt.show()
