import torch
import torch.distributed as dist

from vlmeval.config import supported_VLM
from vlmeval.dataset import build_dataset
from vlmeval.inference import infer_data_job
from vlmeval.inference_video import infer_data_job_video
from vlmeval.inference_mt import infer_data_job_mt
from vlmeval.smp import *
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer
import re

def build_model_from_config(cfg):
    """Instantiate a model described by a config dict.

    ``cfg['class']`` names the model class; it is looked up first in
    ``vlmeval.api`` and then in ``vlmeval.vlm``. All remaining keys are passed
    to the class constructor as keyword arguments. ``cfg`` itself is not
    modified (a deep copy is consumed).

    Raises:
        ValueError: if the class exists in neither module.
    """
    import vlmeval.api
    import vlmeval.vlm
    kwargs = cp.deepcopy(cfg)
    assert 'class' in kwargs
    class_name = kwargs.pop('class')
    # API wrappers take precedence over local VLM implementations.
    for module in (vlmeval.api, vlmeval.vlm):
        if hasattr(module, class_name):
            return getattr(module, class_name)(**kwargs)
    raise ValueError(f'Class {class_name} is not supported in `vlmeval.api` or `vlmeval.vlm`')


def build_dataset_from_config(cfg):
    """Instantiate a dataset described by a config dict.

    ``cfg['class']`` names a class in ``vlmeval.dataset``; every other key is
    forwarded to its constructor. ``cfg`` itself is left untouched (a deep copy
    is consumed).

    Raises:
        ValueError: if ``vlmeval.dataset`` has no such class.
    """
    import vlmeval.dataset
    kwargs = cp.deepcopy(cfg)
    assert 'class' in kwargs
    class_name = kwargs.pop('class')
    if not hasattr(vlmeval.dataset, class_name):
        raise ValueError(f'Class {class_name} is not supported in `vlmeval.dataset`')
    dataset_cls = getattr(vlmeval.dataset, class_name)
    return dataset_cls(**kwargs)


