from tqdm import tqdm
import torch
import argparse
import sys
import os
sys.path.append(os.getcwd())
sys.path.append("RadFM/Quick_demo")
from med_datasets.data_util.mimic_cxr_utils import *
from pipeline.eval.eval_metrics import eval_result_dir
from RadFM.Quick_demo.Model.RadFM.multimodality_model import MultiLLaMAForCausalLM
from RadFM.Quick_demo.test import get_tokenizer, combine_and_preprocess


def load_model(
    ckpt_path="/scratch/pretrained/RadFM/pytorch_model.bin",
    lang_model_path="/home/Otter-Med/RadFM/Quick_demo/Language_files",
    device="cuda",
):
    """Build the RadFM model, load its weights and return it with its tokenizer.

    Args:
        ckpt_path: path to the RadFM state-dict checkpoint (loaded on CPU first).
        lang_model_path: directory passed both to ``MultiLLaMAForCausalLM`` and
            to ``get_tokenizer``.
        device: device the model is moved to after the weights are loaded.

    Returns:
        ``(model, tokenizer, image_padding_tokens)``; the model is in eval mode.
    """
    model = MultiLLaMAForCausalLM(lang_model_path=lang_model_path)
    # Load onto CPU first so the checkpoint does not need a second GPU-resident copy.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt)
    # load_state_dict copies into the model's own tensors; drop the checkpoint
    # dict now rather than keeping a full extra copy alive until return.
    del ckpt
    model = model.to(device)
    model.eval()
    tokenizer, image_padding_tokens = get_tokenizer(lang_model_path)
    return model, tokenizer, image_padding_tokens


def query_radfm(model, tokenizer, image_token, image_paths, text, device="cuda"):
    """Generate RadFM's answer to ``text`` conditioned on a single image.

    Args:
        model: model returned by ``load_model``.
        tokenizer: tokenizer returned by ``load_model``.
        image_token: image padding tokens returned by ``load_model``; forwarded
            to ``combine_and_preprocess``.
        image_paths: sequence of image paths. NOTE: only ``image_paths[0]`` is
            sent to the model; any further entries are ignored.
        text: prompt text.
        device: device the inputs are moved to; must match the model's device.

    Returns:
        The first decoded generation, with special tokens removed.
    """
    # Only the first image is attached, at position 0 (presumably the start of
    # the prompt -- see combine_and_preprocess).
    images = [{'img_path': image_paths[0], 'position': 0}]
    prompt, vision_x = combine_and_preprocess(text, images, image_token)
    with torch.no_grad():
        lang_x = tokenizer(
            prompt, max_length=2048, truncation=True, return_tensors="pt"
        )['input_ids'].to(device)
        vision_x = vision_x.to(device)
        generation = model.generate(lang_x, vision_x)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0]


# Tasks accepted by --task; omitting --task runs plain report generation.
TASKS = ("correction", "history", "template", "comparison")


def _frontal_image_paths(study_path, dicomid2label):
    """Return the study's ``*.jpg`` paths whose label is "PA" or "AP", in glob order.

    ``dicomid2label`` is keyed by the image file name without its ``.jpg``
    extension. Every image is looked up, so a missing key raises ``KeyError``.
    """
    return [
        image_path
        for image_path in study_path.glob("*.jpg")
        if dicomid2label[image_path.stem] in ("PA", "AP")
    ]


def _build_prompt(task, generated_data, previous_findings):
    """Return the instruction sent to RadFM for ``task``.

    Args:
        task: ``None`` or one of ``TASKS``.
        generated_data: parsed ``reports_<task>`` entry (used by the
            template/correction/history tasks), else ``None``.
        previous_findings: findings of the previous report (comparison task
            only), else ``None``.

    Raises:
        ValueError: if ``task`` is not ``None`` and not in ``TASKS``.
    """
    if task is None:
        # Alternative prompt tried previously: "Act as a radiologist and write a
        # diagnostic radiology report for the patient based on their chest radiographs:"
        return "Can you provide a radiology report for this medical image?"
    if task == "template":
        return random.choice(template_instructions).format(template=generated_data["template"])
    if task == "comparison":
        return random.choice(comparison_instructions).format(previous_report=previous_findings)
    if task == "correction":
        message = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
        return message.format(
            incorrect_report=generated_data["incorrect_report"],
            instructions=generated_data["instruction"],
        )
    if task == "history":
        return random.choice(history_instructions).format(history=generated_data["history"])
    raise ValueError(f"unknown task: {task!r}")


