
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr, stats
import subprocess
import matplotlib.pyplot as plt
import random
from typing import List, Tuple
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
import matplotlib.lines as mlines
import matplotlib as mpl

import json
import os
from sklearn.metrics import roc_curve, auc

mpl.rc('font', family='serif', serif=['STIXGeneral'])
def procrustes_2d(X, Y):
    """
    Find the optimal 2D rotation angle theta such that y_i ≈ R(theta) @ x_i.

    This is the rotation-only orthogonal Procrustes problem in the plane,
    solved in closed form from the 2 x 2 cross-covariance of the paired points.

    Args:
        X: N x 2 array-like of source vectors (one 2D point per row); a single
            2-vector is treated as one point
        Y: N x 2 array-like of target vectors, paired row-by-row with X

    Returns:
        theta: rotation angle in radians in [-pi, pi] (counter-clockwise is
            positive); NaN when there are no point pairs, because the
            rotation is undefined in that case

    Raises:
        ValueError: if X and Y are not same-shaped collections of 2D points
    """
    X = np.atleast_2d(np.asarray(X, dtype=float))
    Y = np.atleast_2d(np.asarray(Y, dtype=float))

    # No pairs at all (e.g. a transition type that never occurred): report
    # "undefined" instead of failing on an index into an empty product.
    if X.size == 0 and Y.size == 0:
        return np.nan
    if X.ndim != 2 or X.shape[1] != 2 or X.shape != Y.shape:
        raise ValueError(
            f"expected two N x 2 point sets of equal length, got shapes {X.shape} and {Y.shape}"
        )

    # Cross-covariance matrix: C[i, j] = sum_n Y[n, i] * X[n, j]
    C = Y.T @ X
    a, b, c, d = C[0, 0], C[0, 1], C[1, 0], C[1, 1]

    # trace(R(theta)^T C) = cos(theta) * (a + d) + sin(theta) * (c - b),
    # which is maximised by the angle below.
    return np.arctan2(c - b, a + d)

def theta_per_example(X, Y):
    """
    Compute the unsigned angle between each paired source/target vector.

    Args:
        X: list of source vectors
        Y: list of target vectors, paired element-by-element with X

    Returns:
        thetas: list with one angle in radians (in [0, pi]) per pair
    """
    def _unsigned_angle(source, target):
        # Clip guards arccos against cosines that drift just outside [-1, 1].
        cosine = np.dot(source, target) / (np.linalg.norm(source) * np.linalg.norm(target))
        return np.arccos(np.clip(cosine, -1, 1))

    angles = [_unsigned_angle(source, target) for source, target in zip(X, Y)]
    # zip stops at the shorter input, so this trips when Y is shorter than X.
    assert len(angles) == len(X), f"Mismatch in number of angles computed: {len(angles)} vs {len(X)}"
    return angles

