"""Python helpers for counterfactual inference (CFI) with SpaRTeN.
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def cfi_file(df:pd.DataFrame, out_path:str):
    """Save a counterfactual-inference (CFI) data frame as a CSV file.

    Only the columns are written; the data frame's row index is omitted.

    Args:
        df (pd.DataFrame): DataFrame holding the CFI variables to save.
        out_path (str): Destination path of the CSV file.
    """
    csv_target = out_path
    df.to_csv(path_or_buf=csv_target, index=False)


def cfi_scatter(in_path:str, out_path: str):
    """Plot predicted, CI and observed series for counter-factual inference.

    Reads the CSV at ``in_path`` and keeps only the rows whose ``label`` equals
    that column's maximum. It then draws three line plots against ``step``:
    ``best_preds`` (legend "preds"), ``ci_preds`` (legend "CI") and ``value``
    (legend "observed"). The figure is saved to ``out_path`` and then closed.

    Despite the name, the series are drawn as lines (``sns.lineplot``), not
    as scatter points.

    Args:
        in_path (str): Path to a CSV file with at least the columns ``label``,
            ``step``, ``best_preds``, ``ci_preds`` and ``value``.
        out_path (str): Path the rendered figure is saved to (via
            ``Figure.savefig``).
    """
    df = pd.read_csv(in_path)
    # NOTE(review): presumably the largest ``label`` marks the latest/final
    # run stored in the file — confirm against whatever writes this CSV.
    df_last = df[df['label'] == df['label'].max()]

    # Draw on a dedicated figure rather than the implicit "current" one. This
    # way a caller's open figure is neither closed nor drawn on, and closing it
    # afterwards keeps repeated calls from accumulating open figures.
    fig, ax = plt.subplots()
    try:
        sns.lineplot(data=df_last, x='step', y='best_preds', label='preds', ax=ax)
        sns.lineplot(data=df_last, x='step', y='ci_preds', label='CI', ax=ax)
        sns.lineplot(data=df_last, x='step', y='value', label='observed', ax=ax)
        ax.legend()
        ax.set_ylabel('value')
        fig.savefig(out_path)
    finally:
        plt.close(fig)