import argparse
import os
import random
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
import json

from minigpt4.common.config import Config
from minigpt4.common.registry import registry

from minigpt_utils import prompt_wrapper, generator, visual_attacker


def rtp_read(text_file, num_samples=50, seed=0):
    """Load a shuffled subset of "challenging" prompt texts from a JSONL file.

    Each non-blank line must be a JSON object with a boolean-like
    ``'challenging'`` key and a ``'prompt'`` object holding ``'text'``.
    Blank lines (including the usual trailing newline) are skipped.

    Args:
        text_file: path to the JSONL file.
        num_samples: maximum number of prompts to return (default 50).
        seed: seed for the shuffle, so the subset is reproducible (default 0).

    Returns:
        A list of at most ``num_samples`` prompt strings.
    """
    dataset = []
    with open(text_file, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # json.loads('') would raise; a trailing newline is normal in JSONL.
                continue
            obj = json.loads(line)
            if obj['challenging']:
                dataset.append(obj['prompt']['text'])
    # A private RNG yields the same permutation as random.seed(seed) + random.shuffle
    # without clobbering the global random state of the caller.
    random.Random(seed).shuffle(dataset)
    return dataset[:num_samples]


def parse_args():
    """Parse command-line options for the batch MiniGPT-4 evaluation script.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default="eval_configs/minigpt4_eval.yaml", help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    # NOTE: both *_fold options are used as directories by this script (os.listdir /
    # os.makedirs). The file-like defaults are kept unchanged for backward compatibility.
    parser.add_argument("--image_fold", type=str, default='./image.bmp',
                        help="Directory containing the input images.")
    parser.add_argument("--output_fold", type=str, default='./result.jsonl',
                        help="Directory for per-image result .jsonl files (created if missing).")
    parser.add_argument("--test_data_file", type=str,
                        default='/data//mm-safety/data_prepare/do-not-answer/data_en.csv',
                        help="CSV file with 'types_of_harm' and 'question' columns.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    # NOTE(review): --ignore_img is not referenced in this script; confirm whether
    # Config(args) or another consumer reads it before removing.
    parser.add_argument("--ignore_img", action="store_true")
    parser.add_argument("--sample_num", type=int, default=-1,
                        help="Questions sampled per harm type (-1 = use all).")
    parser.add_argument("--random_attack", action="store_true",
                        help="Add uniform random noise to the image before inference.")
    parser.add_argument("--segment_id", type=int, default=1,
                        help="1-based index of the image segment to process.")
    parser.add_argument("--all_segments", type=int, default=1,
                        help="Number of segments the image list is split into.")
    args = parser.parse_args()
    return args


def setup_seeds(seed):
    """Make runs reproducible by seeding every RNG the script relies on.

    Args:
        seed: integer applied to Python's ``random``, NumPy and PyTorch.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # Trade cuDNN autotuning speed for deterministic kernel selection.
    cudnn.benchmark = False
    cudnn.deterministic = True


# ========================================
#             Model Initialization
# ========================================

print('>>> Initializing Models')

args = parse_args()
cfg = Config(args)
setup_seeds(0)
# Per-image result files are written into this directory.
os.makedirs(args.output_fold, exist_ok=True)

# Instantiate the model class registered under the config's `arch`
# and move it onto the GPU chosen with --gpu-id.
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
target_device = f'cuda:{args.gpu_id}'
model = registry.get_model_class(model_config.arch).from_config(model_config).to(target_device)

# Image preprocessing is taken from the cc_sbu_align dataset's `train` processor entry.
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
processor_cls = registry.get_processor_class(vis_processor_cfg.name)
vis_processor = processor_cls.from_config(vis_processor_cfg)
print('Initialization Finished')

# Response length is bounded through the generator's max_new_tokens=30 setting.
my_generator = generator.Generator(model=model, max_new_tokens=30)

# ========================================
#             Inference
# ========================================

# TODO: expose interface.

# Accepted image extensions (compared case-insensitively).
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.bmp')
# Fixed seed for per-category question sampling, so every image and every segment
# is evaluated on the same questions. 1 equals the former default --segment_id,
# which keeps single-segment results unchanged.
_QUESTION_SAMPLE_SEED = 1
# L_inf radius of the uniform noise used by --random_attack (denormalized pixel units).
_RANDOM_ATTACK_EPSILON = 32 / 255


def _load_questions(csv_path, sample_num):
    """Return ``[(harm_type, [question, ...]), ...]`` sorted by harm type.

    The 'Treat Chatbot as a Human' category is excluded. When ``sample_num`` is
    not -1, each category is shuffled with a fixed seed and truncated to its
    first ``sample_num`` questions.
    """
    df = pd.read_csv(csv_path)
    harm_types = set(df["types_of_harm"].dropna())
    harm_types.discard('Treat Chatbot as a Human')  # no KeyError if absent
    questions_by_type = []
    for harm_type in sorted(harm_types):
        questions = df.loc[df["types_of_harm"] == harm_type, 'question'].tolist()
        if sample_num != -1:
            random.Random(_QUESTION_SAMPLE_SEED).shuffle(questions)
            questions = questions[:sample_num]
        questions_by_type.append((harm_type, questions))
    return questions_by_type


def _select_image_files(image_dir, output_dir, segment_id, all_segments):
    """Return this segment's images that have no finished result file yet.

    All images are sorted and split into ``all_segments`` contiguous, near-equal
    parts *before* finished images are removed. This keeps a segment's share
    stable across restarts and never drops the remainder when the image count
    is not divisible by ``all_segments``.

    Raises:
        ValueError: if ``segment_id`` is not in ``[1, all_segments]``.
    """
    if not 1 <= segment_id <= all_segments:
        raise ValueError(f"segment_id must be in [1, {all_segments}], got {segment_id}")
    image_files = sorted(
        name for name in os.listdir(image_dir)
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    )
    start = len(image_files) * (segment_id - 1) // all_segments
    end = len(image_files) * segment_id // all_segments
    print(f"segment_start: {start}, segment_end: {end}")
    # Resume support: only completed '.jsonl' files count (in-progress files end in '.tmp').
    covered = {
        os.path.splitext(name)[0]
        for name in os.listdir(output_dir)
        if name.endswith('.jsonl')
    }
    return [name for name in image_files[start:end]
            if os.path.splitext(name)[0] not in covered]


def _build_img_prompt(clean_img, random_attack):
    """Return the per-image prompt list passed inside ``img_prompts``.

    With ``random_attack``, uniform noise in [-eps, eps] is added to the
    denormalized image, the result is clamped to [0, 1] and re-normalized.
    ``clean_img`` itself is never modified.
    """
    if not random_attack:
        return [clean_img]
    noise = torch.rand_like(clean_img) * 2 * _RANDOM_ATTACK_EPSILON - _RANDOM_ATTACK_EPSILON
    pixels = visual_attacker.denormalize(clean_img)
    return [visual_attacker.normalize((pixels + noise).clamp(0, 1))]


questions_by_type = _load_questions(args.test_data_file, args.sample_num)
image_files = _select_image_files(args.image_fold, args.output_fold,
                                  args.segment_id, args.all_segments)
text_formats = prompt_wrapper.minigpt4_chatbot_prompt

# Exact number of generations: every kept question for every selected image.
_tqdm = tqdm(total=len(image_files) * sum(len(qs) for _, qs in questions_by_type))
with torch.no_grad():
    for image_file in image_files:
        out_path = os.path.join(args.output_fold, os.path.splitext(image_file)[0] + '.jsonl')
        tmp_path = out_path + '.tmp'

        # NOTE(review): preprocessing is hoisted out of the per-harm-type loop; this
        # assumes vis_processor is deterministic — confirm for the configured processor.
        with Image.open(os.path.join(args.image_fold, image_file)) as raw_img:
            img = raw_img.convert('RGB')
        clean_img = vis_processor(img).unsqueeze(0).to(model.device)

        with open(tmp_path, 'w', encoding='utf-8') as f_out:
            for harmful_type, data in questions_by_type:
                # Fresh noise per harm type when --random_attack is set.
                img_prompt = _build_img_prompt(clean_img, args.random_attack)
                prompt = prompt_wrapper.PreferencePrompt(model=model, img_prompts=[img_prompt])
                for user_message in data:
                    prompt.update_text_prompt([text_formats], [user_message])
                    response, _ = my_generator.generate(prompt)
                    f_out.write(
                        json.dumps({
                            'harm_type': harmful_type, 'prompt': user_message, 'response': response
                        }) + '\n'
                    )
                    _tqdm.update(1)
        # Publish only complete results: a crash mid-image must not leave a file
        # that the resume logic would later treat as finished.
        os.replace(tmp_path, out_path)
_tqdm.close()