def parse_args():
    """Parse the command-line arguments of the evaluation launcher.

    Returns:
        argparse.Namespace with the dataset/model selection, video options,
        output directory, API/judge options, resume flags, and the custom
        merged-model / intervening-layer options. ``merge_model``,
        ``Intervening_layer`` and ``Intervening_module`` default to ``None``.
    """
    help_msg = """\
You can launch the evaluation by setting either --data and --model or --config.

--data and --model:
    Each Arg should be a list of strings, specifying the names of datasets and models.
    To find all supported model names, please refer to the `vlmeval/config.py` or check the output of the command \
        `vlmutil mlist all` in the terminal (you should first have vlmeval installed).
    To find all supported dataset names, please refer to the `vlmeval/dataset/__init__.py` file. The python script \
        to print all supported dataset names is as follows:
        ```python
        from vlmeval.dataset import SUPPORTED_DATASETS
        print(SUPPORTED_DATASETS)
        ```
        or you can check the output of the command `vlmutil dlist all` in the terminal.

--config:
    Launch the evaluation by specifying the path to the config json file. Sample Json Content:
    ```json
    {
        "model": {
            "GPT4o_20240806_T00_HIGH": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 0,
                "img_detail": "high"
            },
            "GPT4o_20240806_T10_Low": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 1.0,
                "img_detail": "low"
            }
        },
        "data": {
            "MME-RealWorld-Lite": {
                "class": "MMERealWorld",
                "dataset": "MME-RealWorld-Lite"
            },
            "MMBench_DEV_EN_V11": {
                "class": "ImageMCQDataset",
                "dataset": "MMBench_DEV_EN_V11"
            }
        }
    }
    ```
    Currently, only `model` and `data` are supported fields. The content of each field is a dictionary.
    For `model`, the key is the name of the model, and the value is a dictionary containing the following keys:
    - `class`: The class name of the model, which should be a class in `vlmeval.vlm` or `vlmeval.api`.
    - Other keys are specific to the model, please refer to the corresponding class.
    For `data`, the key is the name of the dataset (should be the same as the `dataset` field in most cases, \
        except for video datasets), and the value is a dictionary containing the following keys:
    - `class`: The class name of the dataset, which should be a class in `vlmeval.dataset`.
    - `dataset`: The name of the dataset, which should be a string that is accepted by the `dataset` argument of the \
        corresponding class.
    - Other keys are specific to the dataset, please refer to the corresponding class.

    The keys in the `model` and `data` fields will be used for naming the prediction files and evaluation results.
    When launching with `--config`, args for video datasets, such as `--nframe`, `--pack`, `--use-subtitle`, `--fps`, \
        and args for API VLMs, such as `--retry`, `--verbose`, will be ignored.
"""
    parser = argparse.ArgumentParser(description=help_msg, formatter_class=argparse.RawTextHelpFormatter)
    # Essential Args, Setting the Names of Datasets and Models
    parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
    parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
    parser.add_argument('--config', type=str, help='Path to the Config Json File')
    # Args that only apply to Video Dataset
    parser.add_argument('--nframe', type=int, default=8)
    parser.add_argument('--pack', action='store_true')
    parser.add_argument('--use-subtitle', action='store_true')
    parser.add_argument('--fps', type=float, default=-1)
    # Work Dir
    parser.add_argument('--work-dir', type=str, default='./outputs_shiqi', help='select the output directory')
    # Infer + Eval or Infer Only
    parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
    # API Kwargs, Apply to API VLMs and Judge API LLMs
    parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
    # Explicitly Set the Judge Model
    parser.add_argument('--judge', type=str, default=None)
    # Logging Utils
    parser.add_argument('--verbose', action='store_true')
    # Configuration for Resume
    # Ignore: will not rerun failed VLM inference
    parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
    # Reuse: will reuse the existing prediction files
    parser.add_argument('--reuse', action='store_true')
    # Merged checkpoint. `main` reads `args.merge_model` and extracts the <tag> from a
    # `merged_model_<tag>.pth` file name, so the option must always be registered.
    parser.add_argument('--merge-model', '--merge_model', dest='merge_model', type=str, default=None,
                        help='Path to a merged model checkpoint named like merged_model_<tag>.pth')
    # Intervening-layer experiments
    parser.add_argument('--Intervening_layer', type=int, default=None, help='the layer to intervene')
    parser.add_argument('--Intervening_module', type=str, default=None, help='the module to intervene')

    args = parser.parse_args()
    return args


# Default directory for intervening-layer experiment CSVs. Override it with the
# INTERVENING_RESULTS_DIR environment variable when running on another machine.
_DEFAULT_INTERVENING_RESULTS_DIR = (
    "/data_all/intern05/VLM_Merging/evalkit_results/Intervening_layer_experiments/results/mean_llava"
)
# Captures <tag> from checkpoint paths such as `.../merged_model_<tag>.pth`.
_MERGE_MODEL_PATTERN = re.compile(r'merged_model_(.+?)(?=\.pth)')


def _merge_model_suffix(merge_model):
    """Return the ``<tag>`` of a ``merged_model_<tag>.pth`` path, or ``None`` if it does not match."""
    match = _MERGE_MODEL_PATTERN.search(merge_model)
    return match.group(1) if match else None


def _format_layers(layers):
    """Render one intervened layer, or a list/tuple of layers, as a filename token (``3`` or ``1_2``)."""
    if isinstance(layers, (list, tuple)):
        return '_'.join(map(str, layers))
    return str(layers)


def _base_result_file(model_name, dataset_name, args):
    """Default prediction file name, tagged with the intervened layer(s)/module when both are set."""
    if args.Intervening_layer is not None and args.Intervening_module is not None:
        layer_str = _format_layers(args.Intervening_layer)
        return f'{model_name}_{dataset_name}_cut{layer_str}_{args.Intervening_module}.xlsx'
    return f'{model_name}_{dataset_name}.xlsx'


