import json
import os
import random
import sys
from argparse import ArgumentParser, ArgumentTypeError

import torch
import transformers


sys.path.append('./')
from videollama2.conversation import conv_templates
from videollama2.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX
from videollama2.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token, process_video, process_image, expand2square
from videollama2.model.builder import load_pretrained_model

# Resolve the parent of this file's directory (this file's path joined with '../..').
# Fail fast if it is missing, then append it to sys.path. This must run before the
# wildcard import below; presumably `openpvsg_dataset` lives in that directory (confirm).
# NOTE: the `__main__` block later rebinds `model_path` to a Hugging Face model id.
model_path = os.path.abspath(os.path.join(os.path.abspath(__file__), '../..'))
assert os.path.exists(model_path)
sys.path.append(model_path)

from openpvsg_dataset import *

def inference_video(video, questions, tokenizer, model, processor, conv_mode='llama_2', device=None):
    """Generate a text answer from VideoLLaMA2 for a single video.

    Only ``questions[0]`` is used. It is prefixed with the video placeholder token,
    wrapped in the ``conv_mode`` conversation template, and decoded greedily.

    Args:
        video: iterable of frames. Each frame is either a ``torch.Tensor``, which is
            converted via ``.cpu().numpy().astype(np.uint8)``, or something
            ``Image.fromarray`` accepts. Presumably HxWxC frames; TODO confirm
            against the data loader.
        questions: non-empty sequence of prompt strings; only the first is used.
        tokenizer, model, processor: the objects returned by ``load_pretrained_model``.
        conv_mode: key into ``conv_templates``. Defaults to ``'llama_2'``, the
            template used by this script's ``__main__`` driver.
        device: device for the prompt ids and pixel tensor. Defaults to
            ``model.device``.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        ValueError: if ``questions`` is empty.
    """
    if not questions:
        raise ValueError("questions must contain at least one prompt")
    if device is None:
        # Follow the model instead of relying on a module-level `device` global
        # that only exists when this file is run as a script.
        device = model.device

    # 2. Visual preprocess (load & transform image or video).
    # Frames are converted to RGB PIL images, then expanded to square with
    # `expand2square`, using the processor's mean colour (scaled to 0-255) as background.
    images = [
        Image.fromarray(f.cpu().numpy().astype(np.uint8) if isinstance(f, torch.Tensor) else f).convert('RGB')
        for f in video
    ]
    background = tuple(int(x * 255) for x in processor.image_mean)
    images = [expand2square(image, background) for image in images]

    tensor = processor.preprocess(images, return_tensors='pt')['pixel_values'].half().to(device)
    default_mm_token = DEFAULT_MMODAL_TOKEN["VIDEO"]
    modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]
    # `generate` receives a list of visual inputs, one per entry of `modal_list`.
    tensor = [tensor]

    # 3. Text preprocess (tag process & generate prompt).
    question = default_mm_token + "\n" + questions[0]
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
    prompt = conv.get_prompt()
    input_ids = tokenizer_MMODAL_token(prompt, tokenizer, modal_token_index, return_tensors='pt').unsqueeze(0).to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            images_or_videos=tensor,
            modal_list=['video'] * len(tensor),
            do_sample=False,  # greedy decoding
            temperature=0.0,
            max_new_tokens=1024,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return outputs[0]


def _str2bool(value):
    """Parse a command-line boolean value for argparse.

    ``type=bool`` is a trap: ``bool("False")`` is ``True``. This accepts the
    usual true/false spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _load_caption_cache(path):
    """Return the caption results previously saved at *path*, or ``{}`` if there are none."""
    if not os.path.exists(path):
        return {}
    with open(path, 'r') as f:
        return json.load(f)


def _save_caption_cache(results, path):
    """Write *results* to *path* as JSON, replacing the file atomically.

    The payload goes to a sibling ``.tmp`` file first and is then moved into
    place with ``os.replace``. An interrupted run therefore never leaves a
    truncated cache that would break the next ``json.load``.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(results, f)
    os.replace(tmp_path, path)


