#!/usr/bin/env python3

from stream_processor import BufferStream
from helpers import (
    connect_to_influxdb, datetime_iter, get_jitter, ssim_index_to_db, get_ssim_index,
    get_abr_cc, query_measurement, retrieve_expt_config, connect_to_postgres)
import matplotlib.pyplot as plt
import os
import sys
import argparse
import yaml
import json
import time
from datetime import datetime, timedelta
import numpy as np
import matplotlib
matplotlib.use('Agg')


# Hour of the day (UTC) at which the daily backup happens: 11 AM.
backup_hour: int = 11
# Format string for RFC3339-style UTC timestamps, e.g. '2019-01-01T00:00:00Z'.
date_format: str = '%Y-%m-%dT%H:%M:%SZ'

# Process-wide state; these placeholders are (re)assigned by main().
args = None  # parsed command-line arguments (argparse.Namespace)
expt: dict = {}  # experiment configs loaded from the --expt JSON file, if given
influx_client = None  # InfluxDB client from connect_to_influxdb()
postgres_cursor = None  # Postgres cursor; only created when --expt is absent

# Rebuffer accumulator filled by process_rebuffer_session():
# key: abr_cc; value: {'total_play': ..., 'total_rebuf': ...}
g_rebuffer: dict = {}


def do_collect_ssim(d):
    """Accumulate per-abr_cc SSIM index sums and sample counts into *d*.

    Queries the 'video_acked' measurement with no explicit time bounds
    (None, None). For each point, the point's experiment config is resolved
    to an abr_cc key. Every key that is seen gets an entry in *d*, even if
    none of its points contribute an SSIM sample.

    Args:
        d: dict mutated in place; maps abr_cc -> [ssim_index_sum, count].
    """
    acked_points = query_measurement(
        influx_client, 'video_acked', None, None)['video_acked']

    for point in acked_points:
        config = retrieve_expt_config(str(point['expt_id']), expt,
                                      postgres_cursor)
        totals = d.setdefault(get_abr_cc(config), [0.0, 0])

        # Missing samples and an index of exactly 1 are not counted
        # (why 1 is excluded is not visible here -- presumably a sentinel).
        index = get_ssim_index(point)
        if index is None or index == 1:
            continue
        totals[0] += index
        totals[1] += 1


def collect_ssim():
    """Return the average SSIM, in dB, of every abr_cc that has SSIM data.

    Sums and counts are gathered by do_collect_ssim(). The mean SSIM index
    is then converted with ssim_index_to_db(). An abr_cc that was seen but
    contributed no usable SSIM sample is reported on stderr and omitted from
    the result, so every value in the returned dict is a number.

    Returns:
        dict mapping abr_cc -> average SSIM in dB (empty if nothing usable).
    """
    d = {}  # key: abr_cc; value: [sum, count]
    do_collect_ssim(d)

    avg_ssim = {}  # key: abr_cc; value: average SSIM in dB
    for abr_cc, (ssim_sum, count) in d.items():
        if count == 0:
            sys.stderr.write('Warning: {} does not have SSIM data\n'
                             .format(abr_cc))
            continue

        # Average in the index domain first, then convert to dB.
        avg_ssim[abr_cc] = ssim_index_to_db(ssim_sum / count)

    return avg_ssim


def do_collect_jitter(d):
    """Accumulate per-abr_cc jitter sums and counts into *d*.

    Walks the 'video_acked' points (queried with no explicit time bounds)
    as a sliding window of consecutive pairs. A jitter sample is taken from
    each pair whose two points share an expt_id. The window always advances
    by one point, so each sample compares a point with its immediate
    predecessor.

    NOTE(review): consecutive points are only matched on expt_id; if
    different sessions of the same experiment interleave in the query
    result, a pair may span two sessions unless get_jitter() guards against
    that -- confirm.

    Args:
        d: dict mutated in place; maps abr_cc -> [jitter_sum, count].
    """
    video_acked_results = query_measurement(
        influx_client, 'video_acked', None, None)['video_acked']

    prev_pt = None
    for pt in video_acked_results:
        if prev_pt is not None:
            prev_expt_id = str(prev_pt['expt_id'])

            # Only pairs from the same experiment yield a sample; skip the
            # config lookup entirely for pairs that will be discarded.
            if prev_expt_id == str(pt['expt_id']):
                expt_config = retrieve_expt_config(
                    prev_expt_id, expt, postgres_cursor)
                totals = d.setdefault(get_abr_cc(expt_config), [0.0, 0])

                jitter = get_jitter(prev_pt, pt)
                if jitter is not None:
                    totals[0] += jitter
                    totals[1] += 1

        # Always slide the window so the next pair is (pt, next point).
        prev_pt = pt


