from huggingface_hub import hf_hub_download
import torch
import os
from open_flamingo import create_model_and_transforms
from accelerate import Accelerator
from einops import repeat
from PIL import Image
import sys
import argparse
import pandas as pd
from pathlib import Path
from tqdm import tqdm
sys.path.insert(0, os.getcwd())
from med_flamingo.src.utils import FlamingoProcessor
from med_flamingo.scripts.demo_utils import clean_generation
from med_datasets.data_util.mimic_cxr_utils import *
from pipeline.eval.eval_metrics import eval_result_dir


# Single Accelerator shared by load_model / query_medflamingo below; `device` is
# where inputs are moved before generation. To force CPU execution, construct it
# as Accelerator(cpu=True).
accelerator = Accelerator()  # when using cpu: cpu=True
device = accelerator.device


def load_model(llama_path='/scratch/pretrained/decapoda-research-llama-7b-hf/'):
    """Build Med-Flamingo on a local LLaMA-7B and prepare it for inference.

    Args:
        llama_path: Directory holding the LLaMA weights and tokenizer; it is
            used for both ``lang_encoder_path`` and ``tokenizer_path``. The
            default is the cluster location this script was written for.

    Returns:
        ``(model, processor)``: the model after ``accelerator.prepare`` and
        ``eval()``, and a ``FlamingoProcessor`` wrapping the tokenizer and the
        image processor.

    Raises:
        ValueError: if ``llama_path`` does not exist.
    """
    if not os.path.exists(llama_path):
        raise ValueError('Llama model not yet set up, please check README for instructions!')

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path=llama_path,
        tokenizer_path=llama_path,
        cross_attn_every_n_layers=4
    )
    # Fetch the Med-Flamingo checkpoint from the Hugging Face Hub.
    checkpoint_path = hf_hub_download("med-flamingo/med-flamingo", "model.pt")
    print(f'Downloaded Med-Flamingo checkpoint to {checkpoint_path}')
    # strict=False tolerates missing/unexpected keys. Presumably the checkpoint
    # only carries the Flamingo-specific weights, not the base LLaMA/CLIP ones
    # (TODO confirm).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
    processor = FlamingoProcessor(tokenizer, image_processor)

    # Go into eval mode and let the accelerator place the model on its device.
    model = accelerator.prepare(model)
    model.eval()

    return model, processor


def query_medflamingo(model, processor, image_paths, text):
    """Ask Med-Flamingo a question about a set of images and return its answer.

    Args:
        model: The prepared model returned by ``load_model``.
        processor: The ``FlamingoProcessor`` returned by ``load_model``.
        image_paths: Paths of the images to condition on. They are stacked into
            a single sample (batch size 1).
        text: The question/instruction inserted into the prompt.

    Returns:
        The decoded generation, lower-cased and stripped. The question text and
        the prompt markers (``<image>``, ``<unk>``, ``question:``, ``answer:``)
        are removed.
    """
    # Build the prompt from the caller-supplied ``text``. Previously this read a
    # global ``message`` that only exists when the file is run as a script.
    # NOTE(review): there is no space between ``text`` and "Answer: "; kept as-is
    # so prompts match earlier runs.
    prompt = "<image> Question: " + text + "Answer: "
    demo_images = [Image.open(path) for path in image_paths]
    pixels = processor.preprocess_images(demo_images)
    # Add a batch axis (b=1) and a per-image frame axis (T=1) around the N images.
    pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)
    tokenized_data = processor.encode_text(prompt)

    generated_text = model.generate(
        vision_x=pixels.to(device),
        lang_x=tokenized_data["input_ids"].to(device),
        attention_mask=tokenized_data["attention_mask"].to(device),
        max_new_tokens=200,
    )
    response = processor.tokenizer.decode(generated_text[0])
    # The decoded sequence echoes the prompt. Strip the question and the prompt
    # scaffolding, leaving (mostly) the model's answer.
    response = clean_generation(response).lower().strip().replace(text.lower(), "").replace("<image>", "")
    response = response.replace("<unk>", "").replace("<image>", "").replace("question:", "").replace("answer:", "")
    return response


