import os
from imagenet_util import ImageNetCDataset, ImageNetDataset
import json
import numpy as np
from tqdm import tqdm
import argparse
import torch
from utils import get_model, get_ensemble_model_func
from pixmix_loader import PixMixDataset

from OOD_prompts import (
    imagenet_c_prompt_simple,
    caption_str, imagenet_c_prompt_direct,
    imagenet_c_prompt_with_caption,
    imagenet_c_prompt_with_caption_3,
    imagenet_c_prompt_with_caption_5,
    imagenet_c_prompt_with_caption_10,
    imagenet_c_prompt_use_caption_direct,
    imagenet_c_prompt_use_caption_cot,
    imagenet_c_prompt_use_caption_cot_with_rej,
    imagenet_c_prompt_use_multi_caption_cot,
    imagenet_c_prompt_use_multi_caption_cot_with_rej
)

# Backbones accepted by --model (resolved by utils.get_model).
_MODEL_CHOICES = [
    'llama', 'qwen', 'gpt',
    'llama_big', 'qwen_big',
    'qwen_new', 'qwen_new_small',
    'qwen_big_budget',
]

# Prompting strategies accepted by --prompt_mode.
_PROMPT_MODE_CHOICES = [
    'direct', 'caption',
    'caption_and_answer',
    'caption_and_answer_3',
    'caption_and_answer_5',
    'caption_and_answer_10',
    'corrupted_img_and_clean_caption_direct',
    'corrupted_img_and_clean_caption_cot',
    'clean_img_and_corrupted_caption_direct',
    'clean_img_and_corrupted_caption_cot',
    'simple', 'ensemble',
]

parser = argparse.ArgumentParser()
# Corruption name and severity; severity 0 and the name 'pixmix' get special
# handling when the dataset is built below.
parser.add_argument('--corruption_type', type=str, default='gaussian_noise')
parser.add_argument('--corruption_level', type=int, default=1)
# When set, the prompt is extended/replaced so the model may answer "unknown".
parser.add_argument('--allow_rejection', action='store_true', default=False)
parser.add_argument("--model", type=str, default='llama', choices=_MODEL_CHOICES)
# Only consulted for the ensemble mode (torch seed + output filename).
parser.add_argument("--sample_seed", type=int)
parser.add_argument("--prompt_mode", default='direct', choices=_PROMPT_MODE_CHOICES)
# Number of sampled captions generated per image in ensemble mode.
parser.add_argument("--ensemble_num", type=int, default=3)
args = parser.parse_args()

# Copy the parsed options into module-level settings read by the rest of the
# script. `promt_mode` keeps its original (misspelled) name because later code
# refers to it.
(
    corruption_type,
    corruption_level,
    allow_rejection,
    model_name,
    promt_mode,
    ensemble_num,
) = (
    args.corruption_type,
    args.corruption_level,
    args.allow_rejection,
    args.model,
    args.prompt_mode,
    args.ensemble_num,
)

# Load the backbone, its processor, and the callable used below to query it
# with (model, processor, prompt, image).
model, processor, model_answer_func = get_model(model_name)

# Number of images evaluated per run.
test_size = 1000

# Choose the image source.
# - Runs pairing a clean image with a corrupted caption read clean ImageNet.
#   corruption_type is deliberately NOT changed, so the caption file that is
#   looked up later still names the corrupted run.
# - Severity 0 also reads clean ImageNet, and results are filed under 'clean'.
# - 'pixmix' uses PixMixDataset at the requested intensity.
# - Anything else reads the ImageNet-C split for (type, severity).
pairs_clean_img_with_corrupted_caption = 'clean_img_and_corrupted_caption' in promt_mode
if pairs_clean_img_with_corrupted_caption or corruption_level == 0:
    dataset = ImageNetDataset(
        '',
    )
    if not pairs_clean_img_with_corrupted_caption:
        corruption_type = 'clean'
elif corruption_type == 'pixmix':
    dataset = PixMixDataset(intenstity=corruption_level)
else:
    dataset = ImageNetCDataset(
        '',
        corruption_type=corruption_type,
        severity=corruption_level,
    )

# Pool of candidate dataset indices to sample the evaluation subset from.
imgs_in_cifar10_idx = np.load('')

# Fixed global seed: every run draws the same subset in the same order. Caption
# files produced by one run are later indexed positionally by another, so this
# matters.
np.random.seed(10)
in_distribution_idx = np.random.choice(
    imgs_in_cifar10_idx, size=test_size, replace=False
)

# Base prompt template for each --prompt_mode. The clean/corrupted caption-swap
# variants share a template; which side is corrupted is decided by the dataset
# and caption-file selection elsewhere in this script.
_BASE_PROMPTS = {
    'simple': imagenet_c_prompt_simple,
    'direct': imagenet_c_prompt_direct,
    'caption': caption_str,
    'caption_and_answer': imagenet_c_prompt_with_caption,
    'caption_and_answer_3': imagenet_c_prompt_with_caption_3,
    'caption_and_answer_5': imagenet_c_prompt_with_caption_5,
    'caption_and_answer_10': imagenet_c_prompt_with_caption_10,
    'clean_img_and_corrupted_caption_direct': imagenet_c_prompt_use_caption_direct,
    'corrupted_img_and_clean_caption_direct': imagenet_c_prompt_use_caption_direct,
    'clean_img_and_corrupted_caption_cot': imagenet_c_prompt_use_caption_cot,
    'corrupted_img_and_clean_caption_cot': imagenet_c_prompt_use_caption_cot,
    'ensemble': imagenet_c_prompt_use_multi_caption_cot,
}

