import os
import json
import jsonlines
import itertools
import numpy as np
from tqdm import tqdm

def read_jsonl_file(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (e.g. a stray empty line at the end of the file) carry
        # no record; skip them instead of letting json.loads fail on them.
        return [json.loads(line) for line in file if line.strip()]

# Model / dataset whose saved generations are scored. The commented-out pair
# below is the alternative model configuration.
# model_name = 'Llama2_7B_Chat'
# subdir = '../data/all/'
model_name = 'Mistral_7B_Instruct'
subdir = '../data/all/'
split = 'nq'

# Directory components that identify the main generation.
temperature = 1.0
run = 0
# subdir already ends with '/', so the resulting string contains a doubled
# separator ('../data/all//temperature=...').
save_path_dir = f'{subdir}/temperature={temperature}/run={run}'

# Input file: <save_path_dir>/<model_name>_<split>.jsonl. Later code derives
# the sibling-run paths and the output path from this string.
save_path_0 = os.path.join(save_path_dir, f'{model_name}_{split}.jsonl')
# Records of the main generation, loaded as soon as the script runs.
data = read_jsonl_file(save_path_0)

def get_all_facts(input_dictionary):
    """Collect the atomic-fact texts of every record, one list per record.

    Args:
        input_dictionary: Iterable of records (despite the name, a sequence
            of dicts). Each record has an 'annotations' entry that is either
            None or a list of dicts; each of those dicts has a
            'model-atomic-facts' list whose items carry a 'text' key.

    Returns:
        A list with exactly one entry per input record, in input order. Each
        entry is the flat list of fact texts found across all annotations of
        that record; records whose 'annotations' is None or empty yield an
        empty list.
    """
    all_facts = []
    for record in input_dictionary:
        # A record without annotations still gets an (empty) entry so that
        # positions in the result stay aligned with positions in the input.
        annotations = record['annotations'] or []
        all_facts.append([
            fact['text']
            for annotation in annotations
            for fact in annotation['model-atomic-facts']
        ])
    return all_facts
    
# Temperature directory of the additional (re-sampled) runs. Presumably meant
# to allow sibling runs sampled at a different temperature than the main
# generation; it is now actually applied when building their paths.
additional_temperature = 1.0

# One entry per additional run; each entry is the per-example list of
# atomic-fact texts returned by get_all_facts for that run's file.
all_additional_outputs = []
for i in range(1, 5):
    if i == run:
        # Never use the main generation as its own comparison sample.
        continue
    # Derive the sibling path from the main path by swapping the temperature
    # and run components. With the current settings (temperature 1.0, run 0)
    # this yields exactly the same strings as replacing 'run=0' alone.
    save_path = (
        save_path_0
        .replace(f'temperature={temperature}', f'temperature={additional_temperature}')
        .replace(f'run={run}', f'run={i}')
    )
    print(save_path)
    all_additional_outputs.append(get_all_facts(read_jsonl_file(save_path)))

from rank_bm25 import BM25Okapi
import nltk
# nltk.word_tokenize needs the Punkt tokenizer data. Newer NLTK releases
# (3.9 and later) load it from the 'punkt_tab' package rather than the
# pickled 'punkt' one, so fetch both to keep tokenization working across
# NLTK versions.
for _punkt_resource in ('punkt', 'punkt_tab'):
    nltk.download(_punkt_resource)

def get_top(input_sentence, corpus, k=3):
    """Return the corpus sentences that best match a query under BM25.

    Both the query and the corpus sentences are lower-cased and tokenized
    with ``nltk.word_tokenize`` before scoring.

    Args:
        input_sentence: Query string.
        corpus: Sequence of candidate sentences (strings).
        k: Maximum number of sentences to return.

    Returns:
        Up to ``k`` sentences from ``corpus``, ordered by descending BM25
        score; sentences with equal scores keep their corpus order. An empty
        corpus yields an empty list.
    """
    # BM25Okapi computes an average document length while indexing and fails
    # with ZeroDivisionError on an empty corpus, so answer that case here.
    if not corpus:
        return []

    tokenized_corpus = [nltk.word_tokenize(sentence.lower()) for sentence in corpus]

    # Score every corpus sentence against the query.
    bm25 = BM25Okapi(tokenized_corpus)
    query_tokens = nltk.word_tokenize(input_sentence.lower())
    bm25_scores = bm25.get_scores(query_tokens)

    # sorted() is stable, so ties are returned in corpus order.
    ranked = sorted(zip(corpus, bm25_scores), key=lambda pair: pair[1], reverse=True)
    return [sentence for sentence, _ in ranked[:k]]

from alignscore import AlignScore
# Settings for the AlignScore scorer, gathered in one mapping so they can be
# read (and edited) at a glance before the model is constructed.
# NOTE(review): the backbone is named 'roberta-base' while the checkpoint
# file is named 'AlignScore-large' — confirm this pairing is intended.
alignscore_settings = {
    'model': 'roberta-base',
    'batch_size': 32,
    'ckpt_path': '../../AlignScore/ckpts/AlignScore-large.ckpt',
    'evaluation_mode': 'nli_sp',
    'verbose': False,
    'device': 'cuda:0',
}
scorer = AlignScore(**alignscore_settings)

# Number of retrieved sentences per additional run used as scoring context.
k = 5

# Periodic checkpoints and the final dump all go to this one sibling file.
consistency_path = save_path_0.replace(".jsonl", "_consistency_alignscore.jsonl")


def _write_records(path, records):
    """Overwrite ``path`` with ``records``, one JSON object per line."""
    with jsonlines.open(path, mode='w') as writer:
        for record in records:
            writer.write(record)


# For every atomic fact of the main generation: retrieve the k most similar
# facts from each additional run of the same example, score the fact against
# each run's retrieved context with AlignScore, and store the mean score on
# the fact under 'consistency'. Facts with empty text are left unscored.
for i, record in enumerate(tqdm(data)):
    annotations = record['annotations']
    if annotations is not None:
        for annotation in annotations:
            for atomic_fact in annotation['model-atomic-facts']:
                text = atomic_fact['text']
                if not text:
                    continue

                # One context string per additional run. Runs that produced
                # no facts for this example are skipped: BM25 cannot index an
                # empty corpus and an empty context carries no evidence.
                contexts = [
                    ' '.join(get_top(text, run_facts[i], k=k))
                    for run_facts in all_additional_outputs
                    if run_facts[i]
                ]
                if not contexts:
                    # Nothing to compare against; leave the fact unscored,
                    # exactly like facts with empty text.
                    continue

                scores = scorer.score(
                    claims=[text] * len(contexts),
                    contexts=contexts,
                )
                # Store a plain Python float so the value stays
                # JSON-serializable whatever scalar type NumPy returns.
                atomic_fact['consistency'] = float(np.mean(scores))

    # Checkpoint every 10 examples (also for examples without annotations)
    # so a crash does not lose all progress.
    if i % 10 == 0:
        _write_records(consistency_path, data)

# Final dump with every example processed.
_write_records(consistency_path, data)


