import os
import glob
import argparse
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import List
import json

# For the legend
from matplotlib.lines import Line2D


# Put the parent directory of this file at the front of the import path.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Apply seaborn's 'whitegrid' style globally, as soon as the module is loaded.
sns.set_style('whitegrid')


def custom_palette(n_colors):
    """Return a seaborn palette with ``n_colors`` entries from a fixed color list.

    Args:
        n_colors: Number of colors requested (the caller in this file asks
            for one color per plotted algorithm).

    Returns:
        A seaborn color palette holding the first ``n_colors`` base colors.
        When more colors are requested than the base list defines, the base
        colors are repeated cyclically so the result still has ``n_colors``
        entries.
    """
    base_colors = [
        "#D4AC0D",  # yellow
        "#953553",  # red
        "#3498DB",  # blue
        "#7DCEA0",  # green
        "#FA8072",  # salmon
        "#9384C0",  # purple
    ]

    # Giving seaborn the requested size (instead of pre-slicing the list)
    # makes it cycle through the base colors when n_colors exceeds their
    # number, rather than silently returning a shorter palette.
    return sns.color_palette(base_colors, n_colors)


def make_ax_violinplot(
        seaborn_df : pd.DataFrame,
        ax : plt.Axes,
        metric : str,
        colors : List[str],
        title : str = None,
        x : str = "scenario",
):
    """Draw a grouped boxplot of ``metric`` on the given axes.

    Despite its name, this function draws a seaborn *boxplot* (not a violin
    plot). Boxes are grouped along ``x`` and split by the ``"algo"`` column.

    Args:
        seaborn_df: Data to plot; must contain the ``x``, ``metric`` and
            ``"algo"`` columns.
        ax: Axes on which the boxplot is drawn.
        metric: Name of the column plotted on the y axis.
        colors: Palette passed to seaborn for the ``"algo"`` hue levels.
        title: Optional axes title; left untouched when ``None``.
        x: Name of the column plotted on the x axis.

    Returns:
        The axes returned by ``sns.boxplot``.
    """
    sns_ax = sns.boxplot(
        ax=ax,
        data=seaborn_df,
        x=x,
        y=metric,
        hue="algo",
        palette=colors,
        boxprops=dict(alpha=.5),
        width=0.7,
    )

    if title is not None:
        sns_ax.set_title(title)

    sns_ax.spines["bottom"].set_color('black')
    # Remove the left border. Pass the axes explicitly: without `ax`, despine
    # acts on the current pyplot figure, which need not be the one owning `ax`.
    sns.despine(ax=sns_ax, left=True)
    sns_ax.set_ylim(bottom=0)
    return sns_ax


def make_titles(params_path):
    """Derive the subplot title and the figure title from a params JSON file.

    Args:
        params_path: Path to a JSON file with at least the ``"scm"`` and
            ``"num_hidden"`` keys.

    Returns:
        A ``(axtitle, figtitle)`` tuple: ``axtitle`` is the ``"scm"`` value,
        and ``figtitle`` tells whether the model has hidden variables
        (``"num_hidden" > 0``) or is fully observable.
    """
    with open(params_path) as params_file:
        config = json.load(params_file)

    axtitle = config["scm"]
    has_hidden_variables = config["num_hidden"] > 0
    figtitle = (
        "Latent variables model"
        if has_hidden_variables
        else "Fully observable model"
    )
    return axtitle, figtitle


def make_plot_name(params_path):
    """Build the output PDF file name from an experiment params JSON file.

    The name is made of ``"latent"``/``"observable"`` (depending on whether
    ``"num_hidden"`` is positive), optionally followed by ``"sparse"``
    (``"p_edge"`` ≈ 0.3) or ``"dense"`` (``"p_edge"`` ≈ 0.5), joined by
    underscores and suffixed with ``".pdf"``. Any other ``"p_edge"`` value
    adds no density tag.

    Args:
        params_path: Path to a JSON file with the ``"num_hidden"`` and
            ``"p_edge"`` keys.

    Returns:
        The file name, e.g. ``"latent_sparse.pdf"``.
    """
    import math  # local import: only needed for the tolerant float comparison

    with open(params_path) as f:
        params = json.load(f)

    name_parts = ["latent" if params["num_hidden"] > 0 else "observable"]

    p_edge = params["p_edge"]
    # Compare with a tolerance so that values affected by floating point
    # rounding (e.g. 0.1 + 0.2) are still recognised. Non-numeric values keep
    # yielding no density tag, as with a plain equality test.
    if isinstance(p_edge, (int, float)):
        for density, tag in ((0.3, "sparse"), (0.5, "dense")):
            if math.isclose(p_edge, density):
                name_parts.append(tag)
                break

    return "_".join(name_parts) + ".pdf"

