import torch
import transformers

import sys
import os
from argparse import ArgumentParser
from openai import OpenAI

sys.path.append('./')
from videollama2.conversation import conv_templates
from videollama2.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX
from videollama2.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token, process_video, process_image, expand2square
from videollama2.model.builder import load_pretrained_model

# OpenAI client shared by the GPT post-processing step of the script below.
client = OpenAI()

# Put the directory two levels above this file on the import path; the
# ``openpvsg_dataset`` / ``prompts_gpt`` imports that follow presumably live
# there (TODO confirm repository layout).
model_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
assert os.path.exists(model_path)
sys.path.append(model_path)



from openpvsg_dataset import *
from prompts_gpt import wrap_prompt, user

def inference_video(video, questions, tokenizer, model, processor, conv_mode="llama_2", device=None):
    """Ask VideoLLaMA2 ``questions[0]`` about ``video`` and return its answer.

    Args:
        video: Iterable of frames. ``torch.Tensor`` frames are permuted with
            ``permute(1, 2, 0)`` and cast to uint8, so they are assumed to be
            channel-first (C, H, W) -- TODO confirm against the data loader.
            Any other frame is handed directly to ``Image.fromarray``.
        questions: Non-empty sequence of prompt strings; only the first one is
            used.
        tokenizer, model, processor: Objects returned by
            ``load_pretrained_model``.
        conv_mode: Key into ``conv_templates``. Previously read from the
            script-level global of the same name; ``"llama_2"`` is the value
            the script sets.
        device: Device for the model inputs. Defaults to ``model.device``.

    Returns:
        The first sequence from ``model.generate`` decoded with
        ``skip_special_tokens=True``.

    Raises:
        ValueError: If ``questions`` is empty.
    """
    if not questions:
        raise ValueError("questions must contain at least one prompt")
    if device is None:
        device = model.device

    # Visual preprocessing: frames -> RGB PIL images, squared by expand2square
    # using a background colour derived from the processor's image_mean.
    pad_color = tuple(int(channel * 255) for channel in processor.image_mean)
    images = []
    for frame in video:
        if isinstance(frame, torch.Tensor):
            frame = frame.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        images.append(expand2square(Image.fromarray(frame).convert("RGB"), pad_color))

    pixel_values = processor.preprocess(images, return_tensors="pt")["pixel_values"]
    # Place the video on the same device as the token ids instead of relying on
    # the model to relocate it (the upstream example does ``.half().cuda()``).
    video_tensors = [pixel_values.half().to(device)]

    # Text preprocessing: video placeholder token + question in the chat template.
    question = DEFAULT_MMODAL_TOKEN["VIDEO"] + "\n" + questions[0]
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)  # open assistant turn to be generated
    input_ids = tokenizer_MMODAL_token(
        conv.get_prompt(), tokenizer, MMODAL_TOKEN_INDEX["VIDEO"], return_tensors="pt"
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            images_or_videos=video_tensors,
            modal_list=["video"] * len(video_tensors),
            do_sample=False,  # greedy decoding; temperature is not used for sampling
            temperature=0.0,
            max_new_tokens=1024,
            use_cache=True,
        )

    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]


def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    ``argparse`` with ``type=bool`` evaluates ``bool(text)``, which is ``True``
    for every non-empty string (including ``"False"``), so this parser is used
    for boolean options instead. Bool values are returned unchanged.

    Raises:
        ValueError: For unrecognised spellings; argparse reports this as a
            usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y"}:
        return True
    if text in {"0", "false", "f", "no", "n"}:
        return False
    raise ValueError(f"not a boolean: {value!r}")


def _index_by_caption(responses, expected_count):
    """Collect GPT JSON output entries keyed by their ``"caption"`` field.

    Accepted shapes:
      * a list of entry dicts;
      * a dict with exactly ``expected_count`` keys whose values are entries;
      * a dict with a single wrapper key holding a list or dict of entries.

    Entries that are not dicts or lack ``"caption"`` are skipped, so malformed
    output yields fewer (possibly zero) entries instead of raising.

    Args:
        responses: Parsed JSON returned by the chat completion.
        expected_count: Number of captions that were sent to GPT.

    Returns:
        Dict mapping each entry's caption to the entry itself.
    """
    if isinstance(responses, list):
        candidates = list(responses)
    elif isinstance(responses, dict):
        candidates = []
        if len(responses) == expected_count:
            candidates.extend(responses.values())
        if len(responses) == 1:
            inner = next(iter(responses.values()))
            if isinstance(inner, list):
                candidates.extend(inner)
            elif isinstance(inner, dict):
                candidates.extend(inner.values())
    else:
        candidates = []

    return {
        entry["caption"]: entry
        for entry in candidates
        if isinstance(entry, dict) and "caption" in entry
    }


