import argparse
import multiprocessing
import os
import shutil
import sys
import csv
import time
import json
import torch
from pathlib import Path
from tqdm import tqdm
import evaluate
import pandas as pd
from accelerate import Accelerator
import numpy as np
sys.path.insert(-1, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent))


# from flamingo.modeling_flamingo import FlamingoForConditionalGeneration
from otter.modeling_otter import OtterForConditionalGeneration
from flamingo.modeling_flamingo import FlamingoForConditionalGeneration
from pipeline.train.data import get_data
from pipeline.train.train_utils import get_autocast, get_cast_dtype
from pipeline.train.distributed import world_info_from_env
from med_datasets.data_util.mimic_cxr_utils import CATEGORIES
from otter.biovil_encoder import get_cxr_bert
from eval_metrics import calculate_nlg_metrics, calculate_chexpert_metrics_mp, calculate_chexbert_metrics, eval_result_dir, process_annotation


def parse_args():
    """Build the command-line interface for inference/evaluation and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below; dashes in
        option names become underscores (e.g. ``--max-src-length`` -> ``args.max_src_length``).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # ---- dataset / checkpoint locations ---------------------------------------
    add("--mimicit_path", type=str,
        help="path to multi_instruct dataset, this should be /path/to/DC_instruction.json")
    add("--images_path", type=str,
        help="path to images_path dataset, this should be /path/to/DC.json")
    add("--train_config_path", type=str,
        help="path to train_config_path dataset, this should be /path/to/DC/DC_train.json")
    add("--pretrained_model_name_or_path", type=str,
        help="path to huggingface model or model identifier from local path or huggingface.co")
    add("--medical_vision_encoder_path", type=str,
        default="/scratch/pretrained/biovil_image_resnet50_proj_size_128.pt",
        help="path to pretrained medical vision encoder")
    add("--image_path", type=str,
        default="/data/datasets/MIMIC-CXR/files/p10/p10899590/s56700179/",
        help="path to image or folder of image to inference")
    add("--output_dir", type=str, default=None,
        help="path to folder to save inference result")
    # Example prompt: "Act as a radiologist and write a diagnostic radiology report
    # for the patient based on their chest radiographs:"
    add("--instruction", type=str, default="",
        help="instruction to generate text")

    # ---- vision encoding ---------------------------------------------------------
    encode_modes = ["original", "medical_only", "llama_adapter_plus", "llama_adapter_concat"]
    add("--vision_encode_mode", type=str, choices=encode_modes, default="original",
        help=("mode to encoder input images,"
              "'original' for ignoring medical encoder,"
              "'medical_only' for using medical encoder only and an adapter to reshape output,"
              "'llama_adapter_plus' for using LLaMa-adapter style with two same-size image feature input and add after attention,"
              "'llama_adapter_concat' for using LLaMa-adapter style with learnable prefix prompt"))
    add("--vision_encoder_type", type=str, default="biovil",
        choices=["unimedi", "unimedi2d", "unimedi3d", "biovil"],
        help="vision encoder type")

    # ---- dataset selection -------------------------------------------------------
    add("--dataset_type", type=str, default="mimicit",
        choices=["mimic_cxr", "bimcv_covid19", "mimicit", "custom_2d", "custom_3d"],
        help="dataset type")
    add("--dataset_path", type=str, default="/data/datasets/MIMIC-CXR/processed.csv",
        help="path to multi_instruct dataset, this should be a glob pattern such as vision_language_examples.tsv")
    add("--downsample_frames", type=int, default=0,
        help="downsample number of input frames, use when using 3D dataset")
    add("--split", type=str, default="test", choices=["train", "validate", "test"],
        help="dataset split")

    # ---- generation / model options ----------------------------------------------
    add("--n_beams", type=int, default=4,
        help="number of beams in inference beam search")
    add("--med_pos_emb", default="sin", choices=["sin", "flamingo"],
        help=("Positional embedding for medical vision features, "
              "sin for fixed sinusoidal embedding, "
              "flamingo for interpolating original flamingo pos emb, "
              "must be consistent with training"))
    add("--use_med_roi", action="store_true",
        help=("Extract med roi patch by: one forward through perceiver to get label, "
              "then use label to extract med roi patch and run another forward"))
    add("--chexpert_csv_path", default="/scratch/datasets/MIMIC-CXR/chexpert_extracted.csv",
        help="Path to CheXpert label csv")
    add("--max-src-length", type=int, default=1024,
        help="the maximum src sequence length")
    add("--max-tgt-length", type=int, default=1024,
        help="the maximum target sequence length")
    add("--patch-image-size", type=int, default=224)
    add("--seed", type=int, default=42)
    add("--dummy", action="store_true", default=False,
        help="dummy image input")
    add("--precision", default="fp16",
        choices=["amp_bf16", "amp_bfloat16", "bf16", "amp", "fp16", "fp32"],
        help="Floating point precision.")
    add("--use_lora", action="store_true", default=False,
        help="Use lora in flamingo xattn layers")
    add("--feature_path", type=str, default=None,
        help="path to directory containing pre-extracted image features")

    # ---- runtime -----------------------------------------------------------------
    add("--workers", type=int, default=4)
    add("--batch_size", type=int, default=1, choices=[1])
    add("--debug", type=int, default=0,
        help="number of inference to run")

    return parser.parse_args()


def _clean_prediction(text):
    """Turn a decoded generation into a bare, lower-cased report string.

    Cuts the text at the first ``</s>`` and keeps it from the first ``<answer>``
    onwards (each only when the marker occurs after position 0), then removes the
    ``<answer>`` / ``<|endofchunk|>`` markers and surrounding whitespace.
    """
    text = text.strip("\n\t ")
    eos_pos = text.find("</s>")
    if eos_pos > 0:
        text = text[:eos_pos]
    answer_pos = text.find("<answer>")
    if answer_pos > 0:
        text = text[answer_pos:]
    text = text.replace("<answer>", "").replace("<|endofchunk|>", "")
    return text.strip("\n\t ").lower()


def _save_predictions(save_dir, preds, gts, ids):
    """Write ``preds.csv`` / ``gts.csv`` (one row per sample) and one ``<id>.txt`` per sample."""
    print(f"saving inference results in {save_dir}")
    os.makedirs(save_dir, exist_ok=True)
    # newline="" is what the csv module requires for files it writes to.
    with open(os.path.join(save_dir, "preds.csv"), "w", newline="", encoding="utf-8") as preds_csv:
        with open(os.path.join(save_dir, "gts.csv"), "w", newline="", encoding="utf-8") as gts_csv:
            preds_writer = csv.writer(preds_csv, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            gts_writer = csv.writer(gts_csv, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for study_id, pred, gt in zip(ids, preds, gts):
                preds_writer.writerow([pred])
                gts_writer.writerow([gt])
                with open(os.path.join(save_dir, f"{study_id}.txt"), "w", encoding="utf-8") as f:
                    f.write(f"PRED:\n{pred}\n\nGT:\n{gt}\n")


def do_evaluate(model, dataloader, accelerator, save_dir=None, debug=0, n_beams=3, use_med_roi=False, valid_ids=None):
    """Generate a report for each study in ``dataloader`` and score it against the reference.

    Args:
        model: generation model exposing ``text_tokenizer``, ``device`` and ``generate``.
        dataloader: yields batches with ``batch["id"]`` and a ``batch["net_input"]`` dict that
            holds ``patch_images`` and ``input_ids`` (optionally ``med_patch_images``,
            ``noaug_patch_images`` and ``orig_patch_images``). Only batch size 1 is supported.
        accelerator: ``accelerate.Accelerator`` whose ``autocast()`` wraps generation.
        save_dir: if given, this directory is deleted and recreated, then filled with one
            ``<id>.txt`` per study plus ``preds.csv``, ``gts.csv``, ``all_result.csv`` and
            ``result.json``.
        debug: if > 0, stop after the first fully processed item whose 1-based position in
            the dataloader is >= ``debug`` (skipped items still advance the position).
        n_beams: beam width passed to ``model.generate``.
        use_med_roi: run a label-only pass first, encode the predicted labels with CXR-BERT
            and pass them as a query to a second, full generation pass.
        valid_ids: optional container of study ids; studies not in it are skipped.

    Returns:
        dict of metrics merged from the NLG, CheXbert and CheXpert metric helpers.
    """
    if save_dir is not None:
        # Start from an empty directory so files from a previous run cannot mix in.
        shutil.rmtree(save_dir, ignore_errors=True)
        os.makedirs(save_dir)
    cxr_tokenizer, _ = get_cxr_bert()  # only used to encode the label query when use_med_roi
    text_tokenizer = model.text_tokenizer
    # Constant across batches, so resolve it once instead of on every iteration.
    answer_token_id = text_tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]

    preds, gts, ids = [], [], []
    seen_ids = set()  # O(1) duplicate check; `ids` keeps the output order
    model.eval()
    count = 0
    for it, batch in tqdm(enumerate(dataloader), desc="inference", total=len(dataloader.dataset)):
        study_id = batch["id"][0]
        if valid_ids is not None and study_id not in valid_ids:
            continue
        if study_id in seen_ids:
            continue

        net_input = batch["net_input"]
        images = net_input["patch_images"].to(model.device, non_blocking=True).unsqueeze(2)
        med_images = net_input.get("med_patch_images", None)
        if med_images is not None:
            med_images = med_images.to(model.device, non_blocking=True).unsqueeze(2)
        input_ids = net_input["input_ids"].to(model.device, non_blocking=True)
        noaug_images = net_input.get("noaug_patch_images", None)
        orig_images = net_input.get("orig_patch_images", None)
        if noaug_images is not None:
            noaug_images = noaug_images.to(model.device, non_blocking=True).unsqueeze(2)

        # Split at <answer>: the prompt (re-terminated with <answer>) is fed to generate,
        # the remainder is the reference report.
        answer_indices = list(np.where(input_ids.cpu().numpy() == answer_token_id)[-1])
        if len(answer_indices) == 0:
            continue
        assert len(answer_indices) == 1, "Current only support batch size = 1"
        input_ids, gt_ids = torch.tensor_split(input_ids, answer_indices, dim=-1)
        input_ids = torch.concat((input_ids, torch.tensor([[answer_token_id]]).to(input_ids.device)), dim=-1)
        # Drop the leading <answer> token and the last two tokens
        # (presumably end-of-chunk / eos markers — TODO confirm against the dataset).
        gt_ids = gt_ids[:, 1:-2]

        gen_kwargs = dict(
            vision_x=images,
            lang_x=input_ids.to(model.device),
            med_vision_x=med_images,
            attention_mask=torch.ones(input_ids.shape, dtype=torch.int32, device=model.device),
            max_new_tokens=512,
            num_beams=n_beams,
        )
        # Two separate context managers: `no_grad() and autocast()` would only enter autocast.
        with torch.no_grad(), accelerator.autocast():
            if use_med_roi:
                # First forward only predicts labels; they become a text query for the second.
                label = model.generate(**gen_kwargs, label_only=True)
                label_text = ", ".join([c for i, c in enumerate(CATEGORIES) if label[i]])
                query_ids = cxr_tokenizer.batch_encode_plus(
                    batch_text_or_text_pairs=[label_text],
                    add_special_tokens=True,
                    padding='longest',
                    return_tensors='pt'
                ).input_ids[0]
                query_masks = torch.ones(query_ids.shape, dtype=int)
                generated_text = model.generate(
                    **gen_kwargs,
                    label_only=False,
                    query_ids=query_ids,
                    query_masks=query_masks,
                    noaug_vision_x=noaug_images,
                    orig_vision_x=orig_images,
                )
            else:  # Only one forward
                generated_text = model.generate(**gen_kwargs)

        gt = text_tokenizer.decode(gt_ids[0]).strip("\n\t ").lower()
        pred = _clean_prediction(text_tokenizer.decode(generated_text[0]))
        if len(gt) == 0:
            print("Empty GT:\npred:", pred, "gt:", gt)
        else:
            print(pred)
        gts.append(gt)
        preds.append(pred)
        ids.append(study_id)
        seen_ids.add(study_id)
        count += 1

        if 0 < debug <= it + 1:
            break
    print(count)

    assert len(preds) == len(gts) == len(ids) == len(set(ids))
    if save_dir is not None:
        _save_predictions(save_dir, preds, gts, ids)

    all_result = {}
    result = calculate_nlg_metrics(preds, gts, study_ids=ids, all_results=all_result)
    result.update(calculate_chexbert_metrics(preds, gts))
    # `calculate_chexpert_metrics` is not imported in this module (only the `_mp` variant is),
    # so calling it raised NameError after all inference had finished.
    # NOTE(review): assumes the `_mp` variant takes the same arguments — confirm in eval_metrics.
    result.update(calculate_chexpert_metrics_mp(preds, gts, study_ids=ids, all_results=all_result))
    if save_dir is not None:
        pd.DataFrame(all_result).T.to_csv(os.path.join(save_dir, "all_result.csv"))
        with open(os.path.join(save_dir, "result.json"), "w") as f:
            json.dump(result, f, indent=2)
    return result


def main():
    """Entry point: load the model and checkpoint, run generation over the chosen split and print metrics."""
    args = parse_args()
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    # `--precision` uses open_flamingo-style names, but Accelerator only accepts
    # "no"/"fp16"/"bf16"/"fp8"; e.g. "amp_bf16" or "fp32" passed through unchanged raise.
    accelerate_precision = {
        "amp_bf16": "bf16",
        "amp_bfloat16": "bf16",
        "bf16": "bf16",
        "amp": "fp16",
        "fp16": "fp16",
        "fp32": "no",
    }[args.precision]
    accelerator = Accelerator(mixed_precision=accelerate_precision)
    device = accelerator.device

    # Value passed to the model as `num_vision_tokens` for each supported encoder type.
    # NOTE(review): "unimedi" is an accepted --vision_encoder_type choice but has no entry
    # here, so it still raises below; add its token count if it is meant to be supported.
    tokens_per_encoder = {"unimedi3d": 512, "unimedi2d": 196, "biovil": 225}
    if args.vision_encoder_type not in tokens_per_encoder:
        raise ValueError("Vision encoder type not recognized")
    num_vision_tokens = tokens_per_encoder[args.vision_encoder_type]

    is_med_flamingo = "med-flamingo" in args.pretrained_model_name_or_path
    if is_med_flamingo:
        model = FlamingoForConditionalGeneration.from_pretrained(
            "annonymous/openflamingo-9b-hf",
            device_map="auto",
        )
    else:
        model = OtterForConditionalGeneration.from_pretrained(
            "annonymous/openflamingo-9b-hf",
            device_map="auto",
            vision_encode_mode=args.vision_encode_mode,
            downsample_frames=args.downsample_frames,
            num_vision_tokens=num_vision_tokens,
            use_lora=args.use_lora,
        )
        model.text_tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]}
        )
        model.lang_encoder.resize_token_embeddings(len(model.text_tokenizer))
        model.init_medical_vision_encoder(args)
        model.init_medical_roi_extractor(args)

    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    ckpt = torch.load(args.pretrained_model_name_or_path, map_location=device)
    if "model_state_dict" in ckpt:
        ckpt = ckpt["model_state_dict"]
    if is_med_flamingo:
        from pipeline.train.instruction_following import parse_med_flamingo_checkpoint
        ckpt = parse_med_flamingo_checkpoint(model.state_dict(), ckpt)
    missing_keys, unexpected_keys = model.load_state_dict(ckpt, strict=False)

    # Sanity-check the key mismatch: in "original" mode only vision_encoder/lang_encoder
    # weights may be absent from the checkpoint; in llama_adapter modes no adapter weight may be.
    if args.vision_encode_mode == "original":
        assert len(unexpected_keys) == 0, unexpected_keys
        for key in missing_keys:
            assert "vision_encoder" in key or "lang_encoder" in key, key
    elif "llama_adapter" in args.vision_encode_mode:
        assert len(unexpected_keys) == 0, unexpected_keys
        for key in missing_keys:
            assert "adapter" not in key, key

    model.text_tokenizer.padding_side = "left"
    tokenizer = model.text_tokenizer
    dataloader = get_data(args, None, tokenizer, args.dataset_type)[0]

    model, dataloader = accelerator.prepare(model, dataloader)
    # NOTE(review): valid_ids is computed but not forwarded (valid_ids=None below), so every
    # study is evaluated; do_evaluate also deletes args.output_dir before writing into it.
    valid_ids = process_annotation(args.output_dir, "annotation.csv")
    results = do_evaluate(
        model,
        dataloader,
        accelerator,
        save_dir=args.output_dir,
        n_beams=args.n_beams,
        use_med_roi=args.use_med_roi,
        valid_ids=None,
        debug=args.debug
    )
    print(args.mimicit_path)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    # Select the start method before main() can create any worker processes;
    # "spawn" launches fresh interpreters instead of forking this one.
    multiprocessing.set_start_method(method="spawn")
    main()
