import pandas as pd
import numpy as np
import os
import re
import ast
import tqdm


def parse_labels(element):
    """Split a comma-separated label string into whitespace-stripped labels.

    Example: ``"1+, 2-"`` -> ``["1+", "2-"]``. An empty string yields ``[""]``,
    and consecutive commas yield empty labels.
    """
    pieces = element.split(',')
    return list(map(str.strip, pieces))

def check_validity(label_set):
    """Find label numbers that occur with both a '+' and a '-' sign.

    Args:
        label_set: iterable of label strings such as ``"3+"``, ``"3-"`` or
            ``"12"``. Labels that do not end in '+' or '-' (including empty
            strings) are ignored.

    Returns:
        A set holding both ``f"{num}+"`` and ``f"{num}-"`` for every number that
        was seen with conflicting signs; an empty set when there is no conflict.

    Raises:
        ValueError: if a signed label's prefix is not an integer (e.g. ``"a+"``).
    """
    seen = {}
    conflicts = set()
    for label in label_set:
        # endswith() is safe on "" -- indexing label[-1] raised IndexError for the
        # empty labels parse_labels() produces from strings like "12,".
        if not label.endswith(('+', '-')):
            continue
        num = int(label[:-1])
        sign = label[-1]
        first_sign = seen.setdefault(num, sign)
        if first_sign != sign:
            conflicts.add(f"{num}+")
            conflicts.add(f"{num}-")
    return conflicts

def validate_check(input_array):
    """Replace labels whose number appears with both '+' and '-' by '25'.

    Every element of *input_array* is a comma-separated label string. The labels
    of all elements are pooled and passed to ``check_validity``. If it reports
    any conflicts, each element is split again, every conflicting label is
    replaced by ``'25'``, and the pieces are re-joined with ``', '``.

    Returns:
        *input_array* itself when there is no conflict; otherwise a new list with
        one re-joined string per input element, in the same order.
    """
    pooled = set()
    for element in input_array:
        pooled.update(parse_labels(element))

    conflicts = check_validity(pooled)
    if not conflicts:
        return input_array

    return [
        ', '.join('25' if label in conflicts else label for label in parse_labels(element))
        for element in input_array
    ]

def _has_int_prefix(piece):
    """Return True if everything except the last character of *piece* parses as an int."""
    try:
        int(piece[:-1])
    except ValueError:
        return False
    return True


def _normalize_label(text):
    """Return the normalized form of one raw label string, or None if it is invalid.

    Accepted forms, checked in this order:
      * text containing a comma whose ', '-separated pieces each have an integer
        prefix before their final character (e.g. "1+, 2-") -> returned unchanged;
      * a bare 1-2 digit number in [0, 26] (e.g. "7") -> returned unchanged;
      * a 1-2 digit number in [0, 26] followed by one non-digit suffix character
        (e.g. "07+") -> the number without leading zeros plus the suffix ("7+").
    """
    if ',' in text and all(_has_int_prefix(piece) for piece in text.split(', ')):
        return text
    if re.match(r'^\d{1,2}$', text) and 0 <= int(text) <= 26:
        return text
    head, suffix = text[:-1], text[-1:]
    # A digit is not a sign: without this check "27", "99" or "123" were split
    # into number + "sign" and slipped past the 0..26 range check.
    if re.match(r'^\d{1,2}$', head) and not suffix.isdigit() and 0 <= int(head) <= 26:
        return str(int(head)) + suffix
    return None


def clean_label(sent_temp, label_sample):
    """Pair sentences with their validated labels, marking unusable pairs with 'X'.

    Entries equal to '0' or 0 are first dropped from *label_sample*. The i-th
    remaining label is then paired with ``sent_temp[i]``. Each label (converted
    with ``str``) is normalized by ``_normalize_label``. A label of the form
    ``"Label: <x>"`` is retried on ``<x>`` under the same rules, so
    "Label: 12" -> "12" and "Label: 3+" -> "3+". A label that still fails
    produces ``'X'`` in both outputs, in place of the sentence and of the label.

    Args:
        sent_temp: sequence of sentences, indexed in step with the non-zero labels.
        label_sample: raw label cells.

    Returns:
        ``(sent_list_new, label_list_new)`` of equal length, one entry per
        non-zero label. Sentences beyond the number of non-zero labels are not
        returned.

    Raises:
        IndexError: if a valid label has no sentence at the same position.
    """
    sent_list_new = []
    label_list_new = []
    labels = [element for element in label_sample if element != '0' and element != 0]
    for i, raw in enumerate(labels):
        item = str(raw)
        normalized = _normalize_label(item)
        if normalized is None and item.startswith('Label:'):
            normalized = _normalize_label(item[len('Label:'):].strip())
        if normalized is None:
            sent_list_new.append('X')
            label_list_new.append('X')
        else:
            label_list_new.append(normalized)
            sent_list_new.append(sent_temp[i])
    return sent_list_new, label_list_new

def merge_csv_files(directory, input_file, output_file, num_parts=12):
    """Concatenate numbered, headerless part CSVs into a single CSV file.

    Reads ``{input_file}_1.csv`` ... ``{input_file}_{num_parts}.csv`` from
    *directory*, stacks their rows in part order, and writes the result
    (no header, no index) to ``os.path.join(directory, output_file)``.
    Missing parts are reported and skipped. If no part exists, an empty
    file is written.

    Args:
        directory: folder holding the part files; the output is written here too.
        input_file: common file-name prefix of the parts.
        output_file: file name of the merged CSV.
        num_parts: highest part number to look for. The default of 12 matches
            the value this function previously hard-coded.
    """
    frames = []
    for i in range(1, num_parts + 1):
        file_name = f"{input_file}_{i}.csv"
        file_path = os.path.join(directory, file_name)
        if os.path.exists(file_path):
            frames.append(pd.read_csv(file_path, header=None))
        else:
            # "File {file_name} does not exist, skipping."
            print(f"文件 {file_name} 不存在，跳过合并。")

    # Concatenate once at the end. Growing a DataFrame inside the loop is quadratic
    # and concatenating onto an empty frame is deprecated in pandas.
    # pd.concat([]) raises, so fall back to an empty frame when nothing was found.
    merged_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    file_path_output = os.path.join(directory, output_file)
    merged_df.to_csv(file_path_output, index=False, header=False)
    # "Merged file saved as {file_path_output}"
    print(f"合并后的文件已保存为 {file_path_output}")

def parse_string_to_list(string):
    """Parse the string form of a Python list of strings back into a list.

    Before ``ast.literal_eval``, every single-quoted element that directly
    follows '[' or ',' and is followed by ',' or ']' is re-quoted with double
    quotes. This lets elements containing apostrophes (e.g. ``['it's fine']``)
    parse. Double quotes inside such an element are escaped so they do not end
    the new literal early (e.g. ``['say "hi"']``).

    Returns:
        The parsed list, or None (after printing the error) if *string* is not
        a str, is not a valid literal, or does not evaluate to a list.
    """
    if not isinstance(string, str):
        # e.g. NaN from an empty CSV cell; re.sub would raise TypeError on it.
        print(f"Error parsing string: expected str, got {type(string).__name__}")
        return None
    try:
        string = re.sub(
            r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])",
            lambda m: '"' + m.group(1).replace('"', '\\"') + '"',
            string,
        )
        result = ast.literal_eval(string)
        if isinstance(result, list):
            return result
        else:
            raise ValueError("The parsed result is not a list.")
    # literal_eval documents all of these for malformed or oversized input.
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
        print(f"Error parsing string: {e}")
        return None

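# Usage example
# Assumes all part CSV files are stored in the 'path_to_csv_files' directory;
# the merged file will be saved as 'merged_files.csv'.
# ('path_to_csv_files' and 'merged_files.csv' stand for the `directory` and
# `output_file` arguments in the calls below.)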
# merge_csv_files('/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/', 'XH_cut_report_part', 'XH_cut_report_part_all.csv')
# merge_csv_files('/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/', 'XH_sent_label_part', 'XH_sent_label_part_all.csv')


# Sentence and label files from the original (non-XH) run.
sent_original_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/cut_report_part_all.csv'
label_original_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/sent_label_part_all_new.csv'

# Each headerless CSV becomes a list of rows, each row a list of cell values.
sent_original, label_original = (
    pd.read_csv(csv_path, header=None).reset_index(drop=True).values.tolist()
    for csv_path in (sent_original_path, label_original_path)
)

# Column 0 of each sentence row holds a stringified Python list of sentences.
sent_original_list = [
    parse_string_to_list(row[0])
    for row in tqdm.tqdm(sent_original, desc='Loading chopped sentences')
]

# Drop the '0' / 0 padding cells from every label row.
label_original_list = [
    [cell for cell in row if cell != '0' and cell != 0]
    for row in tqdm.tqdm(label_original, desc='Loading sentences labels')
]

# Clean every original report: keep only sentences whose label validates and
# drop the 'X' placeholders clean_label inserts for rejected pairs.
sent_original_list_new = []
label_original_list_new = []
for i in tqdm.tqdm(range(len(sent_original_list)), desc='clean data'):
    sentences = sent_original_list[i]
    kept_sents, kept_labels = clean_label(sentences, label_original_list[i])
    if sentences == kept_sents:
        # Nothing was rejected or truncated, so there is nothing to filter.
        sent_original_list_new.append(kept_sents)
        label_original_list_new.append(kept_labels)
        continue
    sent_original_list_new.append([s for s in kept_sents if s != 'X'])
    label_original_list_new.append([lab for lab in kept_labels if lab != 'X'])

# One single-cell row per report holding the stringified sentence list, and
# label rows right-padded with '0' to at least 59 columns.
# NOTE(review): neither *_new_2 list is read later in this script.
sent_original_list_new_2 = [[str(sent)] for sent in sent_original_list_new]

# Pad a *copy* of each label list. Appending in place mutated the lists held in
# label_original_list_new. The merge loop further down then placed the XH labels
# after the '0' fillers instead of right after the original labels, and its own
# pad-to-50 step never ran.
label_original_list_new_2 = [
    label + ['0'] * (59 - len(label)) for label in label_original_list_new
]

# Merged XH sentence/label files (outputs of the commented-out merge_csv_files calls above).
sent_XH_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_cut_report_part_all.csv'
label_XH_path = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_sent_label_part_all.csv'

sent_XH, label_XH = (
    pd.read_csv(csv_path, header=None).reset_index(drop=True).values.tolist()
    for csv_path in (sent_XH_path, label_XH_path)
)

sent_XH_list = []
label_XH_list = []
for i in tqdm.tqdm(range(len(sent_XH)), desc='Loading XH data'):
    sent_row, label_row = sent_XH[i], label_XH[i]
    raw_cell = sent_row[0]
    # A non-string cell (e.g. NaN for an empty CSV field) or an unparseable one
    # contributes an empty row, which keeps row i aligned with the original data.
    parsed = parse_string_to_list(raw_cell) if type(raw_cell) is str else None
    if parsed is None:
        sent_XH_list.append([])
        label_XH_list.append([])
        continue
    kept_sents, kept_labels = clean_label(parsed, label_row)
    sent_XH_list.append([s for s in kept_sents if s != 'X'])
    label_XH_list.append(validate_check([lab for lab in kept_labels if lab != 'X']))

# Append each report's XH sentences/labels to its cleaned original ones, pad the
# label row with '0' to at least 50 columns, and write both final CSVs.
# NOTE(review): rows are matched by index. If sent_XH_list is longer than
# sent_original_list_new this raises IndexError; if shorter, the trailing
# original rows are silently dropped.
sent_new = []
label_new = []
for i, xh_sents in enumerate(tqdm.tqdm(sent_XH_list, desc='merge data')):
    merged_sents = sent_original_list_new[i] + xh_sents
    merged_labels = label_original_list_new[i] + label_XH_list[i]
    merged_labels += ['0'] * (50 - len(merged_labels))
    sent_new.append([str(merged_sents)])
    label_new.append(merged_labels)


for rows, out_path in (
    (sent_new, '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/cut_report_part_final.csv'),
    (label_new, '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/sent_label_part_final.csv'),
):
    pd.DataFrame(rows).to_csv(out_path, index=False, header=False)