import pandas as pd
import csv
import re
from tqdm import tqdm
import ast
import math


# A label id is written with one or two digits.
_LABEL_NUMBER_RE = re.compile(r'^\d{1,2}$')
# Matches a digit that is not preceded by another digit; used to decide
# whether a sentence mentions any number at all.
_SENTENCE_DIGIT_RE = re.compile(r"(?<!\d)[0-9]|(?<!\d)[1-2][0-9]|(?<!\d)3[0-9]")
_LABEL_PREFIX = 'Label:'
_MAX_LABEL_ID = 26


def _is_signed_label_list(text):
    """Return True if every ', '-separated part is an integer plus one trailing character."""
    for part in text.split(', '):
        try:
            int(part[:-1])
        except ValueError:
            return False
    return True


def _label_id_in_range(text):
    """Return the int value of a 1-2 digit label id in 0.._MAX_LABEL_ID, else None."""
    if _LABEL_NUMBER_RE.match(text) is None:
        return None
    number = int(text)
    return number if 0 <= number <= _MAX_LABEL_ID else None


def _strip_trailing_sign(text):
    """Return the label id of "<id><one non-digit character>" as a string, else None."""
    # The trailing character must not itself be a digit: otherwise "27" would
    # be read as label 2 and "123" as label 12.
    if not text or text[-1].isdigit():
        return None
    number = _label_id_in_range(text[:-1])
    return None if number is None else str(number)


def _normalize_label(text):
    """Apply the label rules to one string; return the cleaned label or None."""
    # Rule 1: comma-separated list such as "3+, 12-" is kept verbatim.
    if ',' in text and _is_signed_label_list(text):
        return text
    # Rule 2: a bare label id such as "7" or "25" is kept verbatim.
    if _label_id_in_range(text) is not None:
        return text
    # Rule 3: a label id followed by one extra character; the character is dropped.
    return _strip_trailing_sign(text)


def _clean_label(item):
    """Clean one label cell, also accepting an optional "Label:" prefix."""
    cleaned = _normalize_label(item)
    if cleaned is None and item.startswith(_LABEL_PREFIX):
        cleaned = _normalize_label(item[len(_LABEL_PREFIX):].strip())
    return cleaned


def clean_labels(label_data_list, sent_data_list, pad_length=59):
    """Normalise per-sentence label cells and keep the sentences that go with them.

    Args:
        label_data_list: Sequence of rows; each row is a list of label cells
            (strings, or ints for empty slots). Int cells are overwritten
            in place with '0'.
        sent_data_list: Sequence of rows parallel to ``label_data_list``.
            ``row[0]`` is expected to be the string form of a list of
            sentences, parsed with ``parse_string_to_list``.
        pad_length: Length every output label row is padded to with '0'.

    Returns:
        A ``(label_rows, sentence_rows)`` tuple. Each label row is a list of
        cleaned labels padded with '0'; each sentence row is a one-element
        list holding ``str()`` of the kept sentences. Rows whose sentence cell
        is not a string or cannot be parsed yield ``[0] * pad_length`` and
        ``['[]']``.

    Side effects:
        Prints every cell that matches no rule, and finally the number of
        such cells.
    """
    error = 0
    label_data_list_new = []
    sent_data_list_new = []

    for row_idx, row in enumerate(tqdm(label_data_list)):
        raw_sentences = sent_data_list[row_idx][0]
        sent_data = (
            parse_string_to_list(raw_sentences)
            if isinstance(raw_sentences, str)
            else None
        )
        if sent_data is None:
            # No usable sentences for this row: emit an all-zero placeholder.
            label_data_list_new.append([0] * pad_length)
            sent_data_list_new.append([str([])])
            continue

        label_list_new = []
        sent_list_new = []
        for col_idx, item in enumerate(row):
            # NOTE(review): a cell that is neither int nor str (e.g. a NaN
            # float from pandas) raises TypeError in the string rules below.
            if isinstance(item, int):
                # Integer cells are normalised to '0' in the caller's data.
                label_data_list[row_idx][col_idx] = '0'
                matched = True
            elif item == 'X':
                matched = True
            else:
                cleaned = _clean_label(item)
                if cleaned is not None:
                    label_list_new.append(cleaned)
                    matched = True
                else:
                    # An unparseable label is tolerated when the sentence
                    # itself contains no number.
                    matched = _SENTENCE_DIGIT_RE.search(sent_data[col_idx]) is None

            if not matched:
                # No rule applies: report the pair and mark both sides as 'X'.
                print('||', sent_data[col_idx], '||', row[col_idx], '||')
                error += 1
                label_list_new.append('X')
                sent_list_new.append('X')
            elif item != '0' and item != 0:
                # NOTE(review): for 'X' cells, non-zero int cells and the
                # "no number in sentence" case the sentence is kept although
                # no label was appended - confirm this is intended.
                sent_list_new.append(sent_data[col_idx])

        label_list_new.extend(['0'] * (pad_length - len(label_list_new)))
        label_data_list_new.append(label_list_new)
        sent_data_list_new.append([str(sent_list_new)])

    print(error)

    return label_data_list_new, sent_data_list_new

