from huggingface_hub import hf_hub_download
import torch
import os
from torchvision import transforms
from einops import repeat
from PIL import Image
import sys
import argparse
import pandas as pd
from pathlib import Path
from tqdm import tqdm
sys.path.insert(0, os.getcwd())
from cvt2distilgpt2.cvt2distilgpt2_mimic_cxr_chen import CvT2DistilGPT2MIMICXRChen
from med_datasets.data_util.mimic_cxr_utils import *
from pipeline.eval.eval_metrics import eval_result_dir




def load_model(model_path="/scratch/pretrained/epoch=8-val_chen_cider=0.425092.ckpt"):
    """Load the CvT2DistilGPT2 MIMIC-CXR model from a checkpoint.

    Args:
        model_path: Path of the checkpoint file handed to
            ``CvT2DistilGPT2MIMICXRChen.load_from_checkpoint``. Defaults to the
            checkpoint location this script was written against.

    Returns:
        The object returned by ``load_from_checkpoint``.
    """
    # Keyword arguments forwarded to load_from_checkpoint together with the
    # checkpoint path.
    # NOTE(review): this looks like a dump of the test-time configuration of the
    # cvt2distilgpt2 project (paths under /home/Otter-Med) -- confirm which of
    # these keys the model constructor actually reads.
    args = {
        'work_dir': '/home/Otter-Med/cvt2distilgpt2',
        'skip_data': True,
        'task': 'mimic_cxr_chen',
        'config': 'config/test_mimic_cxr_chen_cvt2distilgpt2',
        'exp_dir': 'experiments',
        'dataset_dir': 'datasets',
        'ckpt_zoo_dir': 'checkpoints',
        'definition': 'CvT2DistilGPT2MIMICXRChen',
        'module': 'cvt2distilgpt2_mimic_cxr_chen',
        'stages_definition': 'stages',
        'stages_module': 'stages',
        'train': None,
        'trial': 0,
        'resume_last': True,
        'resume_epoch': None,
        'resume_ckpt_path': None,
        'warm_start_ckpt_path': None,
        'monitor': 'val_chen_cider',
        'monitor_mode': 'max',
        'test': True,
        'test_epoch': None,
        'test_ckpt_path': 'checkpoints/mimic_cxr_jpg_chen/cvt_21_to_distilgpt2/epoch=8-val_chen_cider=0.425092.ckpt',
        'fast_dev_run': None,
        'num_workers': 5,
        'devices': 1,
        'num_nodes': 1,
        'memory': None,
        'time_limit': None,
        'submit': None,
        'qos': None,
        'begin': None,
        'slurm_cmd_path': None,
        'email': None,
        'cuda_visible_devices': None,
        'venv_path': None,
        'config_file_name': 'config/test_mimic_cxr_chen_cvt2distilgpt2.yaml',
        'config_name': 'test_mimic_cxr_chen_cvt2distilgpt2',
        'config_dir': '/home/Otter-Med/cvt2distilgpt2/config',
        'config_full_path': '/home/Otter-Med/cvt2distilgpt2/config/test_mimic_cxr_chen_cvt2distilgpt2.yaml',
        'strategy': 'ddp_find_unused_parameters_true',
        'encoder_lr': 5e-05,
        'decoder_lr': 0.0005,
        'mbatch_size': 4,
        'every_n_epochs': 1,
        'precision': 16,
        'decoder_max_len': 128,
        'num_test_beams': 4,
        'enable_progress_bar': True,
        'weights_summary': 'full',
        'early_stopping': True,
        'patience': 10,
        'min_delta': 0.0001,
        'deterministic': False,
        'exp_dir_trial': 'experiments/mimic_cxr_chen/test_mimic_cxr_chen_cvt2distilgpt2/trial_0',
        'warm_start_modules': False,
    }
    # strict=False is passed through unchanged from the original call.
    model = CvT2DistilGPT2MIMICXRChen.load_from_checkpoint(checkpoint_path=model_path, **args, strict=False)
    return model

