import sys
from collections import deque
from vlmeval.dataset import SUPPORTED_DATASETS
from vlmeval.config import *
from vlmeval.smp import *

# Define valid modes
MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval', 'merge_pkl', 'scan')

dataset_levels = {
    'l1': [
    ],
    'l2': [
    ],
    'l3': [
    ],
    'live': [
    ],
    'math': [
    ],
    'spatial': [
    ],
    'ESOV_GA': [
    ],
    'ESOV_GO': [
    ],
    'ESOV_R': [
    ],
    'ESOV_I': [
    ],
    'ESOV_S': [
    ]
}

dataset_levels['l12'] = dataset_levels['l1'] + dataset_levels['l2']
dataset_levels['l23'] = dataset_levels['l2'] + dataset_levels['l3']
dataset_levels['l123'] = dataset_levels['l12'] + dataset_levels['l3']

models = {
    '4.33.0': list(qwen_series) + list(xcomposer_series) + [
        'mPLUG-Owl2', 'flamingov2', 'VisualGLM_6b', 'MMAlaya', 'PandaGPT_13B', 'VXVERSE'
    ] + list(idefics_series) + list(minigpt4_series) + list(instructblip_series),
    '4.37.0': [x for x in llava_series if 'next' not in x] + list(internvl_series) + [
        'TransCore_M', 'emu2_chat', 'MiniCPM-V', 'MiniCPM-V-2', 'OmniLMM_12B',
        'cogvlm-grounding-generalist', 'cogvlm-chat', 'cogvlm2-llama3-chat-19B',
        'mPLUG-Owl3'
    ] + list(xtuner_series) + list(yivl_series) + list(deepseekvl_series) + list(janus_series) + list(cambrian_series),
    '4.36.2': ['Moondream1'],
    '4.40.0': [
        'idefics2_8b', 'Bunny-llama3-8B', 'MiniCPM-Llama3-V-2_5', '360VL-70B', 'Phi-3-Vision',
    ] + list(wemm_series),
    '4.44.0': ['Moondream2'],
    '4.45.0': ['Aria'],
    'latest': ['paligemma-3b-mix-448', 'MiniCPM-V-2_6', 'glm-4v-9b'] + [x for x in llava_series if 'next' in x]
    + list(chameleon_series) + list(ovis_series) + list(mantis_series),
    'api': list(api_models)
}

# SKIP_MODELS will be skipped in report_missing and run APIs
SKIP_MODELS = [
    'MGM_7B', 'GPT4V_HIGH', 'GPT4V', 'flamingov2', 'PandaGPT_13B',
    'GeminiProVision', 'Step1V-0701', 'SenseNova-V6',
    'llava_v1_7b', 'sharegpt4v_7b', 'sharegpt4v_13b',
    'llava-v1.5-7b-xtuner', 'llava-v1.5-13b-xtuner',
    'cogvlm-grounding-generalist', 'InternVL-Chat-V1-1',
    'InternVL-Chat-V1-2', 'InternVL-Chat-V1-2-Plus', 'RekaCore',
    'llava_next_72b', 'llava_next_110b', 'MiniCPM-V', 'sharecaptioner', 'XComposer',
    'VisualGLM_6b', 'idefics_9b_instruct', 'idefics_80b_instruct',
    'mPLUG-Owl2', 'MMAlaya', 'OmniLMM_12B', 'emu2_chat', 'VXVERSE'
] + list(minigpt4_series) + list(instructblip_series) + list(xtuner_series) + list(chameleon_series) + list(vila_series)

LARGE_MODELS = [
    'idefics_80b_instruct', '360VL-70B', 'emu2_chat', 'InternVL2-76B',
]


def completed(m, d, suf):
    score_file = f'outputs/{m}/{m}_{d}_{suf}'
    if osp.exists(score_file):
        return True
    return False


def DLIST(lvl):
    if lvl in dataset_levels.keys():
        return [x[0] for x in dataset_levels[lvl]]
    else:
        from vlmeval.dataset import SUPPORTED_DATASETS
        return SUPPORTED_DATASETS


