#!/usr/bin/env python3

from stream_processor import BufferStream
from helpers import (
    connect_to_influxdb, datetime_iter, get_jitter, ssim_index_to_db, get_ssim_index,
    get_abr_cc, query_measurement, retrieve_expt_config, connect_to_postgres)
import matplotlib.pyplot as plt
import os
import sys
import argparse
import yaml
import json
import time
from datetime import datetime, timedelta
import numpy as np
import matplotlib
matplotlib.use('Agg')

# Parsed command-line arguments; assigned by main().
args = None

# Database handles; both stay None until main() connects.
influx_client = postgres_cursor = None

# expt_id -> init_id -> {"ssim": [...], "rebuf": ..., "total_play": ...},
# filled in by collect_ssim() and collect_rebuffer().
experiments = dict()
total_qoe = dict()

# Bar colour for each label drawn by plot().
colors = dict(
    Fugu="tab:blue",
    Wolfi="tab:orange",
    BBA="tab:green",
    RobustMPC="tab:red",
    MPC="tab:purple",
)


def collect_ssim():
    """Record the SSIM index of every ``video_acked`` point, grouped by session.

    Queries the ``video_acked`` measurement through the module-level
    ``influx_client`` and fills the module-level ``experiments`` mapping in
    place.  A session entry (keyed by ``expt_id`` then ``init_id``, both
    stringified) is created the first time it is seen, with an empty ``ssim``
    list and zeroed ``rebuf`` / ``total_play`` counters; each point then
    appends the value returned by ``get_ssim_index`` to that list.
    """
    points = query_measurement(
        influx_client, 'video_acked', None, None)['video_acked']

    for point in points:
        # One sub-dict of sessions per experiment, created on demand.
        sessions = experiments.setdefault(str(point['expt_id']), {})

        session_key = str(point["init_id"])
        if session_key not in sessions:
            sessions[session_key] = {"ssim": [], "rebuf": 0, "total_play": 0}

        sessions[session_key]["ssim"].append(get_ssim_index(point))


def calc_jitter(ssim):
    """Return the average SSIM change between consecutive samples.

    The absolute differences between neighbouring entries of ``ssim`` are
    summed and divided by the number of *non-zero* differences, and the
    result is passed through ``ssim_index_to_db``.

    Args:
        ssim: Sequence of per-sample SSIM index values for one session.

    Returns:
        ``ssim_index_to_db`` applied to the mean non-zero difference.  When
        there is no non-zero difference at all (fewer than two samples, or a
        perfectly constant sequence) the mean is taken to be ``0.0`` instead
        of dividing by zero, which previously produced NaN.
    """
    values = np.array(ssim)
    jitter = np.abs(values[1:] - values[:-1])
    changes = np.count_nonzero(jitter)

    # No quality switch at all: the average change is zero, not 0/0.
    mean_jitter = np.sum(jitter) / changes if changes else 0.0
    return ssim_index_to_db(mean_jitter)


def collect_rebuffer():
    """Accumulate play time and rebuffering time for every known session.

    Runs a ``BufferStream`` over the module-level ``influx_client`` and, via
    the callback below, adds each reported session's ``play_time`` and
    ``cum_rebuf`` to the matching entry of the module-level ``experiments``
    mapping.  Only sessions that ``experiments`` already contains are
    updated.
    """
    def process_rebuffer_session(session, s):
        """Fold one session summary ``s`` into ``experiments``.

        ``session`` is indexed positionally: element 1 is used as the
        ``init_id`` and element 2 as the ``expt_id``.
        """
        # exclude too short sessions
        if s['play_time'] < 5:
            return

        init_id, expt_id = str(session[1]), str(session[2])

        # A session with no recorded entry in `experiments` has nothing to
        # attach the rebuffer statistics to; skip it rather than abort the
        # whole run with a KeyError.
        stats = experiments.get(expt_id, {}).get(init_id)
        if stats is None:
            return

        stats["total_play"] += s["play_time"]
        stats["rebuf"] += s["cum_rebuf"]

    # The callback is handed to BufferStream; presumably it is invoked once
    # per session while the stream is processed — confirm in stream_processor.
    buffer_stream = BufferStream(process_rebuffer_session)
    buffer_stream.process(influx_client)