def _build_dataset(dataset_name, model_name, args, use_config, cfg, rank, world_size):
    """Build ``dataset_name`` from the JSON config or by its registered name.

    In distributed runs rank 0 builds the dataset first (to do preparation work
    once), then every rank builds it after a barrier.
    """
    if use_config:
        def build():
            return build_dataset_from_config(cfg['data'][dataset_name])
    else:
        dataset_kwargs = {}
        if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
            dataset_kwargs['model'] = model_name
        if dataset_name == 'MMBench-Video':
            dataset_kwargs['pack'] = args.pack
        if dataset_name == 'Video-MME':
            dataset_kwargs['use_subtitle'] = args.use_subtitle

        def build():
            return build_dataset(dataset_name, **dataset_kwargs)

    if world_size > 1:
        if rank == 0:
            build()
        dist.barrier()
    return build()


def _video_result_file(result_file_base, model_name, dataset_name, dataset, args, logger):
    """Return the prediction file name for video datasets; other datasets keep ``result_file_base``.

    Side effects on ``args`` persist for later datasets:
    - ``args.nframe`` is set to 0 when ``--fps`` is given.
    - ``args.pack`` is reset for video datasets without pack support.

    Raises:
        ValueError: for MVBench combined with ``--fps``.
    """
    # For video datasets, fps takes priority over nframe
    if args.fps > 0:
        if dataset_name == 'MVBench':
            raise ValueError('MVBench does not support fps setting, please transfer to MVBench_MP4!')
        args.nframe = 0

    def sampled_name(packstr):
        sampling = f'{args.nframe}frame' if args.nframe > 0 else f'{args.fps}fps'
        return f'{model_name}_{dataset_name}_{sampling}_{packstr}.xlsx'

    if dataset_name == 'MMBench-Video':
        return sampled_name('pack' if args.pack else 'nopack')
    if dataset.MODALITY != 'VIDEO':
        return result_file_base
    if args.pack:
        logger.info(f'{dataset_name} not support Pack Mode, directly change to unpack')
        args.pack = False
    result_file_base = sampled_name('nopack')
    if dataset_name in ['Video-MME', 'LongVideoBench']:
        subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
        result_file_base = result_file_base.replace('.xlsx', f'_{subtitlestr}.xlsx')
    return result_file_base


def _reuse_previous_predictions(prev_pred_roots, pred_root, result_file, result_file_base,
                                dataset_name, commit_id, logger):
    """Copy the newest reusable prediction (or its temporary pickles) into ``pred_root``.

    ``prev_pred_roots`` is scanned from its last entry backwards. The first
    directory that holds a finished ``result_file_base`` wins. Otherwise, from
    the first directory that meets all of these conditions, its ``.pkl`` files
    matching ``dataset_name`` are copied:
    - its path contains ``commit_id``;
    - it is not ``pred_root``;
    - it is non-empty.
    """
    prev_result_file = None
    prev_pkl_file_list = []
    for root in reversed(prev_pred_roots):
        candidate = osp.join(root, result_file_base)
        if osp.exists(candidate):
            prev_result_file = candidate
            break
        if commit_id in root and len(ls(root)) and root != pred_root:
            temp_files = ls(root, match=[dataset_name, '.pkl'])
            if len(temp_files):
                prev_pkl_file_list.extend(temp_files)
                break

    if prev_result_file is not None:
        logger.warning(
            f'--reuse is set, will reuse the prediction file {prev_result_file}.')
        if prev_result_file != result_file:
            shutil.copy(prev_result_file, result_file)
        return
    for fname in prev_pkl_file_list:
        target_path = osp.join(pred_root, osp.basename(fname))
        if not osp.exists(target_path):
            shutil.copy(fname, target_path)
            logger.info(f'--reuse is set, will reuse the prediction pickle file {fname}.')
        else:
            logger.warning(f'File already exists: {target_path}')


def _run_inference(model, model_name, dataset, pred_root, merge_model, args):
    """Dispatch to the video / multi-turn / default inference job and return the (possibly built) model."""
    common = dict(
        work_dir=pred_root,
        model_name=model_name,
        dataset=dataset,
        verbose=args.verbose,
        api_nproc=args.nproc,
    )
    if dataset.MODALITY == 'VIDEO':
        return infer_data_job_video(
            model,
            nframe=args.nframe,
            pack=args.pack,
            subtitle=args.use_subtitle,
            fps=args.fps,
            **common)
    if dataset.TYPE == 'MT':
        return infer_data_job_mt(model, ignore_failed=args.ignore, **common)
    return infer_data_job(
        model,
        ignore_failed=args.ignore,
        merge_model=merge_model,
        Intervening_layer=args.Intervening_layer,
        Intervening_module=args.Intervening_module,
        **common)