def MLIST(lvl, size='all'):
    if lvl == 'all':
        from vlmeval.config import supported_VLM
        return [x for x in supported_VLM]

    model_list = models[lvl]
    if size == 'small':
        model_list = [m for m in model_list if m not in LARGE_MODELS]
    elif size == 'large':
        model_list = [m for m in model_list if m in LARGE_MODELS]
    return [x[0] for x in model_list]


def MISSING(lvl):
    from vlmeval.config import supported_VLM
    models = list(supported_VLM)
    models = [m for m in models if m not in SKIP_MODELS and osp.exists(osp.join('outputs', m))]
    if lvl in dataset_levels.keys():
        data_list = dataset_levels[lvl]
    else:
        data_list = [(D, suff) for (D, suff) in dataset_levels['l123'] if D == lvl]
    missing_list = []
    for f in models:
        for D, suff in data_list:
            if not completed(f, D, suff):
                missing_list.append((f, D))
    return missing_list


def CIRCULAR(inp):
    def proc_str(s):
        chs = set(s)
        chs = [x for x in chs if x not in string.ascii_letters and x != ' ']
        for ch in chs:
            s = s.replace(ch, ' ')
        return s

    def abnormal_entry(line):
        choices = {k: line[k] for k in string.ascii_uppercase if k in line and not pd.isna(line[k])}
        has_label = False
        for k in choices:
            s = proc_str(choices[k]).split()
            hit_words = [x for x in s if x in choices]
            hit_words = set(hit_words)
            if len(hit_words) > 1:
                return True
            if choices[k] in string.ascii_uppercase:
                has_label = True
        return has_label

    assert inp.endswith('.tsv')
    data = load(inp)
    OFFSET = 1e6
    while max(data['index']) >= OFFSET:
        OFFSET *= 10
    n_opt = 2
    for i, ch in enumerate(string.ascii_uppercase):
        if ch in data:
            n_opt = ord(ch) - ord('A') + 1
        else:
            for j in range(i + 1, 26):
                assert string.ascii_uppercase[j] not in data
    groups = defaultdict(list)
    for i in range(len(data)):
        item = data.iloc[i]
        this_n_opt = 0
        for j, ch in enumerate(string.ascii_uppercase[:n_opt]):
            if not pd.isna(item[ch]):
                this_n_opt = j + 1
            else:
                for k in range(j + 1, n_opt):
                    assert pd.isna(item[string.ascii_uppercase[k]]), (k, item)
        assert this_n_opt >= 2 or this_n_opt == 0
        flag = abnormal_entry(item)
        if flag or this_n_opt == 0:
            groups['abnormal'].append(item)
        elif len(item['answer']) > 1 or item['answer'] not in string.ascii_uppercase[:this_n_opt]:
            groups['abnormal'].append(item)
        else:
            groups[this_n_opt].append(item)
    for k in groups:
        groups[k] = pd.concat(groups[k], axis=1).T
        print(f'{k if k == "abnormal" else str(k) + "-choice"} records: {len(groups[k])}')

    data_all = []

    for k in groups:
        if k == 'abnormal':
            warnings.warn(
                f"{len(groups['abnormal'])} abnormal entries detected. The problems can be: "
                "1. Choice labels found in some choice contents; 2. No choices found for this question; "
                "3. The answer is not a valid choice. Will not apply circular to those samples."
            )
            abdata = groups['abnormal']
            abdata['g_index'] = abdata['index']
            data_all.append(abdata)
        else:
            cir_data = []
            assert isinstance(k, int) and k >= 2
            labels = string.ascii_uppercase[:k]
            rotates = [labels]
            dq = deque(labels)
            for i in range(k - 1):
                dq.rotate(1)
                rotates.append(list(dq))
            for i, rot in enumerate(rotates):
                if i == 0:
                    data = groups[k].copy()
                    data['g_index'] = data['index']
                    cir_data.append(data)
                else:
                    try:
                        data = groups[k].copy()
                        data['index'] = [int(x + OFFSET * i) for x in data['index']]
                        data['g_index'] = [int(x % OFFSET) for x in data['index']]
                        data['image'] = data['g_index']
                        c_map = {k: v for k, v in zip(rotates[0], rot)}
                        data['answer'] = [c_map[x] for x in data['answer']]
                        for s, t in c_map.items():
                            data[t] = groups[k][s]
                        cir_data.append(data)
                    except:
                        print(set(data['answer']))
                        raise NotImplementedError
            data_all.append(pd.concat(cir_data))
    data_all = pd.concat(data_all)
    data_all['index'] = [int(x) for x in data_all['index']]
    data_all['g_index'] = [int(x) for x in data_all['g_index']]

    tgt_file = inp.replace('.tsv', '_circular.tsv')
    dump(data_all, tgt_file)
    print(f'Processed data are saved to {tgt_file}: {len(load(inp))} raw records, {len(data_all)} circularized records.')  # noqa: E501
    assert osp.exists(tgt_file)
    print(f'The MD5 for the circularized data is {md5(tgt_file)}')


