# Datasets-----------------------------------------------

# city_prob
# common_claim_prob
# counterfact_prob
# company_prob

# city_prob_llama13b
# common_claim_prob_llama13b
# counterfact_prob_llama13b
# company_prob_llama13b

# city_prob_mistral
# common_claim_prob_mistral
# counterfact_prob_mistral
# company_prob_mistral

# city_prob_qwen
# common_claim_prob_qwen
# counterfact_prob_qwen
# company_prob_qwen

# city_prob_granite
# common_claim_prob_granite
# counterfact_prob_granite
# company_prob_granite

# city_prob_qwen14b
# common_claim_prob_qwen14b
# counterfact_prob_qwen14b
# company_prob_qwen14b

# city_prob_phi4b
# common_claim_prob_phi4b
# counterfact_prob_phi4b
# company_prob_phi4b

# city_prob_vi
# common_claim_prob_vi
# counterfact_prob_vi
# company_prob_vi

# data preprocessing----------------------------------------
# map prob from [0,1] to [-1,1]
# map prediction from T/F to 1/-1
# map label from 1/0 to 1/-1

# parameters--------------------------------------------
# layers_idx: layer columns kept per token (the layers_idx values that precede
#             each probability row's final value)
# token_idx: trailing tokens (rows) kept per query
# temperature: softmax temperature used for the token-importance weights
layers_idx, token_idx, temperature = 31, 6, 0.05

# -------------------------------------------------------
import json
import copy
import numpy as np
from scipy.special import softmax
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt

# Per-layer importance weights, keyed by dataset (one weight per retained
# layer column). Select the preset that matches data_path below.
LAYER_IMPORTANCE_PRESETS = {
    "city": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0001, 0.0001, 0.0, 0.0001, 0.0001, 0.0, 0.003, 0.0169, 0.0386, 0.059, 0.0386, 0.0303, 0.0239, 0.0244, 0.0418, 0.0432, 0.0641, 0.0628, 0.092, 0.0795, 0.0642, 0.0687, 0.0622, 0.0743, 0.1119],
    "common_claim": [0.0002, 0.0002, 0.0003, 0.0002, 0.0011, 0.0015, 0.0016, 0.0001, 0.0001, 0.0003, 0.0001, 0.0001, 0.0005, 0.0037, 0.0042, 0.0255, 0.0265, 0.0215, 0.0374, 0.0328, 0.0375, 0.0619, 0.0731, 0.082, 0.0883, 0.0961, 0.11, 0.0917, 0.0604, 0.0669, 0.0742],
    "counterfact": [0.0001, 0.0001, 0.0001, 0.0002, 0.0004, 0.0001, 0.0001, 0.0001, 0.0, 0.0002, 0.0003, 0.0001, 0.0004, 0.0021, 0.0146, 0.0283, 0.0253, 0.0221, 0.0417, 0.0515, 0.0404, 0.0635, 0.0607, 0.0473, 0.0662, 0.0955, 0.1016, 0.0863, 0.0769, 0.0688, 0.1052],
    "company": [0.0003, 0.0006, 0.0003, 0.0006, 0.0003, 0.0004, 0.0007, 0.0002, 0.0003, 0.0003, 0.0002, 0.0007, 0.0034, 0.0166, 0.0535, 0.0644, 0.057, 0.0436, 0.077, 0.0845, 0.0495, 0.0639, 0.0522, 0.0664, 0.0596, 0.0697, 0.0442, 0.0331, 0.0434, 0.0455, 0.0677],
}

data = []
data_path = "../../data/common_claim_prob.jsonl"

# The layer weights are identical for every sample, so select and trim them once.
importance_layer = LAYER_IMPORTANCE_PRESETS["common_claim"][-layers_idx:]
# ablation (uniform layer weights): importance_layer = [1.0] * layers_idx

