import pandas as pd
import numpy as np
import torch

from sklearn import linear_model

import json
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import preprocessing
from sklearn.metrics import roc_curve, auc

from itertools import product

from datasets import load_from_disk


def get_data(config, config_paired):
    """Load and standardise the single-response and paired-test data.

    Paths come from ``config["embedding"]["save_embeddings_path"]`` (train /
    validation parts) and ``config_paired["embedding"]["save_embeddings_path"]``
    (test part); the ratings dataset is read from ``config["data_path"]``.
    Feature and target scalers are fit on the training split only and then
    applied to every other split.

    Returns:
        A 3-tuple ``(train, valid, test)`` where
        ``train = (xs_train, ys_train, train_ratings)``,
        ``valid = (xs_valid, ys_valid, valid_ratings)`` and
        ``test = (chosen, rejected, score_chosen, score_rejected, ys_test)``.
        ``ys_test`` holds 0/1 labels drawn with a fixed seed (42); targets are
        column vectors because they are reshaped to ``(-1, 1)`` before scaling.
    """
    embeddings_template = config["embedding"]["save_embeddings_path"]
    train_path = embeddings_template.format(part="train", **config)
    validation_path = embeddings_template.format(part="validation", **config)
    ratings = load_from_disk(config["data_path"].format(**config))

    # First 7 characters of "created_at" -- presumably "YYYY-MM"; confirm
    # against the dataset's timestamp format.
    train_ratings = ratings["reward_one"].to_pandas()
    train_ratings["created_at_month"] = train_ratings["created_at"].apply(
        lambda stamp: stamp[:7]
    )
    valid_ratings = ratings["reward_valid"].to_pandas()
    valid_ratings["created_at_month"] = valid_ratings["created_at"].apply(
        lambda stamp: stamp[:7]
    )

    raw_xs_train = torch.load(f"{train_path}_xs.pt")
    raw_ys_train = torch.load(f"{train_path}_ys.pt")
    raw_xs_valid = torch.load(f"{validation_path}_xs.pt")
    raw_ys_valid = torch.load(f"{validation_path}_ys.pt")

    # Scalers are fit on the training split only.
    train_features = raw_xs_train.float().numpy()
    train_targets = raw_ys_train.float().numpy().reshape(-1, 1)
    scaler_x = preprocessing.StandardScaler().fit(train_features)
    scaler_y = preprocessing.StandardScaler().fit(train_targets)

    xs_train = scaler_x.transform(train_features)
    ys_train = scaler_y.transform(train_targets)
    xs_valid = scaler_x.transform(raw_xs_valid.float().numpy())
    ys_valid = scaler_y.transform(raw_ys_valid.float().numpy().reshape(-1, 1))

    test_path = config_paired["embedding"]["save_embeddings_path"].format(
        part="test", **config_paired
    )
    raw_chosen = torch.load(f"{test_path}_xs_chosen.pt")
    raw_rejected = torch.load(f"{test_path}_xs_rejected.pt")
    raw_score_chosen = torch.load(f"{test_path}_score_chosen.pt")
    score_chosen_test = scaler_y.transform(
        raw_score_chosen.float().numpy().reshape(-1, 1)
    )
    raw_score_rejected = torch.load(f"{test_path}_score_rejected.pt")
    score_rejected_test = scaler_y.transform(
        raw_score_rejected.float().numpy().reshape(-1, 1)
    )

    chosen_test = scaler_x.transform(raw_chosen.float().numpy())
    rejected_test = scaler_x.transform(raw_rejected.float().numpy())

    # One uniform draw per test pair from the seeded global NumPy stream.
    np.random.seed(42)
    draws = np.random.rand(chosen_test.shape[0])
    ys_test = np.array([1 if draw < 0.5 else 0 for draw in draws])

    train_bundle = (xs_train, ys_train, train_ratings)
    valid_bundle = (xs_valid, ys_valid, valid_ratings)
    test_bundle = (
        chosen_test,
        rejected_test,
        score_chosen_test,
        score_rejected_test,
        ys_test,
    )
    return train_bundle, valid_bundle, test_bundle


