import json, os

# Root directory scanned by load(): each sub-directory is treated as one
# dataset and every *.jsonl file inside it is read.
BASE = "./data/raw"

def mname(fn):
    """Return the model-name prefix of a run file name.

    The last three ``_``-separated fields of *fn* are dropped (the final
    field still carries the extension) and the remaining fields are
    rejoined with ``_``. For example, ``"a_b_c_d_e.jsonl"`` becomes ``"a_b"``.

    NOTE(review): a name with three or fewer fields yields ``""``. Confirm
    that every run file follows the expected naming scheme.
    """
    fields = fn.split("_")
    kept = fields[:-3]
    return "_".join(kept)

def extract_one_sample(x):
    """Flatten one parsed JSONL record into per-turn summaries.

    Args:
        x: A parsed record. Keys read are ``x["session"]["query_id"]``,
            ``x["session"]["query_text"]``, ``x["history"]`` (an iterable of
            turns) and, once per turn, ``x["model_name"]``. Each turn must
            provide ``"query"``, ``"reflection"`` and ``"eval_result"``.
            ``"eval_result"`` must hold ``"results"`` (items carrying
            ``"text"``), ``"best_rank"`` and ``"best_score"``.

    Returns:
        A tuple ``(turn_summaries, query_id, query_text)``. Each summary is a
        dict with keys ``search_query``, ``think``, ``top_k_results``,
        ``best_rank``, ``best_cosine`` and ``model_name``, in that order.
    """
    session = x['session']
    query_id, query_text = session['query_id'], session['query_text']

    def summarize(turn):
        # Keys are looked up in the same order as before, so a malformed
        # record fails on the same missing key.
        search_query = turn['query']
        think = turn['reflection']
        evaluation = turn['eval_result']
        return {
            "search_query": search_query,
            "think": think,
            "top_k_results": [hit['text'] for hit in evaluation['results']],
            "best_rank": evaluation['best_rank'],
            "best_cosine": evaluation['best_score'],
            "model_name": x['model_name'],
        }

    summaries = [summarize(turn) for turn in x['history']]
    return summaries, query_id, query_text

def _rank_split(turns):
    """Return the fractions of *turns* with ``best_rank`` below 5 and at/above 5."""
    total = len(turns)
    if not total:
        # A query whose records all had an empty history has no turns.
        return {"below5": 0, "above5": 0}
    # NOTE(review): assumes best_rank is always a number. A None or negative
    # "not found" sentinel would raise or be counted as below5. Confirm
    # against the producer of eval_result.
    below = sum(1 for t in turns if t["best_rank"] < 5)
    return {"below5": below / total, "above5": (total - below) / total}


def load(base=BASE):
    """Load every ``*.jsonl`` run under *base* and group the turns by query.

    The layout read is ``<base>/<ds>/<file>.jsonl``; entries directly under
    *base* that are not directories are skipped. Each non-blank line is
    parsed as JSON. The record is tagged with ``model_name = mname(file)``
    and flattened with ``extract_one_sample``. Query ids are namespaced as
    ``"<ds>_<query_id>"`` so that ids from different sub-directories cannot
    collide.

    Args:
        base: Root directory to scan. Defaults to the module-level ``BASE``.

    Returns:
        A dict with two keys:

        * ``"turns"``: maps each namespaced query id to the concatenated
          list of turn dicts from every file that contains that query.
        * ``"user_query"``: maps the same ids to
          ``{"user_query": <query text>, "score": {"below5": f, "above5": f}}``,
          where ``below5`` is the fraction of that query's turns with
          ``best_rank < 5``.
    """
    turns_by_qid = {}
    meta_by_qid = {}
    # os.listdir order is arbitrary. Sorting makes the turn order, and so
    # the output file, reproducible across runs and machines.
    for ds in sorted(os.listdir(base)):
        ds_dir = os.path.join(base, ds)
        if not os.path.isdir(ds_dir):
            continue
        for fname in sorted(os.listdir(ds_dir)):
            if not fname.endswith(".jsonl"):
                continue
            model = mname(fname)
            with open(os.path.join(ds_dir, fname), encoding="utf-8") as fh:
                for line in fh:
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    record["model_name"] = model
                    turns, query_id, user_query = extract_one_sample(record)
                    # An f-string also copes with non-string ids.
                    qid = f"{ds}_{query_id}"
                    if qid not in turns_by_qid:
                        # The query text is taken from the first record seen.
                        turns_by_qid[qid] = []
                        meta_by_qid[qid] = {"user_query": user_query}
                    turns_by_qid[qid].extend(turns)

    # Score once, after every file has been read. Previously this ran inside
    # the dataset loop and re-scored all queries for each dataset.
    for qid, turns in turns_by_qid.items():
        meta_by_qid[qid]["score"] = _rank_split(turns)

    return {"turns": turns_by_qid, "user_query": meta_by_qid}

if __name__ == "__main__":
    data = load()
    print(data.keys())
    # Spot-check one known query. If that dataset is not present, skip the
    # check instead of raising KeyError and aborting before the output is
    # written.
    sample_qid = "hotpotqa_5a7303685542991f9a20c5fe"
    if sample_qid in data["turns"]:
        print(data["user_query"][sample_qid], json.dumps(data["turns"][sample_qid], indent=4))
    out_path = "./data/preprocessed/data.json"
    # open(..., "w") does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)