PTH = osp.realpath(__file__)
IMAGE_PTH = osp.join(osp.dirname(PTH), '../assets/apple.jpg')

msg1 = [
    IMAGE_PTH,
    'What is in this image?'
]
msg2 = [
    dict(type='image', value=IMAGE_PTH),
    dict(type='text', value='What is in this image?')
]
msg3 = [
    IMAGE_PTH,
    IMAGE_PTH,
    'How many apples are there in these images?'
]
msg4 = [
    dict(type='image', value=IMAGE_PTH),
    dict(type='image', value=IMAGE_PTH),
    dict(type='text', value='How many apples are there in these images?')
]


def CHECK(val):
    if val in supported_VLM:
        model = supported_VLM[val]()
        print(f'Model: {val}')
        for i, msg in enumerate([msg1, msg2, msg3, msg4]):
            if i > 1 and not model.INTERLEAVE:
                continue
            res = model.generate(msg)
            print(f'Test {i + 1}: {res}')
    elif val in models:
        model_list = models[val]
        for m in model_list:
            CHECK(m)


def LOCALIZE(fname, new_fname=None):
    if new_fname is None:
        new_fname = fname.replace('.tsv', '_local.tsv')

    base_name = osp.basename(fname)
    dname = osp.splitext(base_name)[0]

    data = load(fname)
    data_new = localize_df(data, dname)
    dump(data_new, new_fname)
    print(f'The localized version of data file is {new_fname}')
    return new_fname


def RUN(lvl, model):
    import torch
    NGPU = torch.cuda.device_count()
    SCRIPT = osp.join(osp.dirname(__file__), '../run.py')
    logger = get_logger('Run Missing')

    def get_env(name):
        assert name in ['433', '437', '440', 'latest']
        load_env()
        env_key = f'ENV_{name}'
        return os.environ.get(env_key, None)

    missing = MISSING(lvl)
    if model == 'all':
        pass
    elif model == 'api':
        missing = [x for x in missing if x[0] in models['api']]
    elif model == 'hf':
        missing = [x for x in missing if x[0] not in models['api']]
    elif model in models:
        missing = [x for x in missing if x[0] in models[missing]]
    elif model in supported_VLM:
        missing = [x for x in missing if x[0] == model]
    else:
        warnings.warn(f'Invalid model {model}.')

    missing.sort(key=lambda x: x[0])
    groups = defaultdict(list)
    for m, D in missing:
        groups[m].append(D)
    for m in groups:
        if m in SKIP_MODELS:
            continue
        for dataset in groups[m]:
            logger.info(f'Running {m} on {dataset}')
            exe = 'python' if m in LARGE_MODELS or m in models['api'] else 'torchrun'
            if m not in models['api']:
                env = None
                env = 'latest' if m in models['latest'] else env
                env = '433' if m in models['4.33.0'] else env
                env = '437' if m in models['4.37.0'] else env
                env = '440' if m in models['4.40.0'] else env
                if env is None:
                    # Not found, default to latest
                    env = 'latest'
                    logger.warning(
                        f"Model {m} does not have a specific environment configuration. Defaulting to 'latest'.")
                pth = get_env(env)
                if pth is not None:
                    exe = osp.join(pth, 'bin', exe)
                else:
                    logger.warning(f'Cannot find the env path {env} for model {m}')
            if exe.endswith('torchrun'):
                cmd = f'{exe} --nproc-per-node={NGPU} {SCRIPT} --model {m} --data {dataset}'
            elif exe.endswith('python'):
                cmd = f'{exe} {SCRIPT} --model {m} --data {dataset}'
            os.system(cmd)


