#!/usr/bin/env python3

from stream_processor import BufferStream
from helpers import (
    connect_to_influxdb, datetime_iter, get_jitter, ssim_index_to_db, get_ssim_index,
    get_abr_cc, query_measurement, retrieve_expt_config, connect_to_postgres)
import matplotlib.pyplot as plt
import os
import sys
import argparse
import yaml
import json
import time
from datetime import datetime, timedelta
import numpy as np
import matplotlib
matplotlib.use('Agg')


# NOTE(review): backup_hour and date_format are not referenced by any code
# visible in this file -- confirm whether they are still needed.
backup_hour = 11  # back up at 11 AM (UTC) every day
date_format = '%Y-%m-%dT%H:%M:%SZ'

# Module-level state that main() fills in and the collect_*/plot_* functions
# read.
args = None  # parsed command-line arguments
expt = {}  # experiment config cache; replaced by the --expt JSON when given
influx_client = None  # InfluxDB client created in main()
postgres_cursor = None  # Postgres cursor; stays None when --expt is given

# Rebuffer totals accumulated by process_rebuffer_session();
# key: abr_cc; value: {'total_play': ..., 'total_rebuf': ...}
g_rebuffer = {}

def kB_to_Mbps(kB):
    """Convert the size of one video chunk into an average bitrate.

    The size is divided by 2.002 (presumably the chunk duration in seconds
    -- TODO confirm), multiplied by 8 and then divided by 1000 twice.

    NOTE(review): the double division by 1000 only yields Mbps when the
    input is given in bytes rather than kB -- confirm the unit of the value
    passed in by the callers.
    """
    size_per_second = kB / 2.002
    bits_per_second = size_per_second * 8
    return (bits_per_second / 1000) / 1000

def do_collect_bitrate(d):
    """Accumulate per-scheme bitrate sums and sample counts into ``d``.

    Queries the 'video_sent' measurement through the module-level
    ``influx_client`` and, for every returned point, adds the bitrate derived
    from the point's 'size' field to the entry of the point's abr_cc group.

    Args:
        d: dict that is mutated in place; key: abr_cc (the value returned by
           get_abr_cc()); value: [sum of bitrates, number of samples].
    """
    video_sent_results = query_measurement(
        influx_client, 'video_sent', None, None)['video_sent']

    for pt in video_sent_results:
        expt_id = str(pt['expt_id'])
        expt_config = retrieve_expt_config(expt_id, expt, postgres_cursor)
        abr_cc = get_abr_cc(expt_config)

        # register the group even when this point turns out to be unusable,
        # so the caller can still warn about groups without any size data
        totals = d.setdefault(abr_cc, [0.0, 0])  # sum, count

        # a missing size must be skipped before the float conversion;
        # every real bitrate value (including exactly 1) is a valid sample
        size = pt['size']
        if size is None:
            continue

        totals[0] += kB_to_Mbps(float(size))
        totals[1] += 1


def collect_bitrate():
    """Return the average bitrate of every abr_cc group.

    Returns:
        dict mapping abr_cc to the mean of the bitrates accumulated by
        do_collect_bitrate(). A group without any sample is reported on
        stderr and left out of the result, so every value in the returned
        dict is a plain number.
    """
    d = {}  # key: abr_cc; value: [sum, count]

    do_collect_bitrate(d)

    # calculate the average bitrate of each group
    avg_bitrate = {}
    for abr_cc, (bitrate_sum, count) in d.items():
        if count == 0:
            sys.stderr.write('Warning: {} does not have SIZE data\n'
                             .format(abr_cc))
            continue

        avg_bitrate[abr_cc] = bitrate_sum / count

    return avg_bitrate


