import json
from tqdm import tqdm
import os
import datetime
import random
import re
import torch

from datasets import load_dataset
import argparse
from swift.llm import (
        get_model_tokenizer, get_template, inference, ModelType,
        get_default_template_type, inference_stream
    )
from swift.utils import seed_everything
import logging
from modelscope import (
    snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
)
from evaluate.seed_bench import SEED_Bench
from evaluate.pope import  POPE
from evaluate.mme import MME
from evaluate.mm_bench import MM_Bench
from evaluate.cobench_yn import CoBench_YN
from evaluate.cobench_ch import CoBench_CH
from evaluate.mmstar import MMstar
from evaluate.scienceQA import ScienceQA
from evaluate.MathVista import MathVista
from evaluate.MMMU_val import MMMU_val
from evaluate.AI2D import AI2D
from evaluate.test_dataset import test_dataset
from swift.tuners import Swift

# Benchmark names accepted by --dataset. Each one has a loader branch in
# _load_val_data and an evaluator in _run_benchmark.
SUPPORTED_DATASETS = (
    'SEED-Bench', 'MMStar', 'MMMU', 'MME', 'MMBench-cc', 'MMBench-en',
    'POPE', 'ScienceQA', 'AI2D', 'MathVista', 'ConBench_CH', 'ConBench_YN',
    'test_dataset',
)

# Local file read for --dataset test_dataset.
TEST_DATASET_PATH = '/View/result/all_data_benchmark/extract_11_items.jsonl'


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_image_path', type=str, default='')
    parser.add_argument('--model_type', type=str, default='glm4v-9b-chat')
    # Restricting to known names makes a typo fail before the model is loaded.
    parser.add_argument('--dataset', type=str, default='MME',
                        choices=SUPPORTED_DATASETS)
    parser.add_argument('--do_sample', type=int, default=1)
    parser.add_argument('--num_sample', type=int, default=10)
    parser.add_argument('--api_model', type=str, default='')
    # The misspelled '--tempeature' keeps dest='tempeature'. Downstream
    # evaluators presumably read args.tempeature (TODO confirm).
    # '--temperature' is accepted as a correctly spelled alias.
    parser.add_argument('--tempeature', '--temperature', dest='tempeature',
                        type=float, default=0.1)
    parser.add_argument('--local_fintune_model_path', type=str, default='')
    return parser.parse_args()


def _load_model_and_template(args):
    """Load the model for args.model_type and return ``(model, template)``.

    The swift checkpoint in args.local_fintune_model_path is applied only if
    a path was given. An empty path evaluates the base model.
    """
    template_type = get_default_template_type(args.model_type)
    model, tokenizer = get_model_tokenizer(
        args.model_type, model_id_or_path=None,
        model_kwargs={'device_map': 'auto'})
    ckpt_dir = args.local_fintune_model_path
    if ckpt_dir:
        model = Swift.from_pretrained(model, ckpt_dir, inference_mode=True)
    model.generation_config.max_new_tokens = 256
    template = get_template(template_type, tokenizer)
    return model, template