def _build_judge_kwargs(dataset, dataset_name, args):
    """Judge-LLM settings for ``dataset.evaluate``: explicit ``--judge`` or a per-dataset default."""
    judge_kwargs = {
        'nproc': args.nproc,
        'verbose': args.verbose,
        'retry': args.retry if args.retry is not None else 3,
    }
    if args.judge is not None:
        judge_kwargs['model'] = args.judge
    elif dataset.TYPE in ['MCQ', 'Y/N'] or listinstr(['MathVerse'], dataset_name):
        judge_kwargs['model'] = 'chatgpt-0125'
    elif listinstr(['MMVet', 'MathVista', 'LLaVABench', 'MMBench-Video', 'MathVision'], dataset_name):
        judge_kwargs['model'] = 'gpt-4-turbo'
    elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name):
        judge_kwargs['model'] = 'gpt-4o'
    return judge_kwargs


def _save_intervening_results(eval_results, model_name, dataset_name, args, logger):
    """Write the results of an intervening-layer run to a CSV file (best-effort; failures are only logged).

    File name: ``{model_prefix}_{weights}_eval_{layers}_{module}_{dataset}.csv``.
    ``weights`` is the ``<tag>`` of ``--merge-model`` when present, else ``"unknown"``.
    """
    weights = "unknown"
    merge_model = getattr(args, 'merge_model', None)
    if merge_model:
        weights = _merge_model_suffix(merge_model) or "unknown"

    cut_results_dir = os.environ.get('INTERVENING_RESULTS_DIR', _DEFAULT_INTERVENING_RESULTS_DIR)
    model_prefix = model_name.split('_')[0]
    layer_str = _format_layers(args.Intervening_layer)
    cut_result_filename = f"{model_prefix}_{weights}_eval_{layer_str}_{args.Intervening_module}_{dataset_name}.csv"
    cut_result_path = osp.join(cut_results_dir, cut_result_filename)

    # Debug information about where the results go
    logger.info("🔍 保存Cut Layer结果:")
    logger.info(f"   权重参数: {weights}")
    logger.info(f"   层: {args.Intervening_layer}")
    logger.info(f"   模块: {args.Intervening_module}")
    logger.info(f"   数据集: {dataset_name}")
    logger.info(f"   保存路径: {cut_result_path}")

    try:
        os.makedirs(cut_results_dir, exist_ok=True)
        if isinstance(eval_results, dict):
            # A dict becomes a single-row DataFrame
            pd.DataFrame([eval_results]).to_csv(cut_result_path, index=False)
            logger.info(f'✅ Cut layer evaluation results (dict) saved to: {cut_result_path}')
        elif isinstance(eval_results, pd.DataFrame):
            eval_results.to_csv(cut_result_path, index=False)
            logger.info(f'✅ Cut layer evaluation results (DataFrame) saved to: {cut_result_path}')
        else:
            logger.warning(f"⚠️  Unknown eval_results type: {type(eval_results)}")

        # Verify that the file was actually created
        if osp.exists(cut_result_path):
            file_size = os.path.getsize(cut_result_path)
            logger.info(f"✅ 文件验证成功: {cut_result_path} (大小: {file_size} bytes)")
        else:
            logger.error(f"❌ 文件创建失败: {cut_result_path}")
    except Exception as e:
        logger.error(f'❌ Failed to save cut layer results: {e}')
        import traceback
        logger.error(traceback.format_exc())


def _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name):
    """Symlink this combination's files from ``pred_root`` into ``pred_root_meta``, replacing old links."""
    cwd = os.getcwd()
    prefix = f'{model_name}_{dataset_name}'
    for fname in os.listdir(pred_root):
        if prefix not in fname:
            continue
        file_addr = osp.join(cwd, pred_root, fname)
        link_addr = osp.join(cwd, pred_root_meta, fname)
        if osp.exists(link_addr) or osp.islink(link_addr):
            os.remove(link_addr)
        os.symlink(file_addr, link_addr)


