import pandas as pd
import numpy as np
import zarr
from tqdm import tqdm

from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import KFold, train_test_split, GroupKFold
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import r2_score, accuracy_score
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import umap

import os
from modules import constants

# Summary table for the run; the first CSV column becomes the DataFrame index.
# The activation-loading loop further down reads its `hall_score` column by
# row position.
summary_csv_path = os.path.join(
    constants.DATA_DIR,
    "interim/sae_activations/normal/summary-gemma-2b-it-FIXED.csv",
)
summary_text = pd.read_csv(summary_csv_path, index_col=0)

# Per-layer metrics of the most recent CV repetition, and the per-repetition
# pooled test R^2 values used later for the error bars.
layer_results = {}
r2_logit_bootstrapped = {}
for layer in [10]:
    # Directory listing the activation arrays for this layer.
    layer_dir = os.path.join(
        constants.DATA_DIR, "interim/sae_activations/normal/gemma-2b", f"layer_{layer}"
    )
    activation_files = os.listdir(layer_dir)
    all_activations = []
    hallucination_scores = []
    index_num = []
    for file in tqdm(activation_files):
        if not file.endswith(".zarr"):
            continue
        # The trailing "_<n>" of the file stem is used as the row position
        # in `summary_text`.
        index = int(file.split(".")[0].split("_")[-1])
        activations = zarr.open(os.path.join(layer_dir, file), mode="r")[:]
        tokens = activations[1:, :]  # Skip BOS token

        # One feature vector per file: per-feature max over tokens, divided
        # by the number of tokens.
        all_activations.append(tokens.max(axis=0) / tokens.shape[0])

        hall_score = summary_text.iloc[index].hall_score
        hallucination_scores.append(hall_score.item())

        # One group id per sample so `groups` lines up with the rows of X
        # (previously one id per token, which did not match X's row count).
        index_num.append(index)
    all_activations = np.vstack(all_activations)
    hallucination_scores = np.hstack(hallucination_scores)
    index_num = np.asarray(index_num)

    # Design matrix and targets do not change between repetitions.
    X = all_activations
    y = hallucination_scores
    y_logit = hallucination_scores
    # Binary target: 1 when the score is above the median, else 0.
    y_bin = np.where(hallucination_scores > np.median(hallucination_scores), 1, 0)
    groups = index_num  # NOTE(review): currently unused; KFold below ignores groups.

    # Start a fresh list for this layer; exactly one entry is appended per
    # repetition (the first repetition used to be stored twice).
    r2_logit_bootstrapped[layer] = []

    for strap in range(10):
        # Unseeded shuffle: every repetition draws a different 10-fold split.
        kf = KFold(n_splits=10, shuffle=True)

        accs = []
        logit_preds = []
        logit_trues = []
        logit_preds_train = []
        logit_trues_train = []
        bin_preds = []
        bin_trues = []
        r2s_mean = []
        for train_index, test_index in kf.split(all_activations, hallucination_scores):
            # Regress the score on 4 PLS components fitted on the training fold.
            pls_logit = PLSRegression(n_components=4)
            pls_logit.fit(X[train_index], y_logit[train_index])
            X_train_score = pls_logit.transform(X[train_index])
            X_test_score = pls_logit.transform(X[test_index])

            # Flatten predictions: some scikit-learn versions return (n, 1)
            # even for a 1-D target, which would break the hstack pooling.
            logit_pred = np.clip(np.ravel(pls_logit.predict(X[test_index])), 0, 1)
            logit_preds.append(logit_pred)
            logit_trues.append(y_logit[test_index])
            logit_pred_train = np.clip(np.ravel(pls_logit.predict(X[train_index])), 0, 1)
            logit_preds_train.append(logit_pred_train)
            logit_trues_train.append(y_logit[train_index])
            r2s_mean.append(r2_score(y_logit[test_index], logit_pred))

            # Classify the median-split label from the PLS scores.
            clf = LogisticRegression(max_iter=1000)
            clf.fit(X_train_score, y_bin[train_index])
            accs.append(clf.score(X_test_score, y_bin[test_index]))
            bin_pred = clf.predict(X_test_score)
            bin_preds.append(bin_pred)
            bin_true = y_bin[test_index]
            bin_trues.append(bin_true)

        # Compute each pooled metric once and reuse it for printing/storing.
        mean_acc = np.mean(accs)
        r2_full = r2_score(np.hstack(logit_trues), np.hstack(logit_preds))
        r2_full_train = r2_score(np.hstack(logit_trues_train), np.hstack(logit_preds_train))
        acc_full = accuracy_score(np.hstack(bin_trues), np.hstack(bin_preds))
        mean_r2 = np.mean(r2s_mean)

        print(f"Mean acc: {mean_acc}")
        print(f"Full R2 logit: {r2_full}")
        print(f"Full R2 logit train: {r2_full_train}")
        print(f"Full acc: {acc_full}")
        print(f"Mean R2 logit: {mean_r2}")

        # Overwritten on every repetition; only the last one is kept here.
        layer_results[layer] = {
            "acc": mean_acc,
            "acc_stacked": acc_full,
            "r2_logit": r2_full,
            "r2_logit_train": r2_full_train,
            "r2_logit_mean": mean_r2,
        }
        r2_logit_bootstrapped[layer].append(r2_full)