def EVAL(dataset_name, data_file, **kwargs):
    from vlmeval.dataset import build_dataset
    logger = get_logger('VLMEvalKit Tool-Eval')
    dataset = build_dataset(dataset_name)
    # Set the judge kwargs first before evaluation or dumping
    judge_kwargs = {'nproc': 4, 'verbose': True}
    if 'model' not in kwargs:
        if dataset.TYPE in ['MCQ', 'Y/N']:
            judge_kwargs['model'] = 'chatgpt-0125'

    else:
        judge_kwargs['model'] = kwargs['model']
    judge_kwargs['nproc'] = kwargs.get('nproc', 4)
    eval_results = dataset.evaluate(data_file, **judge_kwargs)
    if eval_results is not None:
        assert isinstance(eval_results, dict) or isinstance(eval_results, pd.DataFrame)
        logger.info('Evaluation Results:')
    if isinstance(eval_results, dict):
        logger.info('\n' + json.dumps(eval_results, indent=4))
    elif isinstance(eval_results, pd.DataFrame):
        logger.info('\n')
        logger.info(tabulate(eval_results.T) if len(eval_results) < len(eval_results.columns) else eval_results)
    return eval_results


def parse_args_eval():
    parser = argparse.ArgumentParser()
    # Essential Args, Setting the Names of Datasets and Models
    parser.add_argument('cmd', type=str)
    parser.add_argument('data_file', type=str)
    parser.add_argument('--judge', type=str, default=None)
    parser.add_argument('--api-nproc', type=int, default=4)
    parser.add_argument('--retry', type=int, default=None)
    args = parser.parse_args()
    return args


def parse_args_scan():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, nargs='+')
    parser.add_argument('--data', type=str, nargs='+')
    parser.add_argument('--root', type=str, default=None)
    args, unknownargs = parser.parse_known_args()
    return args, unknownargs


def parse_args_sync():
    parser = argparse.ArgumentParser()
    parser.add_argument('--src', type=str)
    parser.add_argument('--tgt', type=str)
    parser.add_argument('--data', type=str, nargs='+')
    args, unknownargs = parser.parse_known_args()
    return args, unknownargs


def MERGE_PKL(pkl_dir, world_size=1):
    prefs = []
    for ws in list(range(1, 9)):
        prefs.extend([f'{i}{ws}_' for i in range(ws)])
    prefs = set(prefs)
    files = os.listdir(pkl_dir)
    files = [x for x in files if x[:3] in prefs]
    # Merge the files
    res_all = defaultdict(dict)
    for f in files:
        full_path = osp.join(pkl_dir, f)
        key = f[3:]
        res_all[key].update(load(full_path))
        os.remove(full_path)

    dump_prefs = [f'{i}{world_size}_' for i in range(world_size)]
    for k in res_all:
        for pf in dump_prefs:
            dump(res_all[k], f'{pkl_dir}/{pf}{k}')
        print(f'Merged {len(res_all[k])} records into {pkl_dir}/{dump_prefs[0]}{k}')


def SCAN_ONE(root, model, dataset):
    from termcolor import colored
    FAIL_MSG = 'Failed to obtain answer via API.'
    root = osp.join(root, model)
    fname = f'{model}_{dataset}.xlsx'
    pth = osp.join(root, fname)
    if osp.exists(pth):
        data = load(pth)
        # Detect Failure
        assert 'prediction' in data
        data['prediction'] = [str(x) for x in data['prediction']]
        fail = [FAIL_MSG in x for x in data['prediction']]
        if sum(fail):
            nfail = sum(fail)
            ntot = len(fail)
            print(colored(f'Model {model} x Dataset {dataset} Inference: {nfail} out of {ntot} failed. {nfail / ntot * 100: .2f}%. ', 'light_red'))  # noqa: E501

        eval_files = ls(root, match=f'{model}_{dataset}_')
        eval_files = [x for x in eval_files if listinstr([f'{dataset}_openai', f'{dataset}_gpt'], x) and x.endswith('.xlsx')]  # noqa: E501

        if len(eval_files) == 0:
            return

        for eval_file in eval_files:
            data = load(eval_file)

            if 'log' in data:
                bad = [x for x in data['log'] if FAIL_MSG in str(x)]
                if len(bad):
                    print(f'Evaluation ({eval_file}): {len(bad)} out of {len(data)} failed.')
    else:
        print(colored(f'Model {model} x Dataset {dataset} Inference Result Missing! ', 'red'))


