import os
import csv
import json
import time
import torch
import transformers
import preprocess as preprocess
from PIL import Image


# Instruction text for the vision model (used by vision_llm_inference).
# The literal "[PATIENT_INFO]" placeholder is replaced with the patient
# description before the prompt is sent.
ECG_SYSTEM_PROMPT = """

You are an expert specializing in cardiology. Your task is to read and analyze an uploaded electrocardiograph (ECG or EKG), 
the patient information is:

"[PATIENT_INFO]"

Based on the uploaded ECG (or EKG) with patient information, summarize ECG Characteristics with Interpretation in following items:
- PR Interval: how many milliseconds, with further Diagnostics and Interpretation of PR Interval
- QT Interval: how many milliseconds, with further Diagnostics and Interpretation of QT Interval
- P-wave: present or not, and explain the reasons. If present, what's the indication
- T-wave: inverted or not, and explain the reasons with morphology. If present, what's the indication with morphology
- ST-segment: what type (normal, elevation, or depression), and explain the reasons
"""

# Worked example answer for the vision model; vision_llm_inference joins it
# (newline-separated) between ECG_SYSTEM_PROMPT and ECG_INPUT_PROMPT.
ECG_DEMO = """
--- 

Below is an example answer:

ECG Characteristics with Interpretation:
- PR Interval: 210 milliseconds. The PR interval is slightly prolonged, suggesting a first-degree atrioventricular (AV) block. This is characterized by a delay in the conduction between the atria and ventricles but every atrial impulse still successfully leads to a ventricular response.
- QT Interval: 400 milliseconds. The QT interval is within normal limits for the patient's heart rate, which suggests there is no acute ischemia and no increased risk of ventricular arrhythmias related to QT prolongation.
- P-wave: Present and consistent, indicating normal atrial depolarization before each ventricular depolarization.
- T-wave: Not inverted; normal morphology which suggests no overt myocardial ischemia or electrolyte disturbances in the context of the ECG alone.
- ST-segment: Isoelectric, which is typical of a normal ECG without acute ischemic changes.
---
"""

# Last part of the vision prompt: asks for a fresh answer with the same five
# items and layout as ECG_DEMO.
ECG_INPUT_PROMPT = """

Following the above example, given an ECG (EKG), generate results with same answering format and items, **do not repeat the demo**:

**Output:**

ECG Characteristics with Interpretation:
- PR Interval
- QT Interval
- P-wave
- T-wave
- ST-segment
"""

# System message for the text model (used by metric_llm_inference, which
# appends METRIC_DEMO to it). Lists every item the interpretation must cover.
METRIC_SYSTEM_PROMPT = """
**Task:** Draft a Clinical Interpretation for a Long-term Holter End-of-study Report

**Objective:** As a knowledgeable clinician, your goal is to draft an "Interpretation" of a patient's clinical report, specifically focusing on the diagnosis.

**Input Format:**

1. **Patient Records (Tabular Data):**
- You will receive tabular data containing a patient's cardiological records with three columns: 
     - **Metric:** Clinical metrics (or variables)
     - **Description:** Description for each metric
     - **Value:** Specific values for each metric (if a metric is absent, the value will be "not present")

2. **Clinical Findings:** written from patient records.

**Task Description:**

A. Using the provided inputs (patient's tabular record and clinical findings), generate an "Interpretation" that reflects a specific clinical diagnosis for the patient.

B. The "Interpretation" should be itemized, including both quantified observations and interpretative conclusions.

C. The following items must be included in the "Interpretation":
- AF/AFL
- VEB
- VT
- SVEB
- SVT
- Pause
- Block
- Sinus
- Symptoms
- QT interval
- PR interval
- P-wave
- T-wave
- ST-segment

D. If an item is not present, simply state "not present" in the interpretation. Do not omit any required item.

"""

# Example interpretation appended to the text model's system message.
# NOTE(review): the example stops at "Symptoms" although METRIC_SYSTEM_PROMPT
# also requires QT/PR interval, P-wave, T-wave and ST-segment items -- confirm
# this omission is intentional.
METRIC_DEMO = """
---

Here’s an example:

**Output:**  
**Interpretation:**
Monitoring started on 2024-Jul-11 at 10:49 and continued for 2 days and 23 hours.
- AF/AFL: AF/AFL was present(98.9%). The Longest episode was 18:49:30, Day 1 / 16:53:30 and the Fastest episode was 163 bpm, Day 2 / 00:32:27.
- VEB: 2603 isolated (1.95% burden), 58 couplets, 4 bigeminy, 2 trigeminy episodes.
- VT: Total number of VT episodes was 32. The longest episode was 104 beats, Day 2 / 01:25:33.
- SVEB: 3474 isolated (2.60% burden), 83 couplets, 36 bigeminy, 52 trigeminy episodes.
- SVT: Total number of SVT episodes was 6. The longest episode was 5 beats, Day 2 / 02:42:42.
- Pause: Total number of Pauses was 1. The longest episode was 2,440ms, Day 1 / 15:28:12.
- Block: No Blocks Present.
- Sinus: Sinus rhythm was present(94.7%). The range was 78 - 115 bpm, Avg 96 bpm.
- Symptoms: None Present.

---

Based on the example above, use the provided "Patient Records" and "Findings" to generate an "Interpretation."
"""