# One row per layer. Only the pooled test R^2 column is kept from the
# per-layer metrics; it is then replaced by the first stored repetition and
# paired with the standard deviation across all stored repetitions.
per_layer_r2_runs = list(r2_logit_bootstrapped.values())

results = pd.DataFrame(layer_results).T.drop(
    columns=["r2_logit_mean", "acc", "acc_stacked", "r2_logit_train"]
)
results["r2_std"] = [np.std(runs) for runs in per_layer_r2_runs]
results["r2_logit"] = [runs[0] for runs in per_layer_r2_runs]
# Set all negative values to 0
results["r2_logit"] = results["r2_logit"].clip(lower=0)
results.index = results.index.astype(str)

# Bar plot of results: one bar per row of `results`, with the R^2 spread
# across repetitions as the error bar.
# Colour sampled at 0.6 along the "Blues" colormap.
dark_blue = plt.get_cmap("Blues")(0.6)

# A single figure only: a second plt.figure() call would leave an extra empty
# figure open that plt.show() would also display.
plt.figure(figsize=(6, 4))

# Everything is placed at the numeric position `i` so bars, labels, error bars
# and the explicit ticks below share one (non-categorical) x axis.
for i, (y, err) in enumerate(zip(results["r2_logit"], results["r2_std"])):
    if y != 0:
        # Use the unpacked `err` rather than results["r2_std"][i]: the index
        # holds strings, so an integer key relies on deprecated positional
        # fallback.
        plt.bar(i, y, yerr=err, capsize=5, color=dark_blue, zorder=10)
        plt.text(i, y + err + 0.01, f"{y:.2f}", ha='center', va='bottom', fontsize=11.5)
    else:
        # Draw a stub line at y=0, same width as bar
        plt.plot([i - 0.4, i + 0.4], [0, 0], color=dark_blue, linewidth=5)
        plt.errorbar(i, 0, yerr=err, capsize=5, color='black', fmt='none', zorder=10)
        plt.text(i, err + 0.01, "0", ha='center', va='bottom')

# NOTE(review): these labels are hard-coded and are not derived from
# `results.index`; confirm they match the layers actually evaluated.
x_ticks = ["1", "7", "11", "13", "18"]
plt.xticks(ticks=range(len(x_ticks)), labels=x_ticks, fontsize=13)
plt.yticks(fontsize=13)
plt.ylabel(r"Hallucination Prediction ($R^2$)", fontsize=11*1.3)
plt.xlabel("Gemma 2B-IT Layer Number", fontsize=11*1.3)


# Clean up spines
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Add faint grid lines
ax.yaxis.grid(True, linestyle='-', alpha=0.3, zorder=-5)

# Save the plot as SVG in the current working directory
plt.savefig(
    "hallucination_r2.svg",
    dpi=300,
    bbox_inches="tight",
    format="svg",
)
plt.show()

### VIP ###
def vip(pls):
    """Return the Variable Importance in Projection score of each feature.

    The scores are computed from the model that is passed in (previously the
    argument was ignored and the module-level ``pls_logit`` was read instead).

    Args:
        pls: A fitted PLS model exposing ``x_scores_``, ``x_weights_`` and
            ``y_loadings_`` (e.g. scikit-learn's ``PLSRegression``).
            ``x_weights_`` must be 2-D with one row per feature and one
            column per component; ``y_loadings_`` is assumed to broadcast
            against ``x_scores_`` (a single target) -- TODO confirm for
            multi-target models.

    Returns:
        A 1-D array with one VIP value per row of ``x_weights_``.
    """
    T, W, Q = pls.x_scores_, pls.x_weights_, pls.y_loadings_
    p, _ = W.shape
    # Per-component weight: sum over samples of (score * y-loading)^2.
    ssy = np.sum((T**2) * (Q**2), axis=0)           # SSY_a
    total = ssy.sum()
    # Weight each feature's squared x-weights by the component sums, then
    # normalise by their total and scale by the number of features.
    return np.sqrt(p * (W**2 @ ssy) / total)

# `pls_logit` here is whatever model the cross-validation loop fitted last.
vip_scores = vip(pls_logit)

# Give each VIP score the sign of its regression coefficient, then keep the
# first ten positions of the ascending sort. These are feature indices, not
# scores.
# NOTE(review): ascending order selects the most negative signed values --
# confirm that this is the intended "top 10".
signed_vip = np.sign(pls_logit.coef_.ravel()) * vip_scores
top10_vip_scores = np.argsort(signed_vip)[:10]
