import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# print(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image, ImageFilter, ImageOps
import PIL.ImageOps
import math
import cv2
import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf

from eval.config import Config

from torchvision import transforms

from collections import Counter

# import kornia
from transformers import set_seed
from lavis.models import load_model_and_preprocess

from vacode_utils.diffusion_noise import add_diffusion_noise
from vacode_utils.vacode_sample import evolve_vacode_sampling

# NOTE(review): called once at import time purely for its side effect. Judging
# by the name, this presumably patches the generation/sampling routine so that
# `model.generate` below accepts the contrastive-decoding kwargs it is passed
# (images_cd, cd_alpha, cd_beta) — confirm in vacode_utils.vacode_sample.
evolve_vacode_sampling()


@hydra.main(version_base=None, config_path="../config/llava", config_name="main")
def main(args: DictConfig) -> None:
    """Hydra entry point: convert the composed config, seed all RNGs, then evaluate.

    Args:
        args: The DictConfig Hydra composes from ``../config/llava/main``.
    """
    run_config = Config.from_omegaconf(args)
    # Seed before evaluation so the random augmentations and sampling are reproducible.
    set_seed(run_config.seed)
    eval_model(run_config)


def augment_image(image, image_tensor, aug_type, vis_processors, args):
    """Build one distorted view of an image for contrastive decoding.

    Args:
        image: The original PIL image (the "edge" branch treats it as RGB).
        image_tensor: The preprocessed original image. It is the input for
            the "noise" augmentation and sets the device of the returned tensor.
        aug_type: One of "color", "edge", "sharp", "crop", "erase", "flip",
            "noise" or "nocd".
        vis_processors: Processor dict from ``load_model_and_preprocess``; its
            "eval" entry converts a PIL image to a tensor.
        args: Config object; only ``args.noise_step`` is read (for "noise").

    Returns:
        For "nocd": None (no contrastive view).
        For "noise": the result of ``add_diffusion_noise`` on ``image_tensor``.
        Otherwise: the augmented image run through ``vis_processors["eval"]``,
        with a leading batch dim added, on ``image_tensor``'s device.

    Raises:
        ValueError: If ``aug_type`` is not recognised.
    """
    # Options that do not produce a PIL image return directly. Previously
    # "nocd" fell through to the re-processing step and crashed on the unset
    # `image_cd` (the guard wrongly tested "regular" instead of "nocd").
    if aug_type == "nocd":
        return None
    if aug_type == "noise":
        return add_diffusion_noise(image_tensor, args.noise_step)

    if aug_type == "color":
        image_cd = PIL.ImageOps.invert(image)
    elif aug_type == "edge":
        # Canny edge map of the grayscale image, expanded back to 3 channels.
        gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        blur = cv2.GaussianBlur(gray, (3, 3), 0)
        edges = cv2.Canny(blur, threshold1=100, threshold2=200)
        image_cd = Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))
    elif aug_type == "sharp":
        image_cd = transforms.RandomAdjustSharpness(10, p=1)(image)
    elif aug_type == "crop":
        # PIL reports size as (width, height); RandomResizedCrop wants (h, w).
        image_size = (image.size[1], image.size[0])
        image_cd = transforms.RandomResizedCrop(image_size)(image)
    elif aug_type == "erase":
        # RandomErasing operates on tensors, so round-trip through ToTensor.
        tensor_image = transforms.ToTensor()(image)
        erased = transforms.RandomErasing(p=1)(tensor_image)
        image_cd = transforms.ToPILImage()(erased)
    elif aug_type == "flip":
        # A vertical then a horizontal flip: a 180-degree rotation.
        image_cd = image.transpose(PIL.Image.FLIP_TOP_BOTTOM)
        image_cd = image_cd.transpose(PIL.Image.FLIP_LEFT_RIGHT)
    else:
        raise ValueError(f"Exp type {aug_type} not found")

    return vis_processors["eval"](image_cd).unsqueeze(0).to(image_tensor.device)


def _load_questions(question_path):
    """Load question records from a JSONL file, one JSON object per non-blank line."""
    with open(os.path.expanduser(question_path), "r") as question_file:
        return [json.loads(row) for row in question_file if row.strip()]


def _build_prompt(line, dataset):
    """Build the text prompt for one question record.

    VQAv2 records hold the question under "question"; all others under "text".
    MME questions are rewritten to request a yes/no answer; every other
    dataset requests a one-word answer.
    """
    qs = line["question"] if dataset == "VQAv2" else line["text"]
    if dataset == "MME":
        prompt = qs.replace("\nAnswer the question using a single word or phrase.", "")
        return prompt + " Answer the question using yes or no."
    return qs + " Please answer this question with one word."


def eval_model(args):
    """Answer every question with InstructBLIP and write the answers as JSONL.

    For each question record (keys "question_id", "image", and "question" or
    "text"), the function:
      1. loads the image from ``args.image_path``;
      2. builds the contrastive views named in ``args.exp_type``, a
         dash-separated list, via ``augment_image``;
      3. calls ``model.generate`` with nucleus sampling and the
         contrastive-decoding settings ``args.cd_alpha`` / ``args.cd_beta``;
      4. appends one JSON line to ``<args.save_path>_answers.jsonl``.

    Args:
        args: The evaluation config (the ``Config`` built in ``main``). This
            function reads model_path, question_path, save_path, dataset,
            image_path, exp_type, temperature, top_p, cd_alpha and cd_beta.
            ``augment_image`` may also read noise_step.
    """
    disable_torch_init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = 'instruct_blip'
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2_vicuna_instruct", model_type=args.model_path, is_eval=True, device=device
    )

    questions = _load_questions(args.question_path)
    answers_file = os.path.expanduser(args.save_path + "_answers.jsonl")
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:
        # os.makedirs("") raises, so only create a directory when the path names one.
        os.makedirs(answers_dir, exist_ok=True)

    # The augmentation list is the same for every question; parse it once.
    exp_type_list = args.exp_type.split("-")

    with open(answers_file, "w") as ans_file:
        for line in tqdm(questions):
            idx = line["question_id"]
            image_file = line["image"]
            prompt = _build_prompt(line, args.dataset)

            with Image.open(os.path.join(args.image_path, image_file)) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = vis_processors["eval"](image).unsqueeze(0).to(device)

            # One contrastive view per requested augmentation, in order. Views
            # that come back as None are dropped. If none remain, generate()
            # receives None rather than an empty list.
            image_tensor_cd = []
            for aug_type in exp_type_list:
                view = augment_image(image, image_tensor, aug_type, vis_processors, args)
                if view is not None:
                    image_tensor_cd.append(view)
            if not image_tensor_cd:
                image_tensor_cd = None

            with torch.inference_mode():
                outputs = model.generate(
                    {"image": image_tensor, "prompt": prompt},
                    use_nucleus_sampling=True, num_beams=1,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    repetition_penalty=1,
                    images_cd=image_tensor_cd,
                    cd_alpha=args.cd_alpha,
                    cd_beta=args.cd_beta,
                )

            outputs = outputs[0]

            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": prompt,
                                       "text": outputs,
                                       "model_id": model_name,
                                       "image": image_file,
                                       "metadata": {}}) + "\n")
            # Flush per answer so partial results survive an interrupted run.
            ans_file.flush()

if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which composes and passes the config,
    # so it is invoked here with no arguments.
    main()
