import os
from imagenet_util import ImageNetDataset
import json
import numpy as np
from tqdm import tqdm
import argparse
import torch
from transformers import (
    MllamaForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor,
    Qwen2_5_VLForConditionalGeneration
)
from utils import get_model, get_ensemble_model_func
from openai import OpenAI
from OOD_prompts import (
    imagenet_c_prompt_direct, imagenet_c_prompt_with_caption, caption_str,
    ood_prompt_direct, ood_prompt_with_caption, ood_prompt_direct_2, ood_prompt_with_caption_2,
    ood_prompt_simple
)

# Suffix telling the model it may answer "unknown" when none of the listed
# classes fits; appended to the two imagenet_c_* prompts below.
unknown_instruction = (
    '\n'
    '**Notice that** if the image *does not* clearly belong to any of the '
    '10 classes provided, classify it as "unknown".\n'
)


def _allow_unknown(base_prompt):
    """Return ``base_prompt`` followed by the "unknown" fallback instruction."""
    return base_prompt + unknown_instruction


# Maps each --prompt_mode value to the prompt text sent to the model.
prompt_map = {
    # Prompts that need the "unknown" suffix added explicitly.
    'direct': _allow_unknown(imagenet_c_prompt_direct),
    'caption_and_answer': _allow_unknown(imagenet_c_prompt_with_caption),
    # Prompts used exactly as imported from OOD_prompts.
    'caption_only': caption_str,
    'direct_2': ood_prompt_direct,
    'caption_and_answer_2': ood_prompt_with_caption,
    'direct_3': ood_prompt_direct_2,
    'caption_and_answer_3': ood_prompt_with_caption_2,
    'simple': ood_prompt_simple,
}

# Command-line interface: which model to evaluate and which prompt to send.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    type=str,
    default='llama',
    choices=[
        'llama', 'qwen', 'gpt', 'qwen_new', 'qwen_new_small', 'qwen_big_budget',
    ],
    help="Model identifier handed to get_model().",
)
parser.add_argument(
    "--prompt_mode",
    type=str,
    default='direct',
    # Derived from prompt_map instead of a hand-copied list, so a prompt added
    # to (or removed from) the map can never be out of sync with the accepted
    # values and cause a KeyError at lookup time.
    choices=list(prompt_map),
    help="Key into prompt_map selecting the prompt text sent with each image.",
)
args = parser.parse_args()
model_name = args.model

# Dataset the evaluation images are read from (indexed as dataset[idx] later).
dataset = ImageNetDataset(
   '',  # NOTE(review): empty path — presumably a placeholder for the dataset root; set before running
)

# Total number of evaluated images and its 60% / 40% split between
# in-distribution and out-of-distribution samples.
N = 3000
in_dist_size = int(N * 0.6)
ood_size = int(N * 0.4)
# Fixed seed so the two np.random.choice draws below pick the same subset on
# every run. The draws consume the global NumPy RNG in order, so reordering
# them would change which images are sampled.
np.random.seed(10)

# Candidate dataset indices for each group.
# NOTE(review): both paths are empty — presumably placeholders for saved index
# arrays; fill in before running.
ood_imgs_idx = np.load('')
imgs_in_cifar10_idx = np.load('')  # presumably images whose class maps to a CIFAR-10 class — confirm

# Sample without replacement so no image is evaluated twice within a group.
in_distribution_idx = np.random.choice(imgs_in_cifar10_idx, size=in_dist_size, replace=False)
ood_idx = np.random.choice(ood_imgs_idx, size=ood_size, replace=False)

# get_model returns the model, its processor, and the callable used to query
# it with a prompt and an image.
model, processor, model_answer_func = get_model(model_name)


# Prompt text sent with every image, selected by --prompt_mode.
actual_prompt = prompt_map[args.prompt_mode]
all_responses = []

# Create the output directory up front: otherwise a missing folder would only
# surface at the final open(), after the whole inference loop has already run,
# and every collected response would be lost.
output_dir = "./model_outputs/imagenet_OOD_results"
os.makedirs(output_dir, exist_ok=True)

# In-distribution indices come first, so an item's position in this array
# determines its OOD label.
eval_indices = np.concatenate([in_distribution_idx, ood_idx])

# tqdm wraps the array itself rather than the enumerate() generator, which has
# no length; this lets the progress bar show a total and an ETA.
for i, idx in enumerate(tqdm(eval_indices)):
    img, label = dataset[idx]
    # Move the first axis last and rescale to 8-bit.
    # NOTE(review): assumes dataset yields a channel-first float tensor with
    # values in [0, 1] — confirm against ImageNetDataset.
    img = (img.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype('uint8')
    response = model_answer_func(model, processor, actual_prompt, img)
    # 0 = in-distribution sample, 1 = out-of-distribution sample.
    ood_label = 0 if i < in_dist_size else 1
    all_responses.append({
        'img_idx': int(idx),
        'label': int(label),
        'response': response,
        'ood_label': int(ood_label),
    })

# One JSON file per (model, prompt mode) combination.
with open(f"{output_dir}/{model_name}_{args.prompt_mode}.json", "w") as f:
    json.dump(all_responses, f)