from matplotlib.ticker import FuncFormatter
import matplotlib.pyplot as plt
import numpy as np
import datetime
import os
import argparse
import random
import shutil
import pickle

# Tick formatter that renders axis values with the general ('g') format at up
# to 16 significant digits; installed on the x-axis of the plots in this file.
formatter = FuncFormatter(lambda y, _: '{:.16g}'.format(y))


# Naive datetime corresponding to Unix timestamp 0 (the epoch, in UTC).
TIME_ORIGIN = datetime.datetime.utcfromtimestamp(0)

# Number of bytes in one megabit (10**6 bits / 8 bits per byte), despite the
# "MB" in the name. Dividing a bytes-per-second value by this gives Mbps.
BYTES_IN_MB = 1000000.0 / 8.0


def validate_dir(dir):
    """Reset ``dir`` to an existing, empty directory.

    Any directory tree already present at ``dir`` is deleted, then the
    directory is created again together with any missing parents.

    Args:
        dir: Path of the directory to reset. Relative and absolute paths
            are both supported, with or without a trailing separator.
    """
    if os.path.exists(dir):
        shutil.rmtree(dir)

    # os.makedirs creates all missing parents in one call. The previous
    # manual '/'-split always prefixed '.', which turned an absolute path
    # into a relative one rooted at the current working directory.
    os.makedirs(dir)


def get_file_throughput(f, max_samples=int(10 * 60 / 5)):
    """Read the leading throughput samples from a trace file.

    Args:
        f: Path to a text file holding one numeric sample per line.
        max_samples: Maximum number of leading samples to return. The
            default, ``int(10*60/5)`` = 120, matches the previous
            hard-coded limit (10 minutes of 5-second samples).

    Returns:
        A list of floats with at most ``max_samples`` entries.

    Raises:
        ValueError: If a non-blank line among the samples that are read
            cannot be parsed as a float.
    """
    samples = []
    with open(f, 'r') as trace:
        for line in trace:
            # Stop as soon as enough samples are collected instead of
            # parsing the whole file and slicing afterwards.
            if len(samples) >= max_samples:
                break
            # Tolerate blank lines (e.g. a trailing empty line).
            if not line.strip():
                continue
            samples.append(float(line))
    return samples


def get_file_variance(f):
    """Return the mean absolute throughput swing of a trace file.

    Consecutive sample-to-sample differences that share a sign are summed
    into a single swing. The result is the mean of the absolute swings,
    divided by ``BYTES_IN_MB``.

    Args:
        f: Path to the trace file, read through ``get_file_throughput``.
    """
    samples = np.array(get_file_throughput(f))

    swings = []
    for step in np.diff(samples).tolist():
        # A strictly positive product means the step continues the
        # direction of the swing currently being accumulated.
        if swings and swings[-1] * step > 0:
            swings[-1] += step
        else:
            swings.append(step)

    return np.mean(np.abs(np.array(swings))) / BYTES_IN_MB