def process_image(image_path):
    """Load an image file and turn it into a single-image model input batch.

    The image is converted to RGB, resized with ``Resize(size=384 + 64)``,
    centre-cropped to 384x384, converted to a tensor and normalised with the
    mean/std constants below.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A tensor of shape ``(1, 3, 384, 384)`` moved to the CUDA device.
    """
    image_transforms = transforms.Compose(
        [
            transforms.Resize(size=384 + 64),
            transforms.CenterCrop(size=[384, 384]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    # Open inside a context manager so the file handle is released
    # deterministically; convert() yields an image that no longer depends on
    # the open file.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")  # "L" (greyscale) or "RGB".
    # unsqueeze(0) adds the leading batch dimension.
    image = image_transforms(image).unsqueeze(0).cuda()
    return image


def query_cvt2distilgpt2(model, image_paths, text, num_test_beams=3):
    """Generate a text response for the first image in ``image_paths``.

    Only ``image_paths[0]`` is used; any further entries are ignored.

    Args:
        model: Object exposing ``generate(num_test_beams, image, text)`` and a
            ``tokenizer`` with ``batch_decode``.
        image_paths: Sequence of image paths; must contain at least one entry.
        text: Prompt text passed straight through to ``model.generate``.
        num_test_beams: First positional argument of ``model.generate``.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    first_image = process_image(image_paths[0])
    generated_ids = model.generate(num_test_beams, first_image, text)
    decoded = model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]


if __name__ == "__main__":
    # Script entry point: run the model over every study of the MIMIC-CXR "test"
    # split, write one "PRED/GT" text file per study into the output directory,
    # then hand that directory to eval_result_dir.
    #
    # NOTE(review): `random`, `parse_report`, `parse_generated`,
    # `get_previous_report_path`, the `create_id2*_dict` helpers and the
    # `*_instructions` lists are not defined in this file; they are presumably
    # supplied by the wildcard import from med_datasets.data_util.mimic_cxr_utils
    # -- confirm.
    method = "cvt2distilgpt2"
    parser = argparse.ArgumentParser()
    # `choices` rejects an unknown task at parse time, before the (expensive)
    # model load, and unlike `assert` it is not stripped under `python -O`.
    parser.add_argument(
        "--task",
        default=None,
        choices=["correction", "history", "template", "comparison"],
    )
    args = parser.parse_args()
    task = args.task
    model = load_model()
    output_dir = Path(f"{method}_{task}_result") if task is not None else Path(f"{method}_result")
    split = "test"
    data_dir = Path("/scratch/datasets/MIMIC-CXR/")
    # if validated:
    #     assert split == "test"
    #     if task is not None:
    #         val_json = data_dir / f"{task}_instructions_val.json"
    #     else:
    #         val_json = data_dir / f"instructions_val.json"
    #     val_study_ids = list(json.load(open(val_json))["data"].keys())
    #     print(val_study_ids)
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving inference results in {output_dir}")
    dicomid2label = create_id2label_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2path = create_id2path_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    metadata = pd.read_csv(data_dir / "mimic-cxr-2.0.0-metadata.csv")

    # Glob the patient directories once; the list gives tqdm its total.
    patient_paths = list((data_dir / "files").glob("p*/p*"))
    for patient_path in tqdm(patient_paths):
        patient_id = patient_path.name
        for study_path in patient_path.glob("s*"):
            study_id = study_path.name
            out_file_path = output_dir / f"{study_id}.txt"
            # The split table is keyed by the study id without its leading "s".
            if studyid2split[study_id[1:]] != split:
                continue
            # Skip studies that already have a result file, so the run can resume.
            if os.path.exists(out_file_path):
                continue

            # Task-specific inputs; a study without them is skipped.
            previous_findings = None
            generated_data = None
            if task == "comparison":
                current_report_path = studyid2path[study_id]
                previous_report_path = get_previous_report_path(Path(current_report_path), metadata)
                if previous_report_path is None:
                    continue
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if len(previous_findings) == 0:
                    continue
            elif task is not None:
                generated_path = data_dir / "files" / f"reports_{task}" / f"{study_id}.txt"
                if not os.path.exists(generated_path):
                    continue
                generated_data = parse_generated(generated_path, task)
                if generated_data is None:
                    continue

            # Keep only the images whose label (looked up by file stem) is PA or AP.
            image_path_list = [
                image_path
                for image_path in study_path.glob("*.jpg")
                if dicomid2label[image_path.stem] in ("PA", "AP")
            ]
            if not image_path_list:
                continue

            report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
            _, findings, _ = parse_report(report_path)
            gt = findings.lower().strip()
            if gt == "":
                continue

            # Build the prompt for the selected task.
            if task is None:
                message = "Act as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs:"
            elif task == "template":
                template_str = generated_data["template"]
                message = random.choice(template_instructions).format(template=template_str)
                # For the template task the generated report replaces the ground truth.
                gt = generated_data["report"].lower().strip()
            elif task == "comparison":
                # Reuse the findings parsed above rather than re-reading the report.
                message = random.choice(comparison_instructions).format(previous_report=previous_findings)
            elif task == "correction":
                incorrect_str = generated_data["incorrect_report"]
                fix_str = generated_data["instruction"]
                message = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
                message = message.format(incorrect_report=incorrect_str, instructions=fix_str)
            elif task == "history":
                history_str = generated_data["history"]
                message = random.choice(history_instructions).format(history=history_str)

            pred = query_cvt2distilgpt2(model=model, text=message, image_paths=image_path_list[:2])
            pred = pred.lower().strip()
            # When the prediction contains "findings:", drop everything before it.
            findings_idx = pred.find("findings:")
            if findings_idx >= 0:
                pred = pred[findings_idx:]
            out_str = f"PRED:\n{pred}\n\nGT:\n{gt}\n"
            print(out_str)
            with open(out_file_path, 'w') as f:
                f.write(out_str)
    eval_result_dir(output_dir)