# Modes with a dedicated "may reject" template; all other modes get a generic
# rejection sentence appended to their base prompt instead.
_REJECTION_PROMPTS = {
    'clean_img_and_corrupted_caption_cot': imagenet_c_prompt_use_caption_cot_with_rej,
    'corrupted_img_and_clean_caption_cot': imagenet_c_prompt_use_caption_cot_with_rej,
    'ensemble': imagenet_c_prompt_use_multi_caption_cot_with_rej,
}

if promt_mode not in _BASE_PROMPTS:
    raise ValueError('Invalid prompt mode')
actual_prompt = _BASE_PROMPTS[promt_mode]

if not allow_rejection:
    rejection_string = ''
else:
    if promt_mode in _REJECTION_PROMPTS:
        actual_prompt = _REJECTION_PROMPTS[promt_mode]
    else:
        actual_prompt += '\nNotice that if you find an image very ambigious and cannot confidently classify it, return "unknown" as the label.'
    # Suffix added to the output filename for rejection-enabled runs.
    rejection_string = '_rejection'

# Per-image result dicts; dumped to `log_file` once the loop finishes.
all_responses = []
# Output path: <model>_<corruption type>_<level>_<prompt mode><rejection>.json
log_file_template = './model_outputs/imgnetc_results/{}_{}_{}_{}{}.json'

if 'ensemble' in promt_mode:
    # From here on `promt_mode` doubles as the filename fragment. It still
    # contains 'ensemble', which the main loop relies on.
    promt_mode = f'ensemble_{ensemble_num}_'
    if args.sample_seed is not None:
        promt_mode += f'seed_{args.sample_seed}_'
        # Presumably makes the sampled (do_sample=True) caption generation in
        # the loop reproducible per seed.
        torch.manual_seed(args.sample_seed)

log_file = log_file_template.format(
    model_name, corruption_type,
    corruption_level, promt_mode,
    rejection_string
)
# Create the output directory up front. Otherwise the final json.dump fails
# after every image has already been processed.
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Caption-swap modes reuse captions from an earlier 'caption' run:
# - clean image + corrupted caption: captions come from the corrupted run;
# - corrupted image + clean caption: captions come from the clean (level 0) run.
caption_file = None
if 'clean_img_and_corrupted_caption' in promt_mode:
    caption_file = log_file_template.format(
        model_name, corruption_type,
        corruption_level, 'caption',
        ''
    )
elif 'corrupted_img_and_clean_caption' in promt_mode:
    caption_file = log_file_template.format(
        model_name, 'clean',
        '0', 'caption',
        ''
    )

if caption_file is not None:
    # Use a context manager so the handle is closed deterministically.
    with open(caption_file) as caption_fh:
        img_captions = json.load(caption_fh)

# Main evaluation loop: build the prompt for each sampled image, query the
# model once, and record the response.
uses_caption_file = (
    'clean_img_and_corrupted_caption' in promt_mode
    or 'corrupted_img_and_clean_caption' in promt_mode
)
if 'ensemble' in promt_mode:
    # Resolve the multi-caption generator once instead of on every image.
    # It is called with model_name only; this assumes it has no per-call side
    # effects (verify in utils).
    multi_cap_gen_func = get_ensemble_model_func(model_name)

for i, idx in tqdm(enumerate(in_distribution_idx), total=len(in_distribution_idx)):
    img, label = dataset[idx]
    # CHW tensor -> HWC uint8 array. The *255 scaling presumes pixel values in
    # [0, 1]; clip guards against overshoot.
    img = (img.permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype('uint8')
    extra_info = ''
    if uses_caption_file:
        entry = img_captions[i]
        # Captions are paired by position. Verify against the index stored by
        # the caption run so a mismatched file fails loudly instead of silently
        # pairing captions with the wrong images.
        stored_idx = entry.get('img_idx')
        if stored_idx is not None and int(stored_idx) != int(idx):
            raise ValueError(
                f'Caption entry {i} belongs to image {stored_idx}, '
                f'expected image {int(idx)}; caption file does not match this run.'
            )
        caption = entry['response']
        if isinstance(caption, list):
            caption = caption[0]
        _actual_prompt = actual_prompt.format(caption)
    elif 'ensemble' in promt_mode:
        all_captions = multi_cap_gen_func(
            model, processor,
            caption_str, img, ensemble_num,
            do_sample=True, temperature=0.6, top_p=0.95, top_k=50
        )
        # Render the sampled captions as a bulleted list for the prompt.
        all_captions = '\n'.join(['- ' + o for o in all_captions])
        _actual_prompt = actual_prompt.format(
            ensemble_num, all_captions
        )
        extra_info = all_captions
    else:
        _actual_prompt = actual_prompt
    response = model_answer_func(model, processor, _actual_prompt, img)
    all_responses.append({
        'img_idx': int(idx),
        'label': int(label),
        'response': response,
        'extra_info': extra_info
    })

with open(log_file, 'w') as f:
    json.dump(all_responses, f)
