import argparse
import fnmatch
from typing import Dict

from mmengine.config import Config, ConfigDict

from opencompass.openicl.icl_inferencer import (AgentInferencer,
                                                ChatInferencer, CLPInferencer,
                                                GenInferencer, LLInferencer,
                                                PPLInferencer,
                                                PPLOnlyInferencer)
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
                               build_model_from_cfg, dataset_abbr_from_cfg,
                               model_abbr_from_cfg)


def parse_args(argv=None):
    """Parse the command-line options of the prompt viewer.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
        ``non_interactive``, ``all``, ``pattern`` and ``count``.
    """
    parser = argparse.ArgumentParser(
        description='View generated prompts based on datasets (and models)')
    parser.add_argument('config', help='Train config file path')
    parser.add_argument('-n',
                        '--non-interactive',
                        action='store_true',
                        help='Skip the selection menu and use the first '
                        'model and the first dataset')
    parser.add_argument('-a',
                        '--all',
                        action='store_true',
                        help='Print prompts for every model/dataset pair')
    parser.add_argument('-p',
                        '--pattern',
                        type=str,
                        help='To match the dataset abbr.')
    parser.add_argument('-c',
                        '--count',
                        type=int,
                        default=1,
                        help='Number of prompts to print')
    # Passing ``None`` keeps the historical behaviour of reading sys.argv.
    return parser.parse_args(argv)


def parse_model_cfg(model_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    """Index model configs by their abbreviation.

    Args:
        model_cfg: Iterable of model configs.

    Returns:
        Mapping from ``model_abbr_from_cfg(cfg)`` to the config itself. If
        several configs share an abbreviation, the last one wins.
    """
    return {model_abbr_from_cfg(cfg): cfg for cfg in model_cfg}


def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    """Index dataset configs by their abbreviation.

    Args:
        dataset_cfg: Iterable of dataset configs.

    Returns:
        Mapping from ``dataset_abbr_from_cfg(cfg)`` to the config itself. If
        several configs share an abbreviation, the last one wins.
    """
    return {dataset_abbr_from_cfg(cfg): cfg for cfg in dataset_cfg}


def print_prompts(model_cfg, dataset_cfg, count=1):
    """Build and print sample prompts for one dataset (and optional model).

    Everything is written to stdout. When a model config is given, the model
    is built so that prompts can be token-counted, truncated to
    ``max_seq_len`` (by dropping in-context examples from the end) and
    rendered through ``model.parse_template``.

    Note that both configs are modified in place: ``tokenizer_only`` is set
    on non-API model configs, and the built dataset is stored under
    ``infer_cfg['retriever']['dataset']``.

    Args:
        model_cfg: Model config, or a falsy value (e.g. ``None``) to print
            the raw prompts without building a model.
        dataset_cfg: Dataset config. Its ``infer_cfg`` must provide
            ``retriever`` and ``inferencer`` entries and may provide
            ``ice_template`` and ``prompt_template``.
        count (int): Maximum number of samples to print prompts for.
            Defaults to 1.

    Raises:
        ValueError: If ``dataset_cfg`` has no ``infer_cfg``.
    """
    # TODO: A really dirty method that copies code from PPLInferencer and
    # GenInferencer. In the future, the prompt extraction code should be
    # extracted and generalized as a static method in these Inferencers
    # and reused here.
    if model_cfg:
        # ``max_seq_len`` may be absent from a model config; ``None`` simply
        # disables the token counting / truncation below.
        max_seq_len = model_cfg.get('max_seq_len')
        if not model_cfg['type'].is_api:
            # NOTE(review): presumably avoids loading model weights since
            # only tokenization is needed here - confirm in model classes.
            model_cfg['tokenizer_only'] = True
        model = build_model_from_cfg(model_cfg)
    else:
        max_seq_len = None
        model = None

    infer_cfg = dataset_cfg.get('infer_cfg')
    if infer_cfg is None:
        # Fail early with a readable message instead of a TypeError on
        # ``None['retriever']`` after the dataset has already been built.
        raise ValueError(
            'The dataset config has no "infer_cfg"; cannot build prompts.')

    dataset = build_dataset_from_cfg(dataset_cfg)

    ice_template = None
    if hasattr(infer_cfg, 'ice_template'):
        ice_template = ICL_PROMPT_TEMPLATES.build(infer_cfg['ice_template'])

    prompt_template = None
    if hasattr(infer_cfg, 'prompt_template'):
        prompt_template = ICL_PROMPT_TEMPLATES.build(
            infer_cfg['prompt_template'])

    # The retriever is built from the config, so the dataset object is
    # injected into the config before building.
    infer_cfg['retriever']['dataset'] = dataset
    retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])

    # One list of in-context example indices per sample.
    ice_idx_list = retriever.retrieve()

    supported_inferencer = [
        AgentInferencer, PPLInferencer, GenInferencer, CLPInferencer,
        PPLOnlyInferencer, ChatInferencer, LLInferencer
    ]
    inferencer_cfg = infer_cfg.inferencer
    if inferencer_cfg.type not in supported_inferencer:
        print(f'Only {supported_inferencer} are supported')
        return

    # The inferencer is the same for every sample, so resolve the prompt
    # style and the generation placeholder once, outside the loop.
    is_label_based = issubclass(inferencer_cfg.type,
                                (PPLInferencer, LLInferencer))
    gen_field_replace_token = inferencer_cfg.get('gen_field_replace_token',
                                                 '')

    for idx in range(min(count, len(ice_idx_list))):
        if is_label_based:
            # PPL/LL style: one prompt per candidate label.
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
            ice = retriever.generate_ice(ice_idx_list[idx],
                                         ice_template=ice_template)
            print('-' * 100)
            print('ICE Template:')
            print('-' * 100)
            print(ice)
            print('-' * 100)
            for label in labels:
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice,
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=None)
                if max_seq_len is not None:
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                    # Drop in-context examples from the end until the prompt
                    # fits. The shortened list is stored back, so it also
                    # applies to the remaining labels of this sample.
                    while len(ice_idx_list[idx]
                              ) > 0 and prompt_token_num > max_seq_len:
                        num_ice = len(ice_idx_list[idx])
                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                              f'Number of tokens: {prompt_token_num} -> ...')
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice = retriever.generate_ice(ice_idx_list[idx],
                                                     ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice,
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = model.get_token_len_from_template(
                            prompt)
                    print(f'Number of tokens: {prompt_token_num}')
                if model is not None:
                    prompt = model.parse_template(prompt, mode='ppl')
                print('-' * 100)
                print(f'Label: {label}')
                print('Sample prompt:')
                print('-' * 100)
                print(prompt)
                print('-' * 100)
        else:
            # Generation style: a single prompt per sample.
            ice_idx = ice_idx_list[idx]
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=gen_field_replace_token,
                ice_template=ice_template,
                prompt_template=prompt_template)
            if max_seq_len is not None:
                prompt_token_num = model.get_token_len_from_template(prompt)
                # Drop in-context examples from the end until the prompt
                # fits into the model's sequence length.
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    num_ice = len(ice_idx)
                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                          f'Number of tokens: {prompt_token_num} -> ...')
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=gen_field_replace_token,
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                print(f'Number of tokens:  {prompt_token_num}')
            if model is not None:
                prompt = model.parse_template(prompt, mode='gen')
            print('-' * 100)
            print('Sample prompt:')
            print('-' * 100)
            print(prompt)
            print('-' * 100)


