import os
# Restrict this process to GPU index 7. This is assigned before torch / modelscope
# are imported below so the restriction is in place before CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from modelscope import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import json
from tqdm import tqdm


def init_model(model_path="./projects/Qwen2.5-VL/models/Qwen2.5-VL-7B-Instruct"):
    """Load the Qwen2.5-VL model and its matching processor.

    Args:
        model_path: Checkpoint location handed to both ``from_pretrained``
            calls, so the model and the processor always come from the same
            place. Defaults to the local Qwen2.5-VL-7B-Instruct directory.

    Returns:
        A ``(model, processor)`` tuple.
    """
    # flash_attention_2 is recommended for better acceleration and memory
    # saving, especially in multi-image and video scenarios.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(model_path)
    return model, processor

def process_image(image_path, model=None, processor=None, prompt="Describe this image", max_new_tokens=128):
    """Generate a text description of a single image.

    Args:
        image_path: Image location placed in the chat message's ``image`` field.
        model: Loaded generation model (required despite the ``None`` default).
        processor: Processor matching ``model`` (required despite the ``None``
            default).
        prompt: Instruction text sent alongside the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The list of decoded strings from ``processor.batch_decode``; one entry
        per input sequence, i.e. a single-element list here.

    Raises:
        ValueError: If ``model`` or ``processor`` is not supplied.
    """
    if model is None or processor is None:
        raise ValueError("process_image requires both `model` and `processor`")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming a literal "cuda".
    inputs = inputs.to(model.device)

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; strip the prompt tokens so only
    # the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)
    return output_text


_DEFAULT_VIDEO_PROMPT = "Describe this video. Provide details about the person in the video, including their gender, clothing, body type (e.g., slim or overweight), and hair color. Also, describe the objects in the background. Keep the description concise and under 100 words, without any additional information."


def process_video(video_path, model=None, processor=None, prompt=_DEFAULT_VIDEO_PROMPT, max_pixels=448 * 448, fps=1.0, max_new_tokens=128):
    """Generate a text description of a single video.

    Args:
        video_path: Video location placed in the chat message's ``video`` field.
        model: Loaded generation model (required despite the ``None`` default).
        processor: Processor matching ``model`` (required despite the ``None``
            default).
        prompt: Instruction text sent alongside the video.
        max_pixels: ``max_pixels`` value set on the video message entry.
        fps: ``fps`` value set on the video message entry.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The list of decoded strings from ``processor.batch_decode``; one entry
        per input sequence, i.e. a single-element list here.

    Raises:
        ValueError: If ``model`` or ``processor`` is not supplied.
    """
    if model is None or processor is None:
        raise ValueError("process_video requires both `model` and `processor`")

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": video_path,
                    "max_pixels": max_pixels,
                    "fps": fps,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # video_kwargs is forwarded to the processor unchanged.
    image_inputs, video_inputs, video_kwargs = process_vision_info(
        messages, return_video_kwargs=True
    )
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    # Follow the model's own device instead of assuming a literal "cuda".
    inputs = inputs.to(model.device)

    # Inference
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; strip the prompt tokens so only
    # the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)
    return output_text



def generate_json_video_dir(target_dir, driving_dir, driving_mask_dir, driving_face_dir, driving_face_mask_dir, reference_dir, save_json_path, model=None, processor=None):
    """Caption every ``.mp4`` in ``target_dir`` and write a metadata JSON file.

    A target video is included only when all of its companion files exist:
    same-named videos in the four driving directories and a ``.png`` with the
    same stem in ``reference_dir``. Videos with a missing companion are
    reported and skipped *before* captioning, so no model inference is spent
    on entries that would be dropped.

    Args:
        target_dir: Directory containing the target ``.mp4`` videos.
        driving_dir: Directory of driving videos.
        driving_mask_dir: Directory of driving mask videos.
        driving_face_dir: Directory of driving face videos.
        driving_face_mask_dir: Directory of driving face mask videos.
        reference_dir: Directory of reference images (``.png``).
        save_json_path: Destination path of the JSON list.
        model: Captioning model. When omitted, the module-level ``model``
            global is used (the behaviour older callers rely on).
        processor: Matching processor. When omitted, the module-level
            ``processor`` global is used.

    Raises:
        NameError: If a video needs captioning but no model/processor was
            passed and no module-level one exists.
    """
    # Backward compatibility: this function historically read these two
    # objects from module globals instead of taking them as arguments.
    if model is None:
        model = globals().get("model")
    if processor is None:
        processor = globals().get("processor")

    # Sorted so the resulting JSON is reproducible across runs/filesystems.
    video_names = sorted(
        name for name in os.listdir(target_dir) if name.endswith(".mp4")
    )

    json_list = []
    for filename in tqdm(video_names):
        print(filename)
        target_path = os.path.join(target_dir, filename)
        # Note: reference images use the .png extension. splitext swaps only
        # the real extension, even if ".mp4" also appears inside the stem.
        reference_name = os.path.splitext(filename)[0] + ".png"

        # (json key, label used in the "not found" message, expected path)
        companions = [
            ("driving_path", "driving video", os.path.join(driving_dir, filename)),
            ("driving_mask_path", "driving mask video", os.path.join(driving_mask_dir, filename)),
            ("driving_face_path", "driving face video", os.path.join(driving_face_dir, filename)),
            ("driving_face_mask_path", "driving face mask video", os.path.join(driving_face_mask_dir, filename)),
            ("reference_file_path", "reference image", os.path.join(reference_dir, reference_name)),
        ]
        missing = next(
            ((label, path) for _, label, path in companions if not os.path.exists(path)),
            None,
        )
        if missing is not None:
            print(f"{missing[0]} not found: {missing[1]}")
            continue

        if model is None or processor is None:
            raise NameError(
                "generate_json_video_dir needs `model` and `processor` "
                "(pass them in or define them at module level)"
            )
        output_text = process_video(target_path, model, processor)

        # Fresh dict per entry so no value can leak over from a previous video.
        entry = {"file_path": target_path, "text": output_text[0], "type": "video"}
        entry.update((key, path) for key, _, path in companions)
        json_list.append(entry)

    with open(save_json_path, "w", encoding="utf-8") as f:
        json.dump(json_list, f, indent=4, ensure_ascii=False)
    print(f"json file saved to {save_json_path}")
    

if __name__ == "__main__":
    # These two names must stay module-level globals called `model` and
    # `processor`: generate_json_video_dir looks them up by name.
    model, processor = init_model()

    # ============== generate json for codvideo ===============
    generate_json_video_dir(
        target_dir="./target",
        driving_dir="./driving",
        driving_mask_dir="./driving_mask",
        driving_face_dir="./driving_face",
        driving_face_mask_dir="./driving_face_mask",
        reference_dir="./reference",
        save_json_path="./metadata.json",
    )