import os
import csv
import json
import torch
from g4f.client import Client
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import BioGptTokenizer, BioGptForCausalLM
import preprocess as preprocess
import multi_agent as ma


# System prompt: states the task, describes the two inputs (tabular patient
# records and clinical findings), and lists the items that every generated
# "Interpretation" must contain.
SYSTEM_PROMPT = """
**Task:** Draft a Clinical Interpretation for a Long-term Holter End-of-study Report

**Objective:** As a knowledgeable clinician, your goal is to draft an "Interpretation" of a patient's clinical report, specifically focusing on the diagnosis.

**Instruction of Input:**

1. **Patient Records (Tabular Data):**
   - You will receive tabular data containing a patient's cardiological records with three columns: 
     - **Metric:** Clinical metrics (or variables)
     - **Description:** Description for each metric
     - **Value:** Specific values for each metric (if a metric is absent, the value will be "not present")

2. **Clinical Findings:** written from patient records.
   
**Task Description:**

A. Using the provided inputs (patient's tabular record and clinical findings), generate an "Interpretation" that reflects a specific clinical diagnosis for the patient.

B. The "Interpretation" should be itemized, including both quantified observations and interpretative conclusions.

C. The following items must be included in the "Interpretation":
    - AF/AFL
    - VEB
    - VT
    - SVEB
    - SVT
    - Pause
    - Block
    - Sinus
    - Symptoms
    - QT interval
    - PR interval
    - P-wave
    - T-wave
    - ST-segment

D. If an item is not present, simply state "not present" in the interpretation. Do not omit any required item.

"""

# One worked example of an interpretation; BaselineInference.__init__ appends
# it to the system prompt.
# NOTE(review): the example only shows the AF/AFL .. Symptoms items; it has no
# QT interval, PR interval, P-wave, T-wave or ST-segment lines even though
# SYSTEM_PROMPT lists them as required -- confirm this is intentional.
DEMO = """
Here’s an example of interpretation:

**Output:**  
**Interpretation:**
Monitoring started on 2024-Jul-11 at 10:49 and continued for 2 days and 23 hours.
- AF/AFL: AF/AFL was present(98.9%). The Longest episode was 18:49:30, Day 1 / 16:53:30 and the Fastest episode was 163 bpm, Day 2 / 00:32:27.
- VEB: 2603 isolated (1.95% burden), 58 couplets, 4 bigeminy, 2 trigeminy episodes.
- VT: Total number of VT episodes was 32. The longest episode was 104 beats, Day 2 / 01:25:33.
- SVEB: 3474 isolated (2.60% burden), 83 couplets, 36 bigeminy, 52 trigeminy episodes.
- SVT: Total number of SVT episodes was 6. The longest episode was 5 beats, Day 2 / 02:42:42.
- Pause: Total number of Pauses was 1. The longest episode was 2,440ms, Day 1 / 15:28:12.
- Block: No Blocks Present.
- Sinus: Sinus rhythm was present(94.7%). The range was 78 - 115 bpm, Avg 96 bpm.
- Symptoms: None Present.

"""

# User-prompt template. BaselineInference.__init__ substitutes the [RECORD],
# [METRIC_FINDING] and [ECG_FINDING] placeholders; the other bracketed fields
# ([DATE], [TIME], [DURATION], [Interpretation on ...]) are not substituted
# anywhere in this file and are sent to the model as written.
INPUT_PROMPT = """
---
**Patient Records (Metric | Description | Value):**  
[RECORD]

**Findings:**  
[METRIC_FINDING]
[ECG_FINDING]

**Interpretation:**
Monitoring started on [DATE] at [TIME] and continued for [DURATION].
- AF/AFL: [Interpretation on AF/AFL]
- VEB: [Interpretation on VEB]
- VT: [Interpretation on VT]
- SVEB: [Interpretation on SVEB]
- SVT: [Interpretation on SVT]
- Pause: [Interpretation on Pause]
- Block: [Interpretation on Block]
- Sinus: [Interpretation on Sinus]
- Symptoms: [Interpretation on Symptoms]
- QT interval: [Interpretation on QT interval]
- PR interval: [Interpretation on PR interval]
- P-wave: [Interpretation on P-wave]
- T-wave: [Interpretation on T-wave]
- ST-segment: [Interpretation on ST-segment]
"""

