import transformers
import torch
from PIL import Image
import time


SYSTEM_PROMPT = """

You are an expert specializing in cardiology. Your task is to read and analyze an uploaded electrocardiogram (ECG or EKG).
The patient information is:

"[PATIENT_INFO]"

Based on the uploaded ECG (or EKG) and the patient information, complete the following steps:

Step-1: Identify the ECG Characteristics

Step-2: Detect the Arrhythmias

Step-3: Provide Diagnostics and Interpretation on the following aspects:
    - PR Interval: duration in milliseconds, with further diagnostics and interpretation of the PR interval
    - QT Interval: duration in milliseconds, with further diagnostics and interpretation of the QT interval
    - AV Block Degree: which type (first degree, second degree type I/II/2:1, third degree), and explain the reason
    - AFL (Atrial Flutter): present or not, and explain the reasons. If present, what it indicates
    - VFL (Ventricular Flutter): present or not, and explain the reasons. If present, what it indicates
    - PVC & PAC & junctional & escape: detected or not; explain the reasons, with further diagnostics and interpretation
    - P-wave: present or not, and explain the reasons. If present, what it indicates
    - T-wave: inverted or not, and explain the reasons with morphology. If inverted, what it indicates given the morphology
    - J-wave: present or not, and explain the reasons. If present, what it indicates
    - ST-segment: which type (normal, elevation, or depression), and explain the reasons
"""

DEMO = """
--- 

Below is an example answer:

Step-1: Identify the ECG Characteristics
- Heart Rate: 75 bpm
- Rhythm: Regular
- P wave: Visible before each QRS complex
- QRS Duration: 110 milliseconds
- QT Interval: 400 milliseconds
- PR Interval: 210 milliseconds
- ST Segment: Isoelectric

Step-2: Detect the Arrhythmias
- PR Interval Prolongation: Suggests a delay in atrioventricular conduction.
- Regular Rhythm with Normal Rate: Typically rules out tachyarrhythmias.
- Normal QT Interval for Rate: No immediate evidence of long QT syndrome.
- No visible additional waves or significant ST-T changes: Normal morphology without apparent extrasystoles or ischemic changes.

Step-3: Provide Diagnostics and Interpretation
- PR Interval: 210 milliseconds. The PR interval is slightly prolonged (above the 200 ms upper limit of normal), suggesting a first-degree atrioventricular (AV) block. This is characterized by delayed conduction from the atria to the ventricles, but every atrial impulse is still conducted to the ventricles.
- QT Interval: 400 milliseconds. The QT interval is within normal limits for the patient's heart rate, which suggests no increased risk of ventricular arrhythmias related to QT prolongation.
- AV Block Degree: First degree, as indicated by the prolonged PR interval without missing any beats.
- AFL (Atrial Flutter): Not present, as there are no characteristic sawtooth flutter waves in the inferior leads or V1.
- VFL (Ventricular Flutter): Not indicated, as the rhythm is regular without rapid, undulating oscillations typical of VFL.
- PVC & PAC & junctional & escape: Not detected in the provided ECG strip; all complexes are preceded by a P wave, and QRS complexes are normal in duration and morphology.
- P-wave: Present and consistent, indicating normal atrial depolarization before each ventricular depolarization.
- T-wave: Not inverted; normal morphology which suggests no overt myocardial ischemia or electrolyte disturbances in the context of the ECG alone.
- J-wave: Not observed in the provided ECG trace.
- ST-segment: Isoelectric, which is typical of a normal ECG without acute ischemic changes.

**Findings:**
The primary diagnosis based on this ECG is a first-degree AV block, which is typically a benign finding but may warrant further investigation if symptomatic (as in this patient's case with dizziness and palpitations). Recommendations would likely include further monitoring and potentially an echocardiogram to evaluate structural heart disease. Additionally, depending on the patient's medical history, a Holter monitor might be considered to rule out intermittent higher-grade AV blocks or other transient arrhythmias not captured during this ECG.

---
"""

INPUT_PROMPT = """

Following the example above, analyze the given ECG (EKG) and generate results with the same answer format and items. Do not repeat the example answer:

**Output:**

Step-1: Identify the ECG Characteristics

Step-2: Detect the Arrhythmias

Step-3: Provide Diagnostics and Interpretation on the following aspects:
    - PR Interval
    - QT Interval
    - AV Block Degree
    - AFL (Atrial Flutter)
    - VFL (Ventricular Flutter)
    - PVC & PAC & junctional & escape
    - P-wave
    - T-wave
    - J-wave
    - ST-segment
"""

# INPUT_PROMPT = "**Generation Results:**"

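def llm_inference(image_path,
                  image_description,
                  system_prompt=SYSTEM_PROMPT,
                  demo=DEMO,
                  input_prompt=INPUT_PROMPT,
                  model_path="/mnt/efs-llm/huggingface/llava-v1.6-vicuna-13b-hf",  # NOTE: replace with your path
                  ):
    """Analyze an ECG image with a LLaVA model and print the generated report.

    `image_description` is the patient information substituted for the
    "[PATIENT_INFO]" placeholder in `system_prompt`.
    """
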
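    # NOTE: the model is reloaded on every call; hoist this out of the
    # function if you run inference on many images.
    vision_pipeline = transformers.pipeline(
        "image-to-text",
        model=model_path,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )
    processor = transformers.AutoProcessor.from_pretrained(model_path)

    # Convert to RGB so grayscale or RGBA scans match the expected input.
    image = Image.open(image_path).convert("RGB")
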
    
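    # Fill in the patient information, then append the worked example and
    # the output template to form the full text prompt.
    system_prompt = system_prompt.replace("[PATIENT_INFO]", image_description)
    system_prompt = system_prompt + "\n" + demo + "\n" + input_prompt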

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": system_prompt},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

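    generated_text = vision_pipeline(
        image,
        prompt=prompt,
        generate_kwargs={"max_new_tokens": 2048},
    )[0]["generated_text"]

    # The decoded output typically echoes the prompt before the model's reply,
    # so keep only the text after the chat template's assistant marker:
    # "ASSISTANT:" for Vicuna-based LLaVA checkpoints, "[/INST]" for
    # Mistral-based ones. Slicing by len(system_prompt) does not work because
    # the chat template adds text around the prompt.
    outputs = generated_text
    for marker in ("ASSISTANT:", "[/INST]"):
        if marker in generated_text:
            outputs = generated_text.split(marker)[-1]
            break
    outputs = outputs.strip()

    print(outputs)
    return outputs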
    
# start = time.time()
# llm_inference(
#     image_path="example/ecg.jpeg",
#     image_description="On May 23, 2024 at 05:03 PM, a 25-year-old male patient underwent an ECG examination using a lead-II system. The test was conducted in the evening, which may influence the patient's heart rhythm patterns",
# )
# end = time.time()
# print(f"Time {end-start}s")