import os
import sys
from typing import Optional
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
from model.LISA import LISAForCausalLM
from model.Llava import conversation as conversation_lib
from model.Llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
import logging
# NOTE(review): `transformers` has already been imported above this line, and some
# transformers / huggingface_hub versions read the offline flag only at import time.
# If offline mode does not take effect, export TRANSFORMERS_OFFLINE in the process
# environment (or set it before the `transformers` import) — TODO confirm.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Module-level state. model / tokenizer / clip_image_processor / transform are assigned
# by the startup hook `load_model`; `args` is rebound to an Args() instance right after
# the Args class definition.
model = tokenizer = clip_image_processor = transform = args = None
class Args:
    """Static server configuration, exposed as class attributes.

    The module-level instance below is read throughout the file; ``load_model``
    also stores the ``[SEG]`` token id on it as ``seg_token_idx``.
    """

    # Checkpoints and prompt template
    version: str = "xinlai/LISA-13B-llama2-v1"
    vision_tower: str = "openai/clip-vit-large-patch14"
    conv_type: str = "llava_v1"  # key into conversation_lib.conv_templates
    use_mm_start_end: bool = True  # wrap the image token in start/end tokens

    # Sizes
    image_size: int = 1024  # target size for ResizeLongestSide and padding in preprocess()
    model_max_length: int = 1024  # passed to the tokenizer
    lora_r: int = 8  # not referenced in the code shown here

    # Precision / quantization; precision values handled: "bf16", "fp16", "fp32"
    precision: str = "fp16"
    load_in_8bit: bool = True
    load_in_4bit: bool = False

    # Device index the vision tower is moved to
    local_rank: int = 0

    # Directory that receives the mask / overlay JPEGs
    vis_save_path: str = "./vis_output"


args = Args()
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=args.image_size,
) -> torch.Tensor:
    """Normalize an image tensor channel-wise and zero-pad it to ``img_size`` x ``img_size``.

    ``pixel_mean`` / ``pixel_std`` are shaped (C, 1, 1) and broadcast over ``x``;
    the last two dims of ``x`` are treated as height and width. Padding is added
    on the bottom and right only. Note that ``img_size`` is bound once, when the
    function is defined.
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # F.pad's tuple is (left, right, top, bottom) for the last two dimensions.
    pad_spec = (0, img_size - width, 0, img_size - height)
    return F.pad(normalized, pad_spec)
def _resolve_torch_dtype(precision):
    """Map the ``precision`` setting to a torch dtype ("bf16", "fp16", otherwise float32)."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32


def _quantization_kwargs(load_in_4bit, load_in_8bit):
    """Return extra ``from_pretrained`` kwargs for bitsandbytes 4-/8-bit loading.

    4-bit takes precedence over 8-bit; returns an empty dict when neither is set.
    The flag is carried only by ``quantization_config``: recent transformers
    releases raise ValueError if ``load_in_4bit``/``load_in_8bit`` is also passed
    as a separate kwarg alongside a ``quantization_config``.
    """
    if load_in_4bit:
        return {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    if load_in_8bit:
        return {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    return {}


@app.on_event("startup")
async def load_model():
    """Load the tokenizer, LISA model, vision tower and CLIP image processor.

    Registered as a FastAPI startup hook. Reads configuration from the module-level
    ``args``, stores the ``[SEG]`` token id on it as ``args.seg_token_idx``, and
    assigns the module globals ``model``, ``tokenizer``, ``clip_image_processor``
    and ``transform``. The bf16 / non-quantized fp16 / fp32 paths move the model
    to CUDA explicitly.
    """
    global model, tokenizer, clip_image_processor, transform, args
    logger.info("Starting model loading...")

    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Id of the "[SEG]" token; handed to LISAForCausalLM as seg_token_idx below.
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    logger.info("Tokenizer loaded")

    torch_dtype = _resolve_torch_dtype(args.precision)
    kwargs = {"torch_dtype": torch_dtype}
    # Quantized loading overrides torch_dtype with half precision.
    kwargs.update(_quantization_kwargs(args.load_in_4bit, args.load_in_8bit))
    model = LISAForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        **kwargs,
    )
    logger.info("Model initialized")

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    logger.info("Vision tower initialized")

    quantized = args.load_in_4bit or args.load_in_8bit
    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and not quantized:
        # The vision tower is detached while DeepSpeed wraps the model (presumably
        # so kernel injection leaves it alone) and reattached in fp16 on the GPU.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()
    logger.info("Model moved to device")

    # Re-fetch from the final `model` object (it may have been replaced above)
    # rather than relying on a reference taken before the precision branch.
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)
    model.eval()
    logger.info("LISA model loaded successfully")
def _cast_to_precision(tensor):
    """Return *tensor* cast to the dtype implied by ``args.precision``.

    "bf16" -> bfloat16, "fp16" -> float16, anything else -> float32.
    """
    if args.precision == "bf16":
        return tensor.bfloat16()
    if args.precision == "fp16":
        return tensor.half()
    return tensor.float()


def _build_prompt_text(user_prompt):
    """Render *user_prompt*, preceded by the image placeholder, with the configured template.

    Raises HTTPException(400) if the user text itself contains ``DEFAULT_IMAGE_TOKEN``.
    Such text would add image placeholders with no matching image, because
    ``tokenizer_image_token`` presumably maps every occurrence of the token.
    """
    if DEFAULT_IMAGE_TOKEN in user_prompt:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt must not contain the reserved token {DEFAULT_IMAGE_TOKEN}",
        )
    image_token = DEFAULT_IMAGE_TOKEN
    if args.use_mm_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []
    conv.append_message(conv.roles[0], image_token + "\n" + user_prompt)
    # Empty second-role turn, presumably so the prompt ends where the reply begins.
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()