with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines instead of failing inside json.loads
        obj = json.loads(line)
        all_probs = obj["result"]["query"]["query_probs"]
        # Keep the last token_idx token rows; in each row keep the layers_idx
        # values that precede the row's final value.
        quer_probs = [row[-layers_idx - 1:-1] for row in all_probs[-token_idx:]]

        # map prob from [0, 1] to [-1, 1]
        scaled_quer_probs = [[2 * (x - 0.5) for x in row] for row in quer_probs]
        original_scaled_quer_probs = copy.deepcopy(scaled_quer_probs)

        # Token importance for the same last token_idx tokens: a temperature
        # softmax over all but the final token; the final token always gets 1.0.
        all_importances = [score for _, score in obj["result"]["query"]["token_score"]]
        last_importances = all_importances[-token_idx:]
        if len(last_importances) == 1:
            scaled_importance = [1.0]
        else:
            scaled_importance = softmax(np.array(last_importances[:-1]) / temperature).tolist()
            scaled_importance.append(1.0)
        importance_token = np.array(scaled_importance)
        # ablation (uniform token weights): importance_token = np.array([1.0] * token_idx)

        # weight cell (token t, layer l) by importance_token[t] * importance_layer[l]
        importance_2D = np.outer(importance_token, importance_layer)
        scaled_quer_probs = (np.array(scaled_quer_probs) * importance_2D).tolist()

        data.append({"scaled_probs": scaled_quer_probs,
                     "original_scaled_prob": original_scaled_quer_probs,
                     # map prediction from T/F to 1/-1
                     "llm_prediction": 1 if obj["result"]["answer"]["answer"].strip().startswith("T") else -1,
                     # final value of the last token row (the value dropped from quer_probs)
                     "llm_prediction_value": all_probs[-1][-1],
                     # map label from 1/0 to 1/-1
                     "label": 1 if obj["label"] == 1 else -1,
                     "query_content": obj["result"]["query"]["query"],
                     "AAEs": []})


print(f"The dataset contains {len(data)} samples")
if not data:
    raise ValueError(f"no samples were loaded from {data_path}")
# Grid shape shared by every QBAF, taken from the first sample.
rows = len(data[0]["scaled_probs"])
cols = len(data[0]["scaled_probs"][0])
print(f"Each sample's scaled_probs has shape: {rows} rows × {cols} columns")

pos = sum(1 for sample in data if sample["label"] == 1)
neg = sum(1 for sample in data if sample["label"] == -1)

print(f"pos:{pos}")
print(f"neg:{neg}")


# Build the QBAF skeleton for a rows x cols grid of arguments.
def generate_graph(rows, cols):
    """Return the edge structure of a grid-shaped QBAF.

    Argument ``row * cols + col`` is grid cell (row, col). Edges are directed
    and stored as ``graph[source] = {targets}``:

    * within every column, each row links to the row below it,
      ``(r, c) -> (r + 1, c)``;
    * along the last row, each column links to the next one,
      ``(rows - 1, c) -> (rows - 1, c + 1)``.

    Every node therefore has a path to node ``rows * cols - 1`` (the
    bottom-right cell), which has no outgoing edges.

    Args:
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        dict mapping each node id in ``range(rows * cols)`` to the set of
        node ids it points to.
    """
    graph = {node: set() for node in range(rows * cols)}

    # vertical chains: (r, c) -> (r + 1, c)
    for col in range(cols):
        for row in range(rows - 1):
            graph[row * cols + col].add((row + 1) * cols + col)

    # horizontal chain along the last row, built once (it does not depend on
    # the column loop above). Guard rows == 0, where there is no last row.
    if rows > 0:
        last_row_start = (rows - 1) * cols
        for col in range(cols - 1):
            graph[last_row_start + col].add(last_row_start + col + 1)
    return graph

# One grid skeleton, sized from data[0], is shared by every sample's QBAF.
graph = generate_graph(rows=rows, cols=cols)

# Write one sample's QBAF to a .bag file.
def generate_and_write_graph(filename, i, graph, data, rows, cols):
    """Write the QBAF for ``data[i]`` to ``filename`` in the .bag text format.

    One ``arg(node, base_score).`` line is written per node of ``graph``, using
    the row-major flattening of ``data[i]["scaled_probs"]`` as base scores.
    Then, for every edge ``node -> target``, a ``sup(node, target).`` line is
    written when both base scores are negative or both are non-negative, and an
    ``att(node, target).`` line otherwise.

    Args:
        filename: output path; an existing file is overwritten.
        i: index of the sample in ``data``.
        graph: dict mapping node id -> iterable of target node ids; node ids
            index the flattened ``scaled_probs``.
        data: list of sample dicts containing ``"scaled_probs"``.
        rows, cols: grid shape; not used here, kept for call compatibility.
    """
    flat_data = [value for row in data[i]["scaled_probs"] for value in row]
    # Write through the file handle instead of redirecting sys.stdout: this
    # needs no `sys` import and cannot leave stdout pointing at a closed file
    # if an exception escapes.
    with open(filename, 'w', encoding="utf-8") as f:
        # base scores for arguments
        for node in graph:
            print(f"arg({node}, {flat_data[node]}).", file=f)

        # edge polarity: support when the signs agree, attack otherwise
        for node, edges in graph.items():
            for edge in edges:
                src, dst = flat_data[node], flat_data[edge]
                same_sign = (src < 0 and dst < 0) or (src >= 0 and dst >= 0)
                relation = "sup" if same_sign else "att"
                print(f"{relation}({node}, {edge}).", file=f)

