import pandas as pd
import statsmodels.api as sm
import numpy as np

pd.options.mode.chained_assignment = None
from sklearn.metrics import r2_score


def train_glm_plot_predict_example(fpr, df_train, df_test, predict_difference=True):
    """Fit a log-log OLS model on ``df_train`` and score it on ``df_test``.

    The regressors are ``log10(n_classes)`` and ``log10(shots)`` plus an
    intercept. The target is ``log10`` of the column named ``fpr`` (per the
    original in-code comment this column holds the TPR measured at that FPR
    -- TODO confirm against the data), optionally with ``float(fpr)``
    subtracted first.

    Args:
        fpr: Column name that is also parseable as a float, e.g. ``"0.01"``.
            It is used both to select the column and as a numeric threshold.
        df_train: DataFrame with columns ``fpr``, ``"n_classes"`` and
            ``"shots"`` used to fit the model.
        df_test: DataFrame with the same columns plus
            ``"feature_extractor"``; rows whose feature extractor is
            ``"vit-b-16"`` are excluded from scoring.
        predict_difference: If True, model ``log10(value - fpr)`` and add
            ``fpr`` back to the predictions; otherwise model ``log10(value)``.

    Returns:
        A ``(summary, r2_test)`` tuple: the statsmodels summary of the fitted
        OLS model and the R^2 of the back-transformed predictions against the
        observed values of the filtered test rows.
    """
    threshold = float(fpr)

    # Only rows strictly above the threshold are kept, in both frames.
    # This keeps log10(value - threshold) defined when predict_difference
    # is True.
    df_train = df_train[df_train[fpr] > threshold]
    df_test = df_test[df_test[fpr] > threshold]

    # fit model
    # np.log10 on the selection returns a new DataFrame, so nothing is
    # written back into a slice of the caller's data.
    X = np.log10(df_train[["n_classes", "shots"]])
    X = sm.add_constant(X)

    # predict (tpr-fpr) or (tpr)
    if predict_difference:
        y = np.log10(df_train[fpr] - threshold)
    else:
        y = np.log10(df_train[fpr])

    results = sm.OLS(y, X).fit()
    coefficients = results.params

    # compute r2_test
    # Work on an explicit copy of the non-ViT test rows.
    df_test_r50 = df_test[df_test["feature_extractor"] != "vit-b-16"].copy()

    log_prediction = (
        np.log10(df_test_r50["shots"]) * coefficients["shots"]
        + np.log10(df_test_r50["n_classes"]) * coefficients["n_classes"]
        + coefficients["const"]
    )
    # Undo the log10 transform of the target.
    y_predict = 10 ** log_prediction

    if predict_difference:
        # The model was fitted on (value - fpr); shift back to the raw scale.
        y_predict = y_predict + threshold

    # r2_score expects (y_true, y_pred); R^2 is not symmetric in its
    # arguments, so the observed values must come first.
    r2_test = r2_score(df_test_r50[fpr].values, y_predict.values)

    return results.summary(), r2_test


if __name__ == "__main__":
    # Load the table of results; it must contain a "feature_extractor"
    # column and one column per entry of `fprs` below.
    df_shots = pd.read_csv("processed_data.csv")

    # Rows for the "vit-b-16" feature extractor are used for fitting,
    # all remaining rows for evaluation.
    df_shots_r50 = df_shots[df_shots["feature_extractor"] != "vit-b-16"]
    df_shots_vit = df_shots[df_shots["feature_extractor"] == "vit-b-16"]

    # Each entry is both a column name and, parsed as a float, a threshold.
    fprs = ["0.1", "0.01", "0.001", "0.0001", "0.00001"]

    r2s_test = []
    summaries = []

    for fpr in fprs:
        summary, r2_test = train_glm_plot_predict_example(
            fpr=fpr,
            df_train=df_shots_vit,
            df_test=df_shots_r50,
            predict_difference=True,
        )
        summaries.append(summary)
        r2s_test.append(r2_test)
        print(f"fpr {fpr}, r2_test {r2_test}")
