import os
import random
import argparse
import numpy as np
from tqdm import tqdm
import json
import time
from PIL import Image

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms

from minigpt4.models import load_preprocess
from minigpt4.common.config import Config
from minigpt4.common.registry import registry

from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

from decoder_zoo.VCD.vcd_utils.vcd_add_noise import add_diffusion_noise
from transformers import StoppingCriteriaList, MaxLengthCriteria
from decoder_zoo.Devils.modify_attention import llama_head_guide
from pycocotools.coco import COCO

import types
from forward_masking_llava import forward_masking


# Evaluation YAML to load for each supported model name (the value of --model).
MODEL_EVAL_CONFIG_PATH = {
    f"llava-1.5-{size}": f"eval_configs/llava-1.5_{size}_eval.yaml"
    for size in ("7b", "13b")
}

# Prompt skeleton per model name; "<question>" is substituted with the query
# text before generation and "<ImageHere>" is left in place for the model.
INSTRUCTION_TEMPLATE = {
    f"llava-1.5-{size}": "USER: <ImageHere>\n<question> ASSISTANT:"
    for size in ("7b", "13b")
}

class Hook:
    """Forward hook that remembers what last went through a module.

    After each forward pass of ``module`` the ``input`` and ``output``
    attributes hold the values handed to the hook callback; both are ``None``
    until the first pass. Call :meth:`close` to detach the hook.
    """

    def __init__(self, module):
        # Nothing has been captured yet.
        self.input, self.output = None, None
        # Keep the registration handle so the hook can be removed later.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Record the latest ``input`` / ``output`` pair, replacing the previous one."""
        self.input, self.output = input, output

    def close(self):
        """Detach the hook; the last captured values stay available."""
        self.hook.remove()

