import os
import argparse
import pandas as pd
import re
import csv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import ast
import math
import numpy as np
from tqdm import tqdm

def parse_args():
    """Read the command-line options of this script.

    Returns:
        argparse.Namespace with two string attributes:
        ``split_id`` (defaults to ``'1'``) and ``model_name_or_path``
        (defaults to ``None``).
    """
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is parsed as a string.
    option_table = (
        ('--split_id', '1'),
        ('--model_name_or_path', None),
    )
    for flag, fallback in option_table:
        cli.add_argument(flag, type=str, default=fallback)
    return cli.parse_args()

def inference(messages, model, tokenizer):
    """Generate the model's reply to a chat-formatted prompt.

    Args:
        messages: Chat messages passed to ``tokenizer.apply_chat_template``
            (this file builds them as a list of ``{"role", "content"}`` dicts).
        model: Object exposing ``device`` and ``generate``.
        tokenizer: Object exposing ``apply_chat_template``, ``eos_token_id``,
            ``convert_tokens_to_ids`` and ``decode``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt part
        of the output is sliced off), with special tokens skipped.
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Stop tokens: the tokenizer's EOS plus "<|eot_id|>" when the vocabulary
    # really contains it. A lookup for a missing token may come back as None
    # or as the unknown-token id; neither is a valid terminator.
    terminators = []
    if tokenizer.eos_token_id is not None:
        terminators.append(tokenizer.eos_token_id)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if eot_id is not None and eot_id != unk_id and eot_id not in terminators:
        terminators.append(eot_id)

    # Decoding is deterministic (do_sample=False). Sampling knobs such as
    # temperature / top_p are deliberately not passed: they have no effect
    # without sampling and only produce configuration warnings.
    generate_kwargs = {"max_new_tokens": 256, "do_sample": False}
    if terminators:
        generate_kwargs["eos_token_id"] = terminators
    outputs = model.generate(input_ids, **generate_kwargs)

    # Keep only what was generated after the prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def append_to_csv(file_path, row):
    """Append one row to a CSV file, creating the file if it does not exist.

    Args:
        file_path: Path of the CSV file to append to.
        row: Iterable of cell values for the new row.
    """
    # UTF-8 is pinned so the output does not depend on the platform locale and
    # matches the default encoding pandas.read_csv uses when the file is read
    # back. newline='' lets the csv module control line endings itself.
    with open(file_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(row)

def parse_string_to_list(string):
    """Parse the text of a Python-style list literal into a list.

    The text is first evaluated unchanged with ``ast.literal_eval``. If that
    does not produce a list, a second attempt is made after rewriting
    single-quoted items (those directly after ``[`` or ``,`` and directly
    before ``,`` or ``]``) as double-quoted ones, which rescues items that
    contain an unescaped apostrophe, e.g. ``['it's fine']``.

    Args:
        string: Text expected to look like ``"['a', 'b']"``.

    Returns:
        The parsed list, or ``None`` when the input is not a string or no
        attempt yields a list (an error message is printed in that case).
    """
    if not isinstance(string, str):
        # e.g. a NaN float produced by pandas for an empty cell.
        print(f"Error parsing string: expected str, got {type(string).__name__}")
        return None

    normalized = re.sub(r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])", r'"\1"', string)
    # Raw text goes first: the rewrite would break valid items that contain
    # double quotes. Skip the second attempt when the rewrite changed nothing.
    candidates = (string,) if normalized == string else (string, normalized)

    error = None
    for candidate in candidates:
        try:
            result = ast.literal_eval(candidate)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as e:
            error = e
            continue
        if isinstance(result, list):
            return result
        error = ValueError("The parsed result is not a list.")

    print(f"Error parsing string: {error}")
    return None

def prompt_data(sent):
    """Wrap one report sentence in the chat messages used for rewriting.

    Args:
        sent: The sentence to be rewritten; it becomes the user message.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: the fixed
        rewriting instructions as the system message, then ``sent`` as the
        user message.
    """
    # Fixed instruction text sent to the model; it is runtime data, so it is
    # kept exactly as-is.
    system_prompt = """
Assume you are an experienced radiologist. Please help me rewrite the following medical diagnostic report sentences, adhering to the following requirements:

Rewriting Requirement:
    (1)Semantic Clarity: The meaning of the sentence is clear and unambiguous.
    (2)Avoid Complex Negative Sentences: The structure of negative sentences is simplified to avoid confusion.
    (3)Avoid Double Negations: Double negatives are not used, as they can complicate understanding.
    (4)Use the 'There be' sentence pattern: Use the 'There be' sentence pattern whenever possible to enhance directness and clarity, for example: There is xxx / There is no xxx(There is no signs of xxx).
    (5)Avoid Inversion: Inverted sentence structures are not used to prevent complexity in understanding, rewrite format:'No xxx is present/show' to format:'There is no signs of xxx'.
    (6)Accuracy of Professional Terminology: Precise and consistent medical terminology is used to maintain accuracy.

Examples:

Input: The patient exhibits signs of acute appendicitis with localized tenderness in the lower right quadrant.
Output: There is symptoms of acute appendicitis, with localized tenderness observed in the lower right abdominal quadrant.

Input: No consolidation or bone fracture is present.
Output: There is no signs of consolidation or bone fracture.

Input: No acute cardio pulmonary process.
Output: There is no signs of acute cardiac or pulmonary issues.

Please process the input medical diagnostic report sentences according to the above prompt.
Notice: you only need to output the rewrite sentence, do not output any other thing!
"""
    system_turn = dict(role="system", content=system_prompt)
    user_turn = dict(role="user", content=sent)
    return [system_turn, user_turn]

def generate():
    """Rewrite every report sentence of one dataset split with the LLM.

    Steps:
      1. Parse the command line and load the causal LM (fp16, moved to CUDA)
         and its tokenizer from ``--model_name_or_path``.
      2. Read ``cut_report_part_<split_id>.csv`` (one stringified list of
         sentences per row, capped at 50 sentences) and
         ``sent_label_part_<split_id>.csv`` (first 50 columns per row, with
         ``0`` / ``'0'`` entries dropped).
      3. For each report, rewrite each sentence whose label text contains
         neither ``'25'`` nor ``'26'``, and append one row to
         ``LLM_cut_report_part_<split_id>.csv`` (the stringified list of
         rewritten sentences) and one row to
         ``LLM_sent_label_part_<split_id>.csv`` (the kept labels, right-padded
         with ``'0'`` to 50 columns).

    The run is resumable: reports whose index is below the number of rows
    already present in the sentence output file are skipped.

    Raises:
        IndexError: If a report has more sentences than non-zero labels.
    """
    args = parse_args()
    max_items = 50  # sentence cap per report and fixed width of a label row
    data_dir = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC_new'

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=1500,
        use_fast=False,
        trust_remote_code=True
    )

    sent_input_path = f'{data_dir}/cut_report_part_{args.split_id}.csv'
    data_sent = pd.read_csv(sent_input_path, header=None).values.tolist()
    data_sent_list = []
    for sent_sample in tqdm(data_sent, desc='Loading chopped sentences'):
        parsed = parse_string_to_list(sent_sample[0])
        # parse_string_to_list returns None on failure. Keep an empty
        # placeholder instead of crashing so that input and output rows stay
        # aligned (which the resume logic below relies on).
        data_sent_list.append([] if parsed is None else parsed[:max_items])

    label_input_path = f'{data_dir}/sent_label_part_{args.split_id}.csv'
    data_label = pd.read_csv(label_input_path, header=None).values.tolist()
    data_label_list = [
        [element for element in label_sample[:max_items] if element != '0' and element != 0]
        for label_sample in tqdm(data_label, desc='Loading sentences labels')
    ]

    sent_output_path = f'{data_dir}/LLM_cut_report_part_{args.split_id}.csv'
    label_output_path = f'{data_dir}/LLM_sent_label_part_{args.split_id}.csv'

    # Resume support: the number of rows already written is the number of
    # reports that can be skipped. An existing but empty file counts as zero.
    cache = 0
    if os.path.exists(sent_output_path):
        try:
            cache = len(pd.read_csv(sent_output_path, header=None))
        except pd.errors.EmptyDataError:
            cache = 0

    for idx in tqdm(range(0, len(data_sent_list))):
        if idx < cache:
            continue
        sent_sample = data_sent_list[idx]
        label_sample = data_label_list[idx]
        if len(label_sample) < len(sent_sample):
            # Fail before any generation work, and say which row is broken.
            raise IndexError(
                f"Report {idx} has {len(sent_sample)} sentences but only "
                f"{len(label_sample)} non-zero labels."
            )

        sent_output = []
        label_output = []
        for sentence, label in zip(sent_sample, label_sample):
            data_label_item = str(label)
            # Decide whether the sentence is kept *before* running the model,
            # so no generation is spent on sentences that get discarded.
            if '25' in data_label_item or '26' in data_label_item:
                continue
            sent_output.append(inference(prompt_data(sentence), model, tokenizer))
            label_output.append(data_label_item)

        # Right-pad the label row with '0' to the fixed width (a no-op when
        # the row is already full).
        label_output.extend(['0'] * (max_items - len(label_output)))
        append_to_csv(sent_output_path, [str(sent_output)])
        append_to_csv(label_output_path, label_output)


# Script entry point: run the rewriting pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    generate()