import altair as alt
# Module-level side effect: enable Altair's 'default' renderer as soon as this
# module is imported.
alt.renderers.enable('default')

from rampwf.utils import assert_read_problem
    

def lookahead_figs(stats_df, stat, y_limits, label=None,
                   n_lookahead=None, n_columns=3):
    """Long horizon stats figs.

    Builds one small chart per target observation column of the problem
    (mean line plus a shaded band) and lays the charts out on a grid.

    Parameters
    ----------
    stats_df : pandas.DataFrame
        (concatenation of) dataframe(s) returned by
        leaderboard.*_stats_table(). The columns read here are
        'lookahead', 'method' and, for every output name,
        '{output_name}_{stat}_mean', '{output_name}_{stat}_cil' and
        '{output_name}_{stat}_ciu' (the latter two are the lower and
        upper edges of the shaded band; presumably confidence-interval
        bounds).
    stat : str
        the name of the stat {r2, bias2, variance, ks}
    y_limits : list of two elements
        the y axes limits in the figures
    label : str or None
        the figure label, if None, identical to stat
    n_lookahead : int or None
        the maximum lookahead (limit of the x axes); rows with a larger
        'lookahead' value are dropped. If None, no filtering is done.
    n_columns : int
        the number of figures in a single row

    Returns
    -------
    an altair figure
    """
    problem = assert_read_problem()

    # Honour the documented default: without this the axis title would
    # literally read "<output_name> None".
    if label is None:
        label = stat

    if n_lookahead is not None:
        stats_df = stats_df[stats_df['lookahead'] <= n_lookahead]

    figs = []
    for output_name in problem._target_column_observation_names:
        # Mean of the statistic as a function of the lookahead, one line
        # per method.
        stat_fig = alt.Chart(stats_df).mark_line(size=3).encode(
            x=alt.X('lookahead:Q', title='number of steps in the future'),
            y=alt.Y(
                f'{output_name}_{stat}_mean:Q',
                title=f'{output_name} {label}',
                scale=alt.Scale(domain=y_limits, zero=False)),
            color=alt.Color('method:N', scale=alt.Scale(scheme='category10')),
        )
        # Shaded band between the *_cil and *_ciu columns; its legend is
        # suppressed so only the line legend is shown.
        error_band_fig = alt.Chart(stats_df).mark_area(opacity=0.2).encode(
            x='lookahead:Q',
            y=f'{output_name}_{stat}_cil:Q',
            y2=f'{output_name}_{stat}_ciu:Q',
            color=alt.Color(
                'method:N', scale=alt.Scale(scheme='category10'),
                legend=None),
        )
        fig = (stat_fig + error_band_fig).resolve_legend(
            color='independent'
        ).properties(
            width=150,
            height=150,
        )
        figs.append(fig)

    # Lay the per-output charts out on a grid, n_columns charts per row.
    fig_rows = [
        alt.hconcat(*figs[i:i + n_columns])
        for i in range(0, len(figs), n_columns)
    ]
    final_fig = alt.vconcat(*fig_rows).configure_view(
        strokeOpacity=0
    ).configure_axis(
        labelFontSize=10,
        titleFontSize=10,
    ).configure_legend(
        titleColor='black', titleFontSize=10, labelFontSize=10,
        labelLimit=1300, symbolSize=50, symbolStrokeWidth=3,
    )
    return final_fig