def check_validity(lst):
    """Check that each label number is used with only one sign across ``lst``.

    Every item is converted to ``str`` and split on ', '. For each distinct
    element the last character is taken as the sign and the rest as the
    number; the elements '25' and '26' are skipped. Empty elements are
    ignored.

    Args:
        lst: Iterable of label strings such as '3+' or '3+, 12-'.

    Returns:
        ``True`` when no number appears with two different signs; otherwise a
        list of the conflicting numbers (as strings) in the order the
        conflicts are first encountered. A number is listed once per extra
        sign it appears with.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result does not depend on string hash randomisation.
    unique_elements = dict.fromkeys(
        element
        for item in lst
        for element in str(item).split(', ')
        if element  # an empty element has no sign character to inspect
    )

    sign_by_number = {}
    error_number = []
    for element in unique_elements:
        if element in ('25', '26'):
            continue
        number, sign = element[:-1], element[-1]
        if sign_by_number.setdefault(number, sign) != sign:
            error_number.append(number)

    return error_number if error_number else True

def parse_string_to_list(string):
    """Parse the string form of a Python list and return the list.

    Single-quoted elements that directly follow '[' or ',' and are directly
    followed by ',' or ']' are re-quoted with double quotes first, so that
    elements containing an apostrophe still parse.

    Args:
        string: Text such as "['first sentence', 'second sentence']".

    Returns:
        The parsed list, or ``None`` (after printing the reason) when the
        input is not a string, cannot be parsed, or does not evaluate to a
        list.
    """
    try:
        requoted = re.sub(r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])", r'"\1"', string)
        result = ast.literal_eval(requoted)
    except (SyntaxError, ValueError, TypeError, RecursionError) as e:
        # TypeError covers non-string input (e.g. a NaN cell) and literals
        # such as a dict with an unhashable key.
        print(f"Error parsing string: {e}")
        return None

    if not isinstance(result, list):
        print("Error parsing string: The parsed result is not a list.")
        return None
    return result

def append_to_csv(file_path, row, encoding='utf-8'):
    """Append one row to a CSV file, creating the file if it does not exist.

    Args:
        file_path: Path of the CSV file to append to.
        row: Iterable of cell values written as a single CSV record.
        encoding: Text encoding of the file. Defaults to UTF-8 so the output
            does not depend on the platform locale.
    """
    # newline='' lets the csv module control line endings itself.
    with open(file_path, mode='a', newline='', encoding=encoding) as file:
        csv.writer(file).writerow(row)


# Input tables: one CSV with the label cells and one with the sentence lists.
# Both are read without a header row and converted to plain nested lists.
filepaths_label = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_sent_label_part_all.csv'
filepaths_sent = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_cut_report_part_all.csv'

data_label = pd.read_csv(filepaths_label, header=None).values.tolist()
data_sent = pd.read_csv(filepaths_sent, header=None).values.tolist()

# data_label_new = []
# data_sent_new = []
#
# num = 0
# for i in range(len(data_sent)):
#     sent_item = data_sent[i]
#     label_item = data_label[i]
#     if type(sent_item[0]) != str:
#         data_sent_new.append([])
#         data_label_new.append(label_item)
#     else:
#         list_temp = parse_string_to_list(sent_item[0])
#         if list_temp != None:
#             data_sent_new.append(list_temp)
#             data_label_new.append(label_item)
#         else:
#             num += 1
#             print(num)
#             a = sent_item[0]
#             b = label_item

# Clean the loaded label rows and collect the sentences kept for each row.
label_data_list_new, sent_data_list_new = clean_labels(data_label, data_sent)

# for i in tqdm(range(0, len(data_label))):
#     data_label_list = data_label[i]
#     temp = [element for element in data_label_list if element != '0']
#     temp = [element for element in temp if element != 0]
#     data_sent_list = parse_string_to_list(data_sent[i][0])
#     data_label_temp = []
#     data_sent_temp = []
#     for j in range(0, len(data_label_list)):
#         label = data_label_list[j]
#         if label != 'X' and label != '0' and label != 0:
#             data_label_temp.append(label)
#             data_sent_temp.append(data_sent_list[j])
#         if label == '0' or label == 0:
#             data_label_temp.append(label)
#     if len(data_label_temp) < 59:
#         for _ in range(59 - len(data_label_temp)):
#             data_label_temp.append('0')
#     data_label_new.append(data_label_temp)
#     quoted_list = [f"'{item}'" for item in data_sent_temp]
#     big_string = ', '.join(quoted_list)
#     big_string_with_brackets = '[' + big_string + ']'
#     temp = [element for element in data_label_temp if element != '0']
#     temp = [element for element in temp if element != 0]
#     if len(parse_string_to_list(big_string_with_brackets)) != len(temp):
#         raise ValueError("数据保存错误")
#     data_sent_new.append([big_string_with_brackets])

# error = 0
# error_number = {str(i): 0 for i in range(1, 25)}
# for i in tqdm(range(0, len(data_label))):
#     data_label_list = data_label[i]
#     temp = [element for element in data_label_list if element != '0']
#     temp = [element for element in temp if element != 0]
#     for item in temp:
#         if type(item) != str:
#             print('fuck')
    # data_sent_list = parse_string_to_list(data_sent[i][0])
    # validity = check_validity(temp)
    # if validity != True:
    #     for num in validity:
    #         error_number[num] += 1
# print(error)
# print(error_number)

# data_num = 5000 * 2 + 50
# data_label_new = data_label[:data_num]
# data_sent_new = data_sent[:data_num]
#
# Write the cleaned labels first, then the cleaned sentences, one CSV row per
# sample. Rows are appended, so existing output files are extended.
for output_path, cleaned_rows in (
    ('/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_label_data_list_new.csv', label_data_list_new),
    ('/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/XH_sent_data_list_new.csv', sent_data_list_new),
):
    for sample in tqdm(cleaned_rows):
        append_to_csv(output_path, sample)