# User-message template for the text model. The [RECORD], [METRIC_FINDING]
# and [ECG_FINDING] placeholders are meant to be filled in by
# metric_llm_inference. The other bracketed slots ([DATE], [TIME], [DURATION],
# "[Interpretation on ...]") are not substituted by the code in this file and
# reach the model as-is.
METRIC_INPUT_PROMPT = """
**Patient Records (Metric | Description | Value):**  
[RECORD]

**Findings:**  
[METRIC_FINDING]
[ECG_FINDING]

**Interpretation:**
Monitoring started on [DATE] at [TIME] and continued for [DURATION].
- AF/AFL: [Interpretation on AF/AFL]
- VEB: [Interpretation on VEB]
- VT: [Interpretation on VT]
- SVEB: [Interpretation on SVEB]
- SVT: [Interpretation on SVT]
- Pause: [Interpretation on Pause]
- Block: [Interpretation on Block]
- Sinus: [Interpretation on Sinus]
- Symptoms: [Interpretation on Symptoms]
- QT interval: [Interpretation on QT interval]
- PR interval: [Interpretation on PR interval]
- P-wave: [Interpretation on P-wave]
- T-wave: [Interpretation on T-wave]
- ST-segment: [Interpretation on ST-segment]
"""
    
def _extract_answer(generated_text):
    """Return the assistant reply contained in a decoded generation.

    The pipeline's ``generated_text`` may begin with an echo of the templated
    prompt. Everything up to the last answer marker is therefore dropped. The
    markers are ``[/INST]`` (Mistral-style chat templates) and ``ASSISTANT:``
    (Vicuna-style chat templates). Surrounding whitespace and square brackets
    are then stripped.

    Only the whole marker is removed. ``str.strip("ASSISTANT:")`` would instead
    remove any leading/trailing A, S, I, T, N or ':' characters from the reply,
    e.g. mangling an answer that starts with "ST-segment".
    """
    answer = generated_text
    for marker in ("[/INST]", "ASSISTANT:"):
        answer = answer.split(marker)[-1]
    return answer.strip().strip("[]")


def vision_llm_inference(
    image_path,
    image_desc=None,
    system_prompt=ECG_SYSTEM_PROMPT,
    demo=ECG_DEMO,
    input_prompt=ECG_INPUT_PROMPT,
    model_path="/mnt/efs-llm/huggingface/llava-v1.6-vicuna-13b-hf",  # NOTE: replace by your path
):
    """Ask a vision-language model to read an ECG image and describe it.

    Args:
        image_path: Path to a JSON file with an ``"imageContent"`` entry (handed
            to the image-to-text pipeline as the image) and a ``"patientInfo"``
            entry (text describing the patient).
        image_desc: Optional patient description. When given, it is used instead
            of the file's ``"patientInfo"``.
        system_prompt: Instruction template. Its ``[PATIENT_INFO]`` placeholder is
            replaced by the patient description.
        demo: Example answer placed after the instructions.
        input_prompt: Output-format request placed last.
        model_path: Local path or hub id of the image-to-text model.

    Returns:
        The model's reply, with any echoed prompt and answer marker removed.

    Raises:
        KeyError: if the JSON lacks ``"imageContent"``, or lacks
            ``"patientInfo"`` while *image_desc* is None.
    """
    # Read the input before loading the (large) model so a bad path fails fast.
    with open(image_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # NOTE(review): presumably a path/URL/base64 string the pipeline can load -- confirm.
    image = data["imageContent"]
    if image_desc is None:
        image_desc = data["patientInfo"]

    # Instructions (with patient info filled in), worked example, output format.
    full_prompt = "\n".join(
        [system_prompt.replace("[PATIENT_INFO]", image_desc), demo, input_prompt]
    )

    vision_pipeline = transformers.pipeline(
        "image-to-text",
        model=model_path,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        trust_remote_code=True,
    )
    processor = transformers.AutoProcessor.from_pretrained(model_path)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": full_prompt},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    generated = vision_pipeline(
        image,
        prompt=prompt,
        generate_kwargs={"max_new_tokens": 2048},
    )[0]["generated_text"]
    return _extract_answer(generated)
    
    
def _load_metrics(metric_path):
    """Read a ``{metric name: value}`` mapping from a ``.csv`` or ``.json`` file.

    CSV: each row with exactly two columns is read as ``key, value``. A
    ``Label,Value`` header row is skipped, and rows of any other width are
    ignored. JSON: the top-level object, or the object under its ``"metrics"``
    key when that key exists.

    Raises:
        ValueError: if *metric_path* ends with neither ``.csv`` nor ``.json``.
    """
    metrics = {}
    if metric_path.endswith(".csv"):
        with open(metric_path, mode="r", newline="", encoding="utf-8") as csv_file:
            for row in csv.reader(csv_file):
                if len(row) != 2:
                    continue  # only two-column rows carry a key/value pair
                key, value = row
                if key == "Label" and value == "Value":
                    continue  # header row
                metrics[key] = value
    elif metric_path.endswith(".json"):
        with open(metric_path, "r", encoding="utf-8") as json_file:
            json_data = json.load(json_file)
        if "metrics" in json_data:
            json_data = json_data["metrics"]
        metrics.update(json_data)
    else:
        # Raised explicitly: an ``assert`` would be stripped under ``python -O``.
        raise ValueError(
            f"Unsupported metrics file (expected .csv or .json): {metric_path}"
        )
    return metrics


