import os
import argparse
from tqdm import tqdm
import pandas as pd
import re
import csv
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with two string attributes:
        ``split_id`` (defaults to ``'1'``) and ``model_name_or_path``
        (defaults to ``None``).
    """
    cli = argparse.ArgumentParser()
    # (flag, default) pairs; every option is parsed as a plain string.
    option_table = (
        ('--split_id', '1'),
        ('--model_name_or_path', None),
    )
    for flag, default_value in option_table:
        cli.add_argument(flag, type=str, default=default_value)
    return cli.parse_args()

def prompt_data(sent):
    """Build the chat messages asking the model to split one report.

    The fixed instructions (rules plus two worked examples) are sent as the
    system message and ``sent`` is sent unchanged as the user message.

    Args:
        sent: The report text to be split into sentences.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: the system
        message followed by the user message.
    """
    system_prompt = """
Assume you are an experienced radiologist. Help me split a medical diagnostic conclusion report into sentences. Please split the sentences where necessary and process the text according to the following rules:

Remove redundant periods, such as '. .'. For example:
Input: feedingtube has been removed. nasogastric tube is coiled in the stomach region. . interval decrease in the right-sided. stable appearance of the left lung.
Output: ['feedingtube has been removed.', 'nasogastric tube is coiled in the stomach region.', 'interval decrease in the right-sided.', 'stable appearance of the left lung.']

For ordered structures like '1.', '2.', keep the ordered structure and the sentence that follows together without splitting. For example:
Input: 1. low lung volumes and mild pulmonary vascular congestion is unchanged. 2. new small right fissural pleural effusion .
Output: ['1. low lung volumes and mild pulmonary vascular congestion is unchanged.', '2. new small right fissural pleural effusion.']

Remove information unrelated to the case, such as treatment recommendations or notifications to relevant personnel. For example:
Input: recommend chest ct with intravenous contrast for further assessment. dr. communicated the above results to dr. at 8:55 am on by telephone. findings were relayed to dr.
Output: (remove these parts)

Remove short sentences that do not contain case information, such as comparisons to previous case results, and any symptoms were not mentioned. For example:
Input: no other relevant change. otherwise little change. comparison is made to prior study from at 4:05 a.m. ap chest compared to and 5:05 p.m.
Output: (remove these parts)

After processing the text, split the remaining information into appropriate sentences and store each sentence in a list.

Example 1:
Input:
recommend chest ct with intravenous contrast for further assessment.dr. communicated the above results to dr. at 8:55 am on by telephone.findings were relayed to dr. 1. low lung volumes and mild pulmonary vascular congestion is unchanged. 2. new small right fissural pleural effusion. no other relevant change. otherwise little change. comparison is made to prior study from at 4:05 a.m. ap chest compared to and 5:05 p.m.
Output:
['1. low lung volumes and mild pulmonary vascular congestion is unchanged.', '2. new small right fissural pleural effusion.']

Example 2:
Input:
right upper lobe pneumonia or mass . however given right hilar fullness a mass resulting in post-obstructive pneumonia is within the differential. recommend chest ct with intravenous contrast for further assessment.dr. communicated the above results to dr. at 8:55 am on by telephone.
Output:
['right upper lobe pneumonia or mass .', 'however given right hilar fullness a mass resulting in post-obstructive pneumonia is within the differential.']

Now, please process the input text according to these rules and store each sentence in a list. Notice: you only need to output the list, do not output any reason!
"""
    system_turn = dict(role="system", content=system_prompt)
    user_turn = dict(role="user", content=sent)
    return [system_turn, user_turn]

def inference(messages, model, tokenizer, max_length):
    """Generate the model's reply to a chat prompt using greedy decoding.

    Args:
        messages: Chat messages (list of ``{"role", "content"}`` dicts) that
            are rendered with the tokenizer's chat template.
        model: Causal language model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer matching ``model``; must support
            ``apply_chat_template``.
        max_length: Maximum number of newly generated tokens (forwarded to
            ``generate`` as ``max_new_tokens``).

    Returns:
        The decoded continuation only (the prompt tokens are stripped), with
        special tokens removed.
    """
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Stop on the regular EOS token and, when the vocabulary has it, on the
    # "<|eot_id|>" end-of-turn token. For tokenizers without that token the
    # lookup yields None or the unknown-token id; neither may be used as a
    # terminator, otherwise generation would stop on any unknown token.
    terminators = []
    if tokenizer.eos_token_id is not None:
        terminators.append(tokenizer.eos_token_id)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if eot_id is not None and eot_id != unk_id and eot_id not in terminators:
        terminators.append(eot_id)

    # Decoding is greedy (do_sample=False). The sampling-only knobs
    # `temperature` and `top_p` are intentionally not passed: they have no
    # effect without sampling and only trigger configuration warnings.
    outputs = model.generate(
        input_ids,
        max_new_tokens=max_length,
        # Fall back to the model's own generation config if nothing was found.
        eos_token_id=terminators or None,
        do_sample=False,
    )

    # Keep only the tokens produced after the prompt.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def split_csv(data, output_prefix, chunk_size=20000):
    """Write ``data`` to consecutive CSV files of at most ``chunk_size`` rows.

    Files are named ``{output_prefix}_part_{k}.csv`` with ``k`` starting at 1
    and are written without the index column.

    Args:
        data: pandas DataFrame (or Series) to split row-wise.
        output_prefix: Path prefix shared by all output files.
        chunk_size: Maximum number of rows per file; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not a positive number.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    total_rows = len(data)
    # Ceiling division: an exact multiple of chunk_size must not produce an
    # extra, header-only trailing file. Empty input still yields one
    # (header-only) file so that `part_1` always exists.
    num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size)

    for i in range(num_chunks):
        start_row = i * chunk_size
        chunk = data.iloc[start_row:start_row + chunk_size]
        output_file = f"{output_prefix}_part_{i + 1}.csv"
        chunk.to_csv(output_file, index=False)
        print(f"Saved {output_file}")