def main():
    """Entry point: load the config, pick model/dataset, print prompts.

    Datasets are taken from ``cfg.datasets`` when present, otherwise from
    every config key ending in ``_datasets``. They can be narrowed with
    ``--pattern``. Depending on the flags, prompts are printed for every
    model/dataset pair (``--all``), for the first pair
    (``--non-interactive``), or for a pair chosen through an interactive
    menu.

    Raises:
        ValueError: If the config defines no datasets, or if no dataset
            abbreviation matches ``--pattern``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # A missing, ``None`` or empty ``models`` entry all mean "no model":
    # prompts are then printed without tokenization or truncation.
    models = cfg['models'] if 'models' in cfg else None
    model2cfg = parse_model_cfg(models) if models else {'None': None}

    if 'datasets' in cfg:
        dataset2cfg = parse_dataset_cfg(cfg.datasets)
    else:
        dataset2cfg = {}
        for key in cfg.keys():
            if key.endswith('_datasets'):
                dataset2cfg.update(parse_dataset_cfg(cfg[key]))

    if not dataset2cfg:
        # Without this guard, --non-interactive fails with an IndexError
        # and --all silently prints nothing.
        raise ValueError(
            'No datasets found in the config. Define `datasets` or a '
            'variable whose name ends with `_datasets`.')

    if args.pattern is not None:
        matches = fnmatch.filter(dataset2cfg, args.pattern)
        if not matches:
            raise ValueError(
                'No dataset match the pattern. Please select from: \n' +
                '\n'.join(dataset2cfg.keys()))
        dataset2cfg = {k: dataset2cfg[k] for k in matches}

    if not args.all:
        if not args.non_interactive:
            model, dataset = Menu(
                [list(model2cfg.keys()),
                 list(dataset2cfg.keys())], [
                     f'Please make a selection of {s}:'
                     for s in ['model', 'dataset']
                 ]).run()
        else:
            # Both mappings are non-empty here, so the first key exists.
            model = next(iter(model2cfg))
            dataset = next(iter(dataset2cfg))
        model_cfg = model2cfg[model]
        dataset_cfg = dataset2cfg[dataset]
        print_prompts(model_cfg, dataset_cfg, args.count)
    else:
        for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
                print('=' * 64, '[BEGIN]', '=' * 64)
                print(f'[MODEL]: {model_abbr}')
                print(f'[DATASET]: {dataset_abbr}')
                print('---')
                print_prompts(model_cfg, dataset_cfg, args.count)
                print('=' * 65, '[END]', '=' * 65)
                print()


# Run the prompt viewer only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