# Shortened prompt variants (used for BaselineInference.short_messages):
# - the system prompt keeps only the text before "**Instruction of Input:**"
#   (the Task and Objective paragraphs);
# - the input template keeps only the text after the "[RECORD]" placeholder,
#   i.e. the Findings and Interpretation sections without the tabular record.
SHORT_SYSTEM_PROMPT = SYSTEM_PROMPT.split("**Instruction of Input:**")[0]
SHORT_INPUT_PROMPT = INPUT_PROMPT.split("[RECORD]")[-1]

class BaselineInference:
    """Single-prompt baseline that drafts a Holter-report interpretation.

    Construction gathers everything the prompt needs: ECG findings from the
    vision model (``ma.vision_llm_inference``), the patient's metrics, and the
    metric-based findings from ``preprocess``. Two chat-message lists are
    then assembled:

    * ``messages`` -- full system prompt plus the tabular record and findings;
      used by ``gf4_inference`` and ``med42_inference``.
    * ``short_messages`` -- shortened prompts without the tabular record;
      rendered by ``message_to_str(True)`` for the meditron, palmyra and
      BioGPT backends.

    Attributes:
        metric_path: Path of the metrics file (``.csv`` or ``.json``).
        image_path: Path passed to the vision model.
        llava_path: Local path of the LLaVA checkpoint.
        ecg_findings: Cleaned ECG findings text from the vision model.
        findings: Newline-joined metric-based findings.
        messages: ``[system, user]`` chat messages with the full prompt.
        short_messages: ``[system, user]`` chat messages with the short prompt.
    """

    def __init__(self, metric_path, image_path):
        """Collect findings for one patient and build both prompt variants.

        Args:
            metric_path: Path to a ``.csv`` or ``.json`` metrics file.
            image_path: Path forwarded to ``ma.vision_llm_inference``.

        Raises:
            ValueError: If ``metric_path`` is neither ``.csv`` nor ``.json``.
        """
        self.metric_path = metric_path
        self.image_path = image_path
        self.llava_path = "../../../huggingface/llava-v1.6-vicuna-13b-hf"

        # get ECG findings
        raw_ecg_output = ma.vision_llm_inference(
            image_path=self.image_path,
            demo="",
            model_path=self.llava_path,
        )
        self.ecg_findings = self._clean_ecg_findings(raw_ecg_output)

        # read metric file
        metrics = self._load_metrics(self.metric_path)

        # get metric-based findings
        tabular_data = "\n".join(preprocess.get_tabular_data(metrics))
        self.findings = "\n".join(preprocess.get_findings(metrics))

        # Each placeholder is replaced including BOTH brackets; replacing only
        # "ECG_FINDING]" would leave a stray "[" in front of the findings.
        system_prompt = SYSTEM_PROMPT + DEMO
        input_prompt = (
            INPUT_PROMPT
            .replace("[RECORD]", tabular_data)
            .replace("[METRIC_FINDING]", self.findings)
            .replace("[ECG_FINDING]", self.ecg_findings)
        )

        short_system_prompt = SHORT_SYSTEM_PROMPT + DEMO
        short_input_prompt = (
            SHORT_INPUT_PROMPT
            .replace("[METRIC_FINDING]", self.findings)
            .replace("[ECG_FINDING]", self.ecg_findings)
        )

        self.messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_prompt},
        ]

        self.short_messages = [
            {"role": "system", "content": short_system_prompt},
            {"role": "user", "content": short_input_prompt},
        ]

    @staticmethod
    def _clean_ecg_findings(raw_output):
        """Strip the vision model's preamble and brackets from its answer."""
        text = raw_output.split("ECG Characteristics with Interpretation:")[-1].strip()
        text = text.split("here are the characteristics and interpretations:")[-1].strip()
        text = text.strip('[]')
        # Collapse blank-line separators into single newlines, then strip any
        # bracket that ends up at either edge.
        return "\n".join(text.split("\n\n")).strip('[]')

    @staticmethod
    def _load_metrics(metric_path):
        """Read a metrics file into a ``{label: value}`` dict.

        CSV: only two-column rows are kept and a ``Label,Value`` header row
        is skipped. JSON: the ``"metrics"`` sub-object is used when that key
        is present, otherwise the top-level object.

        Raises:
            ValueError: If the path ends in neither ``.csv`` nor ``.json``.
        """
        if metric_path.endswith(".csv"):
            metrics = {}
            with open(metric_path, mode='r', newline='') as csv_file:
                for row in csv.reader(csv_file):
                    if len(row) != 2:  # keep only rows with exactly two columns
                        continue
                    key, value = row
                    if key == "Label" and value == "Value":
                        continue  # ignore header
                    metrics[key] = value
            return metrics

        # An explicit exception instead of `assert`, which `python -O` removes.
        if not metric_path.endswith(".json"):
            raise ValueError(
                f"Unsupported metric file (expected .csv or .json): {metric_path}"
            )

        # The context manager closes the handle that `json.load(open(...))` leaked.
        with open(metric_path, "r") as json_file:
            json_data = json.load(json_file)
        if "metrics" in json_data:
            json_data = json_data["metrics"]
        return dict(json_data.items())

    def message_to_str(self, short=False):
        """Render one message pair as a single delimited text prompt.

        Args:
            short: When true, render ``self.short_messages``; otherwise
                ``self.messages``.

        Returns:
            A string containing the system and user contents wrapped in
            ``[Begin ...]`` / ``[End ...]`` markers.
        """
        msg = self.short_messages if short else self.messages
        # The template's indentation and "\n" escapes are part of the prompt
        # text and are kept exactly as they were.
        str_msg = f"""
        [Begin System Prompt]:
        \n
        {msg[0]["content"]}
        \n
        [End System Prompt]
        [Begin Input Prompt]
        \n
        {msg[1]["content"]}
        \n
        [End Input Prompt]
        """

        return str_msg

    def gf4_inference(self, llm_name):
        """Send the full prompt to a hosted model through the g4f client.

        Model names listed by the original author:
        'gpt-3', 'gpt-3.5-turbo', 'gpt-4o', 'gpt-4o-mini', 'gpt-4',
        'gpt-4-turbo', 'llama-3-8b', 'llama-3-70b', 'llama-3.1-8b',
        'llama-3.1-70b', 'llama-3.1-405b', 'mixtral-8x7b', 'mistral-7b',
        'gemini', 'gemini-pro', 'gemini-flash'.

        Args:
            llm_name: Model identifier passed to the client as ``model``.

        Returns:
            The content of the first choice's message.
        """
        client = Client()
        completion = client.chat.completions.create(
            model=llm_name,
            messages=self.messages,
            max_tokens=8192,
        )
        return completion.choices[0].message.content

    def med42_inference(self):
        """Generate an interpretation with the local Llama3-Med42-8B checkpoint.

        Returns:
            The generated text with the prompt prefix removed and whitespace
            stripped.
        """
        model_path = "../../../huggingface/Llama3-Med42-8B"
        # Named `generator` so it does not shadow the imported `pipeline`.
        generator = transformers.pipeline(
            "text-generation",
            model=model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        prompt = generator.tokenizer.apply_chat_template(
            self.messages, tokenize=False, add_generation_prompt=False
        )

        # Stop on either the tokenizer's EOS token or "<|eot_id|>".
        stop_tokens = [
            generator.tokenizer.eos_token_id,
            generator.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        outputs = generator(
            prompt,
            max_new_tokens=1024,
            eos_token_id=stop_tokens,
            do_sample=True,
            temperature=0.4,
            top_k=150,
            top_p=0.75,
        )

        # generated_text starts with the prompt; keep only the continuation.
        return outputs[0]["generated_text"][len(prompt):].strip()

    def meditron_inference(self):
        """Generate an interpretation with the local meditron-7b checkpoint.

        Uses the short prompt rendered by ``message_to_str(True)``.

        Returns:
            The generated text with the prompt prefix removed and whitespace
            stripped.
        """
        model_path = "../../../huggingface/meditron-7b"

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",  # automatically map to available devices
        )

        text_generator = transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )

        prompt = self.message_to_str(True)

        output = text_generator(
            prompt,
            max_new_tokens=1024,
            num_return_sequences=1,
            do_sample=True,
        )

        # generated_text starts with the prompt; keep only the continuation.
        return output[0]['generated_text'][len(prompt):].strip()

    def palmyra_inference(self):
        """Generate an interpretation with the local palmyra-med-20b checkpoint.

        Uses the short prompt rendered by ``message_to_str(True)``.

        Returns:
            The decoded continuation (tokens after the prompt) with special
            tokens skipped.
        """
        model_name = "../../../huggingface/palmyra-med-20b"

        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
        )

        model_inputs = tokenizer(self.message_to_str(True), return_tensors="pt").to(
            "cuda"
        )

        gen_conf = {
            "max_new_tokens": 1024,
            "do_sample": True,
        }

        out_tokens = model.generate(**model_inputs, **gen_conf)

        # Drop the prompt tokens so that only newly generated ones are decoded.
        prompt_length = len(model_inputs.input_ids[0])
        response_ids = out_tokens[0][prompt_length:]
        return tokenizer.decode(response_ids, skip_special_tokens=True)

    def biogpt_inference(self):
        """Run the local BioGPT-Large checkpoint on the short prompt.

        Returns:
            The raw return value of the text-generation pipeline call. Unlike
            the other backends, the prompt is not removed and no string is
            extracted.
        """
        model_name_or_path = "../../../huggingface/BioGPT-Large"

        model = BioGptForCausalLM.from_pretrained(model_name_or_path)
        tokenizer = BioGptTokenizer.from_pretrained(model_name_or_path)
        generator = pipeline(
            'text-generation', model=model, tokenizer=tokenizer, device='cuda:1'
        )
        return generator(
            self.message_to_str(True),
            max_length=1024,
            num_return_sequences=1,
            do_sample=True,
        )


