"""Helpers for using SpaRTeN in zero-shot learning (ZSL): save ZSL results
to CSV and plot them as a facet grid or a scatter plot.
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt




def zsl_file(df: pd.DataFrame, outfile: str):
    """Write a zero-shot-learning results table to a CSV file.

    Only the DataFrame's columns are written; the index is omitted so the
    file round-trips cleanly through ``pd.read_csv``.

    Args:
        df (pd.DataFrame): Zero-shot learning results to save.
        outfile (str): Destination path of the CSV file.
    """
    df.to_csv(path_or_buf=outfile, index=False)

def zsl_facet(in_path: str, out_path: str):
    """Generate a facet grid of line plots from a ZSL CSV file.

    One subplot is drawn per (``X``, ``Y``) combination: grid rows are keyed
    by the ``X`` column and grid columns by the ``Y`` column. Each subplot
    plots ``value`` against ``step``, with one line per ``label``.

    Args:
        in_path (str): Path to the CSV written by ``zsl_file``. It must contain
            the columns ``X``, ``Y``, ``step``, ``value`` and ``label``.
        out_path (str): Path the facet-grid image is saved to.
    """
    zsl_df = pd.read_csv(in_path)
    g = sns.FacetGrid(zsl_df, row="X", col="Y")
    # Operate on the grid's own figure rather than pyplot's implicit "current"
    # figure, and always close it so repeated calls do not accumulate open
    # figures (matplotlib warns once more than 20 are open).
    fig = g.figure
    try:
        g.map_dataframe(sns.lineplot, x="step", y="value", hue="label")
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        plt.close(fig)

def zsl_scatter(in_path: str, out_path: str):
    """Generate a scatter plot of ``value`` against ``step`` from a ZSL CSV file.

    Points are colored by an ``"X,Y"`` label built from the ``X`` and ``Y``
    columns. Any ``label`` column already present in the file is overwritten
    in the in-memory copy; the input file itself is not modified.

    Args:
        in_path (str): Path to the CSV written by ``zsl_file``. It must contain
            the columns ``X``, ``Y``, ``step`` and ``value``.
        out_path (str): Path the scatter-plot image is saved to.
    """
    zsl_df = pd.read_csv(in_path)
    zsl_df['label'] = zsl_df['X'].apply(str) + ',' + zsl_df['Y'].apply(str)
    # Draw on a dedicated figure instead of whatever pyplot figure happens to
    # be current. That way earlier plots cannot bleed in and the caller's
    # figures are left alone. Close the figure afterwards so repeated calls do
    # not leak open figures.
    fig, ax = plt.subplots()
    try:
        sns.scatterplot(data=zsl_df, x='step', y='value', hue='label', ax=ax)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        plt.close(fig)