# %%
import sys
import os
import json
from transformers import AutoTokenizer
import random
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
# Make the parent directory of this script's folder importable; the
# project-local `track` import below relies on this path entry.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

# Environment switch read by the Hugging Face `tokenizers` backend.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Number of items per batch requested from Patient.get_json_batch_list.
batch_size = 1000
# Worker count handed to the Patient helpers and to thread_map.
max_workers = 40
# Maximum number of tokens a rendered sample may have; also embedded in the
# output file names below.
max_length = 4096

from track.track_MIMIC.patient import Patient

# Directory the input JSON batches are read from.
load_path = os.path.join(
    os.path.dirname(__file__),
    "..",
    "data",
    "json_disch_30B_A3B_few_shot_radio_30B_A3B_few_shot",
)

# Directory the train/val JSONL files are written to.
save_path = os.path.join(
    os.path.dirname(__file__),
    "..",
    "data",
    "env"
)

train_filename = f'trainset_{max_length}.jsonl'
val_filename = f'valset_{max_length}.jsonl'

# Create the output directory if needed. `exist_ok=True` avoids the
# check-then-create race of testing os.path.exists() first.
os.makedirs(save_path, exist_ok=True)

train_save_path = os.path.join(save_path, train_filename)
val_save_path = os.path.join(save_path, val_filename)

# Placeholder: replace with the local path (or hub id) of the model whose
# tokenizer and chat template should be used for length filtering.
local_model = "<your_model_path>"

# Fraction of each batch's patients assigned to the training split; the
# remainder goes to validation.
train_split_ratio = 0.9
val_split_ratio = 1 - train_split_ratio

tokenizer = AutoTokenizer.from_pretrained(
    local_model,
    trust_remote_code=True,
)

# %%
def format(pred):
    """Turn one prediction record into a chat-style training sample.

    Args:
        pred: Mapping providing the keys ``'info_pre'``, ``'info_pred'`` and
            ``'info_pred_gt'``.

    Returns:
        A dict with a single ``"messages"`` key holding three turns in order:
        the system prompt (instructions plus few-shot examples), the user
        turn built from ``info_pre`` / ``info_pred``, and the assistant turn
        whose content is ``info_pred_gt`` unchanged.
    """
    # The pieces are concatenated verbatim, in order; the few-shot entries
    # are kept as separate literals so each one can be swapped out easily.
    prompt_parts = [
        "You are a medical prediction assistant tasked with forecasting the potential results of a patient's subsequent medical examination based on their existing data. Representative examples are provided below.\n\n",
        # Radiology example
        'Information to Predict:\n',
        '<few-shot example prompt for RadiologyEvent>',
        "Your response should be:\n",
        '<few-shot example response for RadiologyEvent>',
        # Lab example
        'Information to Predict:\n',
        '<few-shot example prompt for LabEvent>',
        "Your response should be:\n",
        '<few-shot example response for LabEvent>',
        # Microbiology example
        'Information to Predict:\n',
        '<few-shot example prompt for MicrobiologyEvent>',
        "Your response should be:\n",
        '<few-shot example response for MicrobiologyEvent>',
    ]
    system_turn = {"role": "system", "content": "".join(prompt_parts)}

    user_turn = {
        "role": "user",
        "content": (
            f"Preliminary Information:\n{pred['info_pre']}\n\n"
            f"Information to Predict:\n{pred['info_pred']}\n\n"
        ),
    }

    assistant_turn = {"role": "assistant", "content": pred['info_pred_gt']}

    return {"messages": [system_turn, user_turn, assistant_turn]}

# %%
# Batches of inputs found under load_path; the loop below processes and saves
# one batch at a time. NOTE(review): each batch is presumably a list of up to
# batch_size JSON file paths — confirm against Patient.get_json_batch_list.
json_batch_list = Patient.get_json_batch_list(load_path, batch_size=batch_size)