# Patient IDs are the immediate sub-directories of "dataset". Taking the first
# os.walk() entry with a default avoids the NameError the old loop hit when
# the directory was missing (os.walk then yields nothing).
patients = next(os.walk("dataset"), (None, [], []))[1]

# Models queried through the g4f client, in the order they are run and saved.
_G4F_MODELS = [
    "gemini-pro",
    "gpt-4o",
    "llama-3.1-405b",
    "llama-3.1-70b",
    "mixtral-8x22b",
]

for patient in patients:
    # NOTE(review): image_path points at "ecg.json" -- confirm that this is the
    # file ma.vision_llm_inference expects.
    baseline = BaselineInference(
        metric_path=f"dataset/{patient}/metrics.json",
        image_path=f"dataset/{patient}/ecg.json"
    )
    all_findings = baseline.findings + "\n" + baseline.ecg_findings

    # One attempt per hosted model. A failed call is reported and an empty
    # result is still saved, so every patient/model pair has an entry.
    for llm in _G4F_MODELS:
        rst = ""
        try:
            rst = baseline.gf4_inference(llm)
        except Exception as exc:  # best-effort: keep going with other models
            print(f"[{patient}] {llm} inference failed: {exc!r}")
        ma.save(patient, llm, all_findings, rst)

    # NOTE(review): the BioGPT result is discarded and the call is not
    # guarded -- confirm whether it should be saved like the others.
    baseline.biogpt_inference()

    rst = ""
    try:
        rst = baseline.med42_inference()
    except Exception as exc:  # best-effort, same policy as the hosted models
        print(f"[{patient}] med42 inference failed: {exc!r}")
    ma.save(patient, "med42", all_findings, rst)

    # NOTE(review): the Palmyra result is discarded and the call is not
    # guarded -- confirm whether it should be saved like the others.
    baseline.palmyra_inference()

    rst = ma.metric_llm_inference(
        metric_path=f"dataset/{patient}/metrics.json",
        ecg_findings=baseline.ecg_findings,
        model_path="../../../huggingface/Meta-Llama-3.1-8B-Instruct"
    )
    ma.save(patient, "zodiac", all_findings, rst)