def fit_and_finetune(
    data_train, data_valid, data_test, model, params, plot=True, model_name=""
):
    """Grid-search ``model`` over ``params`` and select by validation MSE.

    For every combination in the Cartesian product of ``params`` a fresh
    ``model(**combination)`` is fit on the training split, then scored by
    train / validation / test MSE, by ROC AUC on the paired test set, and by
    the correlation between monthly mean ``ctr`` (train) and monthly mean
    prediction (validation).

    Args:
        data_train: ``(xs_train, ys_train, train_ratings)``; ``train_ratings``
            needs ``created_at_month`` and ``ctr`` columns.
        data_valid: ``(xs_valid, ys_valid, valid_ratings)``; ``valid_ratings``
            needs a ``created_at_month`` column. NOTE: a ``ys_hat`` column is
            written into this DataFrame on every iteration (side effect).
        data_test: ``(xs_chosen, xs_rejected, score_chosen, score_rejected,
            ys_test)``.
        model: Estimator class/factory exposing ``fit``, ``predict``,
            ``coef_`` and ``intercept_``.
        params: Mapping of parameter name to an iterable of candidate values.
        plot: When true, save five diagnostic PDFs under ``../plots/`` and
            show them.
        model_name: Suffix used in the saved plot file names.

    Returns:
        ``(best_params, best_valid_mse, best_model, other_info)`` where
        ``best_params`` is the parameter tuple with the lowest validation MSE
        and ``other_info`` summarises the alternative selection criteria.

    Raises:
        ValueError: If ``params`` yields no parameter combination.
    """
    xs_train, ys_train, train_ratings = data_train
    xs_valid, ys_valid, valid_ratings = data_valid
    (
        xs_chosen_test,
        xs_rejected_test,
        score_chosen_test,
        score_rejected_test,
        ys_test,
    ) = data_test

    def _mse(predicted, target):
        # Targets are (n, 1) column vectors while some estimators return (n,)
        # predictions; subtracting those directly would broadcast to (n, n)
        # and silently produce a wrong MSE, so flatten both sides first.
        return np.mean((np.ravel(predicted) - np.ravel(target)) ** 2)

    train_ratings_gb = (
        train_ratings.groupby("created_at_month").agg({"ctr": "mean"}).reset_index()
    )
    param_keys = list(params.keys())
    run_info = {}
    for pv in product(*[params[k] for k in param_keys]):
        clf = model(**dict(zip(param_keys, pv)))
        clf.fit(xs_train, ys_train)
        ys_hat_valid = clf.predict(xs_valid)

        mse_valid = _mse(ys_hat_valid, ys_valid)
        mse_train = _mse(clf.predict(xs_train), ys_train)
        mse_test_chosen = _mse(clf.predict(xs_chosen_test), score_chosen_test)
        mse_test_rejected = _mse(clf.predict(xs_rejected_test), score_rejected_test)
        mse_test = (mse_test_chosen + mse_test_rejected) / 2

        ys_hat_test = calc_ys_hat(clf, data_test)
        roc_auc, _, _ = calc_and_plot_roc(ys_test, ys_hat_test, plot=False)

        # Monthly mean prediction on validation vs. monthly mean ctr on train.
        valid_ratings["ys_hat"] = np.ravel(ys_hat_valid)
        valid_ratings_gb = (
            valid_ratings.groupby("created_at_month")
            .agg({"ys_hat": "mean"})
            .reset_index()
        )
        merged = pd.merge(
            train_ratings_gb,
            valid_ratings_gb,
            on="created_at_month",
            suffixes=("_train", "_valid"),
        )
        run_info[pv] = {
            "model": clf,
            "mse_valid": mse_valid,
            "mse_train": mse_train,
            "mse_test": mse_test,
            "roc_auc": roc_auc,
            "coef": clf.coef_,
            "intercept": clf.intercept_,
            "n_features": (np.abs(clf.coef_) > 0).sum(),
            "correlation": merged["ctr"].corr(merged["ys_hat"]),
        }

    if not run_info:
        raise ValueError("params produced no parameter combinations to evaluate")

    # Keys are the parameter tuples produced by itertools.product.
    alphas = list(run_info.keys())
    mse_trains = [run_info[alpha]["mse_train"] for alpha in alphas]
    mse_valids = [run_info[alpha]["mse_valid"] for alpha in alphas]
    mse_tests = [run_info[alpha]["mse_test"] for alpha in alphas]
    correlations = [run_info[alpha]["correlation"] for alpha in alphas]
    roc_aucs = [run_info[alpha]["roc_auc"] for alpha in alphas]

    best_correlation_idx = np.argmax(correlations)
    best_roc_idx = np.argmax(roc_aucs)
    mse_valid_idx = np.argmin(mse_valids)
    best_mse_test_idx = np.argmin(mse_tests)
    other_info = {
        "alpha": alphas[best_correlation_idx],
        "correlation": correlations[best_correlation_idx],
        "correlation_at_best": correlations[mse_valid_idx],
        "correlation_at_best_roc": correlations[best_roc_idx],
        "test_mse_alpha": alphas[best_mse_test_idx],
        "roc_auc_alpha": alphas[best_roc_idx],
        "roc_auc_best": roc_aucs[best_roc_idx],
        "roc_auc_at_best": roc_aucs[mse_valid_idx],
    }
    print(alphas[mse_valid_idx], mse_valids[mse_valid_idx], roc_aucs[mse_valid_idx])
    print(alphas[best_roc_idx], mse_valids[best_roc_idx], roc_aucs[best_roc_idx])
    print(
        alphas[best_mse_test_idx],
        mse_valids[best_mse_test_idx],
        roc_aucs[best_mse_test_idx],
    )

    def _plot_vs_alpha(
        values,
        metric_label,
        file_stem,
        selected_label=r"Best $\lambda$: Valid MSE",
        roc_label=r"Best $\lambda$: Test ROC AUC",
    ):
        # One figure: metric against the parameter grid on a log x-axis, with
        # vertical markers at the validation-selected and test-ROC-best values.
        # NOTE(review): x values are the parameter tuples themselves, so this
        # is presumably only meaningful for single-parameter grids.
        plt.plot(alphas, values)
        plt.scatter(alphas, values, label=metric_label)
        plt.xscale("log")
        plt.xlabel(r"$\lambda$", fontsize=16)
        plt.ylabel(metric_label, fontsize=14)
        plt.axvline(
            x=alphas[mse_valid_idx],
            color="red",
            linestyle="--",
            label=selected_label,
        )
        plt.axvline(
            x=alphas[best_roc_idx],
            color="black",
            linestyle="--",
            label=roc_label,
        )
        plt.legend(fontsize=12)
        plt.grid()
        plt.savefig(
            f"../plots/{file_stem}_vs_alpha_{model_name}.pdf",
            dpi=600,
            bbox_inches="tight",
        )
        plt.show()

    if plot:
        _plot_vs_alpha(mse_trains, "Training MSE loss", "mse_train")
        _plot_vs_alpha(mse_valids, "Validation MSE loss", "mse_valid")
        _plot_vs_alpha(mse_tests, "Test MSE loss", "mse_test")
        _plot_vs_alpha(roc_aucs, "Test ROC AUC", "roc_auc_test")
        _plot_vs_alpha(
            correlations,
            "Monthly correlation",
            "correlation",
            selected_label=r"Best Valid MSE $\lambda$",
            roc_label=r"Best ROC AUC $\lambda$",
        )
    return (
        alphas[mse_valid_idx],
        mse_valids[mse_valid_idx],
        run_info[alphas[mse_valid_idx]]["model"],
        other_info,
    )


