import argparse
import multiprocessing
import os
import shutil
import sys
import csv
import time
import json
import torch
from pathlib import Path
from tqdm import tqdm
import evaluate
from accelerate import Accelerator
from torchvision import transforms

import numpy as np
sys.path.insert(-1, str(Path(__file__).parent))
sys.path.insert(0, str(Path(__file__).parent.parent.parent))


# from flamingo.modeling_flamingo import FlamingoForConditionalGeneration
from otter.modeling_otter import OtterForConditionalGeneration
from otter.biovil_encoder.image.data.io import load_image
from pipeline.train.distributed import world_info_from_env
from otter.biovil_encoder.image.data.transforms import create_chest_xray_transform_for_inference

from eval_metrics import calculate_nlg_metrics, calculate_chexpert_metrics_mp, calculate_chexbert_metrics



def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options. Dashed flags are exposed with
        underscores (e.g. ``--max-src-length`` -> ``max_src_length``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        type=str,
        help="text input or path to a text file",
    )
    parser.add_argument(
        "--images",
        type=str,
        help="path to an image or image folder",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        help="path to trained model ckpt",
        default="/data/pretrained/mimic_cxr_checkpoint_8.pt"
    )
    parser.add_argument(
        "--medical_vision_encoder_path",
        type=str,
        help="path to pretrained medical vision encoder",
        default="/data/pretrained/biovil_image_resnet50_proj_size_128.pt",
    )
    # The entry point of this script reads ``args.vision_encoder_type`` to pick
    # the number of vision tokens; without this flag that lookup raises
    # AttributeError. The default matches the BioViL checkpoint used as the
    # default for --medical_vision_encoder_path.
    parser.add_argument(
        "--vision_encoder_type",
        type=str,
        default="biovil",
        choices=["unimedi3d", "unimedi2d", "biovil"],
        help="type of medical vision encoder, used to select the number of vision tokens",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="path to folder to save inference result",
        default=None
    )
    parser.add_argument(
        "--vision_encode_mode",
        type=str,
        choices=["original", "medical_only", "llama_adapter_plus", "llama_adapter_concat"],
        default="llama_adapter_concat",
        help=(
            "mode to encoder input images,"
            "'original' for ignoring medical encoder,"
            "'medical_only' for using medical encoder only and an adapter to reshape output,"
            "'llama_adapter_plus' for using LLaMa-adapter style with two same-size image feature input and add after attention,"
            "'llama_adapter_concat' for using LLaMa-adapter style with learnable prefix prompt"
        )
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="mimic_cxr",
        choices=["mimic_cxr", "bimcv_covid19", "mimicit", "custom_2d", "custom_3d"],
        help="dataset type"
    )
    parser.add_argument(
        "--downsample_frames",
        default=0,
        type=int,
        help="downsample number of input frames, use when using 3D dataset"
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=4,
        help="number of beams in inference beam search"
    )
    parser.add_argument(
        "--max-src-length",
        type=int,
        default=512,
        help="the maximum src sequence length",
    )
    parser.add_argument(
        "--max-tgt-length",
        type=int,
        default=512,
        help="the maximum target sequence length",
    )
    parser.add_argument(
        "--patch-image-size",
        type=int,
        default=224
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42
    )
    parser.add_argument(
        "--precision",
        choices=["amp_bf16", "amp_bfloat16", "bf16", "amp", "fp16", "fp32"],
        default="fp16",
        help="Floating point precision.",
    )
    return parser.parse_args()


