# Datasets-----------------------------------------------

# city_prob
# common_claim_prob
# counterfact_prob
# company_prob

# city_prob_llama13b
# common_claim_prob_llama13b
# counterfact_prob_llama13b
# company_prob_llama13b

# city_prob_mistral
# common_claim_prob_mistral
# counterfact_prob_mistral
# company_prob_mistral

# city_prob_qwen
# common_claim_prob_qwen
# counterfact_prob_qwen
# company_prob_qwen

# city_prob_granite
# common_claim_prob_granite
# counterfact_prob_granite
# company_prob_granite

# city_prob_qwen14b
# common_claim_prob_qwen14b
# counterfact_prob_qwen14b
# company_prob_qwen14b

# city_prob_phi4b
# common_claim_prob_phi4b
# counterfact_prob_phi4b
# company_prob_phi4b

# city_prob_vi
# common_claim_prob_vi
# counterfact_prob_vi
# company_prob_vi

# data preprocessing----------------------------------------
# map prob from [0,1] to [-1,1]
# map prediction from T/F to 1/-1
# map label from 1/0 to 1/-1

# parameters--------------------------------------------
# Number of entries kept from each row of query_probs (one per layer, per the
# original note). The loader slices row[-layers_idx-1:-1], so these are the
# layers_idx entries immediately before the final one; the final entry itself
# is excluded.
layers_idx = 31
# Number of trailing tokens kept: rows taken from the end of query_probs and
# scores taken from the end of token_score.
token_idx = 6
# Divisor applied to the token importances before the softmax in the loader.
temperature = 0.05