def calc_ys_hat(best_model, data_test):
    """Return signed prediction differences for each test pair.

    For pair ``i`` the result is ``pred(chosen) - pred(rejected)`` when
    ``ys_test[i] == 1`` and ``pred(rejected) - pred(chosen)`` otherwise.

    Args:
        best_model: Object exposing ``predict``.
        data_test: 5-tuple ``(xs_chosen, xs_rejected, score_chosen,
            score_rejected, ys_test)``; the two score entries are unused here.

    Returns:
        ``np.ndarray`` with one entry per chosen-side prediction.
    """
    chosen_inputs, rejected_inputs, _, _, labels = data_test
    chosen_scores = best_model.predict(chosen_inputs)
    rejected_scores = best_model.predict(rejected_inputs)
    pair_count = chosen_scores.shape[0]
    return np.array(
        [
            chosen_scores[idx] - rejected_scores[idx]
            if labels[idx] == 1
            else rejected_scores[idx] - chosen_scores[idx]
            for idx in range(pair_count)
        ]
    )


def calc_and_plot_roc(ys_test, ys_hat_test, model_name="", plot=True):
    """Compute the ROC curve and its AUC, optionally plotting and saving it.

    Args:
        ys_test: Ground-truth binary labels.
        ys_hat_test: Scores passed to ``roc_curve`` alongside the labels.
        model_name: Suffix for the saved file ``../plots/roc_curve_<name>.pdf``.
        plot: When true, draw the curve, save it as a PDF and show it.

    Returns:
        ``(roc_auc, fpr, tpr)``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(ys_test, ys_hat_test)
    area = auc(false_pos_rate, true_pos_rate)
    if not plot:
        return area, false_pos_rate, true_pos_rate

    curve_label = f"ROC curve (AUC = {area:.2f})"
    plt.plot(false_pos_rate, true_pos_rate, color="blue", lw=2, label=curve_label)
    # Diagonal reference: the ROC of a random classifier.
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="lower right")
    output_path = f"../plots/roc_curve_{model_name}.pdf"
    plt.savefig(output_path, dpi=600, bbox_inches="tight")
    plt.show()
    return area, false_pos_rate, true_pos_rate


def get_data_paired(config_paired, train_size=None):
    """Build standardised difference features for pairwise classification.

    Chosen/rejected embeddings are loaded for the train, validation and test
    parts. For train and test, each pair is randomly oriented (seeds 13 and
    42 respectively): label 1 means ``chosen - rejected``, label 0 means
    ``rejected - chosen``. Validation pairs are always ``chosen - rejected``
    with label 1. The feature scaler is fit on the training differences.

    Args:
        config_paired: Config with ``["embedding"]["save_embeddings_path"]``.
        train_size: Number of leading training pairs to use; all when ``None``.

    Returns:
        ``((xs_train, ys_train), (xs_valid, ys_valid), (xs_test, ys_test))``
        with scaled NumPy features and torch label tensors.
    """
    path_template = config_paired["embedding"]["save_embeddings_path"]
    train_path, validation_path, test_path = (
        path_template.format(part=part, **config_paired)
        for part in ("train", "validation", "test")
    )

    xs_chosen_train = torch.load(f"{train_path}_xs_chosen.pt")
    xs_rejected_train = torch.load(f"{train_path}_xs_rejected.pt")
    xs_chosen_valid = torch.load(f"{validation_path}_xs_chosen.pt")
    xs_rejected_valid = torch.load(f"{validation_path}_xs_rejected.pt")
    xs_chosen_test = torch.load(f"{test_path}_xs_chosen.pt")
    xs_rejected_test = torch.load(f"{test_path}_xs_rejected.pt")

    # Training pairs: one seeded uniform draw per pair decides the orientation.
    np.random.seed(13)
    if train_size is None:
        train_size = xs_chosen_train.shape[0]
    train_keep = [draw < 0.5 for draw in np.random.rand(train_size)]
    xs_train = torch.stack(
        [
            xs_chosen_train[row] - xs_rejected_train[row]
            if keep
            else xs_rejected_train[row] - xs_chosen_train[row]
            for row, keep in enumerate(train_keep)
        ]
    )
    ys_train = torch.tensor([1 if keep else 0 for keep in train_keep])

    # Validation pairs keep the chosen-minus-rejected orientation throughout.
    valid_count = xs_chosen_valid.shape[0]
    xs_valid = torch.stack(
        [xs_chosen_valid[row] - xs_rejected_valid[row] for row in range(valid_count)]
    )
    ys_valid = torch.tensor([1] * valid_count)

    # Test pairs: same scheme as training but with a different seed.
    np.random.seed(42)
    test_keep = [draw < 0.5 for draw in np.random.rand(xs_chosen_test.shape[0])]
    xs_test = torch.stack(
        [
            xs_chosen_test[row] - xs_rejected_test[row]
            if keep
            else xs_rejected_test[row] - xs_chosen_test[row]
            for row, keep in enumerate(test_keep)
        ]
    )
    ys_test = torch.tensor([1 if keep else 0 for keep in test_keep])

    train_features = xs_train.float().numpy()
    scaler_x = preprocessing.StandardScaler().fit(train_features)
    xs_train = scaler_x.transform(train_features)
    xs_valid = scaler_x.transform(xs_valid.float().numpy())
    xs_test = scaler_x.transform(xs_test.float().numpy())

    return (
        (xs_train, ys_train),
        (xs_valid, ys_valid),
        (xs_test, ys_test),
    )


def get_data_valid_causal(config, config_paired):
    """Load train/observational-validation data plus two paired splits.

    Single-response embeddings and targets come from ``config`` (train and
    validation parts, ratings splits ``reward_one`` and
    ``reward_valid_observational``). Paired chosen/rejected embeddings and
    scores come from ``config_paired`` for the ``validation_causal`` and
    ``test`` parts. Scalers are fit on the training split only.

    Returns:
        ``(train, valid_observational, valid_causal, test)`` where the first
        two are ``(xs, ys, ratings)`` and the last two are
        ``(chosen, rejected, score_chosen, score_rejected, labels)``; labels
        are 0/1 values drawn with seed 42 for each paired split.
    """
    single_template = config["embedding"]["save_embeddings_path"]
    train_path = single_template.format(part="train", **config)
    validation_path = single_template.format(part="validation", **config)
    ratings = load_from_disk(config["data_path"].format(**config))

    # First 7 characters of "created_at" -- presumably "YYYY-MM"; confirm
    # against the dataset's timestamp format.
    train_ratings = ratings["reward_one"].to_pandas()
    train_ratings["created_at_month"] = train_ratings["created_at"].apply(
        lambda stamp: stamp[:7]
    )
    valid_ratings = ratings["reward_valid_observational"].to_pandas()
    valid_ratings["created_at_month"] = valid_ratings["created_at"].apply(
        lambda stamp: stamp[:7]
    )

    raw_xs_train = torch.load(f"{train_path}_xs.pt")
    raw_ys_train = torch.load(f"{train_path}_ys.pt")
    raw_xs_valid = torch.load(f"{validation_path}_xs.pt")
    raw_ys_valid = torch.load(f"{validation_path}_ys.pt")

    train_features = raw_xs_train.float().numpy()
    train_targets = raw_ys_train.float().numpy().reshape(-1, 1)
    scaler_x = preprocessing.StandardScaler().fit(train_features)
    scaler_y = preprocessing.StandardScaler().fit(train_targets)

    xs_train = scaler_x.transform(train_features)
    ys_train = scaler_y.transform(train_targets)
    xs_valid = scaler_x.transform(raw_xs_valid.float().numpy())
    ys_valid = scaler_y.transform(raw_ys_valid.float().numpy().reshape(-1, 1))

    def load_paired_split(part):
        # Load one chosen/rejected split, scale it with the train-fitted
        # scalers and draw seeded 0/1 labels (one uniform draw per pair).
        base = config_paired["embedding"]["save_embeddings_path"].format(
            part=part, **config_paired
        )
        raw_chosen = torch.load(f"{base}_xs_chosen.pt")
        raw_rejected = torch.load(f"{base}_xs_rejected.pt")
        raw_score_chosen = torch.load(f"{base}_score_chosen.pt")
        score_chosen = scaler_y.transform(
            raw_score_chosen.float().numpy().reshape(-1, 1)
        )
        raw_score_rejected = torch.load(f"{base}_score_rejected.pt")
        score_rejected = scaler_y.transform(
            raw_score_rejected.float().numpy().reshape(-1, 1)
        )
        chosen = scaler_x.transform(raw_chosen.float().numpy())
        rejected = scaler_x.transform(raw_rejected.float().numpy())

        np.random.seed(42)
        draws = np.random.rand(chosen.shape[0])
        labels = np.array([1 if draw < 0.5 else 0 for draw in draws])
        return chosen, rejected, score_chosen, score_rejected, labels

    causal_valid_bundle = load_paired_split("validation_causal")
    test_bundle = load_paired_split("test")

    return (
        (xs_train, ys_train, train_ratings),
        (xs_valid, ys_valid, valid_ratings),
        causal_valid_bundle,
        test_bundle,
    )


def fit_and_finetune_valid_causal(
    data_train,
    data_valid_observational,
    data_valid_causal,
    data_test,
    model,
    params,
    plot=True,
    model_name="",
):
    """Grid-search ``model`` and select by ROC AUC on the causal validation set.

    For every combination in the Cartesian product of ``params`` a fresh
    ``model(**combination)`` is fit on the training split and scored by train
    MSE, MSE and ROC AUC on the paired causal-validation set, MSE and ROC AUC
    on the paired test set, and the correlation between monthly mean ``ctr``
    (train) and monthly mean prediction (observational validation).

    Args:
        data_train: ``(xs_train, ys_train, train_ratings)``; ``train_ratings``
            needs ``created_at_month`` and ``ctr`` columns.
        data_valid_observational: ``(xs_valid, ys_valid, valid_ratings)``;
            only the features and ratings are used. NOTE: a ``ys_hat`` column
            is written into ``valid_ratings`` on every iteration (side effect).
        data_valid_causal: ``(xs_chosen, xs_rejected, score_chosen,
            score_rejected, labels)`` used for model selection.
        data_test: Same layout as ``data_valid_causal``, for the test split.
        model: Estimator class/factory exposing ``fit``, ``predict``,
            ``coef_`` and ``intercept_``.
        params: Mapping of parameter name to an iterable of candidate values.
        plot: When true, save six diagnostic PDFs under ``../plots/`` and
            show them.
        model_name: Suffix used in the saved plot file names.

    Returns:
        ``(best_params, causal_valid_mse_at_best, best_model, other_info)``
        where ``best_params`` maximises ROC AUC on the causal validation set.

    Raises:
        ValueError: If ``params`` yields no parameter combination.
    """
    xs_train, ys_train, train_ratings = data_train
    xs_valid, _ys_valid, valid_ratings = data_valid_observational
    (
        xs_chosen_test,
        xs_rejected_test,
        score_chosen_test,
        score_rejected_test,
        ys_test,
    ) = data_test
    (
        xs_chosen_valid_causal,
        xs_rejected_valid_causal,
        score_chosen_valid_causal,
        score_rejected_valid_causal,
        ys_valid_causal,
    ) = data_valid_causal

    def _mse(predicted, target):
        # Targets are (n, 1) column vectors while some estimators return (n,)
        # predictions; subtracting those directly would broadcast to (n, n)
        # and silently produce a wrong MSE, so flatten both sides first.
        return np.mean((np.ravel(predicted) - np.ravel(target)) ** 2)

    train_ratings_gb = (
        train_ratings.groupby("created_at_month").agg({"ctr": "mean"}).reset_index()
    )
    param_keys = list(params.keys())
    run_info = {}
    for pv in product(*[params[k] for k in param_keys]):
        clf = model(**dict(zip(param_keys, pv)))
        clf.fit(xs_train, ys_train)

        mse_train = _mse(clf.predict(xs_train), ys_train)
        ys_hat_valid = clf.predict(xs_valid)

        mse_valid_causal_chosen = _mse(
            clf.predict(xs_chosen_valid_causal), score_chosen_valid_causal
        )
        mse_valid_causal_rejected = _mse(
            clf.predict(xs_rejected_valid_causal), score_rejected_valid_causal
        )
        mse_valid_causal = (mse_valid_causal_chosen + mse_valid_causal_rejected) / 2

        mse_test_chosen = _mse(clf.predict(xs_chosen_test), score_chosen_test)
        mse_test_rejected = _mse(clf.predict(xs_rejected_test), score_rejected_test)
        mse_test = (mse_test_chosen + mse_test_rejected) / 2

        ys_hat_valid_causal = calc_ys_hat(clf, data_valid_causal)
        ys_hat_test = calc_ys_hat(clf, data_test)
        roc_auc_valid_causal, _, _ = calc_and_plot_roc(
            ys_valid_causal, ys_hat_valid_causal, plot=False
        )
        roc_auc, _, _ = calc_and_plot_roc(ys_test, ys_hat_test, plot=False)

        # Monthly mean prediction on validation vs. monthly mean ctr on train.
        valid_ratings["ys_hat"] = np.ravel(ys_hat_valid)
        valid_ratings_gb = (
            valid_ratings.groupby("created_at_month")
            .agg({"ys_hat": "mean"})
            .reset_index()
        )
        merged = pd.merge(
            train_ratings_gb,
            valid_ratings_gb,
            on="created_at_month",
            suffixes=("_train", "_valid"),
        )
        run_info[pv] = {
            "model": clf,
            "mse_valid": mse_valid_causal,
            "mse_train": mse_train,
            "mse_test": mse_test,
            "roc_auc": roc_auc,
            "roc_auc_valid_causal": roc_auc_valid_causal,
            "coef": clf.coef_,
            "intercept": clf.intercept_,
            "n_features": (np.abs(clf.coef_) > 0).sum(),
            "correlation": merged["ctr"].corr(merged["ys_hat"]),
        }

    if not run_info:
        raise ValueError("params produced no parameter combinations to evaluate")

    # Keys are the parameter tuples produced by itertools.product.
    alphas = list(run_info.keys())
    mse_trains = [run_info[alpha]["mse_train"] for alpha in alphas]
    mse_valids = [run_info[alpha]["mse_valid"] for alpha in alphas]
    mse_tests = [run_info[alpha]["mse_test"] for alpha in alphas]
    correlations = [run_info[alpha]["correlation"] for alpha in alphas]
    roc_aucs = [run_info[alpha]["roc_auc"] for alpha in alphas]
    roc_aucs_valid_causal = [
        run_info[alpha]["roc_auc_valid_causal"] for alpha in alphas
    ]

    best_correlation_idx = np.argmax(correlations)
    best_roc_idx = np.argmax(roc_aucs)
    roc_valid_idx = np.argmax(roc_aucs_valid_causal)
    best_mse_test_idx = np.argmin(mse_tests)
    other_info = {
        "alpha": alphas[best_correlation_idx],
        "correlation": correlations[best_correlation_idx],
        "correlation_at_best": correlations[roc_valid_idx],
        "correlation_at_best_roc": correlations[best_roc_idx],
        "test_mse_alpha": alphas[best_mse_test_idx],
        "roc_auc_alpha": alphas[best_roc_idx],
        "roc_auc_best": roc_aucs[best_roc_idx],
        "roc_auc_at_best": roc_aucs[roc_valid_idx],
    }
    print(alphas[roc_valid_idx], mse_valids[roc_valid_idx], roc_aucs[roc_valid_idx])
    print(alphas[best_roc_idx], mse_valids[best_roc_idx], roc_aucs[best_roc_idx])
    print(
        alphas[best_mse_test_idx],
        mse_valids[best_mse_test_idx],
        roc_aucs[best_mse_test_idx],
    )

    def _plot_vs_alpha(
        values,
        metric_label,
        file_stem,
        selected_label=r"Best $\lambda$: Valid ROC",
        roc_label=r"Best $\lambda$: Test ROC AUC",
        show_legend=True,
    ):
        # One figure: metric against the parameter grid on a log x-axis, with
        # vertical markers at the validation-selected and test-ROC-best values.
        # NOTE(review): x values are the parameter tuples themselves, so this
        # is presumably only meaningful for single-parameter grids.
        plt.plot(alphas, values)
        plt.scatter(alphas, values, label=metric_label)
        plt.xscale("log")
        plt.xlabel(r"$\lambda$", fontsize=16)
        plt.ylabel(metric_label, fontsize=14)
        plt.axvline(
            x=alphas[roc_valid_idx],
            color="red",
            linestyle="--",
            label=selected_label,
        )
        plt.axvline(
            x=alphas[best_roc_idx],
            color="black",
            linestyle="--",
            label=roc_label,
        )
        if show_legend:
            plt.legend(fontsize=12)
        plt.grid()
        plt.savefig(
            f"../plots/{file_stem}_vs_alpha_{model_name}.pdf",
            dpi=600,
            bbox_inches="tight",
        )
        plt.show()

    if plot:
        _plot_vs_alpha(mse_trains, "Training MSE loss", "mse_train")
        _plot_vs_alpha(mse_valids, "Validation MSE loss", "mse_valid")
        _plot_vs_alpha(mse_tests, "Test MSE loss", "mse_test")
        # The scatter legend entry now matches the plotted series (validation
        # ROC AUC); it was previously labelled as the test ROC AUC.
        _plot_vs_alpha(roc_aucs_valid_causal, "Valid ROC AUC", "roc_auc_valid")
        # This figure is rendered without a legend, as before.
        _plot_vs_alpha(roc_aucs, "Test ROC AUC", "roc_auc_test", show_legend=False)
        _plot_vs_alpha(
            correlations,
            "Monthly correlation",
            "correlation",
            selected_label=r"Best Valid ROC $\lambda$",
            roc_label=r"Best ROC AUC $\lambda$",
        )
    return (
        alphas[roc_valid_idx],
        mse_valids[roc_valid_idx],
        run_info[alphas[roc_valid_idx]]["model"],
        other_info,
    )


def fit_and_finetune_paired(
    data_train, data_valid, model, params, plot=True, model_name=""
):
    """Grid-search a pairwise classifier and select by validation score.

    For every combination in the Cartesian product of ``params`` a fresh
    ``model(**combination)`` is fit on the training pairs and evaluated with
    its own ``score`` method on the validation and training pairs.

    Args:
        data_train: ``(xs_train, ys_train)``.
        data_valid: ``(xs_valid, ys_valid)``.
        model: Estimator class/factory exposing ``fit`` and ``score``.
        params: Mapping of parameter name to an iterable of candidate values.
        plot: When true, save train/validation score curves under
            ``../plots/`` and show them.
        model_name: Suffix used in the saved plot file names.

    Returns:
        ``(best_params, best_valid_score, best_model)`` where ``best_params``
        is the parameter tuple with the highest validation score.

    Raises:
        ValueError: If ``params`` yields no parameter combination.
    """
    xs_train, ys_train = data_train
    xs_valid, ys_valid = data_valid
    param_keys = list(params.keys())
    run_info = {}
    for pv in product(*[params[k] for k in param_keys]):
        clf = model(**dict(zip(param_keys, pv)))
        clf.fit(xs_train, ys_train)
        score_valid = clf.score(xs_valid, ys_valid)
        score_train = clf.score(xs_train, ys_train)
        run_info[pv] = {
            "model": clf,
            "score_valid": score_valid,
            "score_train": score_train,
        }

    if not run_info:
        raise ValueError("params produced no parameter combinations to evaluate")

    # Keys are the parameter tuples produced by itertools.product.
    Cs = list(run_info.keys())
    score_trains = [run_info[C]["score_train"] for C in Cs]
    score_valids = [run_info[C]["score_valid"] for C in Cs]

    score_valid_idx = np.argmax(score_valids)
    # Report the score that belongs to the selected setting; pairing the best
    # setting with the minimum score (as before) was misleading.
    print(Cs[score_valid_idx], score_valids[score_valid_idx])

    def _plot_scores(title, scores, file_stem):
        # One score-vs-C curve on a log x-axis, saved as PDF and shown.
        plt.title(title)
        plt.plot(Cs, scores)
        plt.scatter(Cs, scores)
        plt.xscale("log")
        plt.savefig(
            f"../plots/{file_stem}_vs_C_{model_name}.pdf",
            dpi=600,
            bbox_inches="tight",
        )
        plt.show()

    if plot:
        _plot_scores("Score Train vs. C", score_trains, "score_train")
        _plot_scores("Score Validation vs. C", score_valids, "score_valid")

    return (
        Cs[score_valid_idx],
        score_valids[score_valid_idx],
        run_info[Cs[score_valid_idx]]["model"],
    )


def calc_ys_hat_paired(best_model, data_test):
    """Return column 1 of ``best_model.predict_proba`` on the test features.

    Args:
        best_model: Object exposing ``predict_proba``.
        data_test: 2-tuple ``(xs_test, ys_test)``; the labels are ignored.
    """
    features, _labels = data_test
    probabilities = best_model.predict_proba(features)
    return probabilities[:, 1]