def metric_llm_inference(
    metric_path,
    ecg_findings="",
    system_prompt=METRIC_SYSTEM_PROMPT,
    demo=METRIC_DEMO,
    input_prompt=METRIC_INPUT_PROMPT,
    model_path="/mnt/efs-llm/huggingface/Meta-Llama-3.1-8B-Instruct",  # NOTE: replace by your path
):
    """Draft a Holter-report interpretation from patient metrics with a chat LLM.

    Args:
        metric_path: ``.csv`` or ``.json`` file holding the patient's metrics
            (format described in ``_load_metrics``).
        ecg_findings: Free-text ECG findings, substituted for ``[ECG_FINDING]``
            in *input_prompt*. ``None`` is treated as an empty string.
        system_prompt: System message. *demo* is appended to it.
        demo: Example interpretation appended to the system message.
        input_prompt: User-message template with ``[RECORD]``,
            ``[METRIC_FINDING]`` and ``[ECG_FINDING]`` placeholders.
        model_path: Local path or hub id of the text-generation model.

    Returns:
        The content of the assistant message generated by the model. The
        prompts and the reply are also printed to stdout.

    Raises:
        ValueError: if *metric_path* is neither a ``.csv`` nor a ``.json`` file.
    """
    metrics = _load_metrics(metric_path)
    tabular_data = "\n".join(preprocess.get_tabular_data(metrics))
    findings = "\n".join(preprocess.get_findings(metrics))

    input_prompt = (
        input_prompt.replace("[RECORD]", tabular_data)
        .replace("[METRIC_FINDING]", findings)
        # The placeholder includes both brackets. Replacing only "ECG_FINDING]"
        # left a stray "[" in front of the ECG findings.
        .replace("[ECG_FINDING]", ecg_findings or "")
    )

    # launch LLM pipeline
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_path,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
    )

    system_prompt = system_prompt + "\n" + demo
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_prompt},
    ]

    # "length" here is a whitespace-separated word count, not a token count.
    print("system_prompt length", len(system_prompt.split()))
    print("input_prompt length", len(input_prompt.split()))
    print("\n>>> System instruction\n", system_prompt)
    print("-"*100)
    print("\n>>> Inputs:\n", input_prompt)
    print("-"*100)
    print("\n>>> Generated Texts:\n")

    outputs = pipeline(
        messages,
        max_new_tokens=2048,
        pad_token_id=pipeline.tokenizer.eos_token_id,
    )
    # With chat-message input, "generated_text" is the message list ending in
    # the new {'role': 'assistant', 'content': ...} turn.
    rst = outputs[0]["generated_text"][-1]["content"]
    print(rst)

    return rst
    
def save(patient, llm, findings, interpretation, base_dir="dataset"):
    """Write one model's findings and interpretation for a patient to JSON.

    The output file is ``<base_dir>/<patient>/<llm>.json``. It is created along
    with any missing directories and overwritten if it already exists. It holds
    two keys:

    - ``"findings"``: *findings* with surrounding whitespace removed.
    - ``"interpretation"``: the text after the last ``**Interpretation:**``
      heading in *interpretation* (the whole text if the heading is absent),
      with surrounding whitespace removed.

    Non-ASCII characters are written as-is (UTF-8), with 4-space indentation.
    """
    # Build the payload before opening the file so a bad argument cannot leave
    # a truncated file behind.
    payload = {
        "findings": findings.strip(),
        "interpretation": interpretation.strip().split("**Interpretation:**")[-1].strip(),
    }
    out_dir = os.path.join(base_dir, str(patient))
    # The patient folder may not exist yet (e.g. a new model/patient run).
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{llm}.json"), "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
        
# patient = "42d53152"
# ecg_findings = vision_llm_inference(
#     image_path=f"dataset/{patient}/ecg.json",
#     # image_desc="On May 23, 2024 at 05:03 PM, a 25-year-old male patient underwent an ECG examination using a lead-II system. The test was conducted in the evening, which may influence the patient's heart rhythm patterns",
#     model_path="../../../huggingface/llava-v1.6-vicuna-13b-hf"
# ).split("ECG Characteristics with Interpretation:")[-1].strip().split(
# "here are the characteristics and interpretations:")[-1].strip().strip('[]')

# ecg_findings = "\n".join(ecg_findings.split("\n\n")).strip('[]')

# metric_llm_inference(
#     metric_path=f"dataset/{patient}/metrics.json",
#     ecg_findings=ecg_findings,
#     model_path="../../../huggingface/Meta-Llama-3.1-8B-Instruct"
# )