# -------------------------------------------------------
import json
import copy
import numpy as np
from scipy.special import softmax
# Read the JSONL file and build one sample dict per record.
data = []
with open("../../data/city_prob.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines (e.g. an extra newline at the end of the file),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        obj = json.loads(line)
        query = obj["result"]["query"]
        all_probs = query["query_probs"]

        # Keep the last token_idx rows and, within each row, the layers_idx
        # entries just before the final one (the final entry is excluded).
        quer_probs = [row[-layers_idx-1:-1] for row in all_probs[-token_idx:]]

        # Map every probability from [0, 1] to [-1, 1].
        # Ablation: scaled_quer_probs = quer_probs
        scaled_quer_probs = [[2 * (x - 0.5) for x in row] for row in quer_probs]
        # Unweighted copy, kept for the baselines computed later in the script.
        original_scaled_quer_probs = copy.deepcopy(scaled_quer_probs)

        # Token weights: a temperature-scaled softmax over the importances of
        # all kept tokens except the last one; the last token always gets 1.0.
        all_importances = [score for _, score in query["token_score"]]
        last_importances = all_importances[-token_idx:]

        if len(last_importances) == 1:
            scaled_importance = [1.0]
        else:
            last_importances = np.array(last_importances)
            scaled_importance = softmax(last_importances[:-1] / temperature).tolist()
            scaled_importance.append(1.0)
        importance_token = np.array(scaled_importance)
        # Ablation: importance_token = np.array([1.0] * token_idx)  # remove token weights

        # Layer weights: uniform (every layer weighted 1.0).
        importance_layer = np.ones(layers_idx)

        # Weight each (token, layer) cell by token weight * layer weight.
        importance_2D = np.outer(importance_token, importance_layer)
        scaled_quer_probs_np = np.array(scaled_quer_probs)
        scaled_quer_probs_np *= importance_2D
        scaled_quer_probs = scaled_quer_probs_np.tolist()

        data.append({"scaled_probs": scaled_quer_probs,
                     "original_scaled_prob": original_scaled_quer_probs,
                     # answers starting with "T" map to 1, everything else to -1
                     "llm_prediction": 1 if obj["result"]["answer"]["answer"].strip().startswith("T") else -1,
                     # final entry of the last row of query_probs
                     "llm_prediction_value": all_probs[-1][-1],
                     # label 1 -> 1, any other label -> -1
                     "label": 1 if obj["label"] == 1 else -1,
                     "query_content": query["query"],
                     "AAEs": []})


# Report the dataset size and the grid shape, taken from the first sample.
print(f"The dataset contains {len(data)} samples")
first_matrix = data[0]["scaled_probs"]
rows, cols = len(first_matrix), len(first_matrix[0])
print(f"Each sample's scaled_probs has shape: {rows} rows × {cols} columns")

# Class balance: labels are encoded as 1 (positive) and -1 (negative).
pos = sum(1 for sample in data if sample["label"] == 1)
neg = sum(1 for sample in data if sample["label"] == -1)
print(f"number of positive samples: {pos}")
print(f"number of negative samples: {neg}")


# generate the QBAF topology over a rows x cols grid of nodes
def generate_graph(rows, cols):
    """Build the directed edge sets of a grid-shaped QBAF.

    Node ``row * cols + col`` stands for grid cell ``(row, col)``. Every node
    points to the node directly below it in the same column, and every node
    in the last row points to its right-hand neighbour, so the bottom-right
    node (``rows * cols - 1``) is the only one without outgoing edges.

    Args:
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A dict mapping each node id in ``range(rows * cols)`` to the set of
        node ids it points to. No edges are added when either dimension is
        not positive.
    """
    n = rows * cols
    graph = {node: set() for node in range(n)}

    # Nothing to link in a degenerate grid; this also avoids indexing the
    # graph with negative node ids.
    if rows <= 0 or cols <= 0:
        return graph

    # Within each column, link every node to the one in the next row.
    # (Remove this loop to drop the vertical edges.)
    for col in range(cols):
        for row in range(rows - 1):
            graph[row * cols + col].add((row + 1) * cols + col)

    # Within the last row, link every node to its right-hand neighbour.
    # This is done once, outside the column loop: repeating it per column
    # only re-adds the same edges.
    # (Remove this loop to drop the horizontal edges.)
    last_row_start = (rows - 1) * cols
    for col in range(cols - 1):
        graph[last_row_start + col].add(last_row_start + col + 1)

    return graph

# Build the grid topology once from the first sample's shape (rows, cols);
# the same graph object is passed to every per-sample file written below.
graph = generate_graph(rows, cols)

# write the QBAF of one sample (base scores and edge polarities) to a file
def generate_and_write_graph(filename, i, graph, data, rows, cols):
    """Write the QBAF of sample ``i`` to ``filename``.

    The file contains one ``arg(node, base_score).`` line per node of
    ``graph`` followed by one ``sup(u, v).`` or ``att(u, v).`` line per edge.
    An edge is a support when both endpoints' base scores have the same sign
    (both negative, or both non-negative) and an attack otherwise.

    Output is written straight to the file instead of temporarily replacing
    ``sys.stdout``, so the function needs no module-level ``sys`` import and
    cannot leave the process-wide stdout redirected if an error occurs.

    Args:
        filename: path of the file to create or overwrite.
        i: index of the sample in ``data``.
        graph: dict mapping each node id to the set of node ids it points to.
        data: list of sample dicts; ``data[i]["scaled_probs"]`` is a list of
            rows whose row-major flattening provides the base score of each
            node.
        rows: unused; kept for interface compatibility.
        cols: unused; kept for interface compatibility.
    """
    # Row-major flattening: node id k gets the k-th value.
    flat_data = [item for sublist in data[i]["scaled_probs"] for item in sublist]

    with open(filename, 'w') as f:
        # base scores for arguments
        for node in graph:
            print(f"arg({node}, {flat_data[node]}).", file=f)

        # polarity for edges
        for node, edges in graph.items():
            for edge in edges:
                if (flat_data[node] < 0 and flat_data[edge] < 0) or (flat_data[node] >= 0 and flat_data[edge] >= 0):
                    print(f"sup({node}, {edge}).", file=f)
                else:
                    print(f"att({node}, {edge}).", file=f)

# write one .bag file per sample, N files in total
N = len(data)
for i in range(N):
    generate_and_write_graph(
        filename=f'../../bags/llmQBAF_{i}.bag',
        i=i,
        graph=graph,
        data=data,
        rows=rows,
        cols=cols,
    )

import sys
sys.path.append("../../src/")
import numpy as np
import uncertainpy.gradual as grad

# load a QBAF (single-sample check before the full run below)
# NOTE(review): the sample index 500 is hard-coded, so this presumably needs a
# dataset with at least 501 samples (or a leftover file from an earlier run) —
# confirm before using a smaller dataset.
bag = grad.BAG("../../bags/llmQBAF_500.bag")

# QE semantics with Probability
# agg_f and inf_f are reused for every sample in the loop further down.
agg_f = grad.semantics.modular.SumAggregation_probability()
inf_f = grad.semantics.modular.QuadraticMaximumInfluence_probability(conservativeness=1.0)

# compute strength
strength_values = grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)
# Debug aid: print every argument's name and strength.
# for arg in bag.arguments.values():
#     print((arg.name,arg.strength))

# use the layer before the last layer
from tqdm import tqdm
def row_nodes(r, rows=layers_idx, cols=token_idx):
    """Return the ``cols`` ids ``r, r + rows, ..., r + (cols - 1) * rows``."""
    nodes = []
    for k in range(cols):
        # step by ``rows`` starting from ``r``
        nodes.append(r + k * rows)
    return nodes