def CCW_needed_check(v1, v2, eps=1e-12):
    """Determines the signed angle direction between two vectors.

    An orthonormal pair (e1, e2) is built with e1 along v1 and e2 taken from a
    fixed reference direction made orthogonal to e1. The angle is that of v2
    expressed in this pair, so its sign is relative to the chosen reference.

    Args:
        v1: first vector (defines the zero-angle direction)
        v2: second vector, same length as v1
        eps: norm threshold below which a vector counts as zero

    Returns:
        Signed angle in radians, in (-pi, pi].

    Raises:
        ValueError: if v1 or v2 has (near-)zero length.
        RuntimeError: if no reference direction has a component orthogonal to
            v1 (only possible for 1-dimensional input).
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    norm_v1 = np.linalg.norm(v1)
    if norm_v1 < eps or np.linalg.norm(v2) < eps:
        raise ValueError("zero-length vector")

    e1 = v1 / norm_v1
    n = e1.size

    def _reference_candidates():
        # The diagonal is tried first so results are unchanged for every input
        # that already worked; standard basis vectors are the fallback for a
        # v1 that is parallel to the diagonal.
        yield np.ones(n) / np.sqrt(n)
        for axis in range(n):
            unit = np.zeros(n)
            unit[axis] = 1.0
            yield unit

    # build e2 as the first reference with a usable component orthogonal to e1
    e2 = None
    for r in _reference_candidates():
        temp = r - np.dot(e1, r) * e1
        norm_temp = np.linalg.norm(temp)
        if norm_temp >= eps:
            e2 = temp / norm_temp
            break
    if e2 is None:
        raise RuntimeError("failed to find reference orthogonal component")

    # coordinates of v2 in basis (e1, e2)
    x = np.dot(e1, v2)
    y = np.dot(e2, v2)

    theta = np.arctan2(y, x)  # signed angle in (-pi, pi]
    return theta


def plot_lambda(degrees_per_file,split_models=None, ordered ="False"):
    """Plots the exponential decay of the error term (Lambda).

    For every model group in `split_models`, one lambda per matching file is
    computed as ``obs['2-2']['F-F'] + obs['2-2']['H-H'] - 1`` and the group is
    passed to `plot_lambda_custom_convergence`, which writes
    ``plots/lambda_convergence_<s[0]>_ordered_<ordered>.pdf``.

    Args:
        degrees_per_file: mapping of file name -> observation dict holding a
            '2-2' entry with 'F-F' and 'H-H' values.
        split_models: iterable of model groups; each group is a sequence of
            substrings, and a file belongs to the group when its name contains
            any of them. ``None`` (the default) means there is nothing to plot.
        ordered: substring that must also occur in the file name. Note this is
            the string "True"/"False", not a bool.
    """
    if split_models is None:
        # The default used to be iterated directly and raised TypeError.
        return
    for s in split_models:
        lambdas = {}
        for file, obs in degrees_per_file.items():
            if not any(s_i in file for s_i in s) or ordered not in file:
                continue
            lambda2 = (obs['2-2']["F-F"] + obs['2-2']["H-H"])-1
            # The 3rd and 4th underscore-separated fields of the file name are
            # used as the curve's key.
            name_parts = file.split("_")
            lambdas[name_parts[2]+"_"+name_parts[3]] = lambda2
        # plot the lambdas
        plot_lambda_custom_convergence(lambdas, f'plots/lambda_convergence_{s[0]}_ordered_{ordered}.pdf')



def plot_lambda_custom_convergence(lambdas:dict,path):
    """Plot ``lambda ** t`` for each named lambda on log-log axes and save a PDF.

    Args:
        lambdas: mapping of curve name -> lambda value. Names are turned into
            display labels through a fixed chain of substring replacements.
        path: output file path. The y-label and legend are only drawn when the
            path contains "llama".
    """
    # Setup
    # Steps 0..19 are evaluated; the x-limits below show only 1..10.
    time_steps = np.arange(0, 20, 1)

    plt.figure(figsize=(10, 6))
    plt.rcParams.update({'font.family': 'serif', 'font.size': 12})

    # get colors from the "Okabe–Ito" color-blind friendly palette
    colors = ['#0072B2', '#D55E00', '#009E73', '#CC79A7', '#56B4E9', '#F0E442', '#E69F00']

    # Distinct markers for accessibility
    markers = ['o', 's', '^', 'D', 'v', 'P', '*']

    for i, (name, lam) in enumerate(lambdas.items()):
        # Calculate the error decay: Error ~ lambda^t
        # We start from an initial error of 1.0 (normalized)
        decay_curve = lam ** time_steps

        # Styles cycle so that more curves than palette entries do not raise
        # IndexError.
        plt.plot(time_steps, decay_curve,
                 label=name.replace("do_not", "Do-Not-Answer").replace("natural_100", "NaturalQA").replace("triviaqa_100", "TriviaQA").replace("sorry_100", "Sorry").replace("sycophancy_negative","S-neg").replace("sycophancy", "S-pos").replace("_100",""),
                 color=colors[i % len(colors)], marker=markers[i % len(markers)],
                 linewidth=2.5)

    plt.yscale('log')  # Crucial: Makes exponential decay look linear
    plt.xscale('log')
    plt.ylim(10**-10, 1)
    plt.xlim(1, 10)
    plt.xlabel('Time (Iterations)',fontsize=30)
    if "llama" in path:
        plt.ylabel('Distance to Stationarity',fontsize=27)
        plt.legend(fontsize=25, frameon=True)

    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()

    plt.savefig(path, format='pdf', dpi=300)
    print(f"Saved lambda convergence plot to {path}")
    plt.close()

def plot_hidden_states_through_time(hidden_states_list: List[List[Tuple]], calculate_angle_per_point: bool = False):
    """
    Estimate 2D rotation angles between consecutive hidden states, grouped by label transition.

    Despite the name, nothing is drawn. The sequences are split in half using
    the global `random` state (seed it beforehand for reproducibility). On the
    training half, the unit-normalised mean state of label 1 ("H") and of
    label 0 ("F") are computed. Every state of the held-out half is projected
    onto the plane spanned by those two means, and for each kind of
    consecutive label transition (H-F, F-F, H-H, F-H) the best-fitting rotation
    from the earlier point to the later one is computed with `procrustes_2d`.

    Args:
        hidden_states_list: one list per sequence; each element is a
            ``(state_vector, label)`` pair with label 1 or 0.
        calculate_angle_per_point: also store the per-pair unsigned angles
            (from `theta_per_example`) under the ``'<transition> per point'`` keys.

    Returns:
        dict of angles in degrees with keys 'H-F', 'F-F', 'H-H', 'F-H',
        'Mean H-F' (angle between the two mean vectors) and, if requested,
        the four ``'... per point'`` arrays.

    Raises:
        IndexError: if fewer than two sequences are given, so the held-out
            half is empty.
    """
    degrees = {}
    n_sequences = len(hidden_states_list)
    # Exactly one draw from the global RNG, so callers' seeding keeps working.
    hidden_states_to_test_on = random.sample(range(n_sequences), int(n_sequences * 0.5))
    if not hidden_states_to_test_on:
        raise IndexError("need at least two sequences to form a held-out split")
    held_out = set(hidden_states_to_test_on)  # O(1) membership for the complement
    hidden_states_to_train_on = [i for i in range(n_sequences) if i not in held_out]

    def _unit_mean_state(label):
        """Unit-normalised mean of all training states carrying `label`."""
        states = [np.array(step[0]) for i in hidden_states_to_train_on
                  for step in hidden_states_list[i] if step[1] == label]
        mean_state = np.mean(states, axis=0)
        return mean_state / np.linalg.norm(mean_state)

    hall_mean_vector = _unit_mean_state(1)
    factual_mean_vector = _unit_mean_state(0)

    # Projection basis, built once because it does not depend on the sequence.
    # First axis: the factual mean. Second axis: the part of the hall mean
    # orthogonal to it.
    f_vec = factual_mean_vector / np.linalg.norm(factual_mean_vector)
    h_vec = hall_mean_vector / np.linalg.norm(hall_mean_vector)
    proj_basis_1 = f_vec
    proj_basis_2 = h_vec - np.dot(h_vec, proj_basis_1) * proj_basis_1
    proj_basis_2 = proj_basis_2 / np.linalg.norm(proj_basis_2)
    # NOTE(review): before this flip the hall mean always has a positive
    # component along proj_basis_2; the flip follows the sign reported by
    # CCW_needed_check. Kept as-is - confirm the intended orientation.
    if np.sign(CCW_needed_check(f_vec, h_vec)) < 0:
        proj_basis_2 = -proj_basis_2

    # (name, label at step i, label at step i + 1)
    transitions = (('H-F', 1, 0), ('F-F', 0, 0), ('H-H', 1, 1), ('F-H', 0, 1))
    sources = {name: [] for name, _, _ in transitions}
    targets = {name: [] for name, _, _ in transitions}

    for seq_index in hidden_states_to_test_on:
        test = hidden_states_list[seq_index]
        points_2d = []
        for step in test:
            state = np.array(step[0])
            points_2d.append((np.dot(state, proj_basis_1), np.dot(state, proj_basis_2)))
        # Bucket each consecutive pair of projected points by label transition.
        for i in range(len(points_2d) - 1):
            for name, current_label, next_label in transitions:
                if test[i][1] == current_label and test[i + 1][1] == next_label:
                    sources[name].append(points_2d[i])
                    targets[name].append(points_2d[i + 1])
                    break

    if calculate_angle_per_point:
        for name, _, _ in transitions:
            degrees[f'{name} per point'] = np.degrees(theta_per_example(sources[name], targets[name]))
    for name, _, _ in transitions:
        theta_optimal = procrustes_2d(sources[name], targets[name])
        degrees[name] = np.degrees(float(theta_optimal))
        print(f"Optimal rotation for {name}: {np.degrees(theta_optimal)} degrees")

    # Angle between the two mean vectors, measured in their own plane.
    basis_1 = factual_mean_vector
    basis_2 = hall_mean_vector - np.dot(hall_mean_vector, basis_1) * basis_1
    basis_2 = basis_2 / np.linalg.norm(basis_2)

    f_2d = np.array([1, 0])  # factual is along x-axis
    h_2d = np.array([np.dot(hall_mean_vector, basis_1), np.dot(hall_mean_vector, basis_2)])
    theta_optimal = procrustes_2d(f_2d.reshape(1, 2), h_2d.reshape(1, 2))
    print(f"Degree Between H and F mean vectors: {np.degrees(theta_optimal)} degrees")
    degrees['Mean H-F'] = np.degrees(float(theta_optimal))
    return degrees




def plot_roc_curve_per_model(degrees_per_file, split_models=None, ordered="False"):
    """Plots the ROC curve for detection capabilities.

    For each model group, the absolute per-point angles of the "degrees_7" run
    are used as scores: label-changing transitions (H-F, F-H) are the positive
    class and label-preserving ones (H-H, F-F) the negative class. The curve
    is saved to ``plots/roc_curve_<s[0]>_ordered_<ordered>.pdf``.

    Args:
        degrees_per_file: mapping of file name -> dict with a "degrees_7"
            entry holding the four ``'<transition> per point'`` sequences.
        split_models: iterable of model groups (each a sequence of substrings
            matched against the file name). ``None`` means nothing to plot.
        ordered: substring that must also occur in the file name.
    """
    if split_models is None:
        # The default used to be iterated directly and raised TypeError.
        return

    def _magnitudes(per_point_angles):
        """Absolute values of one sequence of per-point angles."""
        return [abs(float(angle)) for angle in per_point_angles]

    for s in split_models:
        switch_scores = []
        remain_scores = []

        # --- Data Aggregation ---
        for file, obs in degrees_per_file.items():
            if not any(s_i in file for s_i in s) or ordered not in file:
                continue
            per_point = obs["degrees_7"]
            switch_scores.extend(_magnitudes(per_point["H-F per point"]))
            switch_scores.extend(_magnitudes(per_point["F-H per point"]))
            remain_scores.extend(_magnitudes(per_point["H-H per point"]))
            remain_scores.extend(_magnitudes(per_point["F-F per point"]))

        # A ROC curve needs both classes; skip groups that lack one.
        if not switch_scores or not remain_scores:
            continue

        # 1. Create y_true (labels): 1 for Switch, 0 for Remain
        print(f"Number of Switch scores: {len(switch_scores)}, Number of Remain scores: {len(remain_scores)}")
        y_true = [1] * len(switch_scores) + [0] * len(remain_scores)

        # 2. Create y_scores: The raw angle values (no need to normalize manually)
        y_scores = switch_scores + remain_scores

        # 3. Calculate ROC automatically
        fpr, tpr, thresholds = roc_curve(y_true, y_scores, drop_intermediate=False)
        roc_auc = auc(fpr, tpr)  # Calculate Area Under Curve

        print(f"Model: {s} | AUC: {roc_auc:.3f}")
        print(f"Data range: Min={min(y_scores):.4f}, Max={max(y_scores):.4f}")

        # --- Plotting ---
        plt.figure(figsize=(10, 8))

        plt.plot(fpr, tpr, linewidth=3, label=f'ROC Curve (AUC = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random Chance')

        plt.xlabel('False Positive Rate (FPR)', fontsize=30)
        plt.ylabel('True Positive Rate (TPR)', fontsize=30)

        plt.legend(fontsize=25, frameon=True)
        plt.xticks(fontsize=25)
        plt.yticks(fontsize=25)
        plt.grid(True, alpha=0.3)
        plt.xlim([-0.05, 1.05])
        plt.ylim([-0.05, 1.05])

        plt.tight_layout()
        plt.savefig(f'plots/roc_curve_{s[0]}_ordered_{ordered}.pdf', format='pdf', dpi=300)
        plt.close()



def compute_all_consistencies(observations,split_models=None, ordered ="False", degree = False, degrees_per_file=None, two_topics = False, title_addition=""):
    """Scatter plot correlating Geometric Consistency with Trace.

    Two passes are made over `observations`. The first computes, per model
    group, one geometric value and one trace value per matching file and
    prints their Spearman correlation (per group and pooled). The second
    recomputes the same per-file values to draw the scatter plot, a pooled
    linear-regression line and two hand-built legends, then saves the figure
    under ``plots/``.

    Per file, with ``d`` the degrees dict of one seed run:
      * geometric consistency = (|d["H-F"]| + |d["F-H"]|) / (2 * d["Mean H-F"]),
        averaged over the runs "degrees_7", "degrees_21", "degrees_42"
        (when `degree` is true, d["Mean H-F"] itself is averaged instead);
      * trace = obs['2-2']["F-F"] + obs['2-2']["H-H"].

    Args:
        observations: mapping of file name -> observation dict (needs '2-2',
            and the three "degrees_*" entries unless `degrees_per_file` is given).
        split_models: iterable of model groups, each a sequence of substrings
            matched against the file name. It is iterated directly, so the
            default ``None`` is not usable - a value must be passed.
        ordered: substring ("True"/"False") that must occur in the file name.
        degree: plot the mean H-F angle instead of the geometric consistency.
        degrees_per_file: optional mapping file name -> {"degrees_*": dict}
            that replaces the degrees stored in `observations`.
        two_topics: selects alternative axis limits and a file-name suffix.
        title_addition: extra suffix for the output file name.

    Returns:
        (spearman_corr, sp_value) of the pooled data from the first pass.
    """
    print(f"Computing consistencies for ordered={ordered} degree={degree} two_topics={two_topics}")
    c_geo_all = []
    c_sm_all = []
    # --- Pass 1: statistics per model group ---
    for s in split_models:
        c_geos = []
        c_sms = []
        for file,obs in observations.items():
            # Keep files that mention any substring of this group and the `ordered` flag.
            if not any([s_i in file for s_i in s]) or ordered not in file:
                 continue
            c_geo_seeds = []
            for i in ["degrees_7", "degrees_21", "degrees_42"]:

                if degrees_per_file is None:
                    degrees = obs[i]
                else:
                    degrees = degrees_per_file[file][i]
                c_geo = (abs(degrees["H-F"]) + abs(degrees["F-H"])) / (2 * degrees["Mean H-F"])
                # c_sm does not depend on the seed; it is recomputed each
                # iteration and the last value is used after the loop.
                c_sm = (obs['2-2']["F-F"] + obs['2-2']["H-H"])
                if degree:
                    c_geo_seeds.append(degrees["Mean H-F"])
                else:
                    c_geo_seeds.append(c_geo)
            c_sms.append(c_sm)
            c_geos.append(np.mean(c_geo_seeds))


        spearman_corr, sp_value = spearmanr(c_geos, c_sms)
        print(f"Results for model: {s} under condition ordered={ordered} degree={degree}")
        print("Number of observations:", len(c_geos))
        # Sanity check: exactly six matching files are expected per group
        # (presumably one per entry of dataset_ordered below - TODO confirm).
        assert len(c_geos) ==6
        assert len(c_geos) == len(c_sms)
        print("Spearman Correlation:", round(spearman_corr,4))
        c_geo_all.extend(c_geos)
        c_sm_all.extend(c_sms)
    # Pooled correlation over all groups; this is the value that is returned.
    spearman_corr, sp_value = spearmanr(c_geo_all, c_sm_all)
    print(f"The overall mean of the trace is {round(np.array(c_sm_all).mean(),4)}+-{round(np.array(c_sm_all).std(),4)}")
    print(f"The overall mean of the geometric consistency is {round(np.array(c_geo_all).mean(),4)}+-{round(np.array(c_geo_all).std(),4)}")
    print(f"Results for all models combined")

    print("P-value:", round(sp_value,4))
    print("Spearman Correlation:", round(spearman_corr,4))

    # Okabe–Ito color-blind friendly palette
    cb_colors = ["#0072B2", "#009E73", "#D55E00",
                 "#CC79A7", "#56B4E9", "#F0E442", "#E69F00"]
    dataset_ordered = ["natural_100", "triviaqa_100", "sorry_100", "do_not", "sycophancy_100", "sycophancy_negative"]
    marker_styles = ['o', 's', '^', 'D', 'v', 'P', '*']

    # Create Mappings
    dataset_to_marker = {ds: marker_styles[idx % len(marker_styles)] for idx, ds in enumerate(dataset_ordered)}

    plt.figure(figsize=(8, 6))

    all_x = []
    all_y = []

    # --- Pass 2: plotting loop (one color per model group, one marker per dataset) ---
    for idx, s in enumerate(split_models):
        model_name = s[0]
        model_color = cb_colors[idx % len(cb_colors)]

        # Iterate through all observations
        for file, obs in observations.items():
            # NOTE(review): this pass matches only the group's first substring,
            # while pass 1 accepts any substring of the group - confirm that
            # the difference is intended.
            if model_name not in file or ordered not in file:
                continue

            # Extract dataset name (3rd and 4th underscore-separated fields)
            try:
                current_dataset = file.split("_")[2] + "_" + file.split("_")[3]
            except IndexError:
                continue

            # Skip if dataset is not in our list
            if current_dataset not in dataset_to_marker:
                continue

            # --- Data Calculation ---
            if degrees_per_file is None:
                degrees = [obs["degrees_21"], obs["degrees_42"], obs["degrees_7"]]
            else:
                degrees = [degrees_per_file[file]["degrees_21"], degrees_per_file[file]["degrees_42"],
                           degrees_per_file[file]["degrees_7"]]

            if degree:
                geo_val = np.mean([d["Mean H-F"] for d in degrees])
            else:
                geo_val = np.mean([(abs(d["H-F"]) + abs(d["F-H"])) / (2 * d["Mean H-F"]) for d in degrees])

            sm_val = obs['2-2']["F-F"] + obs['2-2']["H-H"]

            # We plot inside the loop so every point gets the correct marker
            plt.scatter(
                geo_val, sm_val,
                s=80,
                alpha=0.8,
                color=model_color,
                marker=dataset_to_marker[current_dataset],  # Use the specific marker for this file
                label=None  # We handle the legend manually below
            )

            all_x.append(geo_val)
            all_y.append(sm_val)

    # --- Global Regression Line (least-squares fit over every plotted point) ---
    if len(all_x) > 1:
        X_all = np.array(all_x).reshape(-1, 1)
        Y_all = np.array(all_y)
        reg = LinearRegression().fit(X_all, Y_all)

        x_line = np.linspace(min(all_x), max(all_x), 200).reshape(-1, 1)
        y_line = reg.predict(x_line)

        plt.plot(x_line, y_line, linewidth=2.0, color="black", alpha=0.9)

    # --- Custom Legend Construction ---
    model_handles = []

    # 1. Add Model Entries (Colors); the first handle is an invisible line used as a bold heading
    model_handles.append(mlines.Line2D([], [], color='none', label=r'$\bf{Models}$'))

    for idx, s in enumerate(split_models):
        label_text = s[0].replace("gpt", "GPT-OSS-20B") \
            .replace("llama", "Llama-3.1-8B") \
            .replace("Qwen", "Qwen-3-8B")
        color = cb_colors[idx % len(cb_colors)]
        handle = mlines.Line2D([], [], color=color, marker='_', linestyle='None',
                               markersize=10, markeredgewidth=6, label=label_text)
        model_handles.append(handle)
    dataset_handles = []
    # 2. Add Dataset Entries (Markers), again preceded by spacer/heading handles
    dataset_handles.append(mlines.Line2D([], [], color='none', label=r'  '))
    dataset_handles.append(mlines.Line2D([], [], color='none', label=r'$\bf{Datasets}$'))

    for ds in dataset_ordered:
        marker = dataset_to_marker[ds]
        # The replace chain maps the six dataset ids to, in order:
        # "NaturalQA (Hallucination)", "TriviaQA (Hallucination)",
        # "Sorry (Refusal)", "Do-Not-Answer (Refusal)", "S-pos (Sycophancy)",
        # "S-neg (Sycophancy)". "Placeholder" is a temporary token so that the
        # inserted word "Sycophancy" is not itself rewritten by an earlier step.
        clean_label = ds.replace("do_not", "Do-Not-Answer").replace("natural_100", "NaturalQA").replace("triviaqa_100",
                                                                                                        "TriviaQA").replace(
            "sorry_100", "Sorry").replace("sycophancy", "Sycophancy").replace("_100", "").replace("NaturalQA","NaturalQA (Hallucination)").replace("TriviaQA","TriviaQA (Hallucination)").replace("Sorry","Sorry (Refusal)").replace("Sycophancy_negative","S-neg (Placeholder)").replace("Sycophancy","S-pos (Placeholder)").replace("Placeholder","Sycophancy").replace("Do-Not-Answer","Do-Not-Answer (Refusal)")
        handle = mlines.Line2D([], [], color='gray', marker=marker, linestyle='None',
                               markersize=10, label=clean_label)
        dataset_handles.append(handle)

    # Apply the custom legends. A second plt.legend call would replace the
    # first, so the first legend is re-attached to the axes as a plain artist.
    legend1 = plt.legend(handles=model_handles, loc='upper left', fontsize=12, frameon=True)
    plt.gca().add_artist(legend1)
    # Create the second legend (Datasets) at Bottom Right
    plt.legend(handles=dataset_handles, loc='lower right', fontsize=12, frameon=True)
    # --- Final Formatting ---
    plt.xlabel(r'$\theta_{ref}$' if degree else 'Geometric Consistency', fontsize=25)
    plt.ylabel('TR', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    # Fixed axis limits are applied only for the unordered or two-topic settings.
    if "False" == ordered or two_topics:
        plt.ylim(0.8,1.8)
        if two_topics:
            plt.ylim(0.7,1.8)
        plt.xlim(0,80)
    plt.grid(True, linestyle='--', alpha=0.35)
    plt.tight_layout()

    plt.savefig(
        f'plots/consistency_scatter_ordered_{ordered}_degree_{degree}_85{"_two_topics" if two_topics else ""}{title_addition}.pdf', format='pdf', dpi=300
    )
    plt.close()

    return spearman_corr, sp_value











def steps_back_correlation(observations, max_steps_back=3):
    """Bar chart showing the effect of history length (k) on hallucination probability.

    Only observations whose file name contains "True" are used. For each such
    file the label sequences are taken from ``obs["hidden_states"]`` (second
    element of every step) and two per-k metrics are computed with
    `calculate_hallucination_deltas` and `calculate_shapley_likelihood`. Both
    result sets are printed and passed to `plot_steps_back_correlation`.

    Args:
        observations: mapping of file name -> observation dict with a
            "hidden_states" entry (list of sequences of ``(state, label)``).
        max_steps_back: largest history length k to evaluate.
    """
    observations = {k: v for k, v in observations.items() if "True" in k}
    # A model id is the first two underscore-separated fields of the file name.
    # Sorted so the order the files are aggregated in is the same on every run.
    models = sorted({"_".join(file.split("_")[0:2]) for file in observations})
    delta_results = {}
    shapley_results = {}
    for file, obs in observations.items():
        # Keep only the label of every step.
        degree_middle_layers = [[a[1] for a in r] for r in obs["hidden_states"]]
        delta_results[file] = calculate_hallucination_deltas(degree_middle_layers, max_steps_back)
        shapley_results[file] = calculate_shapley_likelihood(degree_middle_layers, max_steps_back)

    print(f"{shapley_results=}")
    print("average shapley across models and datasets for each k:")
    for k in range(1, max_steps_back + 1):
        all_shapley = [file_shapley[k] for file_shapley in shapley_results.values()]
        avg_shapley = np.mean(all_shapley)
        std_shapley = np.std(all_shapley)
        print(f"k={k}: {avg_shapley:.4f} ± {std_shapley:.4f}")
    print(f"{delta_results=}")
    plot_steps_back_correlation(observations, delta_results, models, max_steps_back)
    plot_steps_back_correlation(observations, shapley_results, models, max_steps_back, title_addition="_full_options")

def plot_steps_back_correlation(observations, delta_results,models, max_steps_back=3, title_addition=""):
    """Grouped bar chart of a per-k metric: one bar per dataset plus an overall average.

    For every dataset and every k in 1..max_steps_back, the metric is averaged
    over `models` (error bars show the standard deviation). The figure is
    saved to ``plots/steps_back_correlation<title_addition>.pdf``.

    Args:
        observations: mapping whose keys are the available file names.
        delta_results: mapping file name -> {k: value}.
        models: re-iterable collection of model ids; the file of a model and
            dataset is the first key starting with ``"<model>_<dataset>"``.
        max_steps_back: number of history lengths (bar groups).
        title_addition: suffix for the output file name.

    Raises:
        IndexError: if no file exists for some model/dataset combination.
    """
    dataset_ordered = ["natural_100", "triviaqa_100", "sorry_100", "do_not", "sycophancy_100", "sycophancy_negative"]
    color_blind_palette = ["#D55E00", "#E69F00", "#0072B2", "#56B4E9", "#009E73", "#CC79A7", "#F0E442"]
    steps = range(1, max_steps_back + 1)

    # Gather everything before opening a figure, so that a missing file does
    # not leave a half-drawn figure behind.
    file_names = list(observations)
    all_deltas_per_step = {k: [] for k in steps}
    dataset_stats = []
    for ds in dataset_ordered:
        deltas_per_step = {k: [] for k in steps}
        for model in models:
            start_key = f"{model}_{ds}"
            matches = [file for file in file_names if file.startswith(start_key)]
            if not matches:
                raise IndexError(f"no observation file starts with {start_key!r}")
            file_deltas = delta_results[matches[0]]
            for k in steps:
                deltas_per_step[k].append(file_deltas[k])
                all_deltas_per_step[k].append(file_deltas[k])
        dataset_stats.append((
            [np.mean(deltas_per_step[k]) for k in steps],
            [np.std(deltas_per_step[k]) for k in steps],
        ))

    plt.figure(figsize=(12, 8))

    n_datasets = len(dataset_ordered)
    total_bars = n_datasets + 1  # +1 for the Average bar
    bar_width = 0.8 / total_bars  # Total width of 0.8 divided among the bars
    x = np.arange(1, max_steps_back + 1)  # Base x positions
    for i, (ds, (avg_deltas, std_deltas)) in enumerate(zip(dataset_ordered, dataset_stats)):
        # Offset each dataset's bars
        offset = (i - total_bars / 2 + 0.5) * bar_width
        # NOTE(review): .title() also lower-cases inner capitals, so the legend
        # shows e.g. "Naturalqa" and "S-Neg" - confirm this is wanted.
        plt.bar(x + offset, avg_deltas, width=bar_width,yerr=std_deltas,capsize=3,  label=ds.replace("do_not", "Do-Not-Answer").replace("natural_100", "NaturalQA").replace("triviaqa_100", "TriviaQA").replace("sorry_100", "Sorry").replace("sycophancy_negative","S-neg").replace("sycophancy", "S-pos").replace("_100","").title(), alpha=0.7, color=color_blind_palette[i % len(color_blind_palette)])
    # add an average bar across all datasets
    avg_all = [np.mean(all_deltas_per_step[k]) for k in steps]
    std_all = [np.std(all_deltas_per_step[k]) for k in steps]
    offset_avg = (n_datasets - total_bars / 2 + 0.5) * bar_width
    plt.bar(
        x + offset_avg,
        avg_all,
        width=bar_width,
        yerr=std_all,
        capsize=3,
        label="Average",
        alpha=0.9,
        color="grey",  # Make it distinct from the dataset colors
        hatch="//"  # Texture to distinguish it further
    )
    plt.xticks(x, fontsize=25)  # Center tick labels on the groups

    # Raw string: "\D" is not a valid escape sequence in a normal literal.
    plt.ylabel(r'$\Delta_{k}$', fontsize=30)
    plt.xlabel('Steps Back', fontsize=30)
    plt.yticks(fontsize=25)
    plt.grid(True, linestyle='--', alpha=0.35)
    plt.legend(fontsize=25, frameon=True)
    plt.tight_layout()
    plt.savefig(
        f'plots/steps_back_correlation{title_addition}.pdf', format='pdf', dpi=300)
    plt.close()


def calculate_shapley_likelihood(all_results, max_steps_back=7):
    """
    Measure how much a k-step history improves next-state prediction over a (k-1)-step one.

    For each order k in 1..max_steps_back, empirical transition probabilities
    are estimated from all windows of k states followed by one state. The
    value for k is the average, over every such window, of
    ``P(next | last k states) - P(next | last k-1 states)``, where both
    probabilities are estimated from the same set of windows.

    Args:
        all_results: iterable of sequences of hashable states.
        max_steps_back: largest order k to evaluate.

    Returns:
        dict mapping k -> average probability gain (0.0 when no sequence is
        longer than k).
    """
    gains = {}

    for order in range(1, max_steps_back + 1):
        joint_long = defaultdict(int)     # N(k-context, next)
        seen_long = defaultdict(int)      # N(k-context)
        joint_short = defaultdict(int)    # N((k-1)-context, next)
        seen_short = defaultdict(int)     # N((k-1)-context)
        n_windows = 0

        for sequence in all_results:
            # Sequences with at most `order` states yield an empty range.
            for pos in range(order, len(sequence)):
                outcome = sequence[pos]
                long_context = tuple(sequence[pos - order:pos])
                # Dropping the oldest state gives the (k-1)-context; for
                # order 1 this is the empty tuple.
                short_context = long_context[1:]

                joint_long[(long_context, outcome)] += 1
                seen_long[long_context] += 1
                joint_short[(short_context, outcome)] += 1
                seen_short[short_context] += 1
                n_windows += 1

        # Sum the gain of every observed transition, weighted by its frequency.
        weighted_gain = 0.0
        for (long_context, outcome), occurrences in joint_long.items():
            p_long = occurrences / seen_long[long_context]

            short_context = long_context[1:]
            short_total = seen_short[short_context]
            if short_total > 0:
                p_short = joint_short[(short_context, outcome)] / short_total
            else:
                p_short = 0.0

            weighted_gain += (p_long - p_short) * occurrences

        gains[order] = weighted_gain / n_windows if n_windows > 0 else 0.0

    return gains



from collections import defaultdict


def calculate_hallucination_deltas(all_results, max_steps_back=7):
    """
    Compute the 'Stickiness Delta' for every history length k in 1..max_steps_back.

    Delta_k = P(1 | 1...1 [k times]) - P(1 | 0 1...1 [k-1 times])

    Probabilities are empirical frequencies over all windows of k labels
    followed by one label. A history that never occurs contributes 0.0.

    Args:
        all_results: iterable of label sequences (labels 0 / 1).
        max_steps_back: largest history length to evaluate.

    Returns:
        dict mapping k -> Delta_k.
    """
    deltas = {}

    for depth in range(1, max_steps_back + 1):
        followed_by = defaultdict(int)   # (history, next label) -> count
        occurrences = defaultdict(int)   # history -> count

        for labels in all_results:
            # `end` is the index of the label that follows the history; a
            # sequence with at most `depth` labels gives an empty range.
            for end in range(depth, len(labels)):
                context = tuple(labels[end - depth:end])
                followed_by[(context, labels[end])] += 1
                occurrences[context] += 1

        def p_one_after(context):
            """Empirical P(next == 1 | context); 0.0 for an unseen context."""
            seen = occurrences[context]
            return followed_by[(context, 1)] / seen if seen != 0 else 0.0

        all_ones = tuple([1] * depth)                      # e.g. 1-1-1
        zero_then_ones = tuple([0] + [1] * (depth - 1))    # e.g. 0-1-1

        deltas[depth] = p_one_after(all_ones) - p_one_after(zero_then_ones)
    return deltas

def degrees_cw(all_observations, all_layers = False):
    """
    Run `plot_hidden_states_through_time` for every file under three fixed seeds.

    For each selected layer and each file, the hidden state of that layer and
    the label of every step are extracted, the per-position mean label is
    printed, and the angle analysis is run with the global RNG seeded to 42,
    7 and 21 (in that order).

    Args:
        all_observations: mapping of file name -> dict with a "hidden_states"
            entry; every step is indexed as ``step[0][layer]`` (state) and
            ``step[1]`` (label).
        all_layers: if false, only layer index -2 is used; otherwise layer
            indices 0..3.

    Returns:
        ``{file: {"degrees_42": ..., "degrees_7": ..., "degrees_21": ...}}``
        for the single layer, or that mapping nested under each layer index
        when `all_layers` is true.
    """
    layers = [0,1,2,3] if all_layers else [-2]
    per_layer_degrees = {}
    for layer in layers:
        degrees_per_file = {}
        for file, obs in all_observations.items():
            print(f"Processing file: {file}")
            degree_middle_layers = [
                [(step[0][layer], step[1]) for step in sequence]
                for sequence in obs["hidden_states"]
            ]
            # Mean label per sequence position (needs equally long sequences).
            all_probs = [[pair[1] for pair in sequence] for sequence in degree_middle_layers]
            avg_probs = np.mean(all_probs, axis=0)
            print(f"Average probabilities per position: {avg_probs} {min(avg_probs)=} {max(avg_probs)=} {max(avg_probs)-min(avg_probs)=}")
            seeded_runs = {}
            for seed in (42, 7, 21):
                # The analysis draws its train/test split from the global RNG.
                random.seed(seed)
                seeded_runs[f"degrees_{seed}"] = plot_hidden_states_through_time(
                    degree_middle_layers, calculate_angle_per_point=True)
            degrees_per_file[file] = seeded_runs
        per_layer_degrees[layer] = degrees_per_file
    if all_layers:
        return per_layer_degrees
    return per_layer_degrees[-2]


def _transition_rate(count, complement):
    """Return count / (count + complement), or 0 when no transitions were seen."""
    total = count + complement
    return count / total if total > 0 else 0


def calculate_trace_and_theta_different_length(all_observations, length=10, two_topics_41=False):
    """
    Compute seeded angle measurements and label-transition rates per file,
    using only the first ``length`` steps of every sequence.

    For each file the ``"hidden_states"`` sequences are truncated to
    ``length`` items, reduced to ``(a[0][-2], a[1])`` pairs and passed to
    ``plot_hidden_states_through_time(..., calculate_angle_per_point=True)``
    under seeds 42, 7 and 21. The second tuple element is then used as a
    label to count transitions between *consecutive* steps, pooled over all
    sequences of the file.

    Args:
        all_observations: dict mapping file name -> dict with a
            ``"hidden_states"`` list of sequences.
        length: number of leading steps of each sequence to keep.
        two_topics_41: when True keep only steps at index 6, 11, 16, ...
            (index >= 6 and (index - 6) % 5 == 0) of the truncated sequence.

    Returns:
        ``{file: {"degrees_42": ..., "degrees_7": ..., "degrees_21": ...,
        "2-2": {"F-H", "H-F", "H-H", "F-F" rates}, "hidden_states": <the
        file's original, untruncated hidden states>}}``.

    Raises:
        KeyError: if a consecutive label pair is not one of "0-0", "0-1",
            "1-0", "1-1" when formatted with ``str``.
    """
    layer = -2
    seeds = (42, 7, 21)
    degrees_per_file = {}
    for file, obs in all_observations.items():
        print(f"Processing file: {file}")
        degree_middle_layers = []
        for sequence in obs["hidden_states"]:
            inner_list = []
            for index, a in enumerate(sequence[:length]):  # only the first 'length' steps
                # Per the original author's note the topic pattern is
                # t1,t1,t1,t1,t2,t1 with t2 recurring every 5 steps; only the
                # steps following a (t2, t1) pair are kept, the first being
                # the 7th step (index 6).
                if two_topics_41 and not (index >= 6 and (index - 6) % 5 == 0):
                    continue
                inner_list.append((a[0][layer], a[1]))
            degree_middle_layers.append(inner_list)

        file_result = {}
        degrees_per_file[file] = file_result
        for seed in seeds:
            # Re-seed before each call so every run is reproducible on its own.
            random.seed(seed)
            file_result[f"degrees_{seed}"] = plot_hidden_states_through_time(
                degree_middle_layers, calculate_angle_per_point=True)

        # Guard the shape print so a file without sequences does not raise IndexError.
        first_len = len(degree_middle_layers[0]) if degree_middle_layers else 0
        print(f" All results shape {len(degree_middle_layers)}  {first_len}")

        transitions = {"0-1": 0, "1-0": 0, "0-0": 0, "1-1": 0}
        for result in degree_middle_layers:
            # Pair each step with its successor. The previous loop started at
            # j = 0 and read result[j - 1], which wrapped around to the last
            # element and dropped the real final transition.
            for previous, current in zip(result, result[1:]):
                transitions[f"{previous[1]}-{current[1]}"] += 1

        # Rates are conditional on the label of the earlier step (row-normalised).
        final_2_2 = {
            "F-H": _transition_rate(transitions["0-1"], transitions["0-0"]),
            "H-F": _transition_rate(transitions["1-0"], transitions["1-1"]),
            "H-H": _transition_rate(transitions["1-1"], transitions["1-0"]),
            "F-F": _transition_rate(transitions["0-0"], transitions["0-1"]),
        }
        print("F-H:", round(final_2_2["F-H"], 4))
        print("H-F:", round(final_2_2["H-F"], 4))
        print("H-H:", round(final_2_2["H-H"], 4))
        print("F-F:", round(final_2_2["F-F"], 4))
        file_result['2-2'] = final_2_2
        # Keep the raw states so the result can be fed back in as `all_observations`.
        file_result['hidden_states'] = obs["hidden_states"]
    return degrees_per_file


def correlation_per_layer(all_observations):
    """
    Plot a bar chart of the layer-wise correlations and save it as a PDF.

    Runs ``degrees_cw(..., all_layers=True)``, computes one correlation per
    layer with ``compute_all_consistencies`` and writes the chart to
    ``plots/correlations_per_layer.pdf`` (the directory is created if missing).

    Args:
        all_observations: dict mapping file name -> observation dict, passed
            through to ``degrees_cw`` and ``compute_all_consistencies``.

    Returns:
        None. The figure is written to disk and closed.
    """
    degrees_cw_all_layers = degrees_cw(all_observations, all_layers=True)

    layers_correlations = {}
    for layer, degrees_per_file in degrees_cw_all_layers.items():
        print(f"Calculating correlations for layer {layer}")
        s, p = compute_all_consistencies(all_observations, split_models=[["gpt"], ["llama"], ["Qwen"]], ordered="True", degree=True, degrees_per_file=degrees_per_file)
        layers_correlations[layer] = (s, p)

    # One bar per layer, in the order the layers were processed. The original
    # code noted these four layers sit at roughly 30%, 50%, 85% and 100% of
    # the model depth - TODO confirm against the data producer.
    layer_labels = ["Bottom", "Middle", "Upper", "Top"]
    correlations = [correlation for correlation, _ in layers_correlations.values()]

    color_blind_palette = ["#D55E00", "#E69F00", "#0072B2", "#56B4E9", "#009E73", "#CC79A7"]
    plt.figure(figsize=(10, 6))
    plt.bar(layer_labels, correlations, color=color_blind_palette[:len(layer_labels)],
            edgecolor='black', linewidth=0.8, width=0.6)
    plt.xlabel("Layer", fontsize=30)
    plt.ylabel("Correlation", fontsize=30)
    plt.yticks(fontsize=25)
    plt.xticks(rotation=30, fontsize=25)
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(0.8)
    ax.spines['bottom'].set_linewidth(0.8)

    # Subtle gridlines
    ax.yaxis.grid(True, linestyle='--', alpha=0.3, linewidth=0.5)
    ax.set_axisbelow(True)
    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/correlations_per_layer.pdf', format='pdf', dpi=300)
    plt.close()



if __name__ == "__main__":
    # The plotting helpers write into ./plots; create it without a
    # check-then-create race.
    os.makedirs("plots", exist_ok=True)

    # Collect result files, skipping names ending in "_.json" and anything
    # mentioning "gpt-4o". os.listdir returns entries in arbitrary order, so
    # sort for a reproducible processing order.
    json_files = sorted(
        f for f in os.listdir("results/")
        if f.endswith(".json") and not f.endswith("_.json") and "gpt-4o" not in f
    )
    all_observations = {}

    for json_file in json_files:
        with open(os.path.join("results/", json_file), "r", encoding="utf-8") as f:
            data = json.load(f)
            all_observations[json_file] = data
    print(f"Loaded {len(all_observations)} observation files.")

    # Partition the files by substrings of their names.
    observations_false = {k: v for k, v in all_observations.items() if "False" in k}
    observations_true = {k: v for k, v in all_observations.items() if "True" in k and "two_topics" not in k}
    observations_true_two_topics = {k: v for k, v in all_observations.items() if "True" in k and "two_topics" in k and "41" not in k}
    observations_true_two_topics41 = {k: v for k, v in all_observations.items() if "True" in k and "two_topics41" in k}

    # Angle measurements and transition rates on the first 20 steps of each split.
    observations_per_file_true_two_topics = calculate_trace_and_theta_different_length(observations_true_two_topics, length=20)

    twenty_length_degrees = calculate_trace_and_theta_different_length(observations_true, length=20)

    # From here on `observations_true` / `observations_false` hold the computed
    # results; these still carry "hidden_states", so they can be re-processed below.
    observations_true = twenty_length_degrees

    observations_false = calculate_trace_and_theta_different_length(observations_false, length=20)

    plot_roc_curve_per_model(observations_true, split_models=[["gpt"], ["llama"], ["Qwen"]], ordered="True")

    plot_lambda(observations_true, split_models=[["gpt"], ["llama"], ["Qwen"]], ordered="True")

    steps_back_correlation(observations_true, max_steps_back=3)

    correlation_per_layer(observations_true)

    # NOTE(review): the three "Info two topics data length N" messages below
    # precede runs on `observations_true`, which excludes two-topic files -
    # confirm whether the label or the input is the intended one.
    print("Info two topics data length 10")
    ten_length_degrees = calculate_trace_and_theta_different_length(observations_true, length=10)
    print("Info two topics data length 5")
    five_length_degrees = calculate_trace_and_theta_different_length(observations_true, length=5)
    print("Info two topics data length 15")

    fifteen_length_degrees = calculate_trace_and_theta_different_length(observations_true, length=15)
    print("Info two topics 4,1 data length 20")

    two_topics_4_1_degrees = calculate_trace_and_theta_different_length(observations_true_two_topics41, length=20, two_topics_41=True)

    # Consistency reports for every configuration computed above.
    print("results for ordered = True")
    results = compute_all_consistencies(observations_true, split_models=[["llama"], ["gpt"], ["Qwen"]], ordered="True",
                                        degree=True, degrees_per_file=observations_true)

    print("results for ordered = False")
    results = compute_all_consistencies(observations_false, split_models=[["llama"], ["gpt"], ["Qwen"]], ordered="False",
                                        degree=True, degrees_per_file=observations_false)

    print("results for two topics")
    results = compute_all_consistencies(observations_per_file_true_two_topics, split_models=[["llama"], ["gpt"], ["Qwen"]],
                                        ordered="True",
                                        degree=True, degrees_per_file=observations_per_file_true_two_topics, two_topics=True)

    print("results for two topics 4-1")
    results = compute_all_consistencies(two_topics_4_1_degrees, split_models=[["llama"], ["gpt"], ["Qwen"]],
                                        ordered="True",
                                        degree=True, degrees_per_file=two_topics_4_1_degrees,
                                        title_addition=" (Two Topics 4-1)")

    print("results for length = 10")

    results = compute_all_consistencies(ten_length_degrees, split_models=[["llama"], ["gpt"], ["Qwen"]],
                                        ordered="True",
                                        degree=True, degrees_per_file=ten_length_degrees, title_addition=" (Length=10)")

    print("results for length = 5")
    results = compute_all_consistencies(five_length_degrees, split_models=[["llama"], ["gpt"], ["Qwen"]],
                                        ordered="True",
                                        degree=True, degrees_per_file=five_length_degrees, title_addition=" (Length=5)")
    print("results for length = 15")
    results = compute_all_consistencies(fifteen_length_degrees, split_models=[["llama"], ["gpt"], ["Qwen"]],
                                        ordered="True",
                                        degree=True, degrees_per_file=fifteen_length_degrees, title_addition=" (Length=15)")