def _output_stem(filename):
    """Derive a safe output file stem from a client-supplied upload filename.

    Directory components (forward or back slashes) are dropped, so outputs are
    always written directly inside ``args.vis_save_path``. Only the last extension
    is removed ("a.b.png" -> "a.b"). A missing or empty name becomes "image".
    """
    base = os.path.basename((filename or "").replace("\\", "/"))
    return os.path.splitext(base)[0] or "image"


def _save_mask_outputs(image_rgb, pred_masks, stem):
    """Write each non-empty predicted mask and a red overlay image to ``args.vis_save_path``.

    Returns ``(mask_paths, masked_img_paths)``. Raises RuntimeError if OpenCV
    reports that a file could not be written, so no path to a missing file is returned.
    """
    os.makedirs(args.vis_save_path, exist_ok=True)
    red = np.array([255, 0, 0])  # RGB; the overlay is converted to BGR only when saved
    mask_paths = []
    masked_img_paths = []
    for i, pred_mask in enumerate(pred_masks):
        if pred_mask.shape[0] == 0:
            logger.info(f"Mask {i} is empty, skipping")
            continue
        # Threshold in torch before converting: .numpy() rejects bfloat16 tensors.
        # Only entry 0 along dim 0 is used.
        mask = (pred_mask[0] > 0).detach().cpu().numpy()
        logger.info(f"Mask {i} processed, shape: {mask.shape}")

        mask_path = f"{args.vis_save_path}/{stem}_mask_{i}.jpg"
        # Binary mask stored as 0/100 grey levels in an explicit uint8 image.
        if not cv2.imwrite(mask_path, (mask * 100).astype(np.uint8)):
            raise RuntimeError(f"Failed to write {mask_path}")
        mask_paths.append(mask_path)
        logger.info(f"Mask saved: {mask_path}")

        masked_img_path = f"{args.vis_save_path}/{stem}_masked_img_{i}.jpg"
        save_img = image_rgb.copy()
        # 50/50 blend with red on masked pixels only; the float result is truncated to uint8.
        save_img[mask] = image_rgb[mask] * 0.5 + red * 0.5
        if not cv2.imwrite(masked_img_path, cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)):
            raise RuntimeError(f"Failed to write {masked_img_path}")
        masked_img_paths.append(masked_img_path)
        logger.info(f"Masked image saved: {masked_img_path}")
    return mask_paths, masked_img_paths


@app.post("/process")
async def process_image(
    prompt: str = Form(...),
    image: UploadFile = File(...)
):
    """Run LISA on an uploaded image and a text prompt.

    Form fields:
        prompt: user instruction, placed after the image placeholder in the template.
        image: image file, decoded with OpenCV as 3-channel color.

    Returns a JSON object with ``text_output`` (decoded model text, newlines removed)
    and ``mask_paths`` / ``masked_img_paths`` (server-side paths of the written JPEGs).

    Raises HTTPException: 503 if the model is not loaded; 400 for an empty or
    undecodable image, or a prompt containing the reserved image token; 500 for
    any other failure.
    """
    # NOTE(review): preprocessing and model.evaluate run synchronously inside this
    # async handler, blocking the event loop (including /health) during inference.
    # Offloading to a thread would also require serializing GPU access.
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")
    try:
        logger.info(f"Received prompt: {prompt}")
        logger.info(f"Processing image: {image.filename}")
        prompt_text = _build_prompt_text(prompt)
        logger.info(f"Prompt prepared: {prompt_text}")

        image_bytes = await image.read()
        logger.info("Image bytes read")
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Uploaded image is empty")
        image_np = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
        if image_np is None:
            # Bad client input, not a server fault.
            raise HTTPException(status_code=400, detail="Failed to decode image")
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        logger.info(f"Image decoded, shape: {image_np.shape}")

        # Input for the CLIP vision tower.
        image_clip = _cast_to_precision(
            clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
            .unsqueeze(0)
            .cuda()
        )
        logger.info(f"CLIP image processed, shape: {image_clip.shape}")

        # Resized + normalized + padded input for the segmentation branch.
        image_resized = transform.apply_image(image_np)
        resize_list = [image_resized.shape[:2]]
        image_tensor = _cast_to_precision(
            preprocess(torch.from_numpy(image_resized).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        logger.info(f"Model image processed, shape: {image_tensor.shape}")

        input_ids = tokenizer_image_token(prompt_text, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()
        logger.info(f"Input IDs prepared, shape: {input_ids.shape}")

        logger.info("Starting model evaluation")
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image_tensor,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        logger.info("Model evaluation completed")

        # Drop image placeholder ids before decoding.
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        logger.info(f"Text output: {text_output}")

        mask_paths, masked_img_paths = _save_mask_outputs(
            image_np, pred_masks, _output_stem(image.filename)
        )
        logger.info("Returning response")
        return JSONResponse(content={
            "text_output": text_output,
            "mask_paths": mask_paths,
            "masked_img_paths": masked_img_paths
        })
    except HTTPException:
        # Deliberate client/availability errors keep their own status code.
        raise
    except Exception as e:
        logger.error(f"Error in process_image: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
@app.get("/health")
async def health_check():
    """Report a static "healthy" status plus whether the global ``model`` has been set."""
    return dict(status="healthy", model_loaded=model is not None)
if __name__ == "__main__":
    # uvicorn is needed only when this module is executed directly, so import it here.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8001
    uvicorn.run(app, host=bind_host, port=bind_port)