print("Conducting statistical analysis for all samples...")
# The node with the highest id (bottom-right grid cell) carries the prediction.
final_node = str(rows * cols - 1)
for i in tqdm(range(N)):
    bag = grad.BAG(f"../../bags/llmQBAF_{i}.bag")
    strength_values = grad.algorithms.computeStrengthValues(bag, agg_f, inf_f)
    final_strength = bag.arguments[final_node].strength
    # Sign of the final strength gives the class; the raw value is kept too.
    data[i]["QBAF_prediction"] = 1 if final_strength >= 0 else -1
    data[i]["qbaf_prediction_value"] = final_strength

import random
random.seed(0)
# Baselines computed from the unweighted scaled probabilities of each sample.
for sample in data:
    flat_data = [value for row in sample["original_scaled_prob"] for value in row]

    # random baseline: sign of one uniformly chosen cell
    rdm = random.choice(flat_data)
    sample["random"] = 1 if rdm >= 0 else -1

    # average baseline: sign of the mean over all cells
    avg = sum(flat_data) / len(flat_data)
    sample["average"] = 1 if avg >= 0 else -1

    # majority baseline: non-negative cells win ties
    non_negative = sum(1 for x in flat_data if x >= 0)
    negative = len(flat_data) - non_negative
    sample["majority"] = 1 if non_negative >= negative else -1


# confusion matrix of the QBAF prediction against the LLM's own answer ------------
y_llm_prediction = np.array([x["llm_prediction"] for x in data])
y_QBAF_pred = np.array([x["QBAF_prediction"] for x in data])

from sklearn.metrics import confusion_matrix
# Pin the label order (-1 = negative, 1 = positive) so the matrix is always
# 2x2; without it the four-way unpacking fails when only one class occurs.
tn, fp, fn, tp = confusion_matrix(y_llm_prediction, y_QBAF_pred, labels=[-1, 1]).ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")

# From here on tp/fp/fn/tn hold the sample indices of each cell, not counts.
# Classes: -1 = negative, 1 = positive
tp = np.where((y_llm_prediction == 1) & (y_QBAF_pred == 1))[0]
fp = np.where((y_llm_prediction == -1) & (y_QBAF_pred == 1))[0]
fn = np.where((y_llm_prediction == 1) & (y_QBAF_pred == -1))[0]
tn = np.where((y_llm_prediction == -1) & (y_QBAF_pred == -1))[0]

# Statistics of the final result-------------------------------------
# Each counter is the number of samples on which two predictors agree.
count_QBAF_LLM = 0
count_random_LLM = 0
count_average_LLM = 0
count_majority_LLM = 0
count_QBAF_Lable = 0  # QBAF vs. ground-truth label (original spelling kept)
count_LLM_Label = 0

for sample in data:
    llm = sample["llm_prediction"]
    if llm == sample["QBAF_prediction"]:
        count_QBAF_LLM += 1
    if llm == sample["random"]:
        count_random_LLM += 1
    # Each baseline updates its own counter; these two were previously
    # swapped, which mislabelled the printed numbers.
    if llm == sample["average"]:
        count_average_LLM += 1
    if llm == sample["majority"]:
        count_majority_LLM += 1
    if llm == sample["label"]:
        count_LLM_Label += 1
    # Despite the key name, 0 means the QBAF prediction matches the label
    # and 1 means it does not.
    if sample["label"] == sample["QBAF_prediction"]:
        count_QBAF_Lable += 1
        sample["label_QBAF_consistent"] = 0 # no hallucination so 0
    else:
        sample["label_QBAF_consistent"] = 1 # hallucination so 1

print("number of samples:", len(data))
print(f"random: {count_random_LLM}")
print(f"average: {count_average_LLM}")
print(f"majority: {count_majority_LLM}")
print(f"QBAF: {count_QBAF_LLM}")

# compute AUC
# Per-sample score lists, all in dataset order.
label_QBAF_consis = [sample["label_QBAF_consistent"] for sample in data]
llm_pred = [sample["llm_prediction_value"] for sample in data]
# Shift the QBAF strength with x / 2 + 0.5 (maps [-1, 1] onto [0, 1]).
qbaf_pred = [(sample["qbaf_prediction_value"] / 2) + 0.5 for sample in data]
# Ground truth as 0/1: label -1 becomes 0, anything else becomes 1.
label = [0 if sample["label"] == -1 else 1 for sample in data]

from sklearn.metrics import roc_auc_score
auc_llm = roc_auc_score(label, llm_pred)
auc_QBAF = roc_auc_score(label, qbaf_pred)
print(f"AUC_LLM: {auc_llm}")
print(f"AUC_QBAF: {auc_QBAF}")