import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import os
import json
from tqdm import tqdm
import shortuuid
from pathlib import Path
import sys
sys.path.insert(0, os.getcwd())
sys.path.insert(0, os.path.join(os.getcwd(), "LLaVa_Med"))
from LLaVA_Med.llava import LlavaLlamaForCausalLM
from LLaVA_Med.llava.conversation import conv_templates
from LLaVA_Med.llava.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
import pandas as pd

from PIL import Image
import random
import math
from med_datasets.data_util.mimic_cxr_utils import *
from pipeline.eval.eval_metrics import eval_result_dir


DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def query_llava_med(model, tokenizer, image_processor, text, image_path, image_token_len=256):
    """Ask LLaVA-Med one question about one image and return the cleaned answer.

    The prompt is ``text`` followed by a newline, ``<im_start>``,
    ``image_token_len`` copies of ``<im_patch>`` and ``<im_end>``, wrapped in the
    "simple" conversation template. Generation is sampled (``do_sample=True``,
    temperature 0.7), so repeated calls can return different answers.

    Args:
        model: multimodal model, presumably the one returned by ``load_model``
            (its ``generate`` must accept an ``images=`` tensor).
        tokenizer: tokenizer matching ``model``.
        image_processor: image processor exposing ``preprocess(..., return_tensors='pt')``.
        text: question / instruction for the model.
        image_path: path of the image file to open.
        image_token_len: number of ``<im_patch>`` placeholder tokens to insert.
            Defaults to 256, the value previously hard-coded here. It must match
            what the model's vision tower expects.

    Returns:
        The decoded continuation with leading whitespace and role markers
        ("###", "Assistant:", "Response:") stripped, truncated at the first
        occurrence of the template separator ``conv.sep``.
    """
    # NOTE(review): start/end tokens are always inserted here, whereas load_model
    # only registers them when model.config.mm_use_im_start_end is true — confirm
    # this checkpoint sets that flag.
    text = (text + '\n' + DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN)
    conv = conv_templates["simple"].copy()
    conv.append_message(conv.roles[0], text)
    prompt = conv.get_prompt()
    inputs = tokenizer([prompt])

    # Close the file handle once the pixels have been read. convert("RGB") turns
    # single-channel (grayscale) images into the 3-channel input a CLIP processor expects.
    with Image.open(image_path) as image:
        image_tensor = image_processor.preprocess(
            image.convert("RGB"), return_tensors='pt')['pixel_values'][0]

    input_ids = torch.as_tensor(inputs.input_ids).cuda()
    stopping_criteria = KeywordsStoppingCriteria(['###'], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
            stopping_criteria=[stopping_criteria])

    # The returned sequences start with the prompt; sanity-check that prefix.
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        # Identify the sample by its image path (there is no sample index in scope here).
        print(f'[Warning] {image_path}: {n_diff_input_output} output_ids are not the same as the input_ids')
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]

    # Repeatedly strip whitespace and leading role markers until nothing changes.
    while True:
        cur_len = len(outputs)
        outputs = outputs.strip()
        for pattern in ['###', 'Assistant:', 'Response:']:
            if outputs.startswith(pattern):
                outputs = outputs[len(pattern):].strip()
        if len(outputs) == cur_len:
            break

    # Truncate at the first conversation separator, if the model emitted one.
    index = outputs.find(conv.sep)
    if index != -1:
        outputs = outputs[:index]
    return outputs.strip()

def load_model(model_name="/scratch/pretrained/llava-med"):
    """Load LLaVA-Med, its tokenizer and the vision tower's image processor.

    Args:
        model_name: directory of the LLaVA-Med checkpoint. Defaults to the path
            previously hard-coded here. The config stored in this directory may
            be patched and re-saved in place (see ``patch_config``).

    Returns:
        ``(model, tokenizer, image_processor)``:
        - the model, loaded in float16 and moved to CUDA;
        - the tokenizer, with the image placeholder token(s) registered;
        - the CLIP image processor named by ``model.config.mm_vision_tower``.
    """

    def patch_config(config):
        """Add multimodal settings to the checkpoint config at ``config`` if absent.

        When the saved config has no ``mm_vision_tower`` attribute, the fields in
        ``patch_dict`` are set and the config is written back to disk.
        """
        patch_dict = {
            "use_mm_proj": True,
            "mm_vision_tower": "openai/clip-vit-large-patch14",
            "mm_hidden_size": 1024
        }

        cfg = AutoConfig.from_pretrained(config)
        if not hasattr(cfg, "mm_vision_tower"):
            print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
            for k, v in patch_dict.items():
                setattr(cfg, k, v)
            cfg.save_pretrained(config)

    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Patch before building the model, so the config it loads (and the
    # mm_vision_tower read below) includes the multimodal fields.
    patch_config(model_name)
    model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
    image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)

    # Register the image placeholder tokens. The start/end tokens are registered
    # only if the model config asks for them.
    # NOTE(review): token embeddings are not resized here, so this assumes the
    # checkpoint vocabulary already contains these tokens — confirm.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

    # Move the vision tower to GPU/fp16 and store the placeholder token ids on its
    # config (presumably read by the model to locate image positions in input_ids).
    vision_tower = model.model.vision_tower[0]
    vision_tower.to(device='cuda', dtype=torch.float16)
    vision_config = vision_tower.config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    if mm_use_im_start_end:
        vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

    return model, tokenizer, image_processor


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once any keyword shows up in the generated text.

    The prompt length is taken lazily from ``input_ids`` on the first call; that
    first call never requests a stop. On every later call, the tokens after the
    prompt in the first sequence of the batch are decoded (special tokens
    skipped) and each keyword is searched for as a substring.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.input_ids = input_ids
        self.tokenizer = tokenizer
        self.keywords = keywords
        self.start_len = None

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            # First step: remember where the prompt ends and keep generating.
            self.start_len = self.input_ids.shape[1]
            return False
        continuation = output_ids[:, self.start_len:]
        decoded = self.tokenizer.batch_decode(continuation, skip_special_tokens=True)[0]
        return any(word in decoded for word in self.keywords)