def max_length_string(strings):
    """Return the length of the longest string in ``strings``.

    An empty (or otherwise falsy) collection yields 0.
    """
    return max(map(len, strings)) if strings else 0

def append_to_csv(file_path, row):
    """Append one row to the CSV file at ``file_path``.

    The file is created if it does not exist. It is always written as UTF-8
    so that non-ASCII text cannot fail on platforms whose default encoding is
    narrower, and so the file reads back with pandas' default (UTF-8).

    Args:
        file_path: Path of the CSV file to append to.
        row: Iterable of field values forming a single CSV row.
    """
    # newline='' lets the csv module control line endings itself.
    with open(file_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(row)

def _count_completed_rows(file_path):
    """Return how many CSV rows an earlier run already wrote to ``file_path``.

    The output file has no header line, so every record counts as one
    finished sample. A missing or empty file counts as zero. Records are
    counted with the csv module so quoted multi-line fields count once.
    """
    if not os.path.exists(file_path):
        return 0
    # Only the number of records matters here, so undecodable bytes are
    # replaced rather than allowed to abort the resume.
    with open(file_path, newline='', encoding='utf-8', errors='replace') as handle:
        return sum(1 for _ in csv.reader(handle))


def generate():
    """Run the sentence-splitting model over one report split and save results.

    Reads the ``Report Impression`` column of the input split selected by
    ``--split_id``, sends each impression to the model, and appends every raw
    model response as one row of the output CSV. If the output file already
    exists, samples that were written by a previous run are skipped.

    Raises:
        ValueError: If ``--model_name_or_path`` was not provided.
    """
    args = parse_args()
    if not args.model_name_or_path:
        # Fail early with a clear message instead of an obscure error from
        # `from_pretrained(None)` after the data has already been loaded.
        raise ValueError("--model_name_or_path is required")

    # NOTE(review): the input lives under `CARZero-main` while the output goes
    # to `CARZero` -- confirm this difference is intentional.
    filepaths = f'/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC/report_part_{args.split_id}.csv'
    data = pd.read_csv(filepaths)

    # Missing impressions (non-string cells such as NaN) become empty strings;
    # newlines are flattened to spaces.
    data_list = [
        text.replace("\n", " ") if isinstance(text, str) else ""
        for text in tqdm(data['Report Impression'], total=data.shape[0])
    ]

    file_path = f'/mnt/nvme_share/wuwl/project/CARZero/Dataset/MIMIC/cut_report_part_{args.split_id}.csv'
    # Resume point: the output has no header row, so the number of records in
    # it equals the number of samples already processed.
    cache = _count_completed_rows(file_path)
    if cache >= len(data_list):
        print(f"Nothing to do: {file_path} already holds {cache} rows")
        return

    # Character length of the longest impression plus a fixed margin; used
    # both as the tokenizer's model_max_length and as the generation budget.
    max_length = max_length_string(data_list) + 800

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=max_length,
        use_fast=False,
        trust_remote_code=True,
    )

    for idx in tqdm(range(cache, len(data_list))):
        messages = prompt_data(data_list[idx])
        output = inference(messages, model, tokenizer, max_length)
        # Written immediately so an interrupted run can pick up from here.
        append_to_csv(file_path, [output])


# Script entry point: run the full generation pipeline when executed directly.
if __name__ == "__main__":
    generate()
