import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns
from calibration._utils import save_fig
from calibration._plot import plot_Q_vs_S_ex, plot_QS_1D_ex, set_latex_font, plot_GL_bounds

from .ClassificationExample import Link1DExample, Squares1DExample, Slope1DExample, Slopes1DExample


def append_metrics_to_buffer(out, ex, filename='metrics.csv', **kwargs):
    """Append one row of metrics for ``ex`` to a session buffer and rewrite the CSV.

    Rows are accumulated in ``pytest.buffer_dict`` keyed by the CSV path, so each
    call rewrites ``join(out, filename)`` with every row collected so far for
    that path.

    Args:
        out: Output directory; created if it does not exist.
        ex: Example object providing ``GL``, ``UB_known``, ``UB_ER``,
            ``UB_acc_sq1``, ``UB_acc_sq2``, ``UB_acc_sq``, ``acc_bayes``,
            ``acc_s``, ``ER`` and ``acc_s_wrt_q`` methods.
        filename: Name of the CSV file inside ``out``.
        **kwargs: Extra columns (e.g. the swept parameter) stored before the
            metric columns. A key equal to a metric name is overwritten by
            the metric value.
    """
    kwargs.update({
        'GL': ex.GL(),
        'UB_known': ex.UB_known(),
        'UB_ER': ex.UB_ER(),
        'UB_acc_sq1': ex.UB_acc_sq1(),
        'UB_acc_sq2': ex.UB_acc_sq2(),
        'UB_acc_sq': ex.UB_acc_sq(),
        'acc_bayes': ex.acc_bayes(),
        'acc_s': ex.acc_s(),
        'ER': ex.ER(),
        'acc_s_wrt_q': ex.acc_s_wrt_q(),
    })

    new_row = pd.DataFrame(kwargs, index=[0])
    filepath = join(out, filename)
    previous = pytest.buffer_dict.get(filepath)
    # Do not concatenate with an empty placeholder frame: recent pandas (2.1+)
    # deprecates concat with empty entries (FutureWarning) and may otherwise
    # let the empty frame influence the resulting dtypes.
    df = new_row if previous is None else pd.concat([previous, new_row])
    os.makedirs(out, exist_ok=True)
    pytest.buffer_dict[filepath] = df
    df.to_csv(filepath)


def plot_QS(out, ex, x, n=1000):
    """Draw the two Q/S panels for ``ex`` side by side and save the figure.

    Left panel: ``plot_Q_vs_S_ex``; right panel: ``plot_QS_1D_ex``.

    Args:
        out: Output location forwarded to ``save_fig``.
        ex: Example object to plot.
        x: Value identifying this figure; formatted with 3 significant digits
            and passed to ``save_fig`` as ``x``.
        n: Size argument forwarded to both plotting helpers (positionally to
            ``plot_Q_vs_S_ex``, as ``N`` to ``plot_QS_1D_ex``).
    """
    set_latex_font()
    fig, (ax_q_vs_s, ax_qs_1d) = plt.subplots(1, 2, figsize=(6, 3))
    plot_Q_vs_S_ex(ex, n, ax=ax_q_vs_s)
    # Use ``n`` for both panels so the parameter controls the whole figure
    # (the default of 1000 preserves the previous output).
    plot_QS_1D_ex(ex, ax=ax_qs_1d, N=n)
    save_fig(fig, out, x=f'{x:.3g}')
    # Parametrized tests call this dozens of times; release each figure so
    # pyplot does not keep them all alive. Safe even if save_fig closed it.
    plt.close(fig)


@pytest.mark.parametrize('alpha', np.linspace(0, 1, 10))
def test_link1D(alpha, out):
    """Sweep ``alpha`` over 10 points in [0, 1] for ``Link1DExample(link='sin')``.

    Each case appends one metrics row (tagged with ``alpha``) to the CSV in
    ``out`` and saves the Q/S figure.
    """
    example = Link1DExample(alpha=alpha, link='sin')
    append_metrics_to_buffer(out=out, ex=example, alpha=alpha)
    plot_QS(out=out, ex=example, x=alpha)

@pytest.mark.parametrize('alpha', [1])
def test_link1D_sin(alpha, out):
    """Run ``Link1DExample`` with ``link='poly'`` and ``s='sin'`` at ``alpha=1``.

    Appends one metrics row (tagged with ``alpha``) to the CSV in ``out`` and
    saves the Q/S figure.
    """
    example = Link1DExample(alpha=alpha, link='poly', s='sin')
    append_metrics_to_buffer(out=out, ex=example, alpha=alpha)
    plot_QS(out=out, ex=example, x=alpha)


def test_squares1D(out):
    """Record metrics and the Q/S figure for the default ``Squares1DExample``.

    This example has no swept parameter; the constant ``lol`` column is the
    x column that ``test_plot`` reads for ``squares1D``.
    """
    example = Squares1DExample()
    append_metrics_to_buffer(out=out, ex=example, lol='lol')
    plot_QS(out=out, ex=example, x=1)


@pytest.mark.parametrize('theta', np.linspace(-np.pi/2, np.pi/2, 41))
def test_slope1D(theta, out):
    """Sweep ``theta`` over 41 points in [-pi/2, pi/2] for ``Slope1DExample``.

    The CSV column ``theta`` is stored in units of pi (``theta / np.pi``),
    whereas the figure is labelled with the raw ``theta`` value.
    """
    example = Slope1DExample(alpha=theta)
    theta_in_pi_units = theta / np.pi
    append_metrics_to_buffer(out=out, ex=example, theta=theta_in_pi_units)
    plot_QS(out=out, ex=example, x=theta)


@pytest.mark.parametrize('theta', np.linspace(0, np.pi/2, 41))
def test_slopes1D(theta, out):
    """Sweep ``theta`` over 41 points in [0, pi/2] for ``Slopes1DExample``.

    The CSV column ``theta`` is stored in units of pi (``theta / np.pi``),
    whereas the figure is labelled with the raw ``theta`` value.
    """
    example = Slopes1DExample(alpha=theta)
    theta_in_pi_units = theta / np.pi
    append_metrics_to_buffer(out=out, ex=example, theta=theta_in_pi_units)
    plot_QS(out=out, ex=example, x=theta)


@pytest.mark.parametrize('name', [
    ('link1D', 'alpha'),
    ('squares1D', 'lol'),
    ('slope1D', 'theta'),
])
def test_plot(name, out):
    """Plot GL bounds from a ``metrics.csv`` in a sibling directory of ``out``.

    ``name`` is an ``(experiment, x_column)`` pair. ``experiment`` names the
    directory next to ``out`` that holds ``metrics.csv``. ``x_column`` is the
    column passed to ``plot_GL_bounds``.

    NOTE(review): this presumably relies on the matching ``test_<experiment>``
    having already written that CSV in the same session. ``slopes1D`` is not
    included in the list; confirm whether that is intended.
    """
    ex_name, x = name
    # Named ``input_dir`` rather than ``input`` so the builtin is not shadowed.
    input_dir = os.path.join(os.path.dirname(out), ex_name)
    df = pd.read_csv(os.path.join(input_dir, 'metrics.csv'))
    fig = plot_GL_bounds(df, x)
    save_fig(fig, out, ex=ex_name, x=x)
