import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import argparse
import os, sys, inspect
# Put the parent directory of this script at the front of sys.path so that
# modules living one level up can be imported.
# NOTE(review): nothing visible in this file imports from the parent
# directory; confirm this is still needed.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

# Sequence-pair classifier (per its hub name: BERT-base fine-tuned on QQP).
# Its class-1 softmax probability is used throughout as the similarity of
# two texts.
QQP_MODEL_NAME = "rambodazimi/bert-base-uncased-finetuned-FFT-QQP"
tokenizer = AutoTokenizer.from_pretrained(QQP_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(QQP_MODEL_NAME)

# At most this many self-clarifications per question are compared pairwise.
MAX_CLARIFICATION = 10

parser = argparse.ArgumentParser()
parser.add_argument("--log_path", type=str, required=True,
                    help="Directory containing the *.json clarification logs.")
parser.add_argument("--output_path", type=str, required=True,
                    help="Path of the merged JSON file to write.")
# Parse the command line only when executed as a script: parse_args() exits
# the process on missing required flags, which would make the module
# impossible to import for reuse of its functions.
args = parser.parse_args() if __name__ == "__main__" else None

def get_probs(question1, question2, model, tokenizer):
    """Return the classifier's class-1 probability for a pair of texts.

    The two texts are encoded together as one pair by ``tokenizer``, scored by
    ``model``, and the softmax over the logits is taken. Callers in this file
    use the class-1 probability as the similarity of the two texts.

    Args:
        question1: First text of the pair.
        question2: Second text of the pair.
        model: Sequence-classification model; its output must expose
            ``.logits`` of shape ``(batch, num_classes)``.
        tokenizer: Tokenizer callable accepting a text pair.

    Returns:
        float: softmax probability of class index 1 for the (single) pair.
    """
    inputs = tokenizer(question1, question2, return_tensors="pt", truncation=True, padding=True)
    # Inference only: skip building the autograd graph (saves memory and
    # time; this helper may be called many times per question).
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.softmax(logits, dim=1)[0, 1].item()



    

def weighted_pairwise_similarity_qqp_2(texts, question, model, tokenizer, max_clarifications=None):
    """Score how much a set of clarifications disagree with one another.

    Only the first ``max_clarifications`` texts are used (default
    ``MAX_CLARIFICATION``). For those ``n`` texts:

    * ``weights[i] = get_probs(question, texts[i])``: relevance of
      clarification ``i`` to the original question;
    * ``sim_matrix[i, j] = 1 - get_probs(texts[i], texts[j])`` for ``i != j``
      and ``0`` on the diagonal. This is an *anti*-similarity; the variable
      name is historical.

    Args:
        texts: Sequence of clarification strings.
        question: The original question.
        model: Pair classifier passed through to ``get_probs``.
        tokenizer: Tokenizer passed through to ``get_probs``.
        max_clarifications: Optional cap on the number of texts used;
            ``None`` means ``MAX_CLARIFICATION``.

    Returns:
        tuple: ``(numerator, total, S, res, n, ambiguity)``.

        * ``numerator`` (float): anti-similarity sum weighted by the outer
          product of the weights rescaled to mean 1. Uniform weights are used
          when the raw weights sum to 0.
        * ``total`` (0-d tensor): unweighted anti-similarity sum.
        * ``S`` (0-d tensor): raw sum of the weights.
        * ``res``: ``total / n**2`` as a 0-d tensor, or ``0`` when ``n == 0``.
        * ``n`` (int): number of clarifications used.
        * ``ambiguity`` (float): first output of ``ambiguity_from_similarity``
          applied to ``1 - sim_matrix``.
    """
    if max_clarifications is None:
        max_clarifications = MAX_CLARIFICATION
    n = min(len(texts), max_clarifications)
    kept = texts[:n]

    # Pure inference: no autograd graph is needed for any forward pass below.
    with torch.no_grad():
        # Only the kept texts are scored; the rest would be discarded anyway.
        weights = torch.tensor([get_probs(question, text, model, tokenizer) for text in kept])

        # Both (i, j) and (j, i) are scored, since the pair is encoded in
        # order. The diagonal is defined as 0, so its forward passes are
        # skipped.
        sim_matrix = torch.zeros((n, n), dtype=torch.float64)
        for i in range(n):
            for j in range(n):
                if i != j:
                    sim_matrix[i, j] = 1.0 - get_probs(kept[i], kept[j], model, tokenizer)

    S = weights.sum()

    # Rescale the weights to mean 1. If they sum to 0 (or there are none), the
    # rescaling is undefined (0/0 -> NaN), so fall back to uniform weights.
    if n and S > 0:
        weights = n * weights / S
    else:
        weights = torch.ones_like(weights)
    w_outer = weights.unsqueeze(1) * weights.unsqueeze(0)  # [n, n]

    numerator = (sim_matrix * w_outer).sum()
    total = sim_matrix.sum()
    res = total / n ** 2 if n else 0

    # ambiguity_from_similarity expects a similarity, hence 1 - anti-similarity.
    ambiguity = ambiguity_from_similarity(1 - sim_matrix, alpha=0.5, eps=1e-12)[0]

    return numerator.item(), total, S, res, n, ambiguity
    


def ambiguity_from_similarity(sim_matrix: torch.Tensor, alpha: float = 0.5, eps: float = 1e-12):
    """Turn a square pairwise-similarity matrix into a spectral ambiguity score.

    The matrix is symmetrised, negative entries are clipped to 0 and the
    diagonal is zeroed, giving a weighted adjacency ``A``. The eigenvalues of
    the symmetric normalised Laplacian ``L = I - D^{-1/2} A D^{-1/2}`` are
    summarised in two ways. Here ``D`` is the degree matrix, with ``eps``
    added to the degrees.

    * ``s_fied = 1 - lambda_2 / 2``, where ``lambda_2`` is the second-smallest
      eigenvalue clamped to ``[0, 2]``.
    * ``s_ent``: Shannon entropy of ``lambda_2 .. lambda_n`` (normalised to
      sum to 1, entries floored at ``eps``) divided by ``log(n - 1)``.

    The combined score is ``ambiguity = alpha * s_fied + (1 - alpha) * s_ent``.
    The input tensor is not modified.

    Returns:
        tuple: ``(ambiguity, lambda_2, H, s_fied, s_ent)`` as Python floats,
        where ``H`` is the unnormalised entropy.

        * If the tail eigenvalue sum is ``<= eps``, ``H`` and ``s_ent`` are 0.
        * With fewer than 3 nodes, the fixed tuple
          ``(0.5, lambda_2, nan, 0.5, nan)`` is returned. In that case
          ``lambda_2`` is unclamped, and is ``nan`` below 2 nodes.
    """
    # Weighted adjacency: symmetric, non-negative, no self-loops. Arithmetic
    # and clamp() return fresh tensors, so the in-place fill never touches
    # the caller's matrix.
    adjacency = ((sim_matrix + sim_matrix.T) / 2).clamp(min=0)
    adjacency.fill_diagonal_(0.0)
    size = adjacency.size(0)

    # eps keeps zero-degree (isolated) nodes from producing inf.
    scale = torch.diag((adjacency.sum(dim=1) + eps).pow(-0.5))
    laplacian = torch.eye(size, device=adjacency.device) - scale @ adjacency @ scale

    # eigh returns the eigenvalues of a symmetric matrix in ascending order.
    eigenvalues = torch.linalg.eigh(laplacian).eigenvalues

    if size < 3:
        # Spectrum too small to be informative: return neutral scores.
        raw_second = float(eigenvalues[1].item()) if size >= 2 else float('nan')
        return (0.5, raw_second, float('nan'), 0.5, float('nan'))

    fiedler = eigenvalues[1].clamp(min=0.0, max=2.0)
    fiedler_score = (1.0 - fiedler / 2.0).item()

    spectrum_tail = eigenvalues[1:]
    tail_total = spectrum_tail.sum().item()
    if tail_total <= eps:
        entropy = entropy_score = 0.0
    else:
        distribution = (spectrum_tail / (tail_total + eps)).clamp(min=eps)
        entropy = float(-(distribution * torch.log(distribution)).sum().item())
        max_entropy = torch.log(torch.tensor(size - 1.0)).item()
        entropy_score = float(entropy / (max_entropy + eps))

    combined = float(alpha * fiedler_score + (1 - alpha) * entropy_score)
    return combined, float(fiedler.item()), entropy, fiedler_score, entropy_score



# Score lists kept on every merged record, in the order in which
# weighted_pairwise_similarity_qqp_2 returns the corresponding values.
_SCORE_FIELDS = ("CLARAoq", "CLARA", "w_score_list", "CLARAn", "n_list", "laplace")


def _score_entry(entry):
    """Return the six scores of one log entry as floats, ordered like ``_SCORE_FIELDS``."""
    scores = weighted_pairwise_similarity_qqp_2(entry["self_clarification"], entry["question"], model, tokenizer)
    return [float(value) for value in scores]


def _new_record(entry):
    """Return an empty merged record for ``entry`` (all score lists start empty)."""
    return {
        "question": entry["question"],
        "label": entry["label"],
        **{field: [] for field in _SCORE_FIELDS},
    }


def main(args):
    """Score every ``*.json`` log in ``args.log_path`` and write the merged result.

    Each log file is expected to hold a JSON list of entries with
    ``"question"``, ``"label"`` and ``"self_clarification"`` keys. Entries that
    share the same question and label, across all files, are merged into one
    record. Each of that record's score lists gets one value per occurrence.
    The list of records is written as JSON to ``args.output_path``.

    Files are visited in sorted name order, so the score lists come out in
    the same order on every run and filesystem. Processing is best-effort:
    a file that cannot be decoded, or an entry whose scoring fails, is
    reported on stdout and skipped, and everything else is still processed.
    A failing entry leaves no partial or empty record behind.
    """
    merged = {}
    for filename in sorted(os.listdir(args.log_path)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(args.log_path, filename)
        print(file_path)
        with open(file_path, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
            except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
                print(f"Erreur dans {filename} : {e}")
                continue
        if not isinstance(data, list):
            print(f"Erreur dans {filename} : une liste JSON est attendue, pas {type(data).__name__}")
            continue

        for entry in data:
            try:
                key = (entry["question"], tuple(entry["label"]))
                scores = _score_entry(entry)
                # setdefault hashes the key, so an unhashable label is also
                # caught here instead of aborting the whole run.
                record = merged.setdefault(key, _new_record(entry))
            except Exception as e:  # best-effort boundary: skip only this entry
                print(f"Erreur dans {filename} : {e}")
                continue
            for field, value in zip(_SCORE_FIELDS, scores):
                record[field].append(value)

    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(list(merged.values()), f, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    main(args)