def _evaluate_on_rank0(dataset, dataset_name, model_name, result_file, pred_root,
                       pred_root_meta, judge_kwargs, args, logger):
    """Evaluate one prediction file (rank 0 only) and publish its results.

    Skip cases ``return`` instead of ``continue``-ing the caller's loop, so the
    caller's trailing ``dist.barrier()`` is reached on every rank.
    """
    # Prepare Submission Files for MMMU_TEST AND MMT-Bench_ALL
    if dataset_name in ['MMMU_TEST']:
        result_json = MMMU_result_transfer(result_file)
        logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
                    f'json file saved in {result_json}')
        return
    if 'MMT-Bench_ALL' in dataset_name:
        submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
        logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
                    f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
                    f'submission file saved in {submission_file}')
        return

    # Skip the evaluation part if only infer
    if args.mode == 'infer':
        return

    # Skip the evaluation part if the dataset evaluation is not supported or annotations are missing
    if 'MLLMGuard_DS' in dataset_name:
        logger.info('The evaluation of MLLMGuard_DS is not supported yet. ')
        return
    if 'AesBench_TEST' == dataset_name:
        logger.info(f'The results are saved in {result_file}. '
                    f'Please send it to the AesBench Team via huangyipo@hotmail.com.')
        return
    if dataset_name in ['DocVQA_TEST', 'InfoVQA_TEST', 'Q-Bench1_TEST', 'A-Bench_TEST']:
        logger.info(f'{dataset_name} is a test split without ground-truth. '
                    'Thus only the inference part is supported for those datasets. ')
        return
    if dataset_name in [
        'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
        'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
    ] and not MMBenchOfficialServer(dataset_name):
        logger.error(
            f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation.')
        return

    # Use EVAL_PROXY during evaluation; restore the previous proxy even if evaluation raises
    eval_proxy = os.environ.get('EVAL_PROXY', None)
    old_proxy = os.environ.get('HTTP_PROXY', '')
    if eval_proxy is not None:
        proxy_set(eval_proxy)
    try:
        eval_results = dataset.evaluate(result_file, **judge_kwargs)
    finally:
        if eval_proxy is not None:
            proxy_set(old_proxy)

    # Save the results of intervening-layer experiments
    if (args.Intervening_layer is not None and args.Intervening_module is not None
            and eval_results is not None):
        _save_intervening_results(eval_results, model_name, dataset_name, args, logger)

    # Display Evaluation Results in Terminal
    if eval_results is not None:
        assert isinstance(eval_results, dict) or isinstance(eval_results, pd.DataFrame)
        logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
        logger.info('Evaluation Results:')
        if isinstance(eval_results, dict):
            logger.info('\n' + json.dumps(eval_results, indent=4))
        elif isinstance(eval_results, pd.DataFrame):
            if len(eval_results) < len(eval_results.columns):
                eval_results = eval_results.T
            logger.info('\n' + tabulate(eval_results))

    # Create the symbolic links for the prediction files
    _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name)


