import transformers
import torch
import csv
import preprocess as preprocess
import time

# System-message instructions for the LLM. Describes the two inputs (the
# tabular patient record and the findings text) and the nine items that the
# itemized "Interpretation" must cover. llm_inference concatenates its `demo`
# argument (DEMO by default) onto this text to form the system message.
SYSTEM_PROMPT = """
**Task:** Draft a Clinical Interpretation for a Long-term Holter End-of-study Report

**Objective:** As a knowledgeable clinician, your goal is to draft an "Interpretation" of a patient's clinical report, specifically focusing on the diagnosis.

**Input Format:**

1. **Patient Records (Tabular Data):**
   - You will receive tabular data containing a patient's cardiological records with three columns: 
     - **Metric:** Clinical metrics (or variables)
     - **Description:** Description for each metric
     - **Value:** Specific values for each metric (if a metric is absent, the value will be "not present")

2. **Clinical Findings:** written from patient records.
   
**Task Description:**

A. Using the provided inputs (patient's tabular record and clinical findings), generate an "Interpretation" that reflects a specific clinical diagnosis for the patient.

B. The "Interpretation" should be itemized, including both quantified observations and interpretative conclusions.

C. The following items must be included in the "Interpretation":
   - AF/AFL
   - VEB
   - VT
   - SVEB
   - SVT
   - Pause
   - Block
   - Sinus
   - Symptoms

D. If an item is not present, simply state "not present" in the interpretation. Do not omit any required item.

"""

# One-shot example appended to the system message. It shows only the expected
# itemized output format. It contains no input placeholders, so the example
# patient's record and findings are not shown to the model.
DEMO = """
---

Here’s an example:

**Output:**  
**Interpretation:**
- AF/AFL: AF/AFL was present(98.9%). The Longest episode was 18:49:30, Day 1 / 16:53:30 and the Fastest episode was 163 bpm, Day 2 / 00:32:27.
- VEB: 2603 isolated (1.95% burden), 58 couplets, 4 bigeminy, 2 trigeminy episodes.
- VT: Total number of VT episodes was 32. The longest episode was 104 beats, Day 2 / 01:25:33.
- SVEB: 3474 isolated (2.60% burden), 83 couplets, 36 bigeminy, 52 trigeminy episodes.
- SVT: Total number of SVT episodes was 6. The longest episode was 5 beats, Day 2 / 02:42:42.
- Pause: Total number of Pauses was 1. The longest episode was 2,440ms, Day 1 / 15:28:12.
- Block: No Blocks Present.
- Sinus: Sinus rhythm was present(94.7%). The range was 78 - 115 bpm, Avg 96 bpm.
- Symptoms: None Present.

---

Based on the example above, use the provided "Patient Records" and "Findings" to generate an "Interpretation."
"""

# User-message template. llm_inference substitutes "[RECORD]" with the
# newline-joined tabular patient data and "[FINDING]" with the newline-joined
# findings. The "[Interpretation on ...]" lines are not substituted; they stay
# in the message as an output skeleton for the model to fill in.
INPUT_PROMPT = """
**Patient Records (Metric | Description | Value):**  
[RECORD]

**Findings:**  
[FINDING]

**Interpretation:**
- AF/AFL: [Interpretation on AF/AFL]
- VEB: [Interpretation on VEB]
- VT: [Interpretation on VT]
- SVEB: [Interpretation on SVEB]
- SVT: [Interpretation on SVT]
- Pause: [Interpretation on Pause]
- Block: [Interpretation on Block]
- Sinus: [Interpretation on Sinus]
- Symptoms: [Interpretation on Symptoms]
"""
    
# Text-generation pipelines already built by ``_get_pipeline``, keyed by
# ``model_path``. Loading the model weights is the expensive step, so repeated
# calls to ``llm_inference`` with the same path reuse the loaded pipeline.
_PIPELINE_CACHE = {}


