#!/usr/bin/env python

import argparse
import glob
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Apply seaborn's "white" theme globally at module import, so every figure
# this script creates afterwards uses it.
sns.set_theme(style="white")


def parse_csv(path_pattern, labels):
    """Load every CSV matching *path_pattern* into one DataFrame tagged by label.

    Matching paths are sorted lexicographically and paired positionally with
    *labels*: the i-th sorted file gets ``labels[i]`` in a new ``alg`` column.
    Pairing stops at the shorter of the two sequences (as ``zip`` does), so a
    warning is issued when their lengths differ -- otherwise a missing or
    extra file would silently drop or misattribute labels.

    Args:
        path_pattern: glob pattern passed to ``glob.glob``.
        labels: one label per matched file, in sorted-path order.

    Returns:
        The rows of all paired files stacked vertically, with an ``alg``
        column and a fresh 0..n-1 index; an empty DataFrame when no file
        is paired.
    """
    paths = sorted(glob.glob(path_pattern))
    labels = list(labels)
    if len(paths) != len(labels):
        warnings.warn(
            f'parse_csv: {len(paths)} file(s) matched {path_pattern!r} but '
            f'{len(labels)} label(s) were given; unpaired entries are ignored',
            stacklevel=2)

    frames = []
    for path, label in zip(paths, labels):
        frame = pd.read_csv(path)
        frame['alg'] = label
        frames.append(frame)

    if not frames:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame()
    # One concat is linear in the total size (concatenating inside the loop
    # re-copied everything gathered so far); ignore_index yields 0..n-1.
    return pd.concat(frames, axis=0, ignore_index=True)


def _build_parser():
    """Return the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'algorithm', help='the name of the algorithm: greedy, random, stochastic, or interlace', type=str)
    parser.add_argument(
        'dataset_name', help='the name of the dataset: wishart, wishart_fixed_k, movie_lens, or netflix', type=str)
    parser.add_argument(
        'input_matrix', help='the type of input matrix: L or B', type=str)
    parser.add_argument(
        'x_axis', help='the value displayed on x-axis: k, n', type=str)
    parser.add_argument(
        'y_axis', help='the value displayed on y-axis: time, computed_offdiagonals_V, value', type=str)
    parser.add_argument(
        '--input', help='path to the directory where the raw data is located. default: ../result/', default='../result/', type=str)
    # abspath() keeps the default non-empty: dirname('script.py') is '' when
    # the script is launched from its own directory.
    parser.add_argument(
        '--output', help='path to the directory under which a graph/ subdirectory is created for the output .pdf file. default: the directory containing this script', default=os.path.dirname(os.path.abspath(__file__)), type=str)
    return parser


def main():
    """Plot one metric of the benchmark CSVs for an algorithm/dataset pair.

    Reads ``<input>/<algorithm>/<dataset_name>/*-<input_matrix>*.csv``, draws
    ``y_axis`` against ``x_axis`` with one line per label, and saves the plot
    to ``<output>/graph/<algorithm>_<dataset_name>_<y_axis>_<input_matrix>.pdf``.
    Exits with a usage error when the matched files yield no rows.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # os.path.join(p, '') appends a separator only when p lacks one and,
    # unlike indexing p[-1], also works for an empty string.
    path_to_data = os.path.join(args.input, '')
    path_to_out = os.path.join(args.output, 'graph', '')
    alg = args.algorithm
    dataset_name = args.dataset_name
    input_matrix = args.input_matrix
    x_axis = args.x_axis
    y_axis = args.y_axis

    # Paired with the matched files in sorted-path order by parse_csv.
    labels = ['LazyFast', 'Lazy', 'Fast', 'Naive']
    path_pattern = f'{path_to_data}{alg}/{dataset_name}/*-{input_matrix}*.csv'
    df = parse_csv(path_pattern, labels)
    if df.empty:
        # Fail with a clear message instead of an opaque seaborn error.
        parser.error(f'no data found for pattern {path_pattern}')

    fig = plt.figure()
    graph = sns.lineplot(data=df, x=x_axis, y=y_axis,
                         style='alg', hue='alg', palette='tab10')
    graph.legend()

    os.makedirs(path_to_out, exist_ok=True)
    fig.savefig(
        f'{path_to_out}{alg}_{dataset_name}_{y_axis}_{input_matrix}.pdf')
    plt.close(fig)


# Run the command-line entry point only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
    main()
