import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# print(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image, ImageFilter
import PIL.ImageOps
import math
import cv2
import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf

from eval.config import Config

from torchvision import transforms

from collections import Counter

# import kornia
from transformers import set_seed

from vacode_utils.diffusion_noise import add_diffusion_noise
from vacode_utils.vacode_sample import evolve_vacode_sampling

# Runs once at import time, before any model is loaded or generate() is called.
# NOTE(review): presumably this patches the generation/sampling routine so that
# generate() accepts the images_cd / cd_alpha / cd_beta keyword arguments used
# in eval_model -- confirm in vacode_utils.vacode_sample.
evolve_vacode_sampling()


@hydra.main(version_base=None, config_path="../config/llava", config_name="main")
def main(args: DictConfig):
    """Hydra entry point.

    Converts the composed OmegaConf config into the project ``Config``
    object, seeds the random number generators through ``set_seed`` with
    ``seed`` from that config, and then runs the evaluation loop.
    """
    config = Config.from_omegaconf(args)
    # Seed before evaluation starts so everything run by eval_model is covered.
    set_seed(config.seed)
    eval_model(config)

def augment_image(image, image_tensor, aug_type, image_processor, args):
    """Build the contrastive ("cd") image tensor for a single augmentation type.

    Supported ``aug_type`` values:

    * ``"color"`` -- colour-inverted copy of ``image``.
    * ``"edge"``  -- Canny edge map (thresholds 100/200) of the 3x3
      Gaussian-blurred grayscale image, expanded back to three channels.
    * ``"sharp"`` -- sharpness adjusted by a factor of 10 (always applied).
    * ``"crop"``  -- random resized crop, output size equal to the input size.
    * ``"erase"`` -- random erasing (always applied).
    * ``"flip"``  -- top/bottom flip followed by left/right flip.
    * ``"noise"`` -- ``add_diffusion_noise(image_tensor, args.noise_step)``;
      operates on the already preprocessed tensor, not on ``image``.
    * ``"nocd"``  -- no contrastive image; returns ``None``.

    Args:
        image: PIL image to augment. The ``"color"`` and ``"edge"`` branches
            assume a 3-channel RGB image (``ImageOps.invert`` and
            ``cv2.COLOR_RGB2BGR`` do not accept every PIL mode).
        image_tensor: Preprocessed tensor of ``image``; only used by
            ``"noise"``.
        aug_type: One of the names listed above.
        image_processor: Object exposing
            ``preprocess(img, return_tensors='pt')['pixel_values']``.
        args: Config object; only ``args.noise_step`` is read (for ``"noise"``).

    Returns:
        The augmented image tensor, or ``None`` for ``"nocd"``.

    Raises:
        ValueError: If ``aug_type`` is not one of the supported names.
    """
    # Tensor-level modes never produce a PIL image, so they must not reach the
    # image_processor call below. Previously "nocd" fell through to it and
    # crashed on the unassigned `image_cd`.
    if aug_type == "nocd":
        return None
    if aug_type == "noise":
        return add_diffusion_noise(image_tensor, args.noise_step)

    if aug_type == "color":
        image_cd = PIL.ImageOps.invert(image)
    elif aug_type == "edge":
        image_cd = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        gray = cv2.cvtColor(image_cd, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (3, 3), 0)
        edges = cv2.Canny(blur, threshold1=100, threshold2=200)
        edges_colored = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
        image_cd = Image.fromarray(cv2.cvtColor(edges_colored, cv2.COLOR_BGR2RGB))
    elif aug_type == "sharp":
        image_cd = transforms.RandomAdjustSharpness(10, p=1)(image)
    elif aug_type == "crop":
        # PIL reports (width, height); torchvision expects (height, width).
        image_size = (image.size[1], image.size[0])
        image_cd = transforms.RandomResizedCrop(image_size)(image)
    elif aug_type == "erase":
        # RandomErasing works on tensors, so round-trip through ToTensor.
        tensor_image = transforms.ToTensor()(image)
        image_cd = transforms.RandomErasing(p=1)(tensor_image)
        image_cd = transforms.ToPILImage()(image_cd)
    elif aug_type == "flip":
        image_cd = image.transpose(PIL.Image.FLIP_TOP_BOTTOM)
        image_cd = image_cd.transpose(PIL.Image.FLIP_LEFT_RIGHT)
    else:
        raise ValueError(f"Exp type {aug_type} not found")

    # Every remaining mode produced a PIL image that still needs the model's
    # own preprocessing.
    return image_processor.preprocess(image_cd, return_tensors='pt')['pixel_values'][0]

