import argparse, json, os
import torch
import numpy as np

from torch.utils.data import DataLoader
from src_files.utils.logger import setup_logger
from src_files.utils.helper import get_raw_dict

from models import MODEL_CLASSES

from data.mlc import MLCDataset


# Maps the --data_type CLI value to the dataset class that main() instantiates
# as DataClass(data_name, data_path).
DATA_CLASS = {
    'mlc': MLCDataset
}

# Maps a --model_path value to the short tag that main() embeds in the output
# directory name. A --model_path that is not a key here makes main() fail with
# a KeyError.
MODEL_PATH = {
    "Qwen/Qwen2.5-VL-7B-Instruct": "qwen2_5vl_7b",
    "Qwen/Qwen2-VL-7B-Instruct": "qwen2vl_7b",
    "OpenGVLab/InternVL2_5-8B": "internvl2_5_8b",
    "OpenGVLab/InternVL3-8B": "internvl3_8b"
}

def get_args():
    """Build the command-line parser and parse the process arguments.

    Required options are ``--model_type`` (restricted to the keys of
    ``MODEL_CLASSES``), ``--model_path``, ``--start_id`` and ``--end_id``.
    All other options have defaults.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # (flag, add_argument keyword arguments), registered in this exact order
    # so the generated --help output lists the options in the same sequence.
    option_specs = (
        ('--model_type', {'choices': MODEL_CLASSES.keys(), 'required': True}),
        ('--model_path', {'required': True}),
        ('--data_name', {'default': 'coco2014'}),
        ('--data_type', {'default': 'mlc'}),
        ('--data_path', {'default': ''}),
        ('--output', {'default': './output'}),
        ('--batch_size', {'type': int, 'default': 16}),
        ('--start_id', {'type': int, 'required': True}),
        ('--end_id', {'type': int, 'required': True}),
        ('--seed', {'type': int, 'default': 1}),
        ('--print_freq', {'type': int, 'default': 1}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def main():
    """Ask a yes/no presence question for each class over the whole dataset.

    For every class index from ``--start_id`` to ``--end_id`` (inclusive),
    each image batch is sent to the model with the prompt
    "Please answer yes or no. Is there a <class> in this image?". The
    answers collected for a class are saved as ``answer_<cls_id>.npy`` in
    the run's output directory. The parsed configuration is written to
    ``config.json`` in that same directory, and logs go to a
    ``<start_id>_<end_id>`` sub-directory.

    Raises:
        KeyError: If ``--model_path`` has no entry in ``MODEL_PATH`` or
            ``--data_type`` has no entry in ``DATA_CLASS``.
        ImportError: If the selected entry of ``MODEL_CLASSES`` is ``None``.
        ValueError: If ``--data_name`` has no known class-name file, or the
            requested class id range falls outside the loaded class list.
    """
    # Local import: avoids reaching sys through the undocumented ``os.sys``.
    import sys

    args = get_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)

    try:
        model_tag = MODEL_PATH[args.model_path]
    except KeyError:
        raise KeyError(
            f"Unsupported --model_path '{args.model_path}'; "
            f"expected one of {sorted(MODEL_PATH)}"
        ) from None

    args.output = os.path.join(args.output, args.data_name, f'{model_tag}_{args.data_type}_{args.seed}')
    args.log_output = os.path.join(args.output, f"{args.start_id}_{args.end_id}")
    os.makedirs(args.log_output, exist_ok=True)

    logger = setup_logger(args.log_output, color=False, name="LVLM")
    logger.info("Command: " + ' '.join(sys.argv))

    # Serialise the config first, then write it through a context manager so
    # the file handle is always closed (and flushed) before inference starts.
    raw_config = get_raw_dict(args)
    with open(os.path.join(args.output, "config.json"), 'w') as config_file:
        json.dump(raw_config, config_file, indent=2)

    ModelClass = MODEL_CLASSES[args.model_type]
    if ModelClass is None:
        raise ImportError(f"Model '{args.model_type}' is not available in this environment.")
    model_runner = ModelClass(model_type=args.model_type, model_path=args.model_path)
    model_runner.load_model_and_processor()

    DataClass = DATA_CLASS[args.data_type]
    dataset = DataClass(args.data_name, args.data_path)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    if args.data_name in ['coco2014', 'coco2017']:
        cls_names = load_cls_names(os.path.join(args.data_path, args.data_name, f'{args.data_name}_cls_names.txt'))
    elif args.data_name == 'objects365':
        # NOTE(review): resolved relative to the current working directory,
        # unlike the COCO branch which uses --data_path; confirm this is intended.
        cls_names = load_cls_names('objects365_cls_names.txt')
    else:
        # Without this branch cls_names would be unbound and the loop below
        # would fail with an unrelated NameError.
        raise ValueError(
            f"No class-name file is known for --data_name '{args.data_name}'; "
            "expected 'coco2014', 'coco2017' or 'objects365'."
        )

    # Validate the range up front: a negative index would silently wrap to
    # the end of the list, and an overly large one would only fail after
    # earlier classes had already been processed.
    if args.start_id < 0 or args.end_id >= len(cls_names):
        raise ValueError(
            f"Class id range [{args.start_id}, {args.end_id}] is outside "
            f"the valid range [0, {len(cls_names) - 1}]."
        )

    for cls_id in range(args.start_id, args.end_id+1):
        cls_name = cls_names[cls_id]
        cls_answers = []
        for i, imgs in enumerate(loader):
            # The same question is repeated once per image in the batch.
            batch = {"image":imgs, "text": [f"Please answer yes or no. Is there a {cls_name} in this image?"]*len(imgs)}
            batch_qa = model_runner.run_batch_inference(batch)
            batch_answers = batch_qa['answer']
            cls_answers.extend(batch_answers)
            if i % args.print_freq == 0:
                logger.info(f"[{i}/{len(loader)}] {batch_qa}")

        # One file per class, stored next to config.json (not in the log dir).
        np.save(os.path.join(args.output, f"answer_{cls_id}.npy"), np.array(cls_answers))

def load_cls_names(file_path):
    """Read class names from a UTF-8 text file, one name per line.

    Every line is stripped of surrounding whitespace and lower-cased. Blank
    lines are kept as empty strings, so index ``i`` of the result always
    corresponds to the ``i``-th line of the file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        list: The normalised names in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as names_file:
        return [raw_line.strip().lower() for raw_line in names_file]

# Run the inference CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