def do_evaluate(model, dataloader, accelerator, save_dir=None, debug=0, n_beams=3):
    """Generate a prediction for every sample of ``dataloader`` and score them.

    Each batch must hold exactly one sample whose ``input_ids`` contain a
    single ``<answer>`` token: everything up to and including that token is
    used as the prompt, the tokens after it are the ground truth.

    Args:
        model: model exposing ``text_tokenizer``, ``device``, ``eval()`` and
            ``generate(vision_x, lang_x, med_vision_x, attention_mask, ...)``.
        dataloader: iterable of batches with keys ``"id"`` and ``"net_input"``
            (``patch_images``, ``input_ids`` and optionally
            ``med_patch_images``); ``dataloader.dataset`` must support ``len``.
        accelerator: object providing the ``autocast()`` context manager.
        save_dir: optional output folder. When given it is wiped and
            recreated, then filled with ``preds.csv``, ``gts.csv`` and one
            ``<id>.txt`` file per sample.
        debug: when > 0, stop once the batch index reaches this value.
        n_beams: beam width passed to ``model.generate``.

    Returns:
        dict: merged results of ``calculate_nlg_metrics``,
        ``calculate_chexbert_metrics`` and ``calculate_chexpert_metrics_mp``.
    """
    if save_dir is not None:
        # Only touch the filesystem when an output folder was requested.
        # exist_ok guards against rmtree having silently failed to remove it.
        shutil.rmtree(save_dir, ignore_errors=True)
        os.makedirs(save_dir, exist_ok=True)

    tokenizer = model.text_tokenizer
    # Loop invariant: id of the token separating prompt from ground truth.
    answer_token_id = tokenizer("<answer>", add_special_tokens=False)["input_ids"][-1]

    preds, gts, ids = [], [], []
    seen_ids = set()  # O(1) duplicate check; ids are hashed by the assert below anyway
    model.eval()
    for it, batch in tqdm(enumerate(dataloader), desc="inference", total=len(dataloader.dataset)):
        sample_id = batch["id"][0]
        if sample_id in seen_ids:
            continue

        net_input = batch["net_input"]
        images = (
            net_input["patch_images"]
            .to(model.device, non_blocking=True)
            .unsqueeze(2)
        )
        med_images = net_input.get("med_patch_images", None)
        if med_images is not None:
            med_images = med_images.to(model.device, non_blocking=True).unsqueeze(2)

        input_ids = net_input["input_ids"].to(model.device, non_blocking=True)
        answer_indices = list(np.where(input_ids.cpu().numpy() == answer_token_id)[-1])
        assert len(answer_indices) == 1, "Current only support batch size = 1"
        # Split at <answer>: left part is the prompt, right part starts with <answer>.
        input_ids, gt_ids = torch.tensor_split(input_ids, answer_indices, dim=-1)
        # Put <answer> back at the end of the prompt so generation starts after it.
        input_ids = torch.concat(
            (input_ids, torch.tensor([[answer_token_id]]).to(input_ids.device)), dim=-1
        )
        # Drop the leading <answer> token and the last two tokens
        # (presumably end-of-chunk / end-of-sequence markers - TODO confirm).
        gt_ids = gt_ids[:, 1:-2]

        # Both context managers must be entered. ``a() and b()`` evaluates to
        # ``b()`` alone, which would leave gradient tracking enabled.
        with torch.no_grad(), accelerator.autocast():
            generated_text = model.generate(
                vision_x=images,
                lang_x=input_ids.to(model.device),
                med_vision_x=med_images,
                attention_mask=torch.ones(input_ids.shape, dtype=torch.int32, device=model.device),
                max_new_tokens=512,
                num_beams=n_beams,
            )

        gt = tokenizer.decode(gt_ids[0]).strip("\n\t ").lower()
        pred = tokenizer.decode(generated_text[0]).strip("\n\t ")
        # Cut everything from the first end-of-sequence marker on (also when
        # it is the very first thing in the text).
        eos_pos = pred.find("</s>")
        if eos_pos != -1:
            pred = pred[:eos_pos]
        # Keep only what follows the <answer> marker.
        answer_pos = pred.find("<answer>")
        if answer_pos > 0:
            pred = pred[answer_pos:]
        pred = pred.replace("<answer>", "").replace("<|endofchunk|>", "").strip("\n\t ").lower()
        if len(gt) == 0 or len(pred) == 0:
            # Empty texts cannot be scored; report and skip the sample.
            print("\npred:", pred, "gt:", gt)
            continue
        gts.append(gt)
        preds.append(pred)
        ids.append(sample_id)
        seen_ids.add(sample_id)

        if 0 < debug <= it:
            break
    print(len(preds))

    assert len(preds) == len(gts) == len(ids) == len(set(ids))
    if save_dir is not None:
        print(f"saving inference results in {save_dir}")
        os.makedirs(save_dir, exist_ok=True)
        preds_path = os.path.join(save_dir, "preds.csv")
        gts_path = os.path.join(save_dir, "gts.csv")
        # newline="" is what the csv module expects for files it writes to.
        with open(preds_path, "w", newline="") as preds_csv, \
                open(gts_path, "w", newline="") as gts_csv:
            preds_writer = csv.writer(preds_csv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            gts_writer = csv.writer(gts_csv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            for sample_id, pred, gt in zip(ids, preds, gts):
                preds_writer.writerow([pred])
                gts_writer.writerow([gt])
                with open(os.path.join(save_dir, f"{sample_id}.txt"), "w") as f:
                    f.write(f"PRED:\n{pred}\n\nGT:\n{gt}\n")
    result = calculate_nlg_metrics(preds, gts)
    result.update(calculate_chexbert_metrics(preds, gts))
    result.update(calculate_chexpert_metrics_mp(preds, gts))
    return result

# Preprocessing applied to every loaded image in process_image; built with
# resize=512 and center_crop_size=480 (process_image pads these tensors with
# 480-sized zero images).
med_transform = create_chest_xray_transform_for_inference(
    resize=512, center_crop_size=480
)
# Applied in process_image on top of med_transform's output to obtain the
# 224-sized copy of each image (padded there with 224-sized zero images).
transform = transforms.Resize(224, antialias=True)

def _pad_or_cut_img_tensors(img_tensors, img_size, num_imgs=2):
    """Return ``img_tensors`` with exactly ``num_imgs`` entries along dim 0.

    Missing entries are filled with all-zero 3-channel images of size
    ``img_size`` x ``img_size``; surplus entries are dropped from the end.
    """
    num_present = len(img_tensors)
    if num_present < num_imgs:
        zero_padding = torch.zeros(
            (num_imgs - num_present, 3, img_size, img_size),
            dtype=torch.float,
        )
        return torch.cat((img_tensors, zero_padding), dim=0)
    if num_present > num_imgs:
        return img_tensors[:num_imgs, :, :, :]
    return img_tensors


def process_image(path_list):
    """Load the images at ``path_list`` and build the two model image inputs.

    Every image is loaded with ``load_image`` and passed through
    ``med_transform``; the 224-sized copy is obtained by applying ``transform``
    to that result. Both stacks are padded/cut to two images and get two extra
    singleton dimensions, i.e. (T, C, H, W) -> (1, T, 1, C, H, W).

    Args:
        path_list: iterable of image paths.

    Returns:
        tuple: ``(images, med_images)`` tensors built from the 224-sized and
        480-sized versions respectively.

    Raises:
        ValueError: if ``path_list`` yields no paths (``torch.stack`` cannot
            stack an empty list and would fail with an opaque error).
    """
    images, med_images = [], []
    for path in path_list:
        med_image = med_transform(load_image(path))
        med_images.append(med_image)
        images.append(transform(med_image))
    if not med_images:
        raise ValueError("process_image needs at least one image path")

    images = torch.stack(images)  # (T,C,H,W)
    med_images = torch.stack(med_images)
    images = _pad_or_cut_img_tensors(images, 224).unsqueeze(1).unsqueeze(0)
    med_images = _pad_or_cut_img_tensors(med_images, 480).unsqueeze(1).unsqueeze(0)
    return images, med_images


def process_text(input_text, tokenizer, max_length):
    """Wrap ``input_text`` in the chat prompt template and tokenize it.

    The prompt is ``<image>User: {input_text} GPT:<answer> ``. It is tokenized
    without special tokens, and the tokenizer's BOS id is then prepended
    manually along the last dimension.

    Args:
        input_text: the user instruction / question.
        tokenizer: callable tokenizer returning a mapping with ``"input_ids"``
            and exposing ``bos_token_id``.
        max_length: forwarded to the tokenizer as ``max_length``.
            NOTE(review): no truncation argument is passed, so whether this
            limit is enforced depends on the tokenizer's defaults - confirm.

    Returns:
        torch.Tensor: token ids, BOS first, then the tokenized prompt.
    """
    prompt = f"<image>User: {input_text} GPT:<answer> "
    encoding = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
        max_length=max_length,
    )
    bos_column = torch.tensor([[tokenizer.bos_token_id]])
    prompt_ids = encoding["input_ids"]
    return torch.cat((bos_column, prompt_ids), dim=-1)


if __name__ == "__main__":
    # Interactive inference: load the model once, then repeatedly read a prompt
    # and an image path from stdin, print the generated text and optionally
    # save it under --save_dir.
    # multiprocessing.set_start_method("spawn")
    args = parse_args()
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    accelerator = Accelerator(mixed_precision=args.precision)
    device = accelerator.device

    # The encoder type decides how many vision tokens the model is built with.
    # Fall back to "biovil" when the parser does not define the option, and
    # test the encoder type in every branch (dataset_type can never be
    # "biovil", so checking it there made the chain always raise).
    vision_encoder_type = getattr(args, "vision_encoder_type", "biovil")
    if vision_encoder_type == "unimedi3d":
        num_vision_tokens = 512
    elif vision_encoder_type == "unimedi2d":
        num_vision_tokens = 196
    elif vision_encoder_type == "biovil":
        num_vision_tokens = 225
    else:
        raise ValueError("Vision encoder type not recognized")

    model = OtterForConditionalGeneration.from_pretrained(
        "annonymous/openflamingo-9b-hf",
        device_map="auto",
        vision_encode_mode=args.vision_encode_mode,
        downsample_frames=args.downsample_frames,
        num_vision_tokens=num_vision_tokens,
    )
    model.text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|endofchunk|>", "<image>", "<answer>"]}
    )
    # The embedding table must cover the special tokens added above.
    model.lang_encoder.resize_token_embeddings(len(model.text_tokenizer))
    model.init_medical_vision_encoder(args)
    model.init_medical_roi_extractor(args)
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    ckpt = torch.load(args.ckpt, map_location=device)
    if "model_state_dict" in ckpt:
        ckpt = ckpt["model_state_dict"]
    missing_keys, unexpected_keys = model.load_state_dict(ckpt, strict=False)
    # Sanity-check the non-strict load for the selected encode mode.
    if args.vision_encode_mode == "original":
        assert len(unexpected_keys) == 0
        for key in missing_keys:
            assert "vision_encoder" in key or "lang_encoder" in key
    elif "llama_adapter" in args.vision_encode_mode:
        assert len(unexpected_keys) == 0, unexpected_keys
        for key in missing_keys:
            assert "adapter" not in key

    model.text_tokenizer.padding_side = "left"
    tokenizer = model.text_tokenizer

    model = accelerator.prepare(model)

    while True:
        text = input("Input text or path to text file: ")
        image_path = Path(input("Input path to image or image folder: "))
        start = time.time()
        if os.path.isfile(text):
            with open(text, 'r') as f:
                text = f.read()
        if os.path.isdir(image_path):
            # Sorted so the images kept by process_image (which pads/cuts to a
            # fixed count) do not depend on directory listing order.
            image_path_list = sorted(image_path.iterdir())
        elif os.path.isfile(image_path):
            image_path_list = [image_path]
        else:
            raise FileNotFoundError(f"{image_path} does not exist")
        # Accept JPEG files regardless of extension case or .jpg/.jpeg spelling.
        image_path_list = [
            path for path in image_path_list
            if path.suffix.lower() in (".jpg", ".jpeg")
        ]
        if not image_path_list:
            # process_image cannot stack an empty list; fail with a clear message.
            raise FileNotFoundError(f"no .jpg/.jpeg image found at {image_path}")

        images, med_images = process_image(image_path_list)
        # The prompt is the source sequence, so its limit is max_src_length.
        input_ids = process_text(text, tokenizer, max_length=args.max_src_length)

        # Both context managers must be entered. ``a() and b()`` evaluates to
        # ``b()`` alone, which would leave gradient tracking enabled.
        with torch.no_grad(), accelerator.autocast():
            pred = model.generate(
                vision_x=images.to(model.device),
                lang_x=input_ids.to(model.device),
                med_vision_x=med_images.to(model.device),
                attention_mask=torch.ones(input_ids.shape, dtype=torch.int32, device=model.device),
                max_new_tokens=args.max_tgt_length,
                num_beams=args.n_beams,
            )
        # Keep only the text between the <answer> marker and the first </s>.
        pred = model.text_tokenizer.decode(pred[0]).strip("\n\t ")
        pred = pred[:pred.find("</s>")] if pred.find("</s>") > 0 else pred
        pred = pred[pred.find("<answer>"):] if pred.find("<answer>") > 0 else pred
        pred = pred.replace("<answer>", "").replace("<|endofchunk|>", "").replace("findings:", "").strip("\n\t ").lower()
        print(pred)

        elapse = time.time() - start
        print(f"{elapse} s")
        if args.save_dir:
            # Saved under the image (or folder) name without an extension.
            with open(os.path.join(args.save_dir, image_path.stem), 'w') as f:
                f.write(pred)