def _get_pipeline(model_path):
    """Return a bfloat16 text-generation pipeline for *model_path*, built once per path."""
    pipe = _PIPELINE_CACHE.get(model_path)
    if pipe is None:
        pipe = transformers.pipeline(
            "text-generation",
            model=model_path,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        _PIPELINE_CACHE[model_path] = pipe
    return pipe


def llm_inference(
    metrics,
    system_prompt=SYSTEM_PROMPT,
    demo=DEMO,
    input_prompt=INPUT_PROMPT,
    model_path="/mnt/efs-llm/huggingface/Meta-Llama-3.1-8B-Instruct",  # NOTE: replace by your path
    max_new_tokens=1024,
):
    """Generate a clinical "Interpretation" for one patient's Holter metrics.

    Args:
        metrics: Patient metrics passed unchanged to
            ``preprocess.get_tabular_data`` and ``preprocess.get_findings``.
            Each helper's result is joined with newlines, so each presumably
            returns an iterable of strings (confirm against ``preprocess``).
        system_prompt: Task instructions used as the system message.
        demo: Example text appended (after a newline) to ``system_prompt``.
        input_prompt: User-message template; its ``[RECORD]`` and
            ``[FINDING]`` placeholders are replaced with the rendered
            record and findings.
        model_path: Local path or hub id of the chat model to load. The
            loaded pipeline is cached per path and reused by later calls.
        max_new_tokens: Generation budget passed to the pipeline.

    Returns:
        The ``content`` of the last message in the generated conversation,
        i.e. the model's reply text. It is also printed to stdout.
    """
    # Render the patient's record and findings as newline-joined text blocks.
    tabular_data = "\n".join(preprocess.get_tabular_data(metrics))
    findings = "\n".join(preprocess.get_findings(metrics))
    user_prompt = input_prompt.replace("[RECORD]", tabular_data).replace("[FINDING]", findings)

    # Load (or reuse) the model before printing, as the original ordering did.
    pipe = _get_pipeline(model_path)

    # The one-shot demo travels with the instructions in the system message.
    full_system_prompt = system_prompt + "\n" + demo
    messages = [
        {"role": "system", "content": full_system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # "length" here is a whitespace-separated word count, not a token count.
    print("system_prompt length", len(full_system_prompt.split()))
    print("input_prompt length", len(user_prompt.split()))
    print("\n>>> System instruction\n", full_system_prompt)
    print("-"*100)
    print("\n>>> Inputs:\n", user_prompt)
    print("-"*100)
    print("\n>>> Generated Texts:\n")

    outputs = pipe(messages, max_new_tokens=max_new_tokens)
    # For chat-style input the pipeline returns the whole conversation; the
    # last entry is the newly generated {'role': 'assistant', 'content': ...}.
    reply = outputs[0]["generated_text"][-1]["content"]
    print(reply)

    return reply

def read_metrics(path):
    """Read a two-column metrics CSV into a ``{key: value}`` dict of strings.

    Rows that do not have exactly two columns are skipped, as is a
    ``Label,Value`` header row. The file is decoded as UTF-8 with an optional
    byte-order mark, so a BOM does not end up glued to the first key (which
    would otherwise stop the header row from being recognised).

    Args:
        path: Path to the CSV file.

    Returns:
        Dict mapping each metric key to its raw string value. A later row
        with the same key overwrites an earlier one.
    """
    metrics = {}
    with open(path, mode="r", newline="", encoding="utf-8-sig") as csv_file:
        for row in csv.reader(csv_file):
            if len(row) != 2:  # only "key,value" rows carry a metric
                continue
            key, value = row
            if key == "Label" and value == "Value":
                continue  # header row
            metrics[key] = value
    return metrics


if __name__ == "__main__":
    # test input -- guarded so that importing this module does not load the
    # model and run inference as a side effect.
    data_path = "./example/metric_ex_1.csv"
    metrics = read_metrics(data_path)

    start = time.perf_counter()
    llm_inference(metrics)
    print(f"Time {time.perf_counter() - start}s")