def setup_seeds(config, seed):
    """Seed all random number generators used here and pin cuDNN behaviour.

    Args:
        config: Accepted for call-site compatibility; not read.
        seed: Integer seed applied to ``random``, NumPy and PyTorch.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # for single-GPU
        torch.cuda.manual_seed_all,  # for multi-GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels and no autotuning.
    cudnn.deterministic = True
    cudnn.benchmark = False

def save_args_to_json(args, config_filename='config.json'):
    """Write the attributes of ``args`` to ``config_filename`` as indented JSON.

    Args:
        args: Object whose ``vars()`` mapping is JSON-serialisable, such as an
            ``argparse.Namespace``.
        config_filename: Destination path; an existing file is overwritten.
    """
    with open(config_filename, 'w') as out:
        out.write(json.dumps(vars(args), indent=4))


# Command-line interface of the evaluation script. Unknown arguments are
# tolerated (parse_known_args) so wrappers may pass extra flags through.
parser = argparse.ArgumentParser(description="Evaluation on LVLMs.")
# The model name indexes MODEL_EVAL_CONFIG_PATH and INSTRUCTION_TEMPLATE, so the
# default has to be one of their keys; the former default "llava1.5" was not
# and made a run without --model fail with KeyError.
parser.add_argument("--model", type=str, default="llava-1.5-7b", help="Model name: 'llava-1.5-7b' or 'llava-1.5-13b'. Default is 'llava-1.5-7b'.")
parser.add_argument("--merged_ckpt", type=str, default=None)
parser.add_argument("--decoder", type=str, default="greedy", help="Decoding strategy to use. You can choose from 'greedy', 'dola', 'opera', 'vcd', 'beam', 'devils', 'pai'. Default is 'greedy'.")
parser.add_argument("--beam", type=int, default=1)
parser.add_argument("--sample", action="store_true")
parser.add_argument("--max_new_tokens", type=int, default=512)
parser.add_argument("--num_samples", type=int, default=500)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--options", nargs="+", help="override some settings in the used config, the key-value pair " "in xxx=yyy format will be merged into config file (deprecate), ""change to --cfg-options instead.",)
# data
parser.add_argument("--dataset_name", type=str, default="chair", help="Name of the dataset. Default is 'chair'.")
parser.add_argument("--image_folder", type=str, default="coco/val2014/", help="data path",)
parser.add_argument("--attack_folder", type=str, default="", help="attack images path",)
parser.add_argument("--caption_file_path", type=str, default="../dataset/coco/annotations/captions_val2014.json", help="Caption file of the dataset.")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--num_workers", type=int, default=2, help="num workers")
parser.add_argument("--output_dir", type=str, default="./log/", help="Output directory for saving test results. Default is './log/'.")
# NOTE: store_false with default=True means that passing --verbosity turns the
# verbose printout OFF; omit the flag to keep it on.
parser.add_argument("--verbosity", action="store_false", dest="verbosity", default=True, help="Verbosity. Default: True.")
# opera
parser.add_argument("--scale_factor", type=float, default=50)
parser.add_argument("--threshold", type=int, default=15)
parser.add_argument("--num_attn_candidates", type=int, default=5)
parser.add_argument("--penalty_weights", type=float, default=1.0)
# vcd
parser.add_argument("--cd_alpha", type=float, default=1, help="Alpha param for VCD.")
parser.add_argument("--cd_beta", type=float, default=0.1, help="Beta param for VCD.")
parser.add_argument("--noise_step", type=int, default=500, help="Noise step for VCD.")
# ours
parser.add_argument("--use_ours", action="store_true", help="Enable Ours")
parser.add_argument("--k_sig", type=float, default=1.1,)

# Only the recognised namespace is kept; leftover arguments are discarded.
args = parser.parse_known_args()[0]

# Target device for the model and the input tensors.
device = torch.device("cuda:0")

# ========================================
#             Argument Initialization
# ========================================
# Module-level aliases of the parsed command-line values.
model_name = args.model
# decoding params
decoding_strategy, num_beams, sample = args.decoder, args.beam, args.sample
max_new_tokens, output_dir, verbosity = args.max_new_tokens, args.output_dir, args.verbosity
# dataset params
dataset_name, image_folder = args.dataset_name, args.image_folder
num_workers, batch_size, num_samples = args.num_workers, args.batch_size, args.num_samples
# vcd params
noise_step, cd_alpha, cd_beta = args.noise_step, args.cd_alpha, args.cd_beta
# dola params: every second layer index from 0 up to and including 32
lm_early_exit_layers = list(range(0, 33, 2))
mature_layer = lm_early_exit_layers[-1]
candidate_premature_layers = lm_early_exit_layers[:-1]
premature_layer = None
premature_layer_dist = dict.fromkeys(candidate_premature_layers, 0)
# opera params
scale_factor, threshold = args.scale_factor, args.threshold
num_attn_candidates, penalty_weights = args.num_attn_candidates, args.penalty_weights
# devil params
guide_layer_range, alpha = (5, 18), 0.5

# ========================================
#             Model Initialization
# ========================================
# Attach the eval YAML path of the chosen model to ``args`` before building
# the Config from it (presumably Config reads ``args.cfg_path`` — confirm).
args.cfg_path = MODEL_EVAL_CONFIG_PATH[args.model]
cfg = Config(args)
seed = args.seed
setup_seeds(cfg, seed)

# Instantiate the registered model class, then move it to the device and cast
# it to bfloat16 in two separate steps.
model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config)
model = model.to(device)
model = model.to(torch.bfloat16)
model.eval()

# The eval image processor is built with normalisation switched off; the
# script applies ``norm`` itself where it feeds images to the model.
processor_cfg = cfg.get_config().preprocess
processor_cfg.vis_processor.eval.do_normalize = False
vis_processors, txt_processors = load_preprocess(processor_cfg)

# Persist the effective arguments (now including cfg_path) next to the results.
os.makedirs(args.output_dir, exist_ok=True)
args_dump_path = os.path.join(args.output_dir, 'config.json')
save_args_to_json(args, args_dump_path)

# ========================================
#    Initializing decoding strategy
# ========================================
# Strategies this script knows how to configure.
valid_decoding_strategies = [
    "greedy",
    "dola",
    "opera",
    "vcd",
    "beam",
    "devils",
    "pai",
]

# Fail fast on a mistyped --decoder: without this check an unknown name would
# leave every flag below False and silently run greedy decoding while the
# outputs are labelled with the unknown name.
if decoding_strategy not in valid_decoding_strategies:
    raise ValueError(
        f"Unknown decoding strategy {decoding_strategy!r}; "
        f"expected one of {valid_decoding_strategies}."
    )

# One boolean per strategy; at most the flags of the selected strategy are set.
opera_decoding = False
dola_decoding = False
halc_decoding = False  # never enabled in this script; forwarded to generate() as False
vcd_decoding = False
beam_search = False
pai_decoding = False
devils_decoding = False

# NOTE(review): stopping_criteria is built for OPERA but is not passed to
# model.generate() anywhere in this file.
stopping_criteria = None
output_attentions = False
if decoding_strategy == "greedy":
    pass
elif decoding_strategy == "dola":
    dola_decoding = True
elif decoding_strategy == "opera":
    # OPERA runs on top of beam search and needs attention maps; the beam
    # width given on the command line is overridden with 5.
    beam_search = True
    opera_decoding = True
    num_beams = 5
    output_attentions = True
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=args.max_new_tokens)])
elif decoding_strategy == "beam":
    beam_search = True
elif decoding_strategy == "vcd":
    vcd_decoding = True
elif decoding_strategy == "pai":
    pai_decoding = True
elif decoding_strategy == "devils":
    devils_decoding = True

# ==================================================
### PROPOSED
# ==================================================
# Set-up for the proposed method ("ours"); skipped unless --use_ours is given.
use_ours_cd = False
if args.use_ours:
    k_sig = args.k_sig
    print("Using Ours!! k_sig:", args.k_sig)
    # Number of leading vision-encoder layers whose inputs are captured.
    layer_num = 10
    original_set = []
    if 'llava' in args.model:
        clip_layers = model.llama_model.model.vision_tower.vision_tower.vision_model.encoder.layers
        # One forward hook per captured layer.
        vit_enc_lst = [Hook(clip_layers[i]) for i in range(layer_num)]
        # Remember the unpatched attention forward of the first 22 layers so
        # it can be restored for every sample.
        original_set.extend(clip_layers[i].self_attn.forward for i in range(22))
        # Every layer starts without a mask.
        for encoder_layer in clip_layers:
            encoder_layer.self_attn.masking = None
    attack_image_fol = args.attack_folder
    if vcd_decoding:
        # Combined with VCD, the mask is handed to generate() instead.
        use_ours_cd = True
        print("Using Ours with VCD")
# ==================================================

# Banner with a green terminal background announcing the active strategy.
print(f"\033[42m####### Current Decoding Strategy: {decoding_strategy} #######\033[0m")

if verbosity:
    # (label, value) pairs printed one per line, in this order.
    settings_report = (
        ("\ndecoding strategy: ", decoding_strategy),
        ("backbone model_name: ", args.model),
        ("dataset_name: ", dataset_name),
        ("image_folder: ", image_folder),
        ("output_dir: ", output_dir),
        ("num_samples: ", num_samples),
        ("num_beams: ", num_beams),
        ("seed: ", seed),
    )
    for label, value in settings_report:
        print(label, value)
    print(vis_processors["eval"].transform)

# Per-model guidance settings for the "devils" strategy.
if decoding_strategy == "devils" and "llava" in args.model:
    guide_layer_range, alpha = (5, 18), 0.5

# ========================================
#    Initializing dataset
# ========================================
# Per-channel statistics for the explicit normalisation applied before images
# are fed to the model (these appear to be the CLIP preprocessing values).
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
norm = transforms.Normalize(mean, std)

# Build ``questions``: a list of dicts with "question_id", "image" and "text"
# (AMBER entries additionally carry a ready-made "image_path").
if dataset_name == 'chair':
    # Random subset of COCO images, all asked the same captioning question.
    caption_file_path = args.caption_file_path
    coco = COCO(caption_file_path)
    img_ids = coco.getImgIds()
    random.shuffle(img_ids)
    # A non-positive --num_samples keeps every image.
    sampled_img_ids = img_ids[:args.num_samples] if args.num_samples > 0 else img_ids
    questions = [
        {
            "question_id": sampled_img_id,
            "image": coco.loadImgs(sampled_img_id)[0]["file_name"],
            "text": "Please describe this image in detail.",
        }
        for sampled_img_id in sampled_img_ids
    ]

elif dataset_name == 'pope':
    # One JSON object per line; the first --num_samples entries are used.
    caption_file_path = args.caption_file_path
    with open(caption_file_path, 'r', encoding='utf-8') as f:
        pope_data = [json.loads(line.strip()) for line in f]
    if args.num_samples > 0:
        pope_data = pope_data[:args.num_samples]

    questions = [
        {
            "question_id": entry["question_id"],
            "image": entry["image"],
            "text": entry["text"],
        }
        for entry in pope_data
    ]

elif dataset_name == "amber":
    # Here --caption_file_path is a directory holding the two query files.
    # NOTE: --num_samples is not applied to AMBER; all queries are used.
    generative_path = os.path.join(args.caption_file_path, "query_generative.json")
    existence_path = os.path.join(args.caption_file_path, "query_discriminative-existence.json")

    with open(generative_path, "r", encoding="utf-8") as f:
        generative_data = json.load(f)

    with open(existence_path, "r", encoding="utf-8") as f:
        existence_data = json.load(f)

    amber_data = generative_data + existence_data

    questions = [
        {
            "question_id": d["id"],
            "image": d["image"],
            "image_path": os.path.join(image_folder, d["image"]),
            "text": d["query"],
        }
        for d in amber_data
    ]

else:
    # Without this branch ``questions`` would stay undefined and the script
    # would only fail later with an unrelated NameError.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected 'chair', 'pope' or 'amber'."
    )

# Generated answers are written here, one JSON object per line.
answers_file = os.path.join(args.output_dir, 'captions.jsonl')

# ========================================
#    Generating Answers
# ========================================
# Stream answers to disk as JSON Lines. The ``with`` block guarantees the file
# is closed once generation finishes or fails, before any later code re-opens
# ``answers_file`` for reading.
with open(answers_file, "w", encoding="utf-8") as ans_file:
    start_time = time.time()

    for line in tqdm(questions, total=len(questions)):
        question_id = line["question_id"]
        cur_prompt = line["text"]
        image_file = line["image"]
        # Use the entry's own path when it has one, else join folder + file name.
        image_path = line["image_path"] if "image_path" in line else os.path.join(image_folder, image_file)
        raw_image = Image.open(image_path).convert('RGB')

        template = INSTRUCTION_TEMPLATE[args.model]
        prompt = template.replace("<question>", cur_prompt)
        image = vis_processors["eval"](raw_image).unsqueeze(0)
        image = image.to(device, torch.bfloat16)
        image_cd = None

        # vcd_decoding: noised copy of the image used as the contrastive input.
        if vcd_decoding:
            image_tensor_cd = add_diffusion_noise(image, noise_step)
            # NOTE(review): unlike ``image`` in the generate() call below,
            # image_cd is passed on without ``norm`` — confirm this is intended.
            image_cd = (
                image_tensor_cd.unsqueeze(0).to(torch.bfloat16).cuda()  # half()
                if image_tensor_cd is not None
                else None
            )

        # devils_decoding
        if devils_decoding:
            # Devils in the middle layers
            llama_head_guide(
                model=model.llama_model,
                guided_layer_range=guide_layer_range,
                aggregation="mean",
                alpha=alpha,
                img_start_idx=model.get_vision_start_index(),
                img_end_idx=model.get_vision_end_index(),
            )

        # ==================================================
        ### PROPOSED
        # ==================================================
        ours_mask_cd = None
        if args.use_ours:
            # NOTE(review): replace() rewrites every 'jpg' occurrence in the
            # file name, not only the extension — fine for COCO-style names.
            attack_image_path = os.path.join(attack_image_fol, image_file.replace('jpg', 'png'))
            attack_image_path = attack_image_path.replace('\\', '/')

            attack_raw_image = Image.open(attack_image_path).convert('RGB')
            attack_image = vis_processors["eval"](attack_raw_image).unsqueeze(0).to(device, torch.bfloat16)
            if 'llava' in args.model:
                encoder_layers = model.llama_model.model.vision_tower.vision_tower.vision_model.encoder.layers
                # Undo the patching left over from the previous sample so both
                # probe passes below run through the original attention.
                for ii in range(22):
                    encoder_layers[ii].self_attn.masking = None
                    encoder_layers[ii].self_attn.forward = original_set[ii]
                with torch.no_grad():
                    # The hooks capture each layer's input; read them right
                    # after every pass because the next pass overwrites them.
                    _ = model.llama_model.model.vision_tower(norm(image))
                    orig_vitenc_lst = [vit_enc_lst[i].input[0].squeeze(0) for i in range(layer_num)]
                    _ = model.llama_model.model.vision_tower(norm(attack_image))
                    attack_vitenc_lst = [vit_enc_lst[i].input[0].squeeze(0) for i in range(layer_num)]

            # Average, over the captured layers, the min-max normalised norm
            # (along dim 1) of the attack-vs-original difference.
            mask_sum = 0
            for ii in range(layer_num):
                diff = attack_vitenc_lst[ii] - orig_vitenc_lst[ii]
                token_norm = diff.norm(dim=1)  # computed once instead of three times
                normed_diff = (token_norm - token_norm.min()) / (token_norm.max() - token_norm.min())
                mask_sum += normed_diff
            mask_sum /= layer_num
            # Entries below mean + k_sig * std are True in the mask.
            thr = mask_sum.mean() + k_sig * mask_sum.std(unbiased=False)

            if vcd_decoding:
                # With VCD the mask is handed to generate() instead.
                ours_mask_cd = mask_sum < thr
            elif 'llava' in args.model:
                # Route the first 22 layers through forward_masking; only
                # layers 12..16 receive an actual mask.
                for ii in range(22):
                    attn = encoder_layers[ii].self_attn
                    attn.forward = types.MethodType(forward_masking, attn)
                for kk in range(12, 17):
                    encoder_layers[kk].self_attn.masking = mask_sum < thr
        # ==================================================

        with torch.inference_mode(), torch.no_grad():
            out = model.generate(
                {"image": norm(image), "prompt": prompt, "img_path": image_path},
                output_attentions=output_attentions,
                # Decoding
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                use_nucleus_sampling=sample,
                beam_search=beam_search,
                dola_decoding=dola_decoding,
                opera_decoding=opera_decoding,
                vcd_decoding=vcd_decoding,
                halc_decoding=halc_decoding,
                # DOLA
                premature_layer=premature_layer,
                candidate_premature_layers=candidate_premature_layers,
                mature_layer=mature_layer,
                # OPERA
                key_position=None,
                scale_factor=scale_factor,
                threshold=threshold,
                num_attn_candidates=num_attn_candidates,
                penalty_weights=penalty_weights,
                # VCD
                images_cd=image_cd,
                cd_alpha=cd_alpha,
                cd_beta=cd_beta,
                use_ours_cd=use_ours_cd,
                ours_mask_cd=ours_mask_cd,
                forward_masking=forward_masking,
                # PAI
                pai_decoding=pai_decoding,
                # Devils
                devils_decoding=devils_decoding,
            )
        output_text = out[0]

        # remove unk: drop every '.'-separated sentence that contains "unk".
        # NOTE(review): this is a substring test, so sentences containing words
        # such as "trunk" are dropped as well — confirm this is intended.
        output_text = ".".join(
            sentence for sentence in output_text.split(".") if "unk" not in sentence
        )

        # save results; flush per sample so finished answers survive a crash
        ans_file.write(json.dumps({"question_id": question_id,
                                    "image": image_file,
                                    "prompt": cur_prompt,
                                    "text": output_text,
                                    "model_id": model_name}) + "\n")
        ans_file.flush()

# POPE only: turn the generated answers into yes/no verdicts and write them,
# together with the ground-truth labels, next to the other outputs.
if dataset_name == 'pope':
    # Load label source file again
    with open(args.caption_file_path, 'r', encoding='utf-8') as f:
        txt_labels = [json.loads(raw.strip()) for raw in f]
    if args.num_samples > 0:
        txt_labels = txt_labels[:args.num_samples]

    eval_results, labels, answer_len = [], [], []

    with open(answers_file, 'r', encoding='utf-8') as f:
        generated = [json.loads(raw) for raw in f]

    # Labels and answers are paired by position; zip stops at the shorter list.
    for eval_prompt, gen in zip(txt_labels, generated):
        image_name, text = eval_prompt['image'], eval_prompt['text']
        # Keep only what follows the last "assistant:" marker (the whole
        # string when the marker is absent).
        answer = gen["text"].lower().rpartition("assistant:")[2]
        answer = answer.replace("\n", " ").strip().lower()
        answer_len.append(len(answer))
        # Substring search with "yes" taking priority; None if neither occurs.
        real_answer = next((word for word in ("yes", "no") if word in answer), None)

        eval_results.append({
            "image_path": image_name,
            "question": text,
            "answer": real_answer,
            "model_answer": answer,
        })
        labels.append({"label": eval_prompt['label']})

    save_prefix = os.path.join(args.output_dir, 'output')
    os.makedirs(os.path.dirname(save_prefix), exist_ok=True)

    # Both files hold one JSON object per line despite the .json suffix.
    with open(f'{save_prefix}.json', 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in eval_results)

    with open(f'{save_prefix}_label.json', 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(entry) + '\n' for entry in labels)

# Report the strategy name and the wall-clock seconds elapsed since
# ``start_time`` was taken at the start of answer generation.
end_time = time.time()
elapsed_seconds = end_time - start_time
print(decoding_strategy, elapsed_seconds)
