import numpy as np
import torch
import torchvision
import random
import os
import json
import re
import argparse
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from tqdm import tqdm


# Seed every RNG this script draws from so that runs are reproducible:
# torch for the models, and `random` for the clip start index picked in main().
torch.manual_seed(1234)
random.seed(1234)

# Function to generate caption for an image
def generate_caption(image_path, instruction, tokenizer, model):
    """Ask the chat model to describe a single image.

    The image path and the text instruction are packed into one query with
    ``tokenizer.from_list_format`` and sent to ``model.chat`` without any
    history. Only the reply is returned; the second element of the chat
    result is discarded.
    """
    # Example instruction: 'Generate a detailed caption for this image.'
    query_parts = [{'image': image_path}, {'text': instruction}]
    prompt = tokenizer.from_list_format(query_parts)
    reply, _unused = model.chat(tokenizer, query=prompt, history=None)
    return reply


# Function to draw text on an image
def txt_draw(text, target_size=(512, 512)):
    """Render ``text`` on a blank matplotlib canvas and return it as an RGB array.

    Args:
        text: String to draw; matplotlib wraps long lines (``wrap=True``).
        target_size: ``(width, height)`` in pixels of the returned image.
            Lists are accepted as well as tuples.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 3)`` holding the rendered
        text with the alpha channel dropped.
    """
    fig = plt.figure(dpi=300, figsize=(1, 1))
    try:
        plt.text(-0.1, 1.1, text, fontsize=3.5, wrap=True, verticalalignment="top", horizontalalignment="left")
        plt.axis('off')

        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        # buffer_rgba() exposes the rendered pixels as a (height, width, 4)
        # RGBA buffer, so no ARGB byte shuffling or manual reshape is needed
        # (np.fromstring / ndarray.tostring / tostring_argb are deprecated or
        # removed in current NumPy and Matplotlib releases).
        rgba = np.asarray(canvas.buffer_rgba())
        image = Image.fromarray(rgba)
        image = image.resize(tuple(target_size), Image.LANCZOS)
        return np.asarray(image)[:, :, :3]
    finally:
        # Close only the figure created here, not figures owned by the caller.
        plt.close(fig)


# Function to convert string to dictionary
def convert_string_to_dict(string_dict):
    """Parse a loosely formatted, dict-like caption string into a Python object.

    The input is expected to look like
    ``{hair: {color: 'black', ...}, skin: {...}}`` (unquoted keys, single-quoted
    values). A fixed sequence of textual substitutions turns it into JSON,
    after which every line of the result is parsed with ``json.loads``.

    Args:
        string_dict: The raw caption string.

    Returns:
        The object parsed from the first line, or ``None`` (after printing a
        diagnostic) when any line is not valid JSON.
    """
    # Applied strictly in this order: later substitutions clean up the
    # over-quoting introduced by earlier ones.
    replacements = (
        # Quote every key and value.
        ("{", '{"'),
        ("}", '"}'),
        (": ", '": "'),
        (", ", '", "'),
        ("'", '"'),
        ('notable details', 'notable_details'),
        # Undo the quotes wrongly placed around nested dicts and lists.
        ('": {"', '": {'),
        ('}, "', '}, '),
        ('"{', '{'),
        ('}"', '}'),
        ('"[', '['),
        (']"', ']'),
        ('""', '"'),
    )
    fixed_string_dict = string_dict
    for old, new in replacements:
        fixed_string_dict = fixed_string_dict.replace(old, new)

    # One JSON object per line; every line must parse, the first one is returned.
    json_objects = fixed_string_dict.strip().split('\n')
    try:
        parsed_dicts = [json.loads(obj) for obj in json_objects]
    except ValueError:
        # json.JSONDecodeError is a ValueError. Catching only that keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        print(f"Error decoding JSON: {json_objects}")
        return None
    return parsed_dicts[0]