if __name__ == "__main__":
    # Script entry point: plot `metric` against the number of nodes for each
    # matching experiment directory (one subplot per directory), then save a
    # separate PDF holding only the legend.

    # Command line arguments
    parser = argparse.ArgumentParser(description='Experiment with increasing node size.')
    parser.add_argument('--algorithms', nargs='+', default=['scam', 'camuv'])
    parser.add_argument('--metric', default='shd')
    parser.add_argument('--scms', nargs='+', default=['linear', 'nonlinear'])  # scms in the plot
    parser.add_argument('--p_edge', default=0.3, type=float)  # sparsity of the plot
    parser.add_argument('--num_hidden', default=2, type=int)  # confounded/not plot
    params = vars(parser.parse_args())

    algorithms = params["algorithms"]
    metric = params["metric"]

    # Get the dirs with data for the plots. glob returns paths in arbitrary
    # order, so sort them to keep the subplot order stable across runs.
    logs_dirs = sorted(glob.glob(os.path.join('.', 'logs', 'paper-plots', 'incr_size_*')))
    output_dir = os.path.join('.', 'logs', 'paper-plots')
    plots_dirs = []
    # params.json of one matching directory. Every match shares the requested
    # num_hidden and p_edge, which are the only keys make_plot_name reads.
    reference_params_path = None
    for log_dir in logs_dirs:
        log_params_path = os.path.join(log_dir, "params.json")
        with open(log_params_path) as f:
            log_params = json.load(f)
        # Directories with FCI make a mess out of this.
        # TODO: add fci csv files inside the right dirs
        if (
            log_params["scm"] in params["scms"]
            and log_params["num_hidden"] == params["num_hidden"]
            and log_params["p_edge"] == params["p_edge"]
        ):
            plots_dirs.append(log_dir)
            reference_params_path = log_params_path

    if not plots_dirs:
        raise ValueError("There are no data matching your requirements. Aborting plot")

    pdf_name = make_plot_name(reference_params_path)

    # One color per algorithm, shared by every subplot and by the legend
    colors = custom_palette(len(algorithms))

    # Create axes. squeeze=False keeps `axes` a 2D array even when a single
    # directory matches (otherwise a bare Axes is returned and cannot be indexed).
    fig, axes = plt.subplots(1, len(plots_dirs), figsize=(24, 8), squeeze=False)
    for file_dir, ax in zip(plots_dirs, axes[0]):
        # Titles (figtitle is only needed by the disabled suptitle below)
        params_path = os.path.join(file_dir, "params.json")
        axtitle, figtitle = make_titles(params_path)

        # Combined logs in df, with an 'algo' column identifying the source file
        combined_df = pd.concat([
            pd.read_csv(os.path.join(file_dir, algo + '.csv'), index_col=0).assign(algo=algo)
            for algo in algorithms
        ])

        # Boxplot
        ax = make_ax_violinplot(combined_df, ax, metric, colors, title=axtitle, x="num_nodes")

        # labels
        ax.set_ylabel(metric, size=32)
        ax.set_xlabel("number of nodes", size=32)

        # ticks: only change the label size; re-setting the tick labels from
        # their current text would freeze them regardless of later tick changes
        ax.tick_params(axis='both', labelsize=28)

        # Ax title
        ax.set_title(ax.get_title(), fontsize=32)

        # The legend is saved to its own file below, so drop the per-axes one
        ax_legend = ax.get_legend()
        if ax_legend is not None:
            ax_legend.remove()

    # title = fig.suptitle(figtitle, fontsize=34)
    # title.set_position([0.5, 0.98])
    fig.tight_layout(h_pad=5, w_pad=3)  # h_pad add space between rows, w_pad between cols
    fig.subplots_adjust(top=.85)  # top add space above the first row
    fig.savefig(os.path.join(output_dir, f'{metric}_{pdf_name}'))
    plt.close("all")

    # Make legend
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    lines = [Line2D([0], [0], color=c, lw=4) for c in colors]
    labels = [algo if algo != "scam" else "adascore" for algo in algorithms]  # map scam to adascore
    ax.clear()  # Remove the axis data and labels
    legend = ax.legend(lines, labels, ncol=len(labels), loc='center', fontsize=17, borderaxespad=0)
    ax.axis('off')  # Hide axes
    # Set alpha to legend lines
    for lh in legend.legend_handles:
        lh.set_alpha(.7)

    fig.savefig(os.path.join(output_dir, 'legend.pdf'))
    plt.close("all")