# Write one .bag file per sample; all of them share the grid built from data[0].
N = len(data)
for sample_idx in range(N):
    generate_and_write_graph(f'../../bags/llmQBAF_{sample_idx}.bag', sample_idx, graph, data, rows, cols)

import sys
sys.path.append("../../src/")
import uncertainpy.gradual as grad

# QE semantics with Probability: sum aggregation + quadratic maximum influence
agg_f = grad.semantics.modular.SumAggregation_probability()
inf_f = grad.semantics.modular.QuadraticMaximumInfluence_probability(conservativeness=1.0)

# Sanity check on one QBAF: the last file written above (index N - 1). A
# hard-coded index would be missing, or a stale file from an earlier run, for
# datasets with fewer samples.
bag = grad.BAG(f"../../bags/llmQBAF_{N - 1}.bag")
strength_values = grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)
# for arg in bag.arguments.values():
#     print((arg.name, arg.strength))

feature_list = [{"num_attacks": len(bag.attacks)}]

# The QBAF verdict is the sign of the final strength of the bottom-right grid
# node (id rows * cols - 1).
root_id = str(rows * cols - 1)
for i in range(N):
    bag = grad.BAG(f"../../bags/llmQBAF_{i}.bag")
    # assumed to fill in each argument's .strength, which is read right after
    strength_values = grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)
    root_strength = bag.arguments[root_id].strength
    data[i]["QBAF_prediction"] = 1 if root_strength >= 0 else -1
    data[i]["qbaf_prediction_value"] = root_strength


# Flag samples whose QBAF verdict disagrees with the gold label:
# 1 = inconsistent (treated as a hallucination), 0 = consistent.
for sample in data:
    sample["label_QBAF_consistent"] = int(sample["label"] != sample["QBAF_prediction"])


# Node ids follow the row-major grid numbering used when writing the .bag
# files: token row t and layer column c map to node t * cols + c. The layer
# columns are split into a lower [0, 10), middle [10, 20) and upper [20, cols) band.
if cols <= 20:
    raise ValueError(f"the lower/middle/upper layer bands need more than 20 layer columns, got {cols}")
lower_list = [t * cols + c for t in range(rows) for c in range(10)]
middle_list = [t * cols + c for t in range(rows) for c in range(10, 20)]
upper_list = [t * cols + c for t in range(rows) for c in range(20, cols)]
layer_bands = (("lower", lower_list), ("middle", middle_list), ("upper", upper_list))


def _mean_and_var(values):
    """Return the mean and population variance (divide by n) of a non-empty sequence."""
    mean = sum(values) / len(values)
    return mean, sum((x - mean) ** 2 for x in values) / len(values)


def _count_vertical_attacks(bag, nodes, rows, cols):
    """Count nodes in ``nodes`` that attack the node directly below them.

    "Below" means the same layer column in the next token row (node
    ``k + cols``); nodes in the last row have no such neighbour and are skipped.
    """
    below_limit = cols * (rows - 1)
    return sum(
        1
        for k in nodes
        if k < below_limit and bag.arguments[str(k)] in bag.arguments[str(k + cols)].attackers
    )


# One feature dict per sample. Key order is kept fixed (lower, middle, upper;
# strengths first, then attack counts, then the label) so every sample yields
# the same column order.
features = [{} for _ in range(N)]
for i in range(N):
    bag = grad.BAG(f"../../bags/llmQBAF_{i}.bag")
    # assumed to fill in each argument's .strength, which is read below
    grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)

    # mean / variance of |initial weight| and of |final strength| per band
    for band, nodes in layer_bands:
        band_args = [bag.arguments[str(k)] for k in nodes]
        avg, var = _mean_and_var([abs(arg.initial_weight) for arg in band_args])
        features[i][f"{band}_avg_initial_strength"] = avg
        features[i][f"{band}_var_initial_strength"] = var
        avg, var = _mean_and_var([abs(arg.strength) for arg in band_args])
        features[i][f"{band}_avg_final_strength"] = avg
        features[i][f"{band}_var_final_strength"] = var

    # structural pattern: number of vertical edges in each band that are attacks
    for band, nodes in layer_bands:
        features[i][f"{band}_num_attacks"] = _count_vertical_attacks(bag, nodes, rows, cols)

    features[i]["label"] = data[i]["label_QBAF_consistent"]
