import os
import random
import sys
import time
import json
import requests
import base64
import argparse
import pandas as pd
from pathlib import Path
from tqdm import tqdm
sys.path.append(os.getcwd())
from med_datasets.data_util.mimic_cxr_utils import *
from pipeline.eval.eval_metrics import eval_result_dir

# Azure OpenAI chat-completions URL of the GPT-4V deployment. It can be
# overridden with the GPT4V_ENDPOINT environment variable so another
# deployment / API version can be used without editing the script.
GPT4V_ENDPOINT = os.environ.get(
    "GPT4V_ENDPOINT",
    "https://llm00gpt4v.openai.azure.com/openai/deployments/vision0409/chat/completions?api-version=2024-02-15-preview",
)
# API key sent in the "api-key" header; None when OPENAI_API_KEY is unset
# (the script entry point refuses to run in that case).
GPT4V_KEY = os.environ.get("OPENAI_API_KEY")

# Request headers shared by every call to the endpoint.
headers = {
    "Content-Type": "application/json",
    "api-key": GPT4V_KEY,
}

def query_gpt4v(image_paths: list, instruction: str, max_retries=None, timeout=120, retry_delay=5):
    """Send a text instruction plus images to the GPT-4V endpoint and return the reply.

    Each image is read from disk, base64-encoded and attached as a
    ``data:image/jpeg`` URL after the text instruction in a single user
    message (preceded by a fixed system message).

    Args:
        image_paths: Paths (str or path-like) of the images to attach. They
            are always labelled as JPEG.
        instruction: Text prompt placed before the images.
        max_retries: Maximum number of retries after a failed request.
            ``None`` (default) retries indefinitely.
        timeout: Per-request timeout in seconds passed to ``requests.post``,
            so a stalled connection is retried instead of hanging forever.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The ``content`` string of the first choice in the JSON response.

    Raises:
        requests.HTTPError: immediately on a non-retryable client error
            (4xx other than 408/429), since retrying cannot succeed.
        requests.RequestException: once ``max_retries`` is exhausted.
    """
    image_content = []
    for image_path in image_paths:
        # Context manager closes the handle right away instead of leaking it.
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("ascii")
        image_content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
        })
    payload = {
        "messages": [{
            "role": "system",
            "content": [{
                "type": "text",
                "text": "You are an AI assistant that helps people find information."
            }]
        }, {
            "role": "user",
            "content": [{
                "type": "text",
                "text": instruction
            }] + image_content
        }],
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 800
    }
    failures = 0
    while True:
        try:
            response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()  # HTTPError on a 4xx/5xx status
            break
        except requests.RequestException as e:
            print(e)
            print(type(e))
            # e.response is None for connection errors / timeouts.
            status = getattr(e.response, "status_code", None)
            # 4xx other than 408 (request timeout) and 429 (rate limit) is a
            # permanent failure (bad key, rejected request); retrying it
            # would loop forever.
            if status is not None and 400 <= status < 500 and status not in (408, 429):
                raise
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            time.sleep(retry_delay)
    return response.json()["choices"][0]["message"]["content"]



# Values accepted by --task; None (flag omitted) means plain report generation.
_TASKS = ("correction", "history", "template", "comparison")
# View labels treated as frontal chest radiographs.
_FRONTAL_VIEWS = ("PA", "AP")


def _select_frontal_images(study_path, dicomid2label):
    """Return the study's ``*.jpg`` files whose view label is PA or AP.

    ``dicomid2label`` is looked up by each image's file stem (presumably the
    DICOM id -- confirm against ``create_id2label_dict``). Results are sorted
    by path so the subset sent to the model does not depend on the
    filesystem's directory-listing order.
    """
    return [
        image_path
        for image_path in sorted(study_path.glob("*.jpg"))
        if dicomid2label[image_path.stem] in _FRONTAL_VIEWS
    ]