# Function to vote on the final description
def vote_final_description(frames):
    """Merge per-frame descriptions of one video into a single description.

    Args:
        frames: Mapping of frame id to the raw caption string of that frame.
            Each caption is parsed with ``convert_string_to_dict``; captions
            that fail to parse are ignored.

    Returns:
        A dict with the sections ``hair``, ``skin``, ``clothes`` and
        ``background``. Hair, skin and clothes style/color hold the most
        frequent value across frames (the first one seen wins ties, ``''``
        when there were no votes). Clothes ``notable_details`` and both
        background fields hold a ``', '``-joined string of every value that
        was voted for more than once.

    Raises:
        KeyError: If a parsed caption lacks one of the hair/skin/clothes
            fields, or has a background without ``elements``/``colors``.
    """
    voted_fields = {
        'hair': ('color', 'texture', 'length'),
        'skin': ('tone', 'features'),
        'clothes': ('style', 'color', 'notable_details'),
    }
    background_fields = ('elements', 'colors')

    # One {value: count} tally per (section, field).
    votes = {
        (section, field): {}
        for section, fields in voted_fields.items()
        for field in fields
    }
    for field in background_fields:
        votes[('background', field)] = {}

    def add_vote(section, field, value):
        tally = votes[(section, field)]
        tally[value] = tally.get(value, 0) + 1

    for description in frames.values():
        desc = convert_string_to_dict(description)
        if desc is None:
            continue

        for section, fields in voted_fields.items():
            for field in fields:
                add_vote(section, field, desc[section][field])

        background = desc.get('background')
        if not background:
            # A frame without a background casts no background votes. It must
            # not erase the votes already collected from other frames.
            continue
        for field in background_fields:
            values = background[field]
            # Background fields may be a single value or a list of values.
            if not isinstance(values, list):
                values = [values]
            for value in values:
                add_vote('background', field, value)

    def winner(section, field):
        tally = votes[(section, field)]
        return max(tally, key=tally.get) if tally else ''

    def repeated(section, field):
        tally = votes[(section, field)]
        return ', '.join(value for value, count in tally.items() if count > 1)

    return {
        'hair': {field: winner('hair', field) for field in voted_fields['hair']},
        'skin': {field: winner('skin', field) for field in voted_fields['skin']},
        'clothes': {
            'style': winner('clothes', 'style'),
            'color': winner('clothes', 'color'),
            'notable_details': repeated('clothes', 'notable_details'),
        },
        'background': {field: repeated('background', field) for field in background_fields},
    }

def voting_from_json(args, tokenizer, model, instruction):
    """Vote one description per video and let the chat model refine it.

    If ``args.vote_output_path`` already exists and ``args.caption_store`` is
    ``None``, the stored voting result is reused. Otherwise every
    ``<caption_store>/<video>/captions.json`` is voted with
    ``vote_final_description`` and the result is written to
    ``args.vote_output_path``. Each voted description is then sent to
    ``model.chat`` together with ``instruction`` and the replies are written to
    ``args.processed_caption_path``.

    Args:
        args: Namespace providing ``vote_output_path``, ``caption_store`` and
            ``processed_caption_path``.
        tokenizer: Tokenizer handed to ``model.chat``.
        model: Chat model exposing ``chat(tokenizer, prompt, history=None)``.
        instruction: Refinement instruction appended to every prompt.

    Raises:
        ValueError: If there is neither an existing voting result nor a
            ``caption_store`` to build one from.
    """
    if os.path.exists(args.vote_output_path) and args.caption_store is None:
        print(f'Voted captions exist in {args.vote_output_path}. Skip voting.....')
        with open(args.vote_output_path, 'r') as f:
            voted_res = json.load(f)
    else:
        if args.caption_store is None:
            # os.listdir(None) would silently list the current directory.
            raise ValueError(
                f'{args.vote_output_path} does not exist and --caption_store is not set; '
                'nothing to vote on.'
            )
        voted_res = {}
        captions_path = {i: os.path.join(args.caption_store, i, 'captions.json') for i in os.listdir(args.caption_store)}
        print(f'[INFO] Start Voting | {len(captions_path)} videos in total.')
        for key, cap_path in tqdm(captions_path.items(), desc='Voting Captions'):
            if not os.path.exists(cap_path):
                continue
            # Best effort: a video whose captions cannot be read or voted is
            # reported and skipped instead of aborting the whole run.
            try:
                with open(cap_path, 'r') as f:
                    cap = json.load(f)
                voted_res[key] = vote_final_description(cap)
            except Exception as err:
                print(f'Error during processing {key}!!!!!! ({type(err).__name__}: {err})')
        with open(args.vote_output_path, 'w') as f:
            json.dump(voted_res, f, indent=4)
        print(f'[INFO] Voted captions are saved in {args.vote_output_path}')

    print(f'[INFO] Start Processing | {len(voted_res)} captions in total.')
    processed_captions = {}
    for key, caption_ori in tqdm(voted_res.items(), desc='Processing Voting Results'):
        response, _ = model.chat(tokenizer, f"Given caption {str(caption_ori)}, {instruction}", history=None)
        print(f"{key} -- Caption: {response}")
        # Drop quotes the model may have wrapped around its answer.
        processed_captions[key] = response.strip('"')

    with open(args.processed_caption_path, 'w') as f:
        json.dump(processed_captions, f, indent=4)
    print(f'[INFO] Processed captions are saved in {args.processed_caption_path}')