print("Dataset Preparation Done")


from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier


label_key: str = "label"  # key of the target value inside each feature dict

def to_X_y(data, label_key="label"):
    """Split a list of feature dicts into a design matrix and a label vector.

    The feature columns are every key of ``data[0]`` except ``label_key``, in
    ``data[0]``'s key order; every sample must provide those keys.

    Returns:
        (X, y, feature_names): X is a float array with one row per sample and
        one column per feature name; y is an int array of the labels.
    """
    # feature columns: every key of the first sample except the label
    feature_names = list(filter(lambda key: key != label_key, data[0]))
    matrix = [[sample[name] for name in feature_names] for sample in data]
    labels = [sample[label_key] for sample in data]
    return np.array(matrix, dtype=float), np.array(labels, dtype=int), feature_names

# Build the design matrix from the per-sample QBAF features.
X, y, feature_names = to_X_y(features, label_key)

# Hold out 20% of the samples for testing, stratified on y to keep class ratios.
split = train_test_split(X, y, stratify=y, random_state=42, test_size=0.2)
X_train, X_test, y_train, y_test = split



from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# ===== MLP Model =====
# Standardize the features, then classify with a small feed-forward network.
pipe = Pipeline([
    ('scaler', StandardScaler()),  # zero-mean / unit-variance scaling of each feature
    ('mlp', MLPClassifier(
        hidden_layer_sizes=(32, 16),  # two hidden layers with 32 and 16 units
        activation='relu',
        solver='adam',
        alpha=1e-4,        # strength of the L2 regularization penalty
        learning_rate='adaptive',  # NOTE: scikit-learn only uses this with solver='sgd'
        early_stopping=True,  # stop when the score on an internal validation split stalls
        random_state=42,
        max_iter=300
    ))
])

# train MLP
pipe.fit(X_train, y_train)

# evaluate model on the held-out split
y_pred = pipe.predict(X_test)
test_acc = accuracy_score(y_test, y_pred)
print("Accuracy:", test_acc)
print(classification_report(y_test, y_pred, digits=3))

# Check whether the sample distribution is balanced (class -> count over all samples)
unique, counts = np.unique(y, return_counts=True)
print(dict(zip(unique, counts)))

# compute ROC-AUC from the predicted probability of class 1
y_proba = pipe.predict_proba(X_test)[:, 1]
print("ROC-AUC:", roc_auc_score(y_test, y_proba))

# check overfitting: compare train accuracy with the test accuracy computed above
train_acc = pipe.score(X_train, y_train)
print(f"Train Accuracy: {train_acc:.3f}, Test Accuracy: {test_acc:.3f}")




# compute feature importance using SHAP
import shap
import numpy as np

import pandas as pd

# Background set for KernelExplainer: up to 100 samples drawn without
# replacement. A seeded generator keeps the background, and therefore the SHAP
# estimates, reproducible. min() avoids a ValueError with fewer than 100 samples.
rng = np.random.default_rng(42)
random_indices = rng.choice(len(X), size=min(100, len(X)), replace=False)
background = X[random_indices]

# KernelExplainer on the whole pipeline, so attributions refer to the raw features
explainer = shap.KernelExplainer(pipe.predict_proba, background)

# compute SHAP values for (at most) the first 500 samples
X_explain = X[:500]
shap_values = explainer.shap_values(X_explain)

# Explain class 1 (label 1 = QBAF/label inconsistency). Depending on the SHAP
# version, multi-output results are either a list with one
# (n_samples, n_features) array per class or a single
# (n_samples, n_features, n_classes) array.
if isinstance(shap_values, list):
    shap_class1 = shap_values[1]
else:
    shap_class1 = shap_values[:, :, 1]

# visualization; show=False keeps the figure open so savefig below captures it
# (with the default show=True the plot is displayed first and the saved image
# may come out blank)
shap.summary_plot(shap_class1, X_explain, feature_names=feature_names, plot_type="dot", show=False)

# save to file
plt.savefig("shap_summary_plot.png", dpi=300, bbox_inches="tight")
plt.close()

print("SHAP summary plot saved to shap_summary_plot.png")

# compute mean absolute SHAP values for feature importance
mean_abs_shap = np.mean(np.abs(shap_class1), axis=0)
importance_df = pd.DataFrame({
    "Feature": feature_names,
    "Mean |SHAP|": mean_abs_shap
})

# sorting by importance
importance_df = importance_df.sort_values(by="Mean |SHAP|", ascending=False).reset_index(drop=True)
print(importance_df)