def _build_message(task, generated_data, previous_findings):
    """Return the prompt text for ``task``.

    ``generated_data`` is the parsed generated-report dict (used by the
    template/correction/history tasks); ``previous_findings`` is the prior
    study's FINDINGS text (used by the comparison task). Prompt templates
    are drawn at random from the instruction lists of the data utils.
    """
    if task is None:
        return "Act as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs:"
    if task == "template":
        return random.choice(template_instructions).format(template=generated_data["template"])
    if task == "comparison":
        return random.choice(comparison_instructions).format(previous_report=previous_findings)
    if task == "correction":
        incorrect_str = generated_data["incorrect_report"]
        fix_str = generated_data["instruction"]
        message = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
        return message.format(incorrect_report=incorrect_str, instructions=fix_str)
    if task == "history":
        return random.choice(history_instructions).format(history=generated_data["history"])
    raise ValueError(f"Unknown task: {task}")


def _extract_findings(pred):
    """Lower-case and strip ``pred`` and drop anything before the first "findings:"."""
    pred = pred.lower().strip()
    findings_idx = pred.find("findings:")
    return pred[findings_idx:] if findings_idx >= 0 else pred


def main():
    """Generate GPT-4V reports for the MIMIC-CXR test split, then evaluate them.

    For every test-split study with at least one PA/AP image and a non-empty
    FINDINGS section, the model is prompted according to ``--task`` with up
    to two frontal images. The prediction and the reference are written to
    ``<output_dir>/<study_id>.txt``. Studies whose output file already exists
    are skipped, so an interrupted run can be resumed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", default=None, choices=_TASKS)
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for the random prompt-template choice (default: unseeded).")
    args = parser.parse_args()
    task = args.task
    if args.seed is not None:
        random.seed(args.seed)
    output_dir = Path(f"gpt4v_{task}_result") if task is not None else Path("gpt4v_result")
    split = "test"
    data_dir = Path("/scratch/datasets/MIMIC-CXR/")
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if GPT4V_KEY is None:
        sys.exit("OPENAI_API_KEY environment variable is not set")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving inference results in {output_dir}")
    dicomid2label = create_id2label_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2path = create_id2path_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    metadata = pd.read_csv(data_dir / "mimic-cxr-2.0.0-metadata.csv")

    # Materialized once so tqdm knows the total without a second directory walk.
    patient_paths = list((data_dir / "files").glob("p*/p*"))
    for patient_path in tqdm(patient_paths):
        patient_id = patient_path.name
        for study_path in patient_path.glob("s*"):
            study_id = study_path.name
            out_file_path = output_dir / f"{study_id}.txt"
            # The split table is keyed by the study id without its leading "s".
            if studyid2split[study_id[1:]] != split or out_file_path.exists():
                continue

            generated_data = None
            previous_findings = None
            if task == "comparison":
                previous_report_path = get_previous_report_path(Path(studyid2path[study_id]), metadata)
                if previous_report_path is None:
                    continue
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if not previous_findings:
                    continue
            elif task is not None:
                generated_path = data_dir / "files" / f"reports_{task}" / f"{study_id}.txt"
                if not generated_path.exists():
                    continue
                generated_data = parse_generated(generated_path, task)
                if generated_data is None:
                    continue

            image_paths = _select_frontal_images(study_path, dicomid2label)
            if not image_paths:
                continue
            report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
            _, findings, _ = parse_report(report_path)
            gt = findings.lower().strip()
            if gt == "":
                continue

            message = _build_message(task, generated_data, previous_findings)
            if task == "template":
                # The template task is scored against the generated report.
                gt = generated_data["report"].lower().strip()
            pred = _extract_findings(query_gpt4v(instruction=message, image_paths=image_paths[:2]))
            out_str = f"PRED:\n{pred}\n\nGT:\n{gt}\n"
            print(out_str)
            out_file_path.write_text(out_str)
    eval_result_dir(output_dir)


if __name__ == "__main__":
    main()