def do_collect_jitter(d):
    """Accumulate per-scheme bitrate jitter sums and sample counts into ``d``.

    Walks the points of the 'video_sent' measurement in the order returned by
    the query. For every pair of neighbouring points that carry the same
    'expt_id', the absolute difference of their bitrates is added to the
    entry of that experiment's abr_cc group.

    NOTE(review): neighbouring points are only matched on 'expt_id'; whether
    they also belong to the same playback session is not checked here.

    Args:
        d: dict that is mutated in place; key: abr_cc (the value returned by
           get_abr_cc()); value: [sum of jitter, number of samples].
    """
    video_sent_results = query_measurement(
        influx_client, 'video_sent', None, None)['video_sent']

    prev_pt = None
    for pt in video_sent_results:
        if prev_pt is not None:
            expt_id = str(pt['expt_id'])

            # only compare two points of the same experiment
            if str(prev_pt['expt_id']) == expt_id:
                expt_config = retrieve_expt_config(
                    expt_id, expt, postgres_cursor)
                abr_cc = get_abr_cc(expt_config)

                totals = d.setdefault(abr_cc, [0.0, 0])  # sum, count

                prev_size = prev_pt['size']
                size = pt['size']
                # a pair with a missing size cannot contribute a sample
                if prev_size is not None and size is not None:
                    totals[0] += abs(kB_to_Mbps(float(prev_size))
                                     - kB_to_Mbps(float(size)))
                    totals[1] += 1

        # always advance, so that each difference is taken against the
        # immediately preceding point rather than the first point of the run
        prev_pt = pt


def collect_jitter():
    """Return the average bitrate jitter of every abr_cc group.

    Returns:
        dict keyed by abr_cc. The value is the tuple
        (average jitter, number of samples); a group whose sample count is
        zero is reported on stderr and keeps its raw [sum, count] entry.
    """
    totals = {}  # key: abr_cc; value: [sum, count]
    do_collect_jitter(totals)

    # replace each [sum, count] entry by (average jitter, count)
    for abr_cc, entry in totals.items():
        jitter_sum, num_samples = entry[0], entry[1]

        if num_samples != 0:
            totals[abr_cc] = (jitter_sum / num_samples, num_samples)
        else:
            sys.stderr.write(
                'Warning: {} does not have SIZE data\n'.format(abr_cc))

    return totals


def process_rebuffer_session(session, s):
    """Add one playback session to the rebuffer totals in ``g_rebuffer``.

    Args:
        session: indexable session key; its last element is used as the
                 experiment ID.
        s: dict holding the session's 'play_time' and 'cum_rebuf' values
           (presumably in seconds -- TODO confirm).
    """
    # exclude too short sessions
    if s['play_time'] < 5:
        return

    abr_cc = get_abr_cc(
        retrieve_expt_config(str(session[-1]), expt, postgres_cursor))

    # the first session of a scheme creates its zeroed totals
    totals = g_rebuffer.setdefault(
        abr_cc, {'total_play': 0, 'total_rebuf': 0})
    totals['total_play'] += s['play_time']
    totals['total_rebuf'] += s['cum_rebuf']


def collect_rebuffer():
    """Run the buffer stream over InfluxDB and return the rebuffer totals.

    Every session reported by BufferStream is handed to
    process_rebuffer_session(), which accumulates into ``g_rebuffer``.

    Returns:
        the module-level ``g_rebuffer`` dict;
        key: abr_cc; value: {'total_play': ..., 'total_rebuf': ...}.
    """
    BufferStream(process_rebuffer_session).process(influx_client)
    return g_rebuffer


