import os
import torch
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path


# Question sent to LLaVA for every face image. LLavaAgent.update_qs prepends the
# image placeholder token(s); the text itself carries no surrounding whitespace.
prompt = (
    "Provide a detailed yet concise description of this person's face. "
    "Include their face shape, eyes, nose, mouth, eyebrows, skin texture and tone, "
    "expression, and any notable features like moles, freckles, or wrinkles."
)

# Walk the given folder and return a list of image file paths, keeping only the first `num`.


def traverse_images(folder_path, num=10, extensions=(".jpg", ".png")):
    """Recursively collect image file paths under ``folder_path``.

    Args:
        folder_path: Root directory, walked with ``os.walk``.
        num: Maximum number of paths to return. ``None`` means no limit;
            ``0`` or a negative value returns an empty list.
        extensions: Filename suffix (str) or tuple of suffixes treated as
            images; matching is case-insensitive.

    Returns:
        A list of at most ``num`` joined paths, in ``os.walk`` traversal
        order (not sorted; directory listing order is filesystem-dependent).
    """
    # Previously a non-positive num never matched the counter, so every image
    # was returned; treat it as "take nothing" instead.
    if num is not None and num <= 0:
        return []
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_list = []
    for root, _dirs, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(suffixes):
                image_list.append(os.path.join(root, file))
                if num is not None and len(image_list) >= num:
                    return image_list
    return image_list

# LLavaAgent class definition.


class LLavaAgent:
    """Wrapper around a LLaVA checkpoint that captions batches of images.

    The tokenized prompt is cached on ``self.input_ids`` and rebuilt by
    ``update_qs`` whenever a new question is supplied.
    """

    def __init__(self, model_path, device='cuda', conv_mode='vicuna_v1', load_8bit=False, load_4bit=True):
        """Load the model, tokenizer and image processor.

        Args:
            model_path: Checkpoint directory; ``~`` is expanded.
            device: Torch device string. If it names an explicit index
                (e.g. ``'cuda:1'``), the ``model`` and ``lm_head`` submodules
                are mapped to that index; otherwise ``device_map='auto'``.
            conv_mode: Key into ``conv_templates`` selecting the chat template.
            load_8bit / load_4bit: Quantization flags forwarded to
                ``load_pretrained_model``.
        """
        self.device = device
        device_index = torch.device(self.device).index
        if device_index is not None:
            device_map = {'model': device_index, 'lm_head': device_index}
        else:
            device_map = 'auto'
        model_path = os.path.expanduser(model_path)
        model_name = get_model_name_from_path(model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, None, model_name, device=self.device, device_map=device_map,
            load_8bit=load_8bit, load_4bit=load_4bit)
        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.context_len = context_len
        self.conv_mode = conv_mode
        # Default question, stored already prefixed with the image placeholder.
        self.qs = self._with_image_token(
            "Describe this image and its style in a very detailed manner.")
        self._set_prompt(self.qs)

    def _with_image_token(self, qs):
        """Return ``qs`` prefixed with the image placeholder the model config expects."""
        if self.model.config.mm_use_im_start_end:
            return DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + \
                DEFAULT_IM_END_TOKEN + '\n' + qs
        return DEFAULT_IMAGE_TOKEN + '\n' + qs

    def _set_prompt(self, qs):
        """Build a fresh conversation for ``qs`` and cache its token ids (batch dim 1)."""
        self.conv = conv_templates[self.conv_mode].copy()
        self.conv.append_message(self.conv.roles[0], qs)
        # An empty assistant turn makes get_prompt() end where the model should answer.
        self.conv.append_message(self.conv.roles[1], None)
        prompt_conv = self.conv.get_prompt()
        self.input_ids = tokenizer_image_token(
            prompt_conv, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

    def update_qs(self, qs=None):
        """Set the question used for captioning.

        Args:
            qs: Plain question text (the image placeholder is added here), or
                ``None`` to reuse the default question stored in ``self.qs``.
        """
        qs = self.qs if qs is None else self._with_image_token(qs)
        self._set_prompt(qs)

    def gen_image_caption(self, imgs, temperature=0.2, top_p=0.7, num_beams=1, qs=None):
        """Generate one caption per image.

        Args:
            imgs: Iterable of images accepted by ``image_processor.preprocess``.
                PIL images in a non-RGB mode are converted to RGB first.
            temperature: Sampling temperature; ``<= 0`` disables sampling.
            top_p: Nucleus-sampling threshold passed to ``generate``.
            num_beams: Beam count passed to ``generate``.
            qs: Optional question text; ``None`` uses the default question.

        Returns:
            A list of decoded strings, one per image (``[]`` for no images).
        """
        imgs = list(imgs)
        if not imgs:
            # torch.stack rejects an empty list; there is nothing to caption.
            return []
        self.update_qs(qs)
        # Every image shares the same prompt, so the cached ids are just tiled.
        input_ids = self.input_ids.repeat(len(imgs), 1)

        pixel_values = []
        for image in imgs:
            # Alpha/grayscale/palette PIL images are normalized to 3 channels.
            if hasattr(image, 'convert') and getattr(image, 'mode', 'RGB') != 'RGB':
                image = image.convert('RGB')
            pixel_values.append(self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0])
        # NOTE(review): float16 is hard-coded; confirm it matches the vision
        # tower dtype when the model is loaded in 8-bit/4-bit.
        image_tensor = torch.stack(pixel_values, dim=0).half().to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=512,
                use_cache=True)

        # NOTE(review): depending on the LLaVA version, generate() may also
        # return the prompt tokens; confirm outputs contain only the answer.
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)


# Script entry point: caption every image in a folder and save one .txt file per image.
if __name__ == '__main__':
    # Build the captioning agent (checkpoint path and quantization are site-specific).
    llava_agent = LLavaAgent(
        "/share/home/wangj928/BRL_Lab/FaithDiff/checkpoints/llava_v1.5-13b/llava",
        device='cuda', load_8bit=True, load_4bit=False
    )

    # Input image folder and output folder for the caption .txt files.
    image_folder = "/share/home/wangj928/BRL_Lab/FaithDiff/dataset/arc2face_test/"
    text_folder = "/share/home/wangj928/BRL_Lab/FaithDiff/dataset/arc2face_txt"
    # Images per generate() call. The whole folder in one batch can exhaust GPU memory.
    batch_size = 8
    os.makedirs(text_folder, exist_ok=True)

    # Only top-level .jpg/.png files (no recursion), sorted for a reproducible order.
    image_path_list = [
        os.path.join(image_folder, image_name)
        for image_name in sorted(os.listdir(image_folder))
        if image_name.lower().endswith((".jpg", ".png"))
    ]

    for start in tqdm(range(0, len(image_path_list), batch_size), desc="Captioning"):
        batch_paths = image_path_list[start:start + batch_size]

        # convert() loads pixel data into a new image, so the file can be closed right away.
        images = []
        for image_path in batch_paths:
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))

        # Previously referenced an undefined `prompt2`. Use the module-level prompt.
        captions = llava_agent.gen_image_caption(images, qs=prompt.strip())

        # Write each caption next to a file named after its image (extension dropped).
        for img_path, caption in zip(batch_paths, captions):
            base_name = os.path.splitext(os.path.basename(img_path))[0]
            txt_file = os.path.join(text_folder, base_name + ".txt")
            with open(txt_file, "w", encoding="utf-8") as f:
                f.write(caption.strip())

    print("Captions saved successfully!")