def collect_jitter():
    """Return per-abr_cc average jitter together with its sample count.

    Sums and counts are gathered by do_collect_jitter(). For each abr_cc
    with at least one sample, the value becomes a tuple
    (ssim_index_to_db(mean_jitter), sample_count). An abr_cc with no
    samples is reported on stderr and keeps its raw [0.0, 0] accumulator.

    Returns:
        dict mapping abr_cc -> (average jitter, count) or [sum, count].
    """
    d = {}  # key: abr_cc; value: [sum, count]
    do_collect_jitter(d)

    # Replacing values (not keys) while iterating over items() is safe.
    for abr_cc, (jitter_sum, count) in d.items():
        if count == 0:
            sys.stderr.write('Warning: {} does not have jitter data\n'
                             .format(abr_cc))
            continue

        # NOTE(review): the mean jitter is passed through the SSIM
        # index->dB conversion; this is only meaningful if get_jitter()
        # returns values in the SSIM-index domain -- confirm.
        avg_jitter = ssim_index_to_db(jitter_sum / count)
        d[abr_cc] = (avg_jitter, count)

    return d


def process_rebuffer_session(session, s):
    """Add one buffer session's play and stall time to g_rebuffer.

    Used as the per-session callback of BufferStream. Sessions whose
    s['play_time'] is below 5 are ignored. The experiment id is taken from
    the last element of *session* and mapped to an abr_cc key.

    Args:
        session: sequence identifying the session; its last item is the
            expt_id.
        s: per-session stats with at least 'play_time' and 'cum_rebuf'
            (presumably seconds -- confirm against BufferStream).
    """
    if s['play_time'] < 5:  # too short to be representative
        return

    abr_cc = get_abr_cc(
        retrieve_expt_config(str(session[-1]), expt, postgres_cursor))

    # g_rebuffer is only mutated, never rebound, so no `global` is needed.
    totals = g_rebuffer.setdefault(abr_cc,
                                   {'total_play': 0, 'total_rebuf': 0})
    totals['total_play'] += s['play_time']
    totals['total_rebuf'] += s['cum_rebuf']


def collect_rebuffer():
    """Feed buffer data from InfluxDB through process_rebuffer_session.

    Returns:
        The module-level g_rebuffer dict, which maps
        abr_cc -> {'total_play': ..., 'total_rebuf': ...}.
    """
    BufferStream(process_rebuffer_session).process(influx_client)
    return g_rebuffer