if __name__ == "__main__":

    # Set up data directories and paths
    dataset = "open_pvsg"
    data_file_name = 'pvsg.json'
    save_caption_name = "videollamav2_caption.json"

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../../data/{dataset}"))
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"dataset directory not found: {data_dir}")
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)
    save_caption_path = os.path.join(data_nl_dir, save_caption_name)

    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model')

    # Setup argument parser. Many of these options are not read by this
    # captioning driver; they are kept so existing command lines still parse.
    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='train')
    parser.add_argument("--n-epochs", type=int, default=50)
    parser.add_argument("--load-model", type=_str2bool, default=False)
    parser.add_argument("--save-model", type=_str2bool, default=True)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--use-neg-spec", type=_str2bool, default=False)
    parser.add_argument("--use-neg-kws", type=_str2bool, default=False)
    parser.add_argument("--neg-example-ct", type=int, default=2)
    # Weights are fractional (defaults 0.1), so parse them as floats.
    parser.add_argument("--neg-example-weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_binary_weight", type=float, default=0.1)
    parser.add_argument("--neg_entity_kw_cate_weight", type=float, default=0)

    parser.add_argument("--parallel", action='store_true')
    parser.add_argument("--report-dir", type=str, default="")

    parser.add_argument("--train-num-top-pairs", type=int, default=10)
    parser.add_argument("--test-num-top-pairs", type=int, default=30)
    parser.add_argument("--max-video-len", type=int, default=12)

    # setup question path
    parser.add_argument("--train-num", type=int, default=5000)
    parser.add_argument("--val-num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    parser.add_argument("--use-cuda", action="store_false")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    # The caption loop below handles exactly one video per batch; reject other
    # sizes before spending time loading the model.
    if args.batch_size != 1:
        parser.error("--batch-size must be 1: captions are generated one video at a time")

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    model_name = f"laser_kw_{dataset}_training_{args.train_percentage}_seed_{args.seed}_batch_size_{args.batch_size}_lr_{args.learning_rate}_prov_{args.provenance}_tpk_{args.train_top_k}_negspec_{args.use_neg_spec}_negkw_{args.use_neg_kws}_kwweight_{args.neg_example_weight}"

    if args.model_name is None:
        args.model_name = model_name

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
        cache_path="",
        dataset_dir=data_dir,
        dataset_name=data_file_name,
        batch_size=args.batch_size,
        device=device,
        training_percentage=args.train_percentage,
        testing_percentage=args.test_percentage,
        max_video_len=args.max_video_len,
        neg_kws=args.use_neg_kws,
        neg_spec=args.use_neg_spec,
        neg_example_ct=args.neg_example_ct,
        require_gpt_spec=False,
    )

    # 1. Initialize the model.
    # Base model inference only needs a different model_path:
    # model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B-Base'
    model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B-16F'
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name)
    model = model.to(device)
    model.eval()
    conv_mode = 'llama_2'

    questions = ["Give a detailed caption corresponding to the video."]

    # Resume support: (video_id, caption) pairs already in the cache are skipped.
    all_results = _load_caption_cache(save_caption_path)
    with torch.no_grad():
        for datapoint in train_loader:
            video_ids = datapoint['batched_ids']
            captions = datapoint['batched_captions']
            if len(video_ids) != 1:
                raise ValueError(f"expected one video per batch, got {len(video_ids)}")
            video_id = video_ids[0]
            caption = captions[0]

            video_results = all_results.setdefault(video_id, {})
            if caption in video_results:
                continue

            video_results[caption] = inference_video(
                datapoint['batched_reshaped_raw_videos'], questions, tokenizer, model, processor)

            # Persist after every video so an interrupted run can resume.
            _save_caption_cache(all_results, save_caption_path)

    print('end')