def SCAN(root, models, datasets):
    for m in models:
        if not osp.exists(osp.join(root, m)):
            warnings.warn(f'Model {m} not found in {root}')
            continue
        cur_datasets = []
        if len(datasets) == 0:
            for d in SUPPORTED_DATASETS:
                if osp.exists(osp.join(root, m, f'{m}_{d}.xlsx')):
                    cur_datasets.append(d)
        else:
            cur_datasets = datasets
        cur_datasets = list(set(cur_datasets))
        cur_datasets.sort()
        for d in cur_datasets:
            SCAN_ONE(root, m, d)
        print(colored(f'Finished scanning datasets {cur_datasets} for model {m}.', 'green'))


def cli():
    logger = get_logger('VLMEvalKit Tools')
    args = sys.argv[1:]
    if not args:  # no arguments passed
        return

    if args[0].lower() == 'dlist':
        assert len(args) >= 2
        res = []
        for arg in args[1:]:
            lst = DLIST(arg)
            res.extend(lst)
        print(' '.join(res))
    elif args[0].lower() == 'mlist':
        assert len(args) >= 2
        size = 'all'
        if len(args) > 2:
            size = args[2].lower()
        lst = MLIST(args[1], size)
        print('\n'.join(lst))
    elif args[0].lower() == 'missing':
        assert len(args) >= 2
        missing_list = MISSING(args[1])
        logger = get_logger('Find Missing')
        logger.info(colored(f'Level {args[1]} Missing Results: ', 'red'))
        lines = []
        for m, D in missing_list:
            line = f'Model {m}, Dataset {D}'
            logger.info(colored(line, 'red'))
            lines.append(line)
        mwlines(lines, f'{args[1]}_missing.txt')
    elif args[0].lower() == 'circular':
        assert len(args) >= 2
        CIRCULAR(args[1])
    elif args[0].lower() == 'localize':
        assert len(args) >= 2
        LOCALIZE(args[1])
    elif args[0].lower() == 'check':
        assert len(args) >= 2
        model_list = args[1:]
        for m in model_list:
            CHECK(m)
    elif args[0].lower() == 'run':
        assert len(args) >= 2
        lvl = args[1]
        if len(args) == 2:
            model = 'all'
            RUN(lvl, model)
        else:
            for model in args[2:]:
                RUN(lvl, model)
    elif args[0].lower() == 'eval':
        args = parse_args_eval()
        data_file = args.data_file

        def extract_dataset(file_name):
            fname = osp.splitext(file_name)[0].split('/')[-1]
            parts = fname.split('_')
            for i in range(len(parts)):
                if '_'.join(parts[i:]) in SUPPORTED_DATASETS:
                    return '_'.join(parts[i:])
            return None

        dataset = extract_dataset(data_file)
        assert dataset is not None, f'Cannot infer dataset name from {data_file}'
        kwargs = {'nproc': args.api_nproc}
        if args.judge is not None:
            kwargs['model'] = args.judge
        if args.retry is not None:
            kwargs['retry'] = args.retry
        EVAL(dataset_name=dataset, data_file=data_file, **kwargs)
    elif args[0].lower() == 'merge_pkl':
        assert len(args) == 3
        args[2] = int(args[2])
        assert args[2] in [1, 2, 4, 8]
        MERGE_PKL(args[1], args[2])
    elif args[0].lower() == 'scan':
        args, unknownargs = parse_args_scan()
        # The default value is only for the maintainer usage
        root = args.root if args.root is not None else os.getcwd()
        models = []
        for m in args.model:
            if osp.exists(m) and m.endswith('.txt'):
                lines = mrlines(m)
                models.extend([x.split()[0] for x in lines if len(x.split()) >= 1])
            else:
                models.append(m)
        assert len(models)
        datasets = args.data
        SCAN(root, models, datasets if datasets is not None else [])
    else:
        logger.error('WARNING: command error!')
        return