# Main function to process videos
def _sample_frame_ids(video_len, num_initial_samples, final_sample_count):
    """Choose which frame indices of a ``video_len``-frame video get captioned."""
    if video_len >= num_initial_samples:
        # Random window of consecutive frames.
        start_idx = random.randint(0, video_len - num_initial_samples)
        batch_ids = list(range(start_idx, start_idx + num_initial_samples))
    else:
        batch_ids = list(range(video_len))

    # Thin the window out to at most final_sample_count evenly strided frames.
    if len(batch_ids) > final_sample_count:
        step = len(batch_ids) // final_sample_count
        batch_ids = batch_ids[::step][:final_sample_count]
    return batch_ids


def main(args, tokenizer, model):
    """Caption sampled frames of every video and store them per video.

    For each video a handful of frames is sampled, each frame is saved as PNG
    and captioned with ``generate_caption``, and the ``{frame_id: caption}``
    mapping is written to ``<output_dir>/<video_id>/captions.json``. Videos
    that already have a ``captions.json`` are skipped.

    Args:
        args: Namespace providing ``video_info_json``, ``video_folder``,
            ``output_dir``, ``verbose_images``, ``save_sampled_frames`` and
            ``verbose_video_clip``.
        tokenizer: Tokenizer passed through to ``generate_caption``.
        model: Vision-language chat model passed through to ``generate_caption``.
    """
    if args.video_info_json:
        print(f'Loading video info from {args.video_info_json}')
        with open(args.video_info_json, 'r') as f:
            video_infos = json.load(f)
        video_paths = [vid_info['video'] for vid_info in video_infos]
    else:
        # TODO: Remove this loading
        video_paths = [os.path.join(args.video_folder, i, 'gt.mp4') for i in os.listdir(args.video_folder)]
    print(f'Number of Videos: {len(video_paths)}\nSave Path: {args.output_dir}')

    # Caption instruction for Qwen-VL-Chat
    instruction = "Construct a structured description of the person in the image using a dictionary format. \
        Ensure each key is thoroughly populated with detailed, observable information without using Unicode characters and Chinese characters. \
        For each description, connect properties using 'and' instead of commas, try to avoid using commas. \
        Follow this precise template and fill in the blanks with specific, discernible details from the image: \
        {hair: {color: 'specify color', texture: 'specify texture', length: 'specify length'}, \
        skin: {tone: 'specify tone', features: 'identify any distinguishing features'}, \
        clothes: {style: 'specify style', color: 'specify color', notable details: 'mention specific details', accessories:' any accessories like glasses'}, \
        background: {elements: 'list observable elements', colors: 'name dominant colors'} }. \
        Do not leave any blanks unless absolutely necessary. Provide a clear and comprehensive description for each attribute."

    # Constants
    num_initial_samples = 40
    final_sample_count = 8

    for video_path in tqdm(video_paths, desc='Iterating Videos'):

        # Prepare directory for saving images and captions
        if args.video_info_json:
            video_id = video_path.split('/')[-1].replace('.mp4', '')
        else:
            video_id = video_path.split('/')[-2]
        output_dir = os.path.join(args.output_dir, video_id)
        captions_file = os.path.join(output_dir, 'captions.json')
        if os.path.exists(captions_file):
            continue  # captions were already generated for this video
        print(f'Load Video {video_id} ....')
        os.makedirs(output_dir, exist_ok=True)

        # Read video and extract frames
        video_frames, _, video_info = torchvision.io.read_video(
            video_path,
            pts_unit='sec',
            output_format='TCHW'
        )
        batch_ids = _sample_frame_ids(len(video_frames), num_initial_samples, final_sample_count)

        # Process each frame
        captions = {}
        for idx, frame_id in tqdm(enumerate(batch_ids), total=len(batch_ids), desc='Iterating Frames'):
            frame = video_frames[frame_id].permute(1, 2, 0).numpy()  # CHW -> HWC for PIL
            image = Image.fromarray(frame.astype('uint8'), 'RGB')
            image_path = os.path.join(output_dir, f'sample_{idx}.png')
            image.save(image_path)

            # Generate caption
            caption = generate_caption(image_path, instruction, tokenizer, model)
            captions[frame_id] = caption

            # Put the caption text to the left of the frame
            if args.verbose_images:
                # Render the text panel at the frame's own height; a fixed
                # 512-px panel cannot be concatenated with frames of any
                # other height.
                image_instruct = txt_draw(
                    f"image frame: {frame_id}\ngenerated caption:\n {caption}",
                    target_size=(image.height, image.height),
                )
                out_image = Image.fromarray(np.concatenate((np.array(image_instruct), np.array(image)), 1))
                out_image.save(os.path.join(output_dir, f'captioned_sample_{idx}.png'))

        with open(captions_file, 'w') as f:
            json.dump(captions, f)

        # After generating captions, delete saved frames
        if not args.save_sampled_frames:
            for idx, _ in enumerate(batch_ids):
                os.remove(os.path.join(output_dir, f'sample_{idx}.png'))
                # NOTE(review): this also deletes the captioned images, so
                # --verbose_images only leaves files behind together with
                # --save_sampled_frames. Confirm that this is intended.
                if args.verbose_images:
                    os.remove(os.path.join(output_dir, f'captioned_sample_{idx}.png'))

        # Save video clip
        if args.verbose_video_clip:
            print('Saving Video Clip......')
            clip_path = os.path.join(output_dir, 'video_clip.mp4')
            fps = video_info['video_fps']
            clip_frames = video_frames[batch_ids, ...].permute(0, 2, 3, 1)  # TCHW -> THWC
            torchvision.io.write_video(
                clip_path,
                clip_frames,
                fps,
                video_codec='h264'
            )


