import numpy as np
import statsmodels.api as sm
from scipy.stats import linregress


def format_p_value(p_value):
    """Return a star-notation significance label for *p_value*, ending in a newline.

    Labels (upper bounds are inclusive):
        0.05   < p <= 1      -> 'p: ns'
        0.01   < p <= 0.05   -> 'p <= 0.05 (*)'
        0.001  < p <= 0.01   -> 'p <= 0.01 (**)'
        0.0001 < p <= 0.001  -> 'p <= 0.001 (***)'
        0      <= p <= 0.0001 -> 'p <= 0.0001 (****)'

    Any value outside [0, 1] -- including negative numbers and NaN -- yields
    'p: Invalid'.
    """
    # NaN fails every comparison, so `not (0 <= nan <= 1)` is True and it is
    # reported as invalid. Checking the lower bound here also stops negative
    # values from being labelled as highly significant.
    if not 0 <= p_value <= 1:
        return 'p: Invalid\n'
    if p_value > 0.05:
        return 'p: ns\n'
    if p_value > 0.01:
        return 'p <= 0.05 (*)\n'
    if p_value > 0.001:
        return 'p <= 0.01 (**)\n'
    if p_value > 0.0001:
        return 'p <= 0.001 (***)\n'
    return 'p <= 0.0001 (****)\n'


def print_regression_info(x, y):
    """Print a simple linear fit of *y* on *x*: equation, R^2 and significance.

    The fit comes from ``scipy.stats.linregress``. The significance line is
    produced by ``format_p_value``. A negative intercept is printed as
    ``y = 2.0x - 3.0`` rather than ``y = 2.0x + -3.0``.
    """
    slope, intercept, r_value, p_value, _ = linregress(x, y)

    # Format first, then move any leading minus into the operator so the
    # equation never reads "+ -".
    intercept_text = f'{intercept:.1f}'
    operator = '+'
    if intercept_text.startswith('-'):
        operator, intercept_text = '-', intercept_text[1:]

    print(f'y = {slope:.1f}x {operator} {intercept_text}')
    print(f'R^2: {r_value ** 2:.2f}')
    print(format_p_value(p_value))


def quadratic_fit(x, y, verbose=False):
    """Fit ``y = b + a * x**2`` by ordinary least squares.

    There is deliberately no linear ``x`` term: the design matrix holds only a
    constant column followed by ``x**2``.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values; ``x`` is flattened to 1-D.
    verbose : bool
        If True, print the statsmodels summary table.

    Returns
    -------
    Fitted statsmodels OLS results. ``params[0]`` is the intercept ``b`` and
    ``params[1]`` the ``x**2`` coefficient ``a``.
    """
    # Float dtype avoids silent integer overflow when squaring large int inputs.
    x = np.asarray(x, dtype=float).reshape(-1)
    y = np.asarray(y)

    # has_constant='add' forces the intercept column even when x**2 is itself
    # constant (e.g. x = [-1, 1]); the default 'skip' would omit it and shift
    # the meaning of params[0].
    X = sm.add_constant(x ** 2, has_constant='add')  # b + a*x**2
    model = sm.OLS(y, X).fit()
    if verbose:
        print(model.summary())

    return model


def quadplot(x, model, ax, **params):
    """Draw a fitted quadratic curve and its 95% confidence band on *ax*.

    Parameters
    ----------
    x : array-like
        Observed predictor values. Only their min and max are used, to span
        100 evenly spaced evaluation points.
    model : fitted regression results
        Must provide ``predict(X)`` and ``get_prediction(X).conf_int(alpha=...)``
        for a design matrix with columns ``[1, x**2]``, such as the object
        returned by ``quadratic_fit``.
    ax : matplotlib Axes (or anything with ``plot`` / ``fill_between``)
        Target axes.
    **params
        Extra keyword arguments forwarded to both ``ax.plot`` and
        ``ax.fill_between`` (e.g. ``color``). ``linewidth`` and ``alpha``
        override the defaults (2 for the line, 0.2 for the band) instead of
        raising a duplicate-keyword TypeError. ``label`` is applied to the line
        only, so each fit produces a single legend entry.
    """
    x_range = np.linspace(np.min(x), np.max(x), 100)
    # Build [1, x**2] explicitly. sm.add_constant would skip the intercept
    # column when min(x) == max(x), yielding a one-column design that a
    # two-parameter model cannot evaluate.
    X_pred = np.column_stack([np.ones_like(x_range), x_range ** 2])

    y_pred = model.predict(X_pred)
    ci = np.asarray(model.get_prediction(X_pred).conf_int(alpha=0.05))  # 95% CI

    # Caller kwargs win over the defaults.
    line_kwargs = {'linewidth': 2, **params}
    band_kwargs = {'alpha': 0.2, **params}
    # Keep the band out of the legend so the label is not duplicated.
    band_kwargs.pop('label', None)

    ax.plot(x_range, y_pred, **line_kwargs)
    ax.fill_between(x_range, ci[:, 0], ci[:, 1], **band_kwargs)


def print_quadratic_info(x, y):
    """Fit ``y = b + a * x**2`` with ``quadratic_fit`` and print a short summary.

    Prints the following, in order:
    - the fitted equation;
    - R (the square root of R^2);
    - R^2;
    - the ``format_p_value`` label for the p-value of the last model parameter
      (the ``x**2`` coefficient).

    A negative intercept is printed as ``y = 2.0x^2 - 3.0`` rather than
    ``y = 2.0x^2 + -3.0``.
    """
    model = quadratic_fit(x, y)

    intercept = model.params[0]
    beta2 = model.params[-1]
    r2_value = model.rsquared
    p_value = model.pvalues[-1]

    # Format first, then move any leading minus into the operator so the
    # equation never reads "+ -".
    intercept_text = f'{intercept:.1f}'
    operator = '+'
    if intercept_text.startswith('-'):
        operator, intercept_text = '-', intercept_text[1:]

    print(f'y = {beta2:.1f}x^2 {operator} {intercept_text}')
    print(f'R: {np.sqrt(r2_value):.2f}')
    print(f'R^2: {r2_value:.2f}')
    print(format_p_value(p_value))
