import os
import argparse
from tqdm import tqdm
import pandas as pd
import re
import csv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import ast
import math


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``split_id`` (str, default ``'1'``) and
        ``model_name_or_path`` (str, default ``None``).
    """
    cli = argparse.ArgumentParser()
    # Both options are plain strings; only their defaults differ.
    for flag, fallback in (('--split_id', '1'), ('--model_name_or_path', None)):
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()

def find_and_remove_indices(lst, elements):
    """Locate every entry of ``elements`` in ``lst``, one distinct slot each.

    The search runs on a private copy of ``lst``; ``lst`` itself is never
    modified. After a match, the matched slot of the copy is overwritten
    with ``None`` so that a repeated value resolves to its next occurrence
    instead of the same index again.

    Args:
        lst: Sequence to search.
        elements: Values to look up, in the order the indices are wanted.

    Returns:
        list[int]: One index into ``lst`` per entry of ``elements``.

    Raises:
        ValueError: If an entry has no (remaining) occurrence in ``lst``.
    """
    remaining = list(lst)
    positions = []
    for wanted in elements:
        where = remaining.index(wanted)
        remaining[where] = None  # consume this occurrence
        positions.append(where)
    return positions

def check_validity(lst):
    """Check that no label number appears with two different signs.

    Every entry of ``lst`` is a string of one or more elements joined by
    ``', '``. An element is read as ``<number><sign>``, where the sign is its
    last character. The bare elements ``'25'`` and ``'26'`` carry no sign and
    are ignored. A number is *conflicting* when it occurs with more than one
    distinct sign anywhere in ``lst``.

    Entries are matched against conflicting numbers element by element.
    A plain substring test would wrongly flag ``'15-'`` (or the sentinel
    ``'25'``) when the conflicting number is ``'1'`` (or ``'2'``).

    Args:
        lst: List of label strings.

    Returns:
        ``True`` when there is no conflict; otherwise the list of original
        entries (input order, duplicates kept) that contain at least one
        conflicting number.
    """
    unsigned = ('25', '26')

    def numbers_of(entry):
        # Yield the number part of every signed element in one entry.
        for element in entry.split(', '):
            # Empty elements have no sign character to strip; skip them
            # instead of failing on element[-1].
            if element and element not in unsigned:
                yield element[:-1]

    sign_by_number = {}
    conflicting = set()
    for entry in lst:
        for element in entry.split(', '):
            if not element or element in unsigned:
                continue
            number, sign = element[:-1], element[-1]
            # The first sign seen for a number is the reference; any other
            # sign for the same number marks it as conflicting.
            if sign_by_number.setdefault(number, sign) != sign:
                conflicting.add(number)

    if not conflicting:
        return True
    return [
        entry for entry in lst
        if any(number in conflicting for number in numbers_of(entry))
    ]


def prompt_1_data(sent):
    """Build the chat messages for the sentence-rewriting prompt.

    The system message instructs the model to rewrite a report sentence into
    a clear, active-voice form, or to answer "X" when the sentence is not
    about a specific medical part or symptom.

    Args:
        sent: The report sentence, used verbatim as the user message.

    Returns:
        list[dict]: ``[system message, user message]``, each a dict with
        ``role`` and ``content`` keys.
    """
    system_prompt = """
Assume you are an experienced radiologist. Please help me process the following medical diagnostic report sentences, adhering to the following requirements:

1.Rewriting Requirement: If the sentence content is related to specific medical parts or symptoms, please rewrite the sentence to ensure:
    (1)Semantic Clarity: The meaning of the sentence is clear and unambiguous.
    (2)Avoid Complex Negative Sentences: The structure of negative sentences is simplified to avoid confusion.
    (3)Avoid Double Negations: Double negatives are not used, as they can complicate understanding.
    (4)Use Active Voice: The active voice is employed whenever possible to enhance directness and clarity.
    (5)Avoid Inversion: Inverted sentence structures are not used to prevent complexity in understanding, rewrite format:'No xxx is present/show' to format:'The patient shows no signs of xxx'.
    (6)Accuracy of Professional Terminology: Precise and consistent medical terminology is used to maintain accuracy.