if __name__ == "__main__":
    import json
    import random

    # ---- Data directories and paths -------------------------------------
    dataset = "open_pvsg"
    data_file_name = "pvsg.json"
    save_caption_name = "videollamav2_prompt_caption.json"

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../../data/{dataset}"))
    assert os.path.exists(data_dir)
    data_nl_dir = os.path.join(data_dir, "nl2spec")
    os.makedirs(data_nl_dir, exist_ok=True)
    save_caption_path = os.path.join(data_nl_dir, save_caption_name)

    video_save_dir = os.path.join(data_dir, "pred_video")
    model_dir = os.path.join(data_dir, "model")

    # ---- Arguments --------------------------------------------------------
    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default="train")
    parser.add_argument("--n-epochs", type=int, default=50)
    parser.add_argument("--load-model", type=_str2bool, default=False)
    parser.add_argument("--save-model", type=_str2bool, default=True)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=_str2bool, default=False)
    parser.add_argument("--use-neg-kws", type=_str2bool, default=False)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    # Weights are fractional (default 0.1), so they must parse as floats.
    parser.add_argument("--neg-example-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)

    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--report-dir", type=str, default="")

    parser.add_argument("--train-num-top-pairs", type=int, default=10)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=12)

    # Question-set sizes / percentages.
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters.
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    model_name = f"laser_kw_{dataset}_training_{args.train_percentage}_seed_{args.seed}_batch_size_{args.batch_size}_lr_{args.learning_rate}_prov_{args.provenance}_tpk_{args.train_top_k}_negspec_{args.use_neg_spec}_negkw_{args.use_neg_kws}_kwweight_{args.neg_example_weight}"
    if args.model_name is None:
        args.model_name = model_name

    # NOTE: ``device`` and ``conv_mode`` are module-level names here;
    # ``inference_video`` in this file may rely on them, so keep them defined.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Only the train loader is iterated below.
    _train_dataset, _valid_dataset, train_loader, _test_loader = open_pvsg_loader(
        cache_path="",
        dataset_dir=data_dir,
        dataset_name=data_file_name,
        batch_size=args.batch_size,
        device=device,
        training_percentage=args.train_percentage,
        testing_percentage=args.test_percentage,
        max_video_len=args.max_video_len,
        neg_kws=args.use_neg_kws,
        neg_spec=args.use_neg_spec,
        neg_example_ct=args.neg_example_ct,
        require_gpt_spec=False,
    )

    # ---- VideoLLaMA2 ------------------------------------------------------
    # A separate name keeps the module-level ``model_path`` untouched.
    llm_path = "DAMO-NLP-SG/VideoLLaMA2-7B-16F"
    # Base model inference (only need to replace the path):
    # llm_path = "DAMO-NLP-SG/VideoLLaMA2-7B-Base"
    llm_name = get_model_name_from_path(llm_path)
    tokenizer, model, processor, _context_len = load_pretrained_model(llm_path, None, llm_name)
    model = model.to(device)
    model.eval()
    conv_mode = "llama_2"

    # One VideoLLaMA2 question per persona; ``context_ls`` holds the matching
    # scene description for ``wrap_prompt`` in the same order.
    persona_questions = [
        ["Give a detailed caption corresponding to the video."],
        ["You are a driver. Give a detailed caption corresponding to the video about the objects you care about."],
        ["You are a security guard. Give a detailed caption corresponding to the video about the objects you care about."],
        ["You are a cook. Give a detailed caption corresponding to the video about the objects you care about."],
    ]
    context_ls = [
        "This is a generic video.",
        "This is a driving scene.",
        "This is a security guard senario.",
        "This is a cooking scene.",
    ]

    # Resume support: results map video_id -> {caption: indexed GPT output}.
    if os.path.exists(save_caption_path):
        with open(save_caption_path, "r") as f:
            all_results = json.load(f)
    else:
        all_results = {}

    with torch.no_grad():
        for datapoint in train_loader:
            video_ids = datapoint["batched_ids"]
            captions = datapoint["batched_captions"]
            if len(video_ids) != 1:
                raise ValueError(
                    f"expected batches of exactly one video (use --batch-size 1), got {len(video_ids)}"
                )
            video_id = video_ids[0]
            caption = captions[0]

            if caption in all_results.get(video_id, {}):
                continue  # already processed in a previous run

            raw_video = datapoint["batched_reshaped_raw_videos"]
            cap_ls = [
                inference_video(raw_video, question, tokenizer, model, processor)
                for question in persona_questions
            ]
            prompt = wrap_prompt(cap_ls, context_ls, few_shot=True)

            response = client.chat.completions.create(
                model="gpt-4-0125-preview",
                response_format={"type": "json_object"},
                temperature=0,
                messages=[
                    {"role": "system", "content": user},
                    {"role": "user", "content": prompt},
                ],
            )
            action_responses = json.loads(response.choices[0].message.content)
            cache = _index_by_caption(action_responses, len(cap_ls))
            if not cache:
                # Leave this pair unrecorded so a later run retries it.
                print(f"No usable GPT output for video {video_id}; skipping.")
                continue

            all_results.setdefault(video_id, {})[caption] = cache
            # Checkpoint after every video. Writing to a temp file and renaming
            # keeps an interrupted write from truncating the existing results.
            tmp_path = save_caption_path + ".tmp"
            with open(tmp_path, "w") as f:
                json.dump(all_results, f)
            os.replace(tmp_path, save_caption_path)

    print("end")