import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import argparse
import os, sys, inspect
# Put the directory one level above this script at the front of sys.path so
# that modules located there take precedence when imported from this file.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)

# Tokenizer and sequence-classification model are loaded once, at import time,
# from the same checkpoint identifier.
_QQP_CHECKPOINT = "rambodazimi/bert-base-uncased-finetuned-FFT-QQP"
tokenizer = AutoTokenizer.from_pretrained(_QQP_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_QQP_CHECKPOINT)

# Upper bound on how many clarifications of one entry are scored by
# weighted_pairwise_similarity_qqp_2; any beyond this count are ignored.
MAX_CLARIFICATION = 10

# Command-line interface, parsed at import time. Both options are mandatory:
#   --log_path     directory whose *.json files are read by main()
#   --output_path  file that main() writes the merged JSON result to
parser = argparse.ArgumentParser()
parser.add_argument("--log_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
args = parser.parse_args()


def get_probs(question1, question2, model, tokenizer):
    """Return the classifier's probability of class index 1 for a text pair.

    For the QQP checkpoint loaded in this file, class index 1 is presumably
    the "duplicate question" label -- confirm against the model card.

    Args:
        question1: First text of the pair, forwarded to ``tokenizer``.
        question2: Second text of the pair, forwarded to ``tokenizer``.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing ``logits``.
        tokenizer: Callable invoked as ``tokenizer(question1, question2,
            return_tensors="pt", truncation=True, padding=True)`` and
            returning a mapping of model inputs.

    Returns:
        float: Softmax probability (over dim 1 of the logits) of class
        index 1 for the first row of the batch.
    """
    inputs = tokenizer(question1, question2, return_tensors="pt", truncation=True, padding=True)
    # Inference only: the result is reduced to a Python float, so building an
    # autograd graph for the forward pass would just waste memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)
    return probs[0][1].item()



    

def weighted_pairwise_similarity_qqp_2(texts, question, model, tokenizer):
    """Score how much a list of clarifications disagree with each other.

    Only the first ``MAX_CLARIFICATION`` items of ``texts`` are used. For
    every ordered pair ``(i, j)`` with ``i != j`` the "anti-similarity" is
    ``1 - get_probs(texts[i], texts[j])``; the diagonal is 0. Both orders of
    a pair are evaluated because the classifier is not assumed to be
    symmetric. Each clarification is also weighted by
    ``get_probs(question, text)``.

    Args:
        texts: Sequence of clarification strings (must support ``len`` and
            slicing).
        question: The original question the clarifications refer to.
        model: Sequence-classification model forwarded to ``get_probs``.
        tokenizer: Tokenizer forwarded to ``get_probs``.

    Returns:
        A 5-tuple ``(weighted, total, S, res, n)``:
            weighted: float, sum of the anti-similarity matrix weighted by
                the outer product of the weights, after rescaling the
                weights so that they sum to ``n``.
            total: 0-dim tensor, unweighted sum of the anti-similarity
                matrix.
            S: 0-dim tensor, sum of the raw (un-normalised) weights.
            res: ``total / n**2`` (0-dim tensor), or the int ``0`` when
                ``n == 0``.
            n: number of clarifications actually used.
    """
    # Truncate first so that no forward pass is spent on clarifications that
    # would be discarded afterwards.
    n = min(len(texts), MAX_CLARIFICATION)
    kept = texts[:n]

    # Pure inference: make sure no autograd graph is built by any of the
    # forward passes triggered below.
    with torch.no_grad():
        # Similarity between each clarification and the original question.
        weights = torch.tensor(
            [get_probs(question, el, model, tokenizer) for el in kept]
        )

        # Anti-similarity between clarifications. The diagonal stays 0, so the
        # model is never run on a text paired with itself.
        sim_matrix = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue
                sim_matrix[i, j] = 1 - get_probs(kept[i], kept[j], model, tokenizer)

    sim_matrix = torch.tensor(sim_matrix)

    S = weights.sum()

    # Rescale the weights so that they sum to n (i.e. their mean is 1).
    weights = n * weights / S
    w_outer = weights.unsqueeze(1) * weights.unsqueeze(0)  # [N, N]

    numerator = (sim_matrix * w_outer).sum()
    res = sim_matrix.sum() / n**2 if n != 0 else 0
    return numerator.item(), sim_matrix.sum(), S, res, n
    



def main(args):
    """Merge the scores of all JSON logs in a directory into one JSON file.

    Every ``*.json`` file directly inside ``args.log_path`` is expected to
    hold a list of entries with the keys ``orig_instruction``, ``input``,
    ``isambig`` and ``self_clarification``. Entries sharing the same
    ``(orig_instruction, input)`` pair are grouped into one record, and the
    five values returned by ``weighted_pairwise_similarity_qqp_2`` are
    appended to that record's ``CLARAoq``, ``CLARA``, ``w_score_list``,
    ``CLARAn`` and ``n_list`` lists respectively.

    Any exception raised while processing a file is printed and the rest of
    that file is skipped; entries handled before the failure are kept.

    Args:
        args: Namespace providing ``log_path`` (input directory) and
            ``output_path`` (destination JSON file).
    """
    merged = {}
    # os.listdir returns names in arbitrary order; sorting makes the order of
    # the merged records and of the per-record score lists reproducible.
    for filename in sorted(os.listdir(args.log_path)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(args.log_path, filename)
        print(file_path)
        with open(file_path, "r", encoding="utf-8") as f:
            try:
                data = json.load(f)
                for entry in data:
                    # tuple() makes the input hashable so it can be part of the key.
                    key = (entry["orig_instruction"], tuple(entry["input"]))
                    if key not in merged:
                        merged[key] = {
                            "orig_instruction": entry["orig_instruction"],
                            "input": entry["input"],
                            "isambig": entry["isambig"],
                            "CLARAoq": [],
                            "CLARA": [],
                            "w_score_list": [],
                            "CLARAn": [],
                            "n_list": [],
                        }
                    # NOTE: if scoring raises here, a record created just above
                    # remains in the output with whatever it held so far.
                    a, b, c, d, n = weighted_pairwise_similarity_qqp_2(
                        entry["self_clarification"],
                        entry["orig_instruction"],
                        model,
                        tokenizer,
                    )

                    record = merged[key]
                    # float() turns 0-dim tensors / ints into JSON-serialisable numbers.
                    record["CLARAoq"].append(float(a))
                    record["CLARA"].append(float(b))
                    record["w_score_list"].append(float(c))
                    record["CLARAn"].append(float(d))
                    record["n_list"].append(float(n))
            except Exception as e:
                # Best effort: report the problem and move on to the next file.
                print(f"Erreur dans {filename} : {e}")
    merged_list = list(merged.values())

    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(merged_list, f, indent=4, ensure_ascii=False)

    


# Run the merge only when this file is executed as a script; `args` holds the
# command-line options parsed at module level.
if __name__ == "__main__":
    main(args)