2.Marking Requirement: If the sentence content is unrelated to any specific medical parts or symptoms, such as when mentioning comparisons or general descriptions (e.g., "Compared to chest radiographs since the most recent one"), please output the letter "X" to mark this situation.

Examples:

Input: The patient exhibits signs of acute appendicitis with localized tenderness in the lower right quadrant.
Output: The patient shows symptoms of acute appendicitis, with localized tenderness observed in the lower right abdominal quadrant.

Input: No consolidation or bone fracture is present.
Output: The patient shows no signs of consolidation or bone fracture.

Input: No acute cardio pulmonary process.
Output: The patient shows no signs of acute cardiac or pulmonary issues.

Input: Compared to chest radiographs since the most recent one, there are no significant changes noted.
Output: X

Please process the input medical diagnostic report sentences according to the above prompt.
Notice: you only need to output the rewrite sentence or "X" , do not output any other thing!
"""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sent},
    ]
def prompt_2_data(sent):
    """Build the chat messages for the condition-labelling prompt.

    The system message lists 25 numbered condition labels with descriptions
    and worked examples, and asks the model to answer with the matching
    label number(s) only.

    Args:
        sent: The report sentence, used verbatim as the user message.

    Returns:
        list[dict]: ``[system message, user message]``, each a dict with
        ``role`` and ``content`` keys.
    """
    system_prompt = """
Assume you are an experienced radiologist. Help me identify the correct medical condition label for the given radiology report sentence. Below are the medical condition labels with corresponding numbers and their medical descriptions:

1.Atelectasis: Lung tissue exhibits signs of partial or complete atelectasis, with decreased lung volume and increased localized radiographic density.
2.Pleural Effusion: There is an abnormal accumulation of fluid within the pleural cavity, which is evident on X-ray imaging as a blunted costophrenic angle or the presence of a fluid level.
3.Pneumothorax: There is evidence of free air within the thoracic cavity, resulting in partial or complete atelectasis. This is characterized by a retracted lung edge and an area devoid of lung markings.
4.Cardiomegaly: The cardiac silhouette is enlarged, indicating a size beyond the normal range. Limits the heart only, vascular is not in consideration.
5.Opacity: A lung region demonstrates increased radiopacity, potentially suggestive of inflammatory changes, a tumor, or hemorrhage.
6.Pneumonia: Patchy to diffuse lung infiltrates are observed, frequently associated with air bronchograms.
7.Pulmonary Mass: A lung mass, either well-circumscribed or poorly defined, is present, typically measuring over 3 cm in diameter.
8.Edema: Interstitial or alveolar fluid accumulation in the lungs is evident, characterized by increased and indistinct lung markings, a common finding in cardiogenic pulmonary edema.
9.Lung Nodule: A round or oval opacity within the lung, measuring less than 3 cm in diameter.
10.Lung Infiltration: Patchy or reticular opacities are noted within the lung tissue, suggestive of inflammatory processes or other infiltrative conditions.
11.Fibrosis: Interstitial lung thickening and fibrosis are present, exhibiting a reticular pattern and a honeycomb-like appearance on imaging studies.
12.Emphysema: Overinflation of the lungs with alveolar destruction is observed, manifesting as increased lung lucency and expanded lung fields on the chest X-ray.
13.Pleural Thickening: Pleural layer thickening is evident, appearing as broadened pleural shadows on imaging, often attributed to chronic inflammation or fibrosis.
14.Hernia: Internal organ protrusion through either normal or abnormal openings is observed; in the case of a diaphragmatic hernia, this may present as an abnormal position and contour of the diaphragm on X-ray imaging.
15.Consolidation: The lung alveoli are opacified with fluid or solid material, manifesting as regions of increased density with indistinct borders on imaging. the lung tissue exhibits a solidified appearance, a common finding in cases of pneumonia.
16.Bone Fracture: A discontinuity within the bone structure is observed on X-ray, characterized by a disrupted cortical bone and the presence of fracture lines.
17.Enlarged Cardiomediastinum: An enlargement of the mediastinal shadow or cardiomediastinal silhouette is noted.
18.Pleural Other: Other pleural abnormalities, such as pleural calcifications or plaques, are present, exhibiting specific imaging features that are indicative of the underlying condition.
19.Lung Lesion: An encompassing term for a variety of abnormal imaging findings within the lung, encompassing nodules, masses, infiltrates, and other anomalies.
20.Support Devices: Visualized on imaging are various medical devices, including catheters, stents, prosthetic heart valves, and other implanted or inserted medical apparatus.
21.Abnormal Lesion: A non-specific term denoting any abnormal imaging findings within the lungs, which may encompass a range of presentations such as nodules, masses, opacities, or other anomalies.
22.Lung Granuloma: A small pulmonary nodule, usually measuring less than 3 cm in diameter, frequently exhibiting calcification.
23.Calcified Granuloma: A calcified granuloma is observed, manifesting as a high-density nodule on imaging studies.
24.Tissue Calcification: Calcifications within soft tissue are noted, appearing as areas of increased density on imaging, indicative of calcified spots or plaques.
25.No Mention: None of the above symptoms are mentioned or related, cannot exist with any other labels at the same time.