def find_nearest(array, value):
    """Return the index of the entry of ``array`` closest to ``value``.

    When several entries are equally close, the lowest index is returned.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()


def plot_tput_distr(tputs):
    """Plot the empirical CDF of throughput values to ``./throughput.png``.

    Values are clipped to the range [0, 15] before being binned into a
    100-bin histogram; the cumulative distribution is drawn against a
    logarithmic x-axis.

    Args:
        tputs: Sequence of throughput values (the x-axis is labelled Mbps).
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('Throughput (Mbps)')
    ax.set_ylabel('CDF')

    tputs = np.clip(np.array(tputs), 0, 15)

    count, bins = np.histogram(tputs, bins=100)
    pdf = count / float(sum(count))
    cdf = np.cumsum(pdf)

    # Plot against the right edge of each bin so the CDF reaches 1 at the
    # last edge.
    ax.semilogx(bins[1:], cdf, label='FCC traces')
    ax.set_ylim(0, 1)
    ax.set_xlim(0.1)
    ax.legend()

    ax.xaxis.set_major_formatter(formatter)
    fig.savefig('./throughput.png', bbox_inches='tight')
    # pyplot keeps every figure alive until it is closed; release it so
    # repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_variance_distr(vars):
    """Plot the distribution of throughput differences to ``./variance.png``.

    Values are clipped to [-10, 10] and counted in unit-width bins; the
    per-bin counts are plotted against each bin's right edge.

    Args:
        vars: Sequence of difference values (the x-axis is labelled
            'Diff (Mbps)').
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('Diff (Mbps)')
    # NOTE(review): the y-axis is labelled 'CDF' but the raw per-bin counts
    # are what is plotted below -- confirm which of the two is intended.
    ax.set_ylabel('CDF')

    vars = np.clip(np.array(vars), -10, 10)

    # Edges must run through +10 inclusive: with np.arange(-10, 10) the last
    # edge was 9, so every value clipped into (9, 10] fell outside the
    # histogram range and was silently dropped.
    count, bins = np.histogram(vars, bins=np.arange(-10, 11))

    ax.plot(bins[1:], count, label='FCC traces')
    ax.legend()

    ax.xaxis.set_major_formatter(formatter)
    fig.savefig('./variance.png', bbox_inches='tight')
    # pyplot keeps every figure alive until it is closed; release it so
    # repeated calls do not accumulate open figures.
    plt.close(fig)


def get_throughput_in_mbps(input):
    """Scale throughput from bytes per second to megabits per second.

    Args:
        input: A number or NumPy array of throughput values; each trace
            line holds the throughput of a 5-second interval.

    Returns:
        ``input`` divided by ``BYTES_IN_MB``.
    """
    return input / BYTES_IN_MB


def read_csv(traces_file, pickle_out, max_lines=None):
    """Group CSV throughput measurements by (uid, target) and pickle them.

    The first line of ``traces_file`` is discarded (presumably a header
    row). Every following line is split on commas: column 0 is the unit id,
    column 1 a ``%Y-%m-%d %H:%M:%S`` timestamp, column 2 the target and
    column 6 the throughput in bytes per second. Rows that cannot be
    decoded or parsed are skipped.

    Args:
        traces_file: Path to the input CSV file.
        pickle_out: Path the resulting dict is pickled to. The dict maps
            ``(uid, target)`` tuples to lists of throughput floats in file
            order.
        max_lines: Optional cap on the number of successfully parsed rows;
            ``None`` (the default) reads the whole file.
    """
    bw_measurements = {}
    accepted = 0

    with open(traces_file, 'rb') as f:
        f.readline()
        for line in f:
            try:
                parse = line.decode('UTF-8').split(',')
                uid = parse[0]
                # The timestamp value itself is not used; it is parsed so
                # that rows with a malformed timestamp are still rejected.
                datetime.datetime.strptime(parse[1], '%Y-%m-%d %H:%M:%S')
                target = parse[2]
                throughput = float(parse[6])  # bytes per second
            except (ValueError, IndexError):
                # ValueError: undecodable bytes, bad timestamp or number;
                # IndexError: too few columns. Anything else (including
                # KeyboardInterrupt) is no longer swallowed.
                continue

            bw_measurements.setdefault((uid, target), []).append(throughput)

            accepted += 1
            if max_lines is not None and accepted >= max_lines:
                break

    with open(pickle_out, 'wb') as file:
        pickle.dump(bw_measurements, file)


def create_cooked(pickle_file, output_dir):
    """Write one trace file per qualifying (uid, target) series and plot stats.

    A series is kept only if it has at least 120 samples (600 seconds at
    5 seconds per sample), its mean throughput is at most 6 Mbps and its
    minimum throughput is at least 0.2 Mbps. Each kept series is written to
    ``output_dir`` as ``trace_<uid>_<target>`` (with ':' and '/' replaced by
    '-'), one raw bytes-per-second value per line. The throughput and
    variance plots are then generated from the kept series.

    Args:
        pickle_file: Path to the pickle holding a dict that maps
            ``(uid, target)`` to a list of throughput values.
        output_dir: Existing directory the trace files are written into.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only feed this
    # function pickle files produced by a trusted source.
    with open(pickle_file, 'rb') as file:
        bw_measurements = pickle.load(file)

    traces = []
    variances = []

    for k, throughputs in bw_measurements.items():
        # Check the length first so an empty series is skipped here instead
        # of raising in t.min() below.
        if len(throughputs) * 5 < 600:
            continue

        t = get_throughput_in_mbps(np.array(throughputs))
        avg_throughput, min_throughput = t.mean(), t.min()
        if avg_throughput > 6 or min_throughput < 0.2:
            continue

        traces += t.tolist()
        variances.append(np.mean(np.abs(np.diff(t))).tolist())

        out_name = 'trace_' + '_'.join(k)
        out_name = out_name.replace(':', '-').replace('/', '-')

        # os.path.join keeps the file inside output_dir even when the
        # directory is given without a trailing separator.
        out_file = os.path.join(output_dir, out_name)
        with open(out_file, 'wb') as f:
            for i in throughputs:
                f.write((str(i) + '\n').encode('UTF-8'))

    plot_tput_distr(traces)
    plot_variance_distr(variances)