def main():
    """Run inference (and, with ``--mode all``, evaluation) for every model x dataset pair.

    Predictions go to ``<work_dir>/<model_dir>/T<date>_G<commit>/``. The model
    directory is the model name, with the ``<tag>`` of ``--merge-model``
    appended when one is given. Each combination is wrapped in a try/except so
    one failure does not abort the remaining ones. A barrier after every
    combination (including failed ones) keeps distributed ranks in step.
    """
    logger = get_logger('RUN')
    rank, world_size = get_rank_and_world_size()
    args = parse_args()
    # Tolerate parsers that do not register --merge-model
    merge_model = getattr(args, 'merge_model', None)
    use_config, cfg = False, None
    if args.config is not None:
        assert args.data is None and args.model is None, '--data and --model should not be set when using --config'
        use_config, cfg = True, load(args.config)
        args.model = list(cfg['model'].keys())
        args.data = list(cfg['data'].keys())
    else:
        assert args.data, '--data should be a list of data files'
        assert args.model, '--model should be a list of model names'

    if rank == 0:
        if not args.reuse:
            logger.warning('--reuse is not set, will not reuse previous (before one day) temporary files')
        else:
            logger.warning('--reuse is set, will reuse the latest prediction & temporary pickle files')

    if 'MMEVAL_ROOT' in os.environ:
        args.work_dir = os.environ['MMEVAL_ROOT']

    if not use_config:
        # Entries exposing a `keywords` dict (presumably functools.partial objects) are
        # patched in place with the CLI retry/verbose settings.
        for constructor in supported_VLM.values():
            if not hasattr(constructor, 'keywords'):
                continue
            if 'retry' in constructor.keywords and args.retry is not None:
                constructor.keywords['retry'] = args.retry
            if 'verbose' in constructor.keywords and args.verbose is not None:
                constructor.keywords['verbose'] = args.verbose

    if world_size > 1:
        local_rank = os.environ.get('LOCAL_RANK', 0)
        torch.cuda.set_device(int(local_rank))
        dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=3600))

    for model_name in args.model:
        model = None
        date, commit_id = timestr('day'), githash(digits=8)
        eval_id = f"T{date}_G{commit_id}"

        # Merged checkpoints get their own output directory: <model_name><tag>
        model_dir_name = model_name
        if merge_model:
            merge_model_suffix = _merge_model_suffix(merge_model)
            if merge_model_suffix is None:
                raise ValueError(f"Invalid merge_model path: {merge_model}")
            model_dir_name = model_name + merge_model_suffix
        pred_root_meta = osp.join(args.work_dir, model_dir_name)
        pred_root = osp.join(pred_root_meta, eval_id)
        os.makedirs(pred_root_meta, exist_ok=True)
        # List earlier runs before today's directory is created
        prev_pred_roots = ls(pred_root_meta, mode='dir')
        if args.reuse:
            prev_pred_roots.sort()
        os.makedirs(pred_root, exist_ok=True)

        if use_config:
            model = build_model_from_config(cfg['model'][model_name])

        for dataset_name in args.data:
            try:
                result_file_base = _base_result_file(model_name, dataset_name, args)
                dataset = _build_dataset(dataset_name, model_name, args, use_config, cfg, rank, world_size)
                if dataset is None:
                    logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
                    continue
                if not use_config:
                    result_file_base = _video_result_file(
                        result_file_base, model_name, dataset_name, dataset, args, logger)

                # Handling Multi-Turn Dataset
                if dataset.TYPE == 'MT':
                    result_file_base = result_file_base.replace('.xlsx', '.tsv')

                result_file = osp.join(pred_root, result_file_base)

                # Reuse the previous prediction file if exists
                if rank == 0 and prev_pred_roots and args.reuse:
                    _reuse_previous_predictions(
                        prev_pred_roots, pred_root, result_file, result_file_base,
                        dataset_name, commit_id, logger)

                if world_size > 1:
                    dist.barrier()

                if model is None:
                    model = model_name  # which is only a name

                # Perform the Inference
                model = _run_inference(model, model_name, dataset, pred_root, merge_model, args)

                # Set the judge kwargs first before evaluation or dumping
                judge_kwargs = _build_judge_kwargs(dataset, dataset_name, args)
                if rank == 0:
                    logger.info(judge_kwargs)

                if world_size > 1:
                    dist.barrier()

                # Only Rank 0 handles the evaluation part
                if rank == 0:
                    _evaluate_on_rank0(
                        dataset, dataset_name, model_name, result_file, pred_root,
                        pred_root_meta, judge_kwargs, args, logger)

            except Exception as e:
                # Fall through to the barrier below so a rank that failed (e.g. rank 0
                # during evaluation) stays in step with the others.
                logger.exception(f'Model {model_name} x Dataset {dataset_name} combination failed: {e}, '
                                 'skipping this combination.')

            if world_size > 1:
                dist.barrier()

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == '__main__':
    # `load_env` is not defined in this file; it presumably comes from the
    # `vlmeval.smp` star import and loads environment settings (e.g. API keys)
    # before `main` reads os.environ — TODO confirm.
    load_env()
    main()