if __name__ == '__main__':
    # Two modes: by default, caption sampled frames of every video with
    # Qwen-VL-Chat; with --vote_captions, vote the stored frame captions of
    # each video and refine the result with Qwen-Chat.
    parser = argparse.ArgumentParser(description='Caption sampled video frames, or vote and refine stored frame captions.')
    # args for captioning each video
    parser.add_argument('--model_path', default='/dockerdata/models/Qwen-VL-Chat/', type=str, help='The path of the captioner model')
    parser.add_argument('--video_folder', default='/root/VExpress/test_samples/training_samples/', type=str, help='Folder containing videos to process')
    parser.add_argument('--video_info_json', default=None, type=str, help='The json file contains the information (path) of the videos')
    parser.add_argument('--verbose_images', action='store_true', help='save each frames with its corresponding captions drawed')
    parser.add_argument('--verbose_video_clip', action='store_true', help='save the sampled video clip')
    parser.add_argument('--save_sampled_frames', action='store_true', help='save the sampled frames from the video clip')
    parser.add_argument('--output_dir', default='/root/VExpress/output/captions/training_samples', type=str, help='Folder to save captions and extracted images')
    # args for voting from captions
    parser.add_argument('--vote_model_path', default='/root/Qwen-14B-Chat', type=str, help='The path of the voting model')
    parser.add_argument('--vote_captions', action='store_true', help='vote captions based on the generated framewise captions')
    parser.add_argument('--caption_store', default=None, type=str, help='Path storing framewise captions for each video')
    parser.add_argument('--vote_output_path', default='/root/VExpress/output/captions/training_samples/collected_captions.json', type=str, help='JSON file to save the voted captions of all videos')
    parser.add_argument('--processed_caption_path', default='/root/VExpress/output/captions/training_samples/processed_captions.json', type=str, help='JSON file to save the refined captions of all videos')
    args = parser.parse_args()

    if args.vote_captions:
        print('[INFO] Loading Qwen-Chat....')
        # Note: The default behavior now has injection attack prevention off.
        tokenizer = AutoTokenizer.from_pretrained(args.vote_model_path, trust_remote_code=True)
        # Load the model from the same path as the tokenizer so that
        # --vote_model_path is honoured. No explicit dtype is requested here.
        model = AutoModelForCausalLM.from_pretrained(args.vote_model_path, device_map="auto", trust_remote_code=True).eval()

        # Generation hyperparameters (generation length, top_p, ...) could be
        # set through model.generation_config; with transformers>=4.32.0
        # there is no need to do this.
        print('[INFO] Loaded Qwen-Chat!')

        instruction = "Please refine the caption to make it concise and remove any sematically duplicate values!!!, without using Unicode characters or Chinese characters.\
            The notable details of the clothes should only mention unique details, and the elements and colors of the background should be listed without repetition.\
            Return only the improved caption, without any extra text or explanations!!! Therefore I can use the returns as the improved captions directly. \
            **Don not** use descriptions such as: 'The refined caption is:', 'caption:', 'Refined caption:', etc."

        # Voting final caption for each video, based on the json file
        voting_from_json(args, tokenizer, model, instruction)
    else:
        # Generate framewise captions for each video, and then store the result in a json file for each video

        print('[INFO] Loading Qwen-VL-Chat....')
        # Initialize tokenizer and model for caption generation
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, revision='main', use_fast=True, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="cuda", revision='main', trust_remote_code=True).eval()
        print('[INFO] Loaded Qwen-VL-Chat!')

        main(args, tokenizer, model)