def filter_messages_by_length(messages, max_length):
    """Return *messages* if the rendered chat fits in *max_length* tokens.

    Args:
        messages: Sample dict whose ``"messages"`` entry is the list of chat
            turns (dicts with ``"role"`` and ``"content"``).
        max_length: Maximum allowed number of tokens after the module-level
            ``tokenizer`` applies its chat template.

    Returns:
        The same *messages* object when it is within the limit, otherwise
        ``None`` so callers can drop it.
    """
    chat = messages["messages"]

    # Cheap pre-filter: skip tokenization for samples that are far too long.
    # The previous check used len(messages), i.e. the number of keys of the
    # wrapper dict (always 1), so it never triggered. Count the characters of
    # the turn contents instead. The "4 characters per token" factor is the
    # original heuristic; samples above it are dropped without tokenizing.
    total_chars = sum(
        len(turn["content"])
        for turn in chat
        if isinstance(turn.get("content"), str)
    )
    if total_chars > max_length * 4:
        return None

    # Exact check on the tokenized, template-rendered conversation.
    # `enable_thinking` is an extra keyword forwarded to the chat template;
    # its effect depends on the model's template.
    token_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    if len(token_ids) > max_length:
        return None
    return messages

# Process the input one batch at a time: load patients, split them into
# train/val, expand to per-admission prediction samples, render them as chat
# messages, drop over-long samples, and append the rest to the JSONL files.
for json_batch in json_batch_list:

    # NOTE(review): the hard-coded 48 differs from the module-level
    # `max_workers` used by every other call below — confirm it is intended.
    patient_list = Patient.get_patient_list_from_batch(
        json_batch, max_workers=48)
    print(f"Loaded {len(patient_list)} patients from {json_batch[0]} to {json_batch[-1]}")

    # Shuffle once and slice, so the two splits are an exact partition of the
    # batch. The previous `p not in train_patient_list` scan was quadratic,
    # depended on how Patient defines equality, and would drop validation
    # patients that compare equal to a training patient.
    shuffled_patient_list = random.sample(patient_list, len(patient_list))
    train_count = int(len(shuffled_patient_list) * train_split_ratio)
    train_patient_list = shuffled_patient_list[:train_count]
    val_patient_list = shuffled_patient_list[train_count:]

    train_admission_wise_patient_list = Patient.get_admission_wise_patient_list(
        train_patient_list, max_workers=max_workers)

    val_admission_wise_patient_list = Patient.get_admission_wise_patient_list(
        val_patient_list, max_workers=max_workers)

    train_batch_prediction_list = Patient.prepare_prediction_list_from_patient_list(
        train_admission_wise_patient_list,
        max_workers=max_workers,
    )

    val_batch_prediction_list = Patient.prepare_prediction_list_from_patient_list(
        val_admission_wise_patient_list,
        max_workers=max_workers,
    )

    train_batch_message_list = [format(pred) for pred in tqdm(train_batch_prediction_list, desc="Formatting messages")]
    val_batch_message_list = [format(pred) for pred in tqdm(val_batch_prediction_list, desc="Formatting messages")]

    # filter_messages_by_length returns None for samples that are too long;
    # those placeholders are removed right after each parallel map.
    train_batch_message_list = thread_map(
        filter_messages_by_length,
        train_batch_message_list,
        [max_length] * len(train_batch_message_list),
        max_workers=max_workers,
        desc="Filtering messages",
    )
    train_batch_message_list = [msg for msg in train_batch_message_list if msg is not None]

    val_batch_message_list = thread_map(
        filter_messages_by_length,
        val_batch_message_list,
        [max_length] * len(val_batch_message_list),
        max_workers=max_workers,
        desc="Filtering messages",
    )
    val_batch_message_list = [msg for msg in val_batch_message_list if msg is not None]

    # Files are opened in append mode so every batch adds to the same two
    # files; re-running the script appends to whatever is already there.
    print(f"Saving {len(train_batch_message_list)} messages to {train_save_path}")
    with open(train_save_path, 'a', encoding="utf-8") as f:
        for message in train_batch_message_list:
            f.write(json.dumps(message) + '\n')

    print(f"Saving {len(val_batch_message_list)} messages to {val_save_path}")
    with open(val_save_path, 'a', encoding="utf-8") as f:
        for message in val_batch_message_list:
            f.write(json.dumps(message) + '\n')