if __name__ == "__main__":
    # Batch inference of LLaVA-Med on the MIMIC-CXR test split. For every study
    # with at least one PA/AP image, the model is prompted (a plain question, or an
    # instruction built for --task) with the first such image. The prediction and
    # ground truth are written to <output_dir>/<study_id>.txt, and the directory
    # is scored with eval_result_dir at the end.
    method = "llavamed"
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        default=None,
        choices=["correction", "history", "template", "comparison"],
        help="instruction task to evaluate; omit for plain report generation",
    )
    args = parser.parse_args()
    task = args.task
    # Load the model only after argument validation, so --help and an invalid
    # --task fail immediately instead of after a full model load.
    model, tokenizer, image_processor = load_model()

    output_dir = Path(f"{method}_{task}_result") if task is not None else Path(f"{method}_result")
    split = "test"
    data_dir = Path("/scratch/datasets/MIMIC-CXR/")
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving inference results in {output_dir}")
    dicomid2label = create_id2label_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    studyid2split = create_id2split_dict(data_dir / "mimic-cxr-2.0.0-split.csv")
    studyid2path = create_id2path_dict(data_dir / "mimic-cxr-2.0.0-metadata.csv")
    metadata = pd.read_csv(data_dir / "mimic-cxr-2.0.0-metadata.csv")

    # Glob once. tqdm takes its total from the list length.
    patient_paths = list((data_dir / "files").glob("p*/p*"))
    for patient_path in tqdm(patient_paths):
        patient_id = patient_path.name
        for study_path in patient_path.glob("s*"):
            study_id = study_path.name
            # The split lookup uses the study directory name without its leading "s".
            if studyid2split[study_id[1:]] != split:
                continue
            out_file_path = output_dir / f"{study_id}.txt"
            # if os.path.exists(out_file_path): continue

            # Task-specific filtering; the data loaded here is reused to build the prompt.
            previous_findings = None
            generated_data = None
            if task == "comparison":
                current_report_path = studyid2path[study_id]
                previous_report_path = get_previous_report_path(Path(current_report_path), metadata)
                if previous_report_path is None:
                    continue
                _, previous_findings, _ = parse_report(data_dir / previous_report_path)
                if len(previous_findings) == 0:
                    continue
            elif task is not None:
                generated_path = data_dir / "files" / f"reports_{task}" / f"{study_id}.txt"
                if not os.path.exists(generated_path):
                    continue
                generated_data = parse_generated(generated_path, task)
                if generated_data is None:
                    continue

            # Keep only PA/AP views; dicomid2label is keyed by the JPEG file stem.
            frontal_images = [
                image_path
                for image_path in study_path.glob("*.jpg")
                if dicomid2label[image_path.stem] in ("PA", "AP")
            ]
            if not frontal_images:
                continue

            report_path = data_dir / "files" / "reports" / patient_id[:3] / patient_id / f"{study_id}.txt"
            full_report, findings, _ = parse_report(report_path)
            gt = findings.lower().strip()
            if gt == "":
                print(f"No findings in parsed report: {full_report}")
                continue

            if task is None:
                message = "What is shown in this image?"
                # message = "Act as a radiologist and write a diagnostic radiology report for the patient based on their chest radiographs:"
            elif task == "template":
                template_str = generated_data["template"]
                message = random.choice(template_instructions).format(template=template_str)
                gt = generated_data["report"].lower().strip()
            elif task == "comparison":
                # previous_findings was already parsed during filtering above.
                message = random.choice(comparison_instructions).format(previous_report=previous_findings)
            elif task == "correction":
                incorrect_str = generated_data["incorrect_report"]
                fix_str = generated_data["instruction"]
                message = "Report: {incorrect_report}" + '\n' + random.choice(correction_instructions)
                message = message.format(incorrect_report=incorrect_str, instructions=fix_str)
            elif task == "history":
                history_str = generated_data["history"]
                message = random.choice(history_instructions).format(history=history_str)

            pred = query_llava_med(model=model, tokenizer=tokenizer, text=message,
                                   image_processor=image_processor, image_path=frontal_images[0])
            pred = pred.lower().strip()
            # Drop any preamble before a "findings:" section, if the model wrote one.
            findings_idx = pred.find("findings:")
            if findings_idx >= 0:
                pred = pred[findings_idx:]
            out_str = f"PRED:\n{pred}\n\nGT:\n{gt}\n"
            print(out_str)
            with open(out_file_path, 'w', encoding="utf-8") as f:
                f.write(out_str)
    eval_result_dir(output_dir)