import torch
from transformers import AutoConfig
import argparse
from transformers import AutoTokenizer
from model.EventChatModel import EventChatQwenCausalLM
from common.common import get_event_images_list, tokenizer_event_token, generate_event_image, split_event_by_time
import numpy as np
import time 
from PIL import Image
from dataset.conversation import conv_templates
from dataset.constants import EVENT_TOKEN_INDEX, DEFAULT_EVENT_TOKEN, DEFAULT_EV_START_TOKEN, DEFAULT_EV_END_TOKEN, EVENT_PLACEHOLDER, DEFAULT_EVENT_PATCH_TOKEN, IGNORE_INDEX

def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas, centred along its shorter axis.

    Args:
        pil_img: Image object exposing ``size`` as ``(width, height)`` and
            ``mode``; it is pasted onto the new canvas unchanged.
        background_color: Fill colour handed to ``Image.new`` for the canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        of the same mode whose side equals the larger of width and height.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    # Only the offset along the shorter axis is non-zero; the other is 0.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas

def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, matching the old ``bool("")`` behaviour.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True,
                        help="Passed to from_pretrained for config, tokenizer and model.")
    parser.add_argument("--model_base", type=str, default=None,
                        help="Currently unused by this script.")
    parser.add_argument("--query", type=str, default="What is happening in this scene?",
                        help="Question placed after the event token in the prompt.")
    parser.add_argument("--conv_mode", type=str, default='eventgpt_qwen',
                        help="Key used to look up the template in conv_templates.")
    parser.add_argument("--sep", type=str, default=",",
                        help="Currently unused by this script.")
    parser.add_argument("--context_len", type=int, default=2048,
                        help="Currently unused by this script.")
    parser.add_argument("--temperature", type=float, default=0.4,
                        help="Sampling temperature; sampling is disabled when <= 0.")
    parser.add_argument("--top_p", type=float, default=1,
                        help="Forwarded to model.generate.")
    parser.add_argument("--num_beams", type=int, default=1,
                        help="Forwarded to model.generate.")
    parser.add_argument("--max_new_tokens", type=int, default=512,
                        help="Forwarded to model.generate.")
    # type=bool would turn "--spatial_temporal_encoder False" into True.
    parser.add_argument("--spatial_temporal_encoder", type=str2bool, default=False,
                        help="Currently unused by this script.")
    parser.add_argument("--event_frame", type=str, required=True,
                        help="Path to a .npy file holding 'event_data' and 'event_features'.")
    return parser.parse_args(argv)


def main(argv=None):
    """Load the model, build the prompt, run generation and print the answer.

    Args:
        argv: Optional argument list; defaults to the process arguments.
    """
    args = parse_args(argv)

    config = AutoConfig.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False)
    model = EventChatQwenCausalLM.from_pretrained(
        args.model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        config=config,
    )

    # Register the event-related special tokens the model config asks for.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_EVENT_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_EV_START_TOKEN, DEFAULT_EV_END_TOKEN], special_tokens=True)
    # NOTE(review): the embedding matrix is not resized after adding tokens
    # (the resize call was disabled in the original script); presumably the
    # checkpoint already accounts for these tokens -- confirm.

    # NOTE(review): event_processor is not used further in this script; the
    # lookup is kept in case get_visual_tower() has side effects -- confirm
    # and remove if not.
    event_processor = model.get_visual_tower().event_processor

    # NOTE(review): flash_attention_2 presumably needs a CUDA device, so the
    # CPU fallback may not actually run -- confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # NOTE(review): the prompt always uses the bare event token, even when
    # mm_use_im_start_end is set; confirm whether the start/end-wrapped form
    # is expected in that case.
    qs = DEFAULT_EVENT_TOKEN + "\n" + args.query

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
    prompt = conv.get_prompt()

    # SECURITY: allow_pickle=True can execute arbitrary code while loading;
    # only pass event files from a trusted source.
    event_npy = np.load(args.event_frame, allow_pickle=True).item()
    event_data = event_npy['event_data']
    event_feature = event_npy['event_features']
    # Hard-coded; presumably (height, width) of the event frames -- TODO confirm.
    event_image_size = [240, 320]

    input_ids = tokenizer_event_token(
        prompt, tokenizer, EVENT_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            event_data=event_data,
            event_feature=event_feature,
            event_image_sizes=event_image_size,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)


if __name__ == '__main__':
    main()