import argparse
import os
import random
import json

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry

from minigpt_utils import prompt_wrapper, generator
torch.set_num_threads(8)


def parse_args(argv=None):
    """Parse the command-line options for this inference script.

    Args:
        argv: optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` falls back to ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with ``cfg_path``, ``gpu_id``,
        ``img_text_pair_path``, ``output_fold`` and ``options``.

    Raises:
        SystemExit: on invalid arguments or when ``--img_text_pair_path`` is
            missing (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg_path", default="eval_configs/minigpt4_eval.yaml", help="path to configuration file.")
    parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")

    # The script unconditionally opens this file, so require it up front: a
    # missing value now yields a usage error instead of crashing on open(None).
    parser.add_argument(
        "--img_text_pair_path",
        type=str,
        required=True,
        help="JSON Lines file of image/prompt records to run inference on.",
    )
    parser.add_argument("--output_fold", type=str, default='')

    # NOTE(review): the help text mentions --cfg-options, which this parser
    # does not define; kept verbatim for compatibility.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args(argv)


def setup_seeds(config):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()`` (presumably the distributed process rank), so processes
    with different ranks receive different seeds.

    Args:
        config: loaded config object exposing ``run_cfg.seed``.
    """
    base_seed = config.run_cfg.seed
    rank_seed = base_seed + get_rank()

    # Each library keeps its own generator; seed them all with the same value.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    # Disable cuDNN autotuning and request deterministic kernels for
    # reproducibility at the cost of speed.
    cudnn.deterministic = True
    cudnn.benchmark = False


# ========================================
#             Model Initialization
# ========================================
# Build the model, its image preprocessor and a text generator from the YAML
# config named by --cfg_path.

print('>>> Initializing Models')

args = parse_args()
cfg = Config(args)

# Look up the model class registered under the config's `arch` name, build it
# from the config and move it to the requested GPU. `device_8bit` is set to
# the same GPU index (presumably the device used for 8-bit loading).
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(f'cuda:{args.gpu_id}')

# Image preprocessing reuses the cc_sbu_align dataset's "train" visual
# processor settings from the config.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor_cls = registry.get_processor_class(vis_processor_cfg.name)
vis_processor = vis_processor_cls.from_config(vis_processor_cfg)

# max_new_tokens=30 keeps each generated response short.
my_generator = generator.Generator(model=model, max_new_tokens=30)
print('Initialization Finished')

# NOTE(review): `out` is never populated in the visible code; kept in case
# later code relies on the name.
out = []

# The chat template is identical for every record, so resolve it once.
text_formats = prompt_wrapper.minigpt4_chatbot_prompt

# Run generation for every record of the JSON Lines input file. The `with`
# statement guarantees the input file is closed even if a record fails.
with open(args.img_text_pair_path, 'r', encoding='utf-8') as f_in, torch.no_grad():
    for line in f_in:
        # Tolerate blank lines (e.g. a trailing newline at end of file),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        item = json.loads(line)
        img_path = item['img_path']
        user_message = item['prompt']

        # Copy the pixels into an RGB image, then release the file handle
        # immediately instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        img_prompt = [vis_processor(img).unsqueeze(0).to(model.device)]

        # img_prompts receives a nested list ([[tensor]]), as before;
        # presumably one inner list of images per sample — confirm against
        # prompt_wrapper.PreferencePrompt.
        prompt = prompt_wrapper.PreferencePrompt(model=model, img_prompts=[img_prompt])
        prompt.update_text_prompt([text_formats], [user_message])
        response, _ = my_generator.generate(prompt)
        print(response)