if __name__ == "__main__":
    # Used below to pick instruction templates. Import it explicitly rather than
    # relying on it leaking in through the star import of mimic_cxr_utils.
    import random

    parser = argparse.ArgumentParser(
        description="Run Med-Flamingo on the MIMIC-CXR test split and evaluate the predictions."
    )
    # `choices` rejects unknown tasks before the (expensive) model load.
    parser.add_argument(
        "--task",
        default=None,
        choices=["correction", "history", "template", "comparison"],
        help="Instruction task to run; omit for plain report generation.",
    )
    args = parser.parse_args()
    task = args.task
    model, processor = load_model()
    output_dir = Path(f"medflamingo_{task}_result") if task is not None else Path("medflamingo_result")
    split = "test"
    data_dir = Path("/scratch/datasets/MIMIC-CXR/")
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving inference results in {output_dir}")
    metadata_csv = data_dir / "mimic-cxr-2.0.0-metadata.csv"
    dicomid2label = create_id2label_dict(metadata_csv)
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2path = create_id2path_dict(metadata_csv)
    metadata = pd.read_csv(metadata_csv)

    # Glob the patient directories once (tqdm takes the total from the list).
    patient_paths = list((data_dir / "files").glob("p*/p*"))
    for patient_path in tqdm(patient_paths):
        patient_id = patient_path.name
        for study_path in patient_path.glob("s*"):
            study_id = study_path.name
            out_file_path = output_dir / f"{study_id}.txt"
            # The split lookup is keyed without the leading "s" of the study dir.
            if studyid2split[study_id[1:]] != split:
                continue
            # Resume support: studies with an existing result file are skipped.
            if os.path.exists(out_file_path):
                continue

            generated_data = None
            previous_findings = None
            if task == "comparison":
                # Needs a prior report with non-empty findings to compare against.
                current_report_path = studyid2path[study_id]
                previous_report_path = get_previous_report_path(Path(current_report_path), metadata)
                if previous_report_path is None:
                    continue
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if not previous_findings:
                    continue
            elif task is not None:
                generated_path = data_dir / "files" / f"reports_{task}" / f"{study_id}.txt"
                if not os.path.exists(generated_path):
                    continue
                generated_data = parse_generated(generated_path, task)
                if generated_data is None:
                    continue

            # Keep frontal views only (PA/AP); the label dict is keyed by the jpg stem.
            image_path_list = [
                path for path in study_path.glob("*.jpg")
                if dicomid2label[path.stem] in ("PA", "AP")
            ]
            if not image_path_list:
                continue

            report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
            _, findings, _ = parse_report(report_path)
            gt = findings.lower().strip()
            if not gt:
                continue

            # `message` is deliberately a module-level name (this block is not a function).
            if task is None:
                message = "Act as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs:"
            elif task == "template":
                message = random.choice(template_instructions).format(template=generated_data["template"])
                gt = generated_data["report"].lower().strip()
            elif task == "comparison":
                # Reuse the previous findings parsed during filtering above.
                message = random.choice(comparison_instructions).format(previous_report=previous_findings)
            elif task == "correction":
                message = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
                message = message.format(
                    incorrect_report=generated_data["incorrect_report"],
                    instructions=generated_data["instruction"],
                )
            elif task == "history":
                message = random.choice(history_instructions).format(history=generated_data["history"])

            # At most two frontal images are passed to the model.
            pred = query_medflamingo(model=model, text=message, processor=processor, image_paths=image_path_list[:2])
            pred = pred.lower().strip()
            # If the model produced a "findings:" section, keep only from there on.
            findings_idx = pred.find("findings:")
            if findings_idx >= 0:
                pred = pred[findings_idx:]
            out_str = f"PRED:\n{pred}\n\nGT:\n{gt}\n"
            print(out_str)
            with open(out_file_path, 'w') as f:
                f.write(out_str)
    eval_result_dir(output_dir)