def eval_model(args):
    """Answer every question in a JSONL file and write the answers as JSONL.

    Loads the model named by ``args.model_path`` / ``args.model_base``, then
    for each question record builds a prompt from the ``args.conv_mode``
    conversation template, preprocesses the image, builds one contrastive
    image tensor per augmentation listed in ``args.exp_type`` and calls
    ``model.generate``. Answers go to ``args.save_path + "_answers.jsonl"``,
    flushed after each record.

    Args:
        args: Config object. Fields read here: ``model_path``, ``model_base``,
            ``question_path``, ``save_path``, ``dataset``, ``conv_mode``,
            ``image_path``, ``exp_type`` (augmentation names joined by "-"),
            ``cd_alpha``, ``cd_beta``, ``do_sample``, ``temperature``,
            ``top_p``, ``top_k``. The object is also forwarded to
            ``augment_image``.

    Each question record must contain ``"question_id"`` and ``"image"``,
    plus ``"question"`` when ``args.dataset == "VQAv2"`` or ``"text"``
    otherwise.
    """
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(model_path, args.model_base, model_name)

    # Read the questions inside a context manager so the handle is closed, and
    # ignore blank lines (e.g. a trailing newline) instead of crashing on them.
    with open(os.path.expanduser(args.question_path), "r") as question_file:
        questions = [json.loads(q) for q in question_file if q.strip()]

    answers_file = os.path.expanduser(args.save_path + "_answers.jsonl")
    # dirname is "" when save_path has no directory part; makedirs("") raises.
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # allow multiple augmentations; the list is the same for every question
    exp_type_list = args.exp_type.split("-")

    # The context manager guarantees the answers file is closed even if
    # generation fails part-way through.
    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["question_id"]
            image_file = line["image"]

            if args.dataset == "VQAv2":
                qs = line["question"]
            elif args.dataset == "MME":
                qs = line["text"]
            else:
                qs = line["text"] + " Please answer this question with one word."

            cur_prompt = qs
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

            # Normalise to RGB: the "edge" (cv2.COLOR_RGB2BGR) and "color"
            # (ImageOps.invert) augmentations reject grayscale / RGBA / palette
            # images. convert() loads the pixel data, so the file handle can be
            # released right away.
            with Image.open(os.path.join(args.image_path, image_file)) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

            # augment_image returns None for modes without a contrastive image
            # ("nocd"); those are skipped.
            image_tensor_cd = []
            for aug_type in exp_type_list:
                cd_tensor = augment_image(image, image_tensor, aug_type, image_processor, args)
                if cd_tensor is not None:
                    image_tensor_cd.append(cd_tensor.unsqueeze(0).half().cuda())
            if not image_tensor_cd:
                image_tensor_cd = None

            # NOTE(review): the original built a KeywordsStoppingCriteria here
            # but never passed it to generate(); that unused object was dropped.
            # The stop string is stripped from the decoded text below instead.
            # Confirm whether generate() was meant to receive a stopping criterion.
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half().cuda(),
                    images_cd=image_tensor_cd,
                    cd_alpha=args.cd_alpha,
                    cd_beta=args.cd_beta,
                    do_sample=args.do_sample,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    top_k=args.top_k,
                    max_new_tokens=1024,
                    use_cache=True)

            # The returned ids are expected to begin with the prompt ids; warn
            # if they differ, then decode only the part after the prompt.
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
            outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            ans_file.write(json.dumps({"question_id": idx,
                                        "prompt": cur_prompt,
                                        "text": outputs,
                                        "model_id": model_name,
                                        "image": image_file,
                                        "metadata": {}}) + "\n")
            # Flush per record so partial results survive an interrupted run.
            ans_file.flush()

if __name__ == "__main__":
    main()