def calc_qoe_per_file(ssim_coeff, rebuf_coeff):
    """Compute one QoE score per session, grouped by experiment.

    For each session in the module-level ``experiments`` mapping the score is

        ssim_index_to_db(mean(ssim))
            - ssim_coeff * calc_jitter(ssim)
            - rebuf_coeff * rebuf / (total_play / 2.002)

    Args:
        ssim_coeff: Weight applied to the jitter term.
        rebuf_coeff: Weight applied to the normalised rebuffering term.

    Returns:
        Dict mapping every ``expt_id`` to a list of QoE scores.  Sessions
        without any SSIM sample or without accumulated play time are left
        out, so a list may be empty.
    """
    total_qoe = {}

    for expt_id, sessions in experiments.items():
        session_scores = total_qoe.setdefault(expt_id, [])

        for session in sessions.values():
            ssim = session["ssim"]
            total_play = session["total_play"]

            # Sessions that collect_rebuffer() skipped (too short) or never
            # saw keep total_play == 0; scoring them would divide by zero.
            # The same holds for a session with no SSIM samples.
            if not ssim or total_play <= 0:
                continue

            jitter = calc_jitter(ssim)
            # divide by the number of points to be agnostic to the play_time
            # (2.002 is presumably the chunk duration in seconds — confirm)
            rebuf = session["rebuf"] / (total_play / 2.002)

            qoe = ssim_index_to_db(np.sum(ssim) / len(ssim))
            qoe -= ssim_coeff * jitter
            qoe -= rebuf_coeff * rebuf

            session_scores.append(qoe)

    return total_qoe


def plot(experiments):
    """Draw a grouped bar chart of QoE percentiles and save it to disk.

    Args:
        experiments: Mapping of label -> sequence of QoE values.  The label
            is used both as the legend entry and as the key into the
            module-level ``colors`` palette.

    One group of bars is drawn per percentile in
    ``[0, 5, 10, 25, 50, 75, 90, 100]``, with one bar per label (labels in
    sorted order).  The figure is written to ``percentiles.png`` in the
    current working directory.
    """
    percentiles_to_measure = [0, 5, 10, 25, 50, 75, 90, 100]

    # A percentile of an empty sample is undefined, so such series are
    # left out of the chart entirely.
    series = [(label, values)
              for label, values in sorted(experiments.items())
              if len(values) > 0]
    series_count = len(series)

    # 0.15 is the historical bar width; shrink it only when there are so
    # many series that a group would spill into its neighbour.
    width = min(0.15, 0.8 / max(series_count, 1))

    fig, ax = plt.subplots(layout='constrained')
    x = np.arange(len(percentiles_to_measure))  # the label locations

    for index, (label, values) in enumerate(series):
        measurement = np.percentile(values, percentiles_to_measure)
        # Labels missing from the palette get color=None, which lets
        # matplotlib pick the next colour of its default cycle.
        ax.bar(x + width * index, measurement, width,
               label=label, color=colors.get(label))

    # Centre each tick under its group of bars, whatever the series count.
    group_centre = width * max(series_count - 1, 0) / 2
    ax.set_xticks(x + group_centre, percentiles_to_measure)
    ax.set_ylabel('QoE')
    ax.set_xlabel('Percentile')

    if series_count:
        ax.legend(loc='upper center', ncol=series_count,
                  bbox_to_anchor=(0.5, 1.15))
    fig.set_figwidth(10)
    fig.set_figheight(3)

    fig.savefig('percentiles.png', dpi=150, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def main():
    """Command-line entry point.

    With ``--file PATH`` the QoE values are loaded from that JSON file and
    plotted; no database is contacted.  Otherwise the YAML settings given by
    ``--yaml_settings`` are loaded, the databases are queried, per-session
    QoE is computed, the result is written to ``result.json`` and plotted.

    Sets the module-level ``args``, ``postgres_cursor`` and
    ``influx_client``.
    """
    global args, influx_client, postgres_cursor

    parser = argparse.ArgumentParser()
    parser.add_argument('--yaml_settings', default='src/settings.yml')
    parser.add_argument('--file', default=None)
    args = parser.parse_args()

    if args.file:
        with open(args.file, 'r') as fp:
            total_qoe = json.load(fp)
        plot(total_qoe)
        return

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    # Read the coefficients first so a malformed settings file fails
    # before any connection is opened.
    abr_config = yaml_settings["experiments"][0]["fingerprint"]["abr_config"]
    rebuf_coeff = abr_config["rebuffer_length_coeff"]
    ssim_coeff = abr_config["ssim_diff_coeff"]

    # create a Postgres client and perform queries
    postgres_client = connect_to_postgres(yaml_settings)
    try:
        postgres_cursor = postgres_client.cursor()

        # create an InfluxDB client and perform queries
        influx_client = connect_to_influxdb(yaml_settings)

        collect_ssim()
        collect_rebuffer()
        total_qoe = calc_qoe_per_file(ssim_coeff, rebuf_coeff)

        # Persist the numbers before plotting: a plotting failure must not
        # throw away the result of the database queries, and the saved file
        # can be re-plotted later with --file.
        with open('result.json', 'w') as fp:
            json.dump(total_qoe, fp)

        plot(total_qoe)
    finally:
        # Release the database resources even when a step above raised.
        if postgres_cursor is not None:
            postgres_cursor.close()
        postgres_client.close()


if __name__ == '__main__':
    main()
