import pickle
import argparse
import pandas as pd
from tqdm import tqdm
import Levenshtein
import os


def parse_args():
    """Parse the command-line options of this script.

    Every option is optional and defaults to ``None``:

    * ``--model_name``, ``--test_data_path``, ``--MIMIC_data_path`` and
      ``--report_corpus_path`` are parsed as strings;
    * ``--num_chunks`` is parsed as an integer.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (option name, converter) pairs, registered in this order.
    option_specs = (
        ('model_name', str),
        ('test_data_path', str),
        ('MIMIC_data_path', str),
        ('report_corpus_path', str),
        ('num_chunks', int),
    )
    cli = argparse.ArgumentParser()
    for option, converter in option_specs:
        cli.add_argument(f'--{option}', type=converter, default=None)
    return cli.parse_args()

def split_list_into_chunks(lst, num_chunks):
    """Split ``lst`` into ``num_chunks`` new lists of near-equal length.

    Each chunk first receives a consecutive slice of
    ``len(lst) // num_chunks`` items.  The items left over at the tail of
    ``lst`` are then appended one by one to the leading chunks, so when the
    length is not divisible by ``num_chunks`` the chunks are NOT contiguous
    slices of the input.

    Args:
        lst: The list to split; it is not modified.
        num_chunks: Number of chunks to produce.

    Returns:
        A list containing ``num_chunks`` lists.

    Raises:
        ZeroDivisionError: If ``num_chunks`` is 0.
    """
    base_size = len(lst) // num_chunks

    chunks = []
    for chunk_idx in range(num_chunks):
        start = chunk_idx * base_size
        chunks.append(lst[start:start + base_size])

    # Deal the tail items that did not fit the even split to the first chunks.
    tail_start = num_chunks * base_size
    for chunk_idx, leftover_item in enumerate(lst[tail_start:]):
        chunks[chunk_idx].append(leftover_item)
    return chunks


def match_report(A, B):
    """Find, for every entry of ``A``, the closest string in ``B``.

    Each entry of ``A`` is a sequence of strings that is joined with single
    spaces into one query string.  The query is compared with every candidate
    in ``B`` using ``Levenshtein.distance`` and the candidate with the
    smallest distance wins; on ties the earliest candidate in ``B`` is kept.

    Args:
        A: Sequence of string sequences (the queries).
        B: Iterable of candidate strings.

    Returns:
        A list with one single-element list per entry of ``A`` holding the
        best-matching candidate, or the placeholder ``'error'`` when ``B``
        is empty.
    """
    matches = []
    for sentences in tqdm(A, desc="Retrieving"):
        query = ' '.join(sentences)
        # min() returns the first minimal element, i.e. the same candidate a
        # strict "<" scan would keep.
        closest = min(
            B,
            key=lambda candidate: Levenshtein.distance(query, candidate),
            default='error',
        )
        matches.append([closest])
    return matches


if __name__ == '__main__':
    # Script flow:
    #   1. Load the report corpus and de-duplicate its sentence lists.
    #   2. Merge the per-chunk SimR_<k>.pkl results, keeping for every query
    #      the entry with the largest value across chunks.
    #   3. Look up the winning corpus entry for every query and match it
    #      against the frontal-view "Report Impression" texts of the MIMIC
    #      CSV by Levenshtein distance.
    #   4. Save the matched reports as retrieval_report.csv.

    args = parse_args()

    # Per-model directory holding both the SimR_<k>.pkl inputs and the output.
    output_dir = f'/mnt/nvme_share/wuwl/project/CARZero-main/data/output/retrieval_based_report_generation/{args.model_name}'
    retrieval_report_path = f'{output_dir}/retrieval_report.csv'

    # Check for an existing result first so a re-run does not load the
    # corpus, the score pickles and the CSV only to throw them away.
    if os.path.exists(retrieval_report_path):
        print('retrieval report 已存在，检索完成')
        raise SystemExit(0)

    # SECURITY: pickle.load can execute arbitrary code while unpickling;
    # only point this script at trusted corpus / SimR files.
    with open(args.report_corpus_path, "rb") as f:
        print(f"Loading captions from {args.report_corpus_path}")
        path2sent, path2label, to_remove, label_ids = pickle.load(f)

    # De-duplicate the sentence lists (converted to tuples to be hashable).
    # NOTE(review): set iteration order of string tuples depends on the
    # per-process hash seed, so the indices stored in the SimR files only
    # line up with this list if the producing run saw the same ordering
    # (e.g. a fixed PYTHONHASHSEED) -- confirm against the producer script.
    report_corpus = [list(t) for t in set(tuple(sublist) for sublist in path2sent.values())]
    chunk_len_list = [len(chunk) for chunk in split_list_into_chunks(report_corpus, args.num_chunks)]
    del path2sent, path2label, to_remove, label_ids

    # Each SimR file is read as a list with one single-entry dict per query,
    # {index inside the chunk: value}; presumably the value is a similarity
    # score (larger is better) -- verify against the producer script.
    SimR_all = []
    for i in range(args.num_chunks):
        pickle_file = f'{output_dir}/SimR_{i+1}.pkl'
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)
        if i == 0:
            # Chunk 0 has offset 0, so its keys are already corpus indices.
            SimR_all = data
            continue

        # NOTE(review): this offset assumes chunk i is the contiguous slice
        # of report_corpus that starts at sum(chunk_len_list[:i]).
        # split_list_into_chunks instead appends the leftover tail items to
        # the first chunks, so if the SimR files were built from its chunks
        # and len(report_corpus) is not divisible by num_chunks the mapping
        # is shifted -- confirm against the producer script.
        chunk_offset = sum(chunk_len_list[:i])
        for idx, candidate in enumerate(data):
            best_value = next(iter(SimR_all[idx].values()))
            new_key, new_value = next(iter(candidate.items()))
            if new_value > best_value:
                SimR_all[idx] = {new_key + chunk_offset: new_value}

    # Corpus entry (list of sentences) selected for every query.
    retrieval_sent_list = [report_corpus[next(iter(best))] for best in SimR_all]

    report_oral = pd.read_csv(args.MIMIC_data_path)
    frontal_rows = report_oral[report_oral["Frontal/Lateral"] == 'Frontal']
    # Empty CSV cells are read as NaN (a float), which Levenshtein.distance
    # cannot compare with a string, so rows without an impression are dropped.
    report_oral = frontal_rows["Report Impression"].dropna().tolist()

    print('retrieval report 不存在，开始检索')
    retrieval_report = pd.DataFrame(match_report(retrieval_sent_list, report_oral))
    retrieval_report.to_csv(retrieval_report_path, index=False)
    print('检索完成，retrieval report 已保存')