def _trim_to_findings(pred):
    """Lower-case and strip ``pred``, then drop everything before the first "findings:"."""
    pred = pred.lower().strip()
    findings_idx = pred.find("findings:")
    return pred[findings_idx:] if findings_idx >= 0 else pred


def main():
    """Run RadFM over the MIMIC-CXR test split and write one PRED/GT file per study.

    Studies whose output file already exists are skipped, so an interrupted
    run can be resumed by re-running the script.
    """
    method = "radfm"
    parser = argparse.ArgumentParser()
    # `choices` rejects a bad task before the expensive model load
    # (the previous assert only ran afterwards and is stripped under -O).
    parser.add_argument("--task", default=None, choices=TASKS)
    parser.add_argument("--seed", type=int, default=None,
                        help="seed for the random instruction choice (default: unseeded)")
    parser.add_argument("--eval", action="store_true",
                        help="run eval_result_dir on the output directory afterwards")
    args = parser.parse_args()
    task = args.task
    if args.seed is not None:
        random.seed(args.seed)

    output_dir = Path(f"{method}_{task}_result") if task is not None else Path(f"{method}_result")
    split = "test"
    data_dir = Path("/scratch/datasets/MIMIC-CXR/")
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving inference results in {output_dir}")

    metadata_csv = data_dir / "mimic-cxr-2.0.0-metadata.csv"
    dicomid2label = create_id2label_dict(metadata_csv)
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2path = create_id2path_dict(metadata_csv)
    metadata = pd.read_csv(metadata_csv)

    model, tokenizer, image_token = load_model()

    # Walk the tree once; sorted so seeded runs visit studies in a stable order.
    patient_paths = sorted((data_dir / "files").glob("p*/p*"))
    for patient_path in tqdm(patient_paths):
        patient_id = patient_path.name
        for study_path in sorted(patient_path.glob("s*")):
            study_id = study_path.name
            out_file_path = output_dir / f"{study_id}.txt"
            # The split lookup uses the id without its leading "s".
            if studyid2split[study_id[1:]] != split:
                continue
            if out_file_path.exists():
                continue

            generated_data = None
            previous_findings = None
            if task == "comparison":
                # Looked up and parsed once; reused for the prompt below.
                previous_report_path = get_previous_report_path(Path(studyid2path[study_id]), metadata)
                if previous_report_path is None:
                    continue
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if len(previous_findings) == 0:
                    continue
            elif task is not None:
                generated_path = data_dir / "files" / f"reports_{task}" / f"{study_id}.txt"
                if not generated_path.exists():
                    continue
                generated_data = parse_generated(generated_path, task)
                if generated_data is None:
                    continue

            image_paths = _frontal_image_paths(study_path, dicomid2label)
            if not image_paths:
                continue

            report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
            _, findings, _ = parse_report(report_path)
            gt = findings.lower().strip()
            if not gt:
                continue
            if task == "template":
                # The template task is scored against the generated report, not the original findings.
                gt = generated_data["report"].lower().strip()

            message = _build_prompt(task, generated_data, previous_findings)
            pred = query_radfm(model=model, tokenizer=tokenizer, text=message,
                               image_token=image_token, image_paths=image_paths[:2])
            pred = _trim_to_findings(pred)

            out_str = f"PRED:\n{pred}\n\nGT:\n{gt}\n"
            print(out_str)
            with open(out_file_path, "w", encoding="utf-8") as f:
                f.write(out_str)

    if args.eval:
        eval_result_dir(output_dir)


if __name__ == "__main__":
    main()