def _read_json_records(path):
    """Read *path* as a single JSON document, falling back to JSON Lines.

    The whole file is first parsed as one JSON value. If that fails, each
    non-blank line is parsed as its own JSON value and a list is returned.

    Raises:
        json.JSONDecodeError: if the file is neither valid JSON nor JSONL.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        # JSONL holds one value per line. Parsing the whole file as one
        # document fails with "Extra data" as soon as there is a second line.
        return [json.loads(line) for line in content.splitlines()
                if line.strip()]


def _load_val_data(dataset_name):
    """Load and pre-filter the evaluation split for *dataset_name*.

    Returns either the split object from ``load_dataset`` or a plain list
    (when rows are filtered, or for the local test file).

    Raises:
        NotImplementedError: if *dataset_name* has no loader.
    """
    if dataset_name == 'SEED-Bench':
        return load_dataset("lmms-lab/SEED-Bench",
                            cache_dir="/dataset/SEED-Bench")['test']
    if dataset_name == 'MMStar':
        return load_dataset("Lin-Chen/MMStar", "val")["val"]
    if dataset_name == 'MMMU':
        # NOTE(review): cache dir is "/dataset/MMU", not "MMMU". Kept as-is
        # so an existing cache is reused.
        return load_dataset("lmms-lab/MMMU",
                            cache_dir="/dataset/MMU")["validation"]
    if dataset_name == 'MME':
        return load_dataset("lmms-lab/MME")["test"]
    if dataset_name == 'MMBench-cc':
        return load_dataset("lmms-lab/MMBench", 'cc')["test"]
    if dataset_name == 'MMBench-en':
        return load_dataset("lmms-lab/MMBench", 'en')["test"]
    if dataset_name == 'POPE':
        return load_dataset("lmms-lab/POPE", "default")["test"]
    if dataset_name == 'ScienceQA':
        split = load_dataset("derek-thomas/ScienceQA")['test']
        return [item for item in split if item['image'] is not None]
    if dataset_name == 'AI2D':
        return load_dataset("lmms-lab/ai2d")['test']
    if dataset_name == 'MathVista':
        split = load_dataset("AI4Math/MathVista",
                             cache_dir="/dataset/MathVista")['testmini']
        return [item for item in split
                if item['image'] is not None
                and item["question_type"] == 'multi_choice']
    if dataset_name in ('ConBench_CH', 'ConBench_YN'):
        split = load_dataset("ConBench/ConBench_D",
                             cache_dir="/dataset/ConBench")['test']
        if dataset_name == 'ConBench_YN':
            return [item for item in split if item['image'] is not None]
        return split
    if dataset_name == 'test_dataset':
        return _read_json_records(TEST_DATASET_PATH)
    raise NotImplementedError(dataset_name)


def _take_subset(val_data, do_sample, num_sample):
    """Return the first *num_sample* rows as a list if *do_sample* is truthy.

    The count is clamped to ``len(val_data)``, so asking for more samples
    than exist does not raise IndexError. Rows are fetched one index at a
    time because slicing a ``datasets.Dataset`` yields a dict of columns,
    not a list of rows. When *do_sample* is falsy, *val_data* is returned
    unchanged.
    """
    if not do_sample:
        return val_data
    count = min(num_sample, len(val_data))
    return [val_data[i] for i in range(count)]


def _setup_logging(args):
    """Create the dated output directory and log INFO messages into it."""
    month_day = datetime.datetime.now().strftime("%m%d")
    save_directory = f"/Uncertainty_MLLMs/UVLM/output_{month_day}/{args.dataset}/output_{args.model_type}"
    os.makedirs(save_directory, exist_ok=True)
    log_filename = f'{save_directory}/{args.model_type}.txt'
    logging.basicConfig(filename=log_filename, level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')


def _run_benchmark(args, val_data, model, template):
    """Run the evaluator that matches args.dataset on *val_data*.

    Raises:
        NotImplementedError: if args.dataset has no evaluator.
    """
    dataset_name = args.dataset
    # Both ConBench variants come from one split. Keep only rows whose
    # question_field matches the format the evaluator scores.
    if dataset_name == 'ConBench_CH':
        rows = [item for item in val_data
                if item['question_field'] == 'Choices']
        CoBench_CH(args, rows, model, template)
        return
    if dataset_name == 'ConBench_YN':
        rows = [item for item in val_data
                if item['question_field'] == 'N/Y']
        CoBench_YN(args, rows, model, template)
        return

    evaluators = {
        'SEED-Bench': SEED_Bench,
        'POPE': POPE,
        'MME': MME,
        'MMBench-cc': MM_Bench,
        # Previously loaded but never dispatched. Assumed to share the
        # MMBench-cc row schema (TODO confirm).
        'MMBench-en': MM_Bench,
        'MMStar': MMstar,
        'MMMU': MMMU_val,
        'AI2D': AI2D,
        'MathVista': MathVista,
        'ScienceQA': ScienceQA,
        'test_dataset': test_dataset,
    }
    try:
        evaluator = evaluators[dataset_name]
    except KeyError:
        raise NotImplementedError(dataset_name) from None
    evaluator(args, val_data, model, template)


def main():
    """Parse CLI args, load the model and data, and run the chosen benchmark."""
    args = _parse_args()
    model, template = _load_model_and_template(args)
    seed_everything(1)

    try:
        val_data = _load_val_data(args.dataset)
    except Exception as e:
        # There is no fallback dataset. Report the failure and stop instead
        # of continuing with val_data unbound.
        print("")
        print(f"Error occurred while loading dataset {args.dataset!r}: {e}")
        raise

    val_data = _take_subset(val_data, args.do_sample, args.num_sample)
    _setup_logging(args)
    _run_benchmark(args, val_data, model, template)


if __name__ == '__main__':
    main()