def create_variance_buckets(cooked_dir, files):
    """Group trace files by their variance, as given by ``get_file_variance``.

    Bucket centres are ``np.arange(0, 5, 0.5)``; each file goes into the
    bucket whose centre is nearest to its variance. The centres and the
    bucket sizes are printed.

    Args:
        cooked_dir: Directory containing the trace files.
        files: File names relative to ``cooked_dir``.

    Returns:
        A list of lists of file names, one inner list per bucket centre.
    """
    BUCKETS = np.arange(0, 5, 0.5)
    buckets = [[] for _ in range(len(BUCKETS))]
    for f in files:
        # os.path.join works whether or not cooked_dir ends with a separator.
        file_variance = get_file_variance(os.path.join(cooked_dir, f))
        bucket_idx = find_nearest(BUCKETS, file_variance)
        buckets[bucket_idx].append(f)

    print(BUCKETS)
    print([len(b) for b in buckets])
    return buckets


def create_avg_buckets(cooked_dir, files):
    """Group trace files by their median throughput in Mbps.

    Despite the name, the statistic used is the median, not the mean.
    Bucket centres are ``np.arange(0, 3.5, 0.3)``; each file goes into the
    bucket whose centre is nearest to its median. The centres and the
    bucket sizes are printed.

    Args:
        cooked_dir: Directory containing the trace files.
        files: File names relative to ``cooked_dir``.

    Returns:
        A list of lists of file names, one inner list per bucket centre.
    """
    BUCKETS = np.arange(0, 3.5, 0.3)
    buckets = [[] for _ in range(len(BUCKETS))]
    for f in files:
        # os.path.join works whether or not cooked_dir ends with a separator.
        t = np.array(get_file_throughput(os.path.join(cooked_dir, f)))
        t = get_throughput_in_mbps(t)
        bucket_idx = find_nearest(BUCKETS, np.median(t))
        buckets[bucket_idx].append(f)

    print(BUCKETS)
    print([len(b) for b in buckets])
    return buckets


def create_random_subset(files, cooked_dir, output_dir, num_of_files=None):
    """Copy a random subset of trace files into test/ and train/ folders.

    The files are shuffled, the first ``num_of_files`` are selected, and
    the first 20% of the selection (rounded down) is copied to
    ``<output_dir>/test/`` while the rest goes to ``<output_dir>/train/``.
    The throughput and variance plots are then generated from the per-file
    median throughput and the sample-to-sample differences of the selection.

    Args:
        files: File names relative to ``cooked_dir``. The caller's list is
            not modified.
        cooked_dir: Directory containing the trace files.
        output_dir: Directory holding existing ``test`` and ``train``
            sub-directories.
        num_of_files: Number of files to select; ``None`` selects all.
            Values larger than ``len(files)`` are capped.

    Returns:
        The selected file names, in selection order.
    """
    # Shuffle a copy so the caller's list is left untouched.
    shuffled = list(files)
    random.shuffle(shuffled)

    if num_of_files is None:
        num_of_files = len(shuffled)
    # Cap at the number of files actually available. Previously the 20% test
    # quota was derived from the requested count, so asking for more files
    # than exist could send every file to test/ and none to train/.
    num_of_files = max(0, min(num_of_files, len(shuffled)))
    num_test = int(0.2 * num_of_files)

    # The trailing '' keeps a trailing separator, so shutil.copy always
    # treats the destination as a directory.
    test_dir = os.path.join(output_dir, 'test', '')
    train_dir = os.path.join(output_dir, 'train', '')

    subset = []
    traces = []
    variances = []
    for idx, f in enumerate(shuffled[:num_of_files]):
        src = os.path.join(cooked_dir, f)
        t = get_throughput_in_mbps(np.array(get_file_throughput(src)))
        traces.append(np.median(t))
        variances += np.diff(t).tolist()

        shutil.copy(src, test_dir if idx < num_test else train_dir)
        subset.append(f)

    plot_tput_distr(traces)
    plot_variance_distr(variances)

    return subset