Examples:

Sentence: there is no focal consolidation pleural effusion or pneumothorax.
Label: 15, 2, 3

Sentence: bilateral nodular opacities that most likely represent nipple shadows.
Label: 9

Sentence: chronic deformity of the posterior left sixth and seventh ribs are noted.
Label: 16

Sentence: the patient shows no signs of free air below the right hemidiaphragm.
Label: 3

Sentence: the imaged upper abdomen shows no remarkable findings.
Label: 25

Sentence: the patient's overall condition is normal.
Label: 25

Sentence: central vascular engorgement.
Label: 25

Special Note:
If the sentence mentions a condition (whether positive or negative), use the corresponding label.
If the sentence describes multiple conditions, except 25(No Mention), output multiple labels, separated by commas ",".
Remember, 25(No Mention) and other labels cannot exist at the same time.

In case you forgot, let me repeat these labels:
1. Atelectasis
2. Pleural Effusion
3. Pneumothorax
4. Cardiomegaly
5. Opacity
6. Pneumonia
7. Pulmonary Mass
8. Edema
9. Lung Nodule
10. Lung Infiltration
11. Fibrosis
12. Emphysema
13. Pleural Thickening
14. Hernia
15. Consolidation
16. Bone Fracture
17. Enlarged Cardiomediastinum
18. Pleural Other
19. Lung Lesion
20. Support Devices
21. Abnormal Lesion
22. Lung Granuloma
23. Calcified Granuloma
24. Tissue Calcification
25. No Mention

Now, please select the most appropriate label for the following sentence and output only the corresponding number(s).
Notice: you only need to output the label (pure numbers), do not output anything else!
"""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sent},
    ]



def inference(messages, model, tokenizer):
    """Run one greedy chat completion and return the decoded reply text.

    Args:
        messages: Chat messages (list of ``{"role", "content"}`` dicts) that
            ``tokenizer.apply_chat_template`` turns into input ids.
        model: Object exposing ``device`` and ``generate`` (presumably a
            Hugging Face causal LM; confirm against the caller).
        tokenizer: Tokenizer matching ``model``; must provide
            ``apply_chat_template``, ``convert_tokens_to_ids``, ``decode``
            and ``eos_token_id``.

    Returns:
        str: The newly generated tokens only (the prompt is stripped),
        decoded with special tokens skipped.
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Stop on the tokenizer's EOS and, when the vocabulary really has it, on
    # the "<|eot_id|>" end-of-turn token. Tokenizers without that token can
    # answer with None or with their unknown-token id; neither is a valid
    # stop id, so both are filtered out.
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id == getattr(tokenizer, "unk_token_id", None):
        eot_id = None
    terminators = []
    for token_id in (tokenizer.eos_token_id, eot_id):
        if token_id is not None and token_id not in terminators:
            terminators.append(token_id)

    # Decoding is greedy (do_sample=False). temperature/top_p only apply to
    # sampling, so they are not passed: they were ignored and only triggered
    # configuration warnings.
    generate_kwargs = {"max_new_tokens": 256, "do_sample": False}
    if terminators:
        generate_kwargs["eos_token_id"] = terminators
    outputs = model.generate(input_ids, **generate_kwargs)

    # Keep only the tokens generated after the prompt.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)