def plot_ssim_rebuffer(ssim, rebuffer, jitter, ssim_coef, rebuf_coef):
    """Plot average bitrate against stall percentage for every abr_cc group.

    One point is drawn per group that appears in both ``ssim`` and
    ``rebuffer``. For groups that also have jitter data, a combined score
    (bitrate minus the weighted jitter and rebuffer terms) is printed to
    stdout. The figure is saved to ``args.output``.

    Args:
        ssim: dict; key: abr_cc; value: average bitrate (y coordinate).
        rebuffer: dict; key: abr_cc; value: dict with 'total_rebuf' and
                  'total_play'.
        jitter: dict; key: abr_cc; value: (average jitter, sample count).
        ssim_coef: weight of the jitter term in the printed score.
        rebuf_coef: weight of the rebuffer term in the printed score.
    """
    fig, ax = plt.subplots()
    title = '[{}, {}] (UTC)'.format(args.start_time, args.end_time)
    ax.set_title(title)
    ax.set_xlabel('Time spent stalled (%)')
    # the plotted averages are produced by kB_to_Mbps(), hence Mbps
    ax.set_ylabel('Average bitrate (Mbps)')
    ax.grid()

    for abr_cc in ssim:
        if abr_cc not in rebuffer:
            sys.stderr.write('Warning: {} does not exist in rebuffer\n'
                             .format(abr_cc))
            continue

        abr_cc_str = '{}+{}'.format(*abr_cc)

        total_rebuf = rebuffer[abr_cc]['total_rebuf']
        total_play = rebuffer[abr_cc]['total_play']
        rebuf_rate = total_rebuf / total_play

        # annotate with the stalled minutes and the played hours
        abr_cc_str += '\n({:.1f}m/{:.1f}h)'.format(total_rebuf / 60,
                                                   total_play / 3600)

        x = rebuf_rate * 100  # %
        y = ssim[abr_cc]

        # the score needs jitter data, the plotted point does not; a group
        # without jitter is still drawn instead of aborting the whole plot
        if abr_cc in jitter:
            z = jitter[abr_cc][0]
            print(abr_cc_str,
                  y - ssim_coef*z - rebuf_coef*total_rebuf/(total_play/2.002))
        else:
            sys.stderr.write('Warning: {} does not exist in jitter\n'
                             .format(abr_cc))

        ax.scatter(x, y)
        ax.annotate(abr_cc_str, (x, y))

    # clamp x-axis to [0, 100]
    xmin, xmax = ax.get_xlim()
    xmin = max(xmin, 0)
    xmax = min(xmax, 100)
    ax.set_xlim(xmin, xmax)
    ax.invert_xaxis()

    output = args.output
    fig.savefig(output, dpi=150, bbox_inches='tight')
    # release the figure once it has been written to disk
    plt.close(fig)
    sys.stderr.write('Saved plot to {}\n'.format(output))


def main():
    """Parse the command line, collect the statistics and write the plot.

    Experiment configs come from the JSON file given with --expt or, when
    that option is absent, from Postgres. The Postgres cursor and connection
    are closed on every exit path, including the early exit taken when no
    data is found.
    """
    global args, expt, postgres_cursor, influx_client

    parser = argparse.ArgumentParser()
    parser.add_argument('yaml_settings')
    parser.add_argument('--from', dest='start_time',
                        help='datetime in UTC conforming to RFC3339',
                        default=None)
    parser.add_argument('--to', dest='end_time',
                        help='datetime in UTC conforming to RFC3339',
                        default=None)

    parser.add_argument('--expt', help='e.g., expt_cache.json')
    parser.add_argument('-o', '--output', required=True)
    args = parser.parse_args()

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    postgres_client = None
    try:
        if args.expt is not None:
            with open(args.expt, 'r') as fh:
                expt = json.load(fh)
        else:
            # create a Postgres client and perform queries
            postgres_client = connect_to_postgres(yaml_settings)
            postgres_cursor = postgres_client.cursor()

        # create an InfluxDB client and perform queries
        influx_client = connect_to_influxdb(yaml_settings)

        # collect bitrate, jitter and rebuffer
        ssim = collect_bitrate()
        jitter = collect_jitter()
        rebuffer = collect_rebuffer()

        if not ssim or not rebuffer:
            sys.exit('Error: no data found in the queried range')

        # the score weights are taken from the first configured experiment
        abr_config = (
            yaml_settings["experiments"][0]["fingerprint"]["abr_config"])
        rebuffer_coeff = abr_config["rebuffer_length_coeff"]
        ssim_coeff = abr_config["ssim_diff_coeff"]

        print('rebuffer_coeff', rebuffer_coeff)
        print('ssim_coeff', ssim_coeff)

        print(ssim)
        print(jitter)
        print(rebuffer)

        # plot bitrate vs rebuffer
        plot_ssim_rebuffer(ssim, rebuffer, jitter, ssim_coeff, rebuffer_coeff)
    finally:
        # runs on success, on sys.exit() and on errors alike
        if postgres_cursor:
            postgres_cursor.close()
        if postgres_client is not None:
            postgres_client.close()


# run main() only when the file is executed as a script, not when imported
if __name__ == '__main__':
    main()