def filter_files(cooked_dir, files, max_median_mbps=3, min_variance=None):
    """Select trace files whose median throughput is within a limit.

    Args:
        cooked_dir: Directory containing the trace files.
        files: File names relative to ``cooked_dir``.
        max_median_mbps: Files whose median throughput (Mbps) exceeds this
            value are dropped. Defaults to 3, the previous hard-coded limit.
        min_variance: Optional lower bound on the mean absolute
            sample-to-sample difference (Mbps). ``None`` (the default)
            disables the check, matching the previous behaviour in which
            this filter was commented out.

    Returns:
        The file names that passed the filters, in input order.
    """
    filtered_files = []
    for f in files:
        # os.path.join works whether or not cooked_dir ends with a separator.
        t = np.array(get_file_throughput(os.path.join(cooked_dir, f)))
        t = get_throughput_in_mbps(t)

        if np.median(t) > max_median_mbps:
            continue

        if min_variance is not None:
            var = np.mean(np.abs(np.diff(t)))
            if var < min_variance:
                continue

        filtered_files.append(f)

    return filtered_files


if __name__ == '__main__':
    # Pipeline driver. Steps are cumulative: starting at step N also runs
    # every later step.
    #   0: parse the raw CSV into a pickle of per-(uid, target) series
    #   1: write the cooked per-series trace files
    #   2: filter the cooked traces and copy a random train/test subset
    parser = argparse.ArgumentParser(description="Traces converter")
    parser.add_argument(
        "--csv",
        default='./csv/curr_webget_2022_01.csv',
        help='traces file'
    )
    parser.add_argument(
        "--pickle",
        default='./pickles/curr_webget_2022_01.pickle',
        # Was a copy of the --csv help text.
        help='pickle file of parsed traces'
    )
    parser.add_argument(
        "--cooked",
        default='./cooked/cooked_2022_01/',
        help='cooked traces dir'
    )
    parser.add_argument(
        "--output",
        default='./traces/',
        # Was a copy of the --cooked help text.
        help='output dir for the train/test subsets'
    )
    parser.add_argument(
        "-s",
        type=int,
        default=2,
        help='step to begin'
    )
    parser.add_argument(
        "-c",
        type=int,
        default=500,
        help='num of traces'
    )
    args = parser.parse_args()

    if args.s <= 0:
        read_csv(args.csv, args.pickle)
    if args.s <= 1:
        # validate_dir wipes and recreates the directory.
        validate_dir(args.cooked)
        create_cooked(args.pickle, args.cooked)
    if args.s <= 2:
        # os.path.join works whether or not --output ends with a separator.
        validate_dir(os.path.join(args.output, 'train'))
        validate_dir(os.path.join(args.output, 'test'))
        files = os.listdir(args.cooked)

        files = filter_files(args.cooked, files)
        print("debug", len(files))

        subset = create_random_subset(files, args.cooked, args.output, args.c)

        # Both bucket helpers are called for the distribution summaries they
        # print; their return values are not needed here.
        create_avg_buckets(args.cooked, subset)
        create_variance_buckets(args.cooked, subset)