def split_csv(data, output_prefix, chunk_size=20000):
    """Write ``data`` to consecutive header-less CSV files of ``chunk_size`` rows.

    Files are named ``{output_prefix}_part_{k}.csv`` with ``k`` starting at 1.
    Only non-empty chunks are written: an input whose length is an exact
    multiple of ``chunk_size`` gets no trailing empty file, and empty input
    produces no file at all.

    Args:
        data: Row-sliceable object whose slices provide ``to_csv`` (for
            example a pandas DataFrame).
        output_prefix: Path prefix of the output files.
        chunk_size: Maximum number of rows per file; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    total_rows = len(data)
    # Ceiling division: number of chunks that actually contain rows.
    num_chunks = (total_rows + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_row = i * chunk_size
        chunk = data[start_row:start_row + chunk_size]
        output_file = f"{output_prefix}_part_{i + 1}.csv"
        chunk.to_csv(output_file, index=False, header=False)
        print(f"Saved {output_file}")

def append_to_csv(file_path, row):
    """Append one row to the CSV file at ``file_path``, creating it if needed.

    Args:
        file_path: Destination CSV path.
        row: Iterable of cell values written as a single CSV record.
    """
    # UTF-8 is pinned explicitly so the file does not depend on the platform
    # default encoding and matches pandas.read_csv's default when read back.
    # newline='' lets the csv module control line endings.
    with open(file_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(row)

def parse_string_to_list(string):
    """Parse the text form of a list (e.g. ``"['a', 'b']"``) into a list.

    The text is first evaluated as-is with ``ast.literal_eval``. Only when
    that fails on syntax is a repaired copy tried, in which every
    single-quoted item is re-quoted with double quotes. This rescues items
    containing an unescaped apostrophe, such as ``['the patient's heart']``.
    Trying the raw text first keeps valid literals that contain double
    quotes from being broken by the repair.

    Args:
        string: Text expected to hold a Python list literal.

    Returns:
        list | None: The parsed list, or ``None`` (after printing the reason)
        when the input is not a string, cannot be parsed, or does not
        evaluate to a list.
    """
    if not isinstance(string, str):
        # e.g. a float NaN coming from an empty CSV cell.
        print(f"Error parsing string: expected str, got {type(string).__name__}")
        return None

    requoted = re.sub(r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])", r'"\1"', string)
    error = None
    for candidate in (string, requoted):
        try:
            result = ast.literal_eval(candidate)
        except (SyntaxError, ValueError) as e:
            error = e
            continue
        if isinstance(result, list):
            return result
        # A well-formed literal of another type cannot be repaired by
        # re-quoting, so stop here.
        error = ValueError("The parsed result is not a list.")
        break

    print(f"Error parsing string: {error}")
    return None


# Every output row is padded with '0' up to this many columns.
_LABEL_ROW_WIDTH = 59
# Label written in place of out-of-range and self-contradictory entries.
_INVALID_LABEL = '26'
# Marker that may appear in the input label file; literally "input out of range".
_OUT_OF_RANGE_MARK = '输入超出范围'


def _count_csv_rows(path):
    """Return the number of non-empty CSV records in ``path`` (0 if it is missing)."""
    if not os.path.exists(path):
        return 0
    # csv.reader copes with an empty file and with rows of different widths,
    # both of which make pandas.read_csv raise.
    with open(path, newline='', encoding='utf-8') as handle:
        return sum(1 for record in csv.reader(handle) if record)


def _build_label_row(sample, label_list):
    """Return the cleaned, '0'-padded label row for one report."""
    if len(sample) == 0:
        return ['0'] * _LABEL_ROW_WIDTH

    # Drop padding cells, then normalise the out-of-range marker first so
    # that check_validity sees it as the sign-less '26' it skips.
    labels = [
        _INVALID_LABEL if cell == _OUT_OF_RANGE_MARK else cell
        for cell in label_list
        if cell not in ('0', '')
    ]

    validity = check_validity(labels)
    if validity is not True:
        # Overwrite every self-contradictory label with the invalid marker.
        # TODO: re-label these sentences with the LLM instead: rewrite
        # sample[index] via prompt_1_data + inference, then classify the
        # rewrite via prompt_2_data + inference. That requires loading the
        # model/tokenizer named by --model_name_or_path in generate().
        for index in find_and_remove_indices(labels, validity):
            labels[index] = _INVALID_LABEL

    # NOTE(review): the number of labels is not compared with len(sample);
    # confirm whether one label per sentence should be enforced here.
    labels.extend(['0'] * (_LABEL_ROW_WIDTH - len(labels)))
    return labels


def generate():
    """Clean the per-sentence labels of every report and append them to a CSV.

    Steps:
      1. Read the sentence file; column 0 of each row holds the text form of
         a list of sentences, parsed with ``parse_string_to_list``.
      2. Read the label file (one row of label cells per report).
      3. For each report not yet present in the output file, replace
         out-of-range and self-contradictory labels with ``'26'``, pad the
         row with ``'0'`` to 59 columns, and append it to the output CSV.

    The run is resumable: reports whose row already exists in the output
    file are skipped. No model is loaded, because the only code that used
    one (LLM re-labelling) is not enabled yet.

    Raises:
        ValueError: If a sentence row cannot be parsed, or the label file
            has fewer rows than the sentence file.
    """
    # Still parse the CLI so --help and unknown-option errors keep working;
    # no option is consumed by the currently active code path.
    parse_args()

    # Per-split variant:
    # filepaths = f'/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/cut_report_part_{args.split_id}.csv'
    sent_path = '/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/cut_report_part_all.csv'
    sent_rows = pd.read_csv(sent_path, header=None).values.tolist()
    data_sent_list = [parse_string_to_list(row[0]) for row in sent_rows]
    # parse_string_to_list signals failure with None; fail before any output
    # row is written.
    if any(sample is None for sample in data_sent_list):
        raise ValueError("数据解析错误")

    # Per-split variant:
    # filepaths = f'/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/sent_label_part_{args.split_id}.csv'
    label_path = '/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/sent_label_part_all.csv'
    # Read every cell as text: type inference would turn all-numeric cells
    # into ints/floats and empty cells into NaN, which the label checks
    # cannot handle.
    data_label = pd.read_csv(
        label_path, header=None, dtype=str, keep_default_na=False
    ).values.tolist()

    total = len(data_sent_list)
    if len(data_label) < total:
        raise ValueError(
            f"label file has {len(data_label)} rows but sentence file has {total}"
        )

    # Per-split variant:
    # file_path = f'/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/sent_label_part_modified_{args.split_id}.csv'
    file_path = '/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/sent_label_part_all_new.csv'
    # Rows already written by a previous run are skipped.
    cache = _count_csv_rows(file_path)

    for idx in tqdm(range(cache, total), initial=cache, total=total):
        append_to_csv(file_path, _build_label_row(data_sent_list[idx], data_label[idx]))


# Script entry point: run the label clean-up only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    generate()