def plot_ssim_rebuffer(ssim, rebuffer, jitter, ssim_coef, rebuf_coef,
                       chunk_duration=2.002):
    """Scatter-plot average SSIM against stall percentage and save it.

    Plots one annotated point per abr_cc present in both *ssim* and
    *rebuffer*. For each of those that also has jitter data, the function
    prints a score to stdout:

        ssim - ssim_coef * jitter
             - rebuf_coef * total_rebuf / (total_play / chunk_duration)

    The figure is written to args.output and then closed.

    Args:
        ssim: abr_cc -> average SSIM in dB.
        rebuffer: abr_cc -> {'total_play': ..., 'total_rebuf': ...}.
        jitter: abr_cc -> (average jitter, sample count).
        ssim_coef: weight applied to the jitter term.
        rebuf_coef: weight applied to the rebuffer term.
        chunk_duration: divisor that turns total_play into a chunk count;
            2.002 by default (presumably the video chunk length in seconds
            -- confirm).
    """
    fig, ax = plt.subplots()
    ax.set_title('[{}, {}] (UTC)'.format(args.start_time, args.end_time))
    ax.set_xlabel('Time spent stalled (%)')
    ax.set_ylabel('Average SSIM (dB)')
    ax.grid()

    for abr_cc in ssim:
        if abr_cc not in rebuffer:
            sys.stderr.write('Warning: {} does not exist in rebuffer\n'
                             .format(abr_cc))
            continue

        y = ssim[abr_cc]
        # An abr_cc without SSIM samples may still hold its raw [sum, count]
        # accumulator instead of a dB value; it cannot be plotted.
        if isinstance(y, list):
            sys.stderr.write('Warning: {} does not have SSIM data\n'
                             .format(abr_cc))
            continue

        total_rebuf = rebuffer[abr_cc]['total_rebuf']
        total_play = rebuffer[abr_cc]['total_play']
        x = total_rebuf / total_play * 100  # %

        abr_cc_str = '{}+{}'.format(*abr_cc)
        abr_cc_str += '\n({:.1f}m/{:.1f}h)'.format(total_rebuf / 60,
                                                   total_play / 3600)

        # Jitter is only needed for the printed score; a key can be missing
        # when no consecutive same-experiment pair was found for it.
        if abr_cc in jitter:
            z, _ = jitter[abr_cc]
            print(abr_cc_str, y - ssim_coef*z
                  - rebuf_coef*total_rebuf/(total_play/chunk_duration))
        else:
            sys.stderr.write('Warning: {} does not exist in jitter\n'
                             .format(abr_cc))

        ax.scatter(x, y)
        ax.annotate(abr_cc_str, (x, y))

    # clamp x-axis to [0, 100], then show higher stall rates on the left
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(max(xmin, 0), min(xmax, 100))
    ax.invert_xaxis()

    output = args.output
    fig.savefig(output, dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure's memory
    sys.stderr.write('Saved plot to {}\n'.format(output))


def main():
    """Command-line entry point: collect SSIM/jitter/rebuffer data and plot it.

    Reads the YAML settings. Experiment configs come either from the --expt
    JSON cache or from Postgres. Metrics are collected from InfluxDB, the
    ABR coefficients from the first experiment are printed, and the plot is
    written to --output. Any Postgres cursor/connection opened here is
    closed on every exit path, including the sys.exit() for missing data.
    """
    global args, expt, influx_client, postgres_cursor

    parser = argparse.ArgumentParser()
    parser.add_argument('yaml_settings')
    parser.add_argument('--from', dest='start_time',
                        help='datetime in UTC conforming to RFC3339',
                        default=None)
    parser.add_argument('--to', dest='end_time',
                        help='datetime in UTC conforming to RFC3339',
                        default=None)
    parser.add_argument('--expt', help='e.g., expt_cache.json')
    parser.add_argument('-o', '--output', required=True)
    args = parser.parse_args()

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    # Validate the coefficients up front so a bad config fails before the
    # (potentially long) database queries run.
    try:
        abr_config = \
            yaml_settings["experiments"][0]["fingerprint"]["abr_config"]
        rebuffer_coeff = abr_config["rebuffer_length_coeff"]
        ssim_coeff = abr_config["ssim_diff_coeff"]
    except (KeyError, IndexError, TypeError) as e:
        sys.exit('Error: {} lacks experiments[0].fingerprint.abr_config '
                 'coefficients ({!r})'.format(args.yaml_settings, e))

    postgres_client = None
    if args.expt is not None:
        with open(args.expt, 'r') as fh:
            expt = json.load(fh)
    else:
        # create a Postgres client and perform queries
        postgres_client = connect_to_postgres(yaml_settings)
        postgres_cursor = postgres_client.cursor()

    try:
        # create an InfluxDB client and perform queries
        influx_client = connect_to_influxdb(yaml_settings)

        # collect ssim, jitter and rebuffer
        ssim = collect_ssim()
        jitter = collect_jitter()
        rebuffer = collect_rebuffer()

        if not ssim or not rebuffer:
            sys.exit('Error: no data found in the queried range')

        print('rebuffer_coeff', rebuffer_coeff)
        print('ssim_coeff', ssim_coeff)

        print(ssim)
        print(jitter)
        print(rebuffer)

        # plot ssim vs rebuffer
        plot_ssim_rebuffer(ssim, rebuffer, jitter, ssim_coeff, rebuffer_coeff)
    finally:
        # SystemExit from sys.exit() also passes through here.
        if postgres_cursor is not None:
            postgres_cursor.close()
        if postgres_client is not None:
            # assumes a DB-API 2.0 connection (it exposes cursor())
            postgres_client.close()


if __name__ == '__main__':
    main()
