#!/usr/bin/env python3
# Copyright (c) 2024 ByteDance. All Rights Reserved.
# GLEE Training Script.
# GLEE: General Object Foundation Model for Images and Videos at Scale (CVPR 2024)
# https://arxiv.org/abs/2312.09158


import os
import itertools
from typing import Any, Dict, List, Set
import torch
import logging
from contextlib import ExitStack

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import default_argument_parser, launch
from datetime import datetime

from detectron2.evaluation import (
    DatasetEvaluator,
    print_csv_format,
    inference_context
)

# related to Trainer
from detectron2.projects.train_net import Trainer as Trainer_detectron2
from detectron2.projects.train_net import load_config_dict_to_opt, load_opt_from_config_files, setup
from collections import OrderedDict

# related to inference_on_dataset
import datetime
import logging
import time
from collections import OrderedDict, abc

from typing import List, Union
import torch
from torch import nn

from detectron2.utils.comm import get_world_size, is_main_process, synchronize, all_gather
from detectron2.utils.logger import log_every_n_seconds

# related to omnilabeltools
from omnilabeltools import OmniLabel, OmniLabelEval, visualize_image_sample
import json
from detectron2.structures import BoxMode

import copy
import gc


def num_of_words(text):
    """Return the number of whitespace-separated words in `text`.

    Runs of whitespace and leading/trailing whitespace do not produce empty
    "words", so ``"traffic  light "`` counts as 2 and an empty string as 0.
    """
    return len(text.split())

def _iter_description_chunks(descriptions, description_ids, chunk_size=100):
    """Yield ``(descriptions, description_ids, is_expression)`` chunks for one image.

    A description with more than two words (per `num_of_words`) is yielded on
    its own with ``is_expression=True``. Otherwise up to `chunk_size`
    consecutive descriptions starting at the current position (whatever their
    length) are yielded together with ``is_expression=False``.
    """
    position = 0
    num_descriptions = len(descriptions)
    while position < num_descriptions:
        if num_of_words(descriptions[position]) > 2:
            yield [descriptions[position]], [description_ids[position]], True
            position += 1
        else:
            stop = min(position + chunk_size, num_descriptions)
            yield descriptions[position:stop], description_ids[position:stop], False
            position = stop


def inference_on_dataset2(
    model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None], dataset_name,
    chunk_size: int = 100, output_folder: str = "./exp"
):
    """
    Run model on the data_loader, dump the predictions as an OmniLabel-style
    results json and evaluate them with `OmniLabelEval`.
    Also benchmark the inference speed of `model.__call__`.
    The model will be used in eval mode.

    For every batch only ``inputs[0]`` is read. Its descriptions are split into
    chunks (see `_iter_description_chunks`) and the model is called once per
    chunk on a deep copy of the batch.

    Args:
        model (callable): a callable which takes an object from
            `data_loader` and returns some outputs.

            If it's an nn.Module, it will be temporarily set to `eval` mode.
            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model. Each
            element is indexed as ``inputs[0]`` and must provide the keys
            ``'inference_obj_descriptions'``, ``'description_ids'`` and
            ``'image_id'``.
        evaluator: accepted for signature compatibility; this function does not
            use it.
        dataset_name (str): used to name the results file and to look up the
            ground-truth ``json_file`` in `MetadataCatalog`.
        chunk_size (int): maximum number of descriptions passed to the model in
            one call for the non-expression branch.
        output_folder (str): directory the results json is written to. It is
            created if missing.

    Returns:
        An empty `OrderedDict` on the main process (the summary is printed, not
        returned) and `None` on every other process.
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} batches".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_data_time = 0
    total_compute_time = 0
    total_eval_time = 0
    predictions = []
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            stack.enter_context(inference_context(model))
        # Gradients stay disabled for the whole loop, so no per-call no_grad is needed.
        stack.enter_context(torch.no_grad())

        start_data_time = time.perf_counter()

        for idx, inputs in enumerate(data_loader):
            total_data_time += time.perf_counter() - start_data_time
            if idx == num_warmup:
                # Discard the timings collected during the warm-up iterations.
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0

            chunks = _iter_description_chunks(
                inputs[0]['inference_obj_descriptions'],
                inputs[0]['description_ids'],
                chunk_size,
            )
            for chunk_des, chunk_ids, is_expression in chunks:
                # Timers restart per chunk so that each model call is counted
                # once; building the chunk input is part of the compute time.
                start_compute_time = time.perf_counter()

                current_input = copy.deepcopy(inputs)
                current_input[0]['inference_obj_descriptions'] = chunk_des
                current_input[0]['description_ids'] = chunk_ids
                if is_expression:
                    current_input[0]['expressions'] = chunk_des[0]

                outputs = model(current_input)

                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                total_compute_time += time.perf_counter() - start_compute_time

                start_eval_time = time.perf_counter()

                ### add to results json in omnilabel style
                # map the class index predicted within this chunk to its description id
                cont_ids_2_descript_ids = dict(enumerate(chunk_ids))
                instances = outputs[0]['instances']
                pred_boxes = BoxMode.convert(
                    instances.pred_boxes.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
                )
                pred_labels = instances.pred_classes.cpu()
                pred_scores = instances.scores.cpu()

                for box_idx, box in enumerate(pred_boxes):
                    predictions.append({
                        "image_id": inputs[0]['image_id'],
                        "bbox": box.cpu().tolist(),
                        "description_ids": [cont_ids_2_descript_ids[pred_labels[box_idx].item()]],
                        "scores": [pred_scores[box_idx].item()],
                    })
                total_eval_time += time.perf_counter() - start_eval_time

                gc.collect()
                torch.cuda.empty_cache()

            gc.collect()
            torch.cuda.empty_cache()

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            data_seconds_per_iter = total_data_time / iters_after_start
            compute_seconds_per_iter = total_compute_time / iters_after_start
            eval_seconds_per_iter = total_eval_time / iters_after_start
            total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
            if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
                # NOTE: `datetime` is the module here; the file's later
                # `import datetime` shadows the earlier `from datetime import datetime`.
                eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    (
                        f"Inference done {idx + 1}/{total}. "
                        f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                        f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                        f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                        f"Total: {total_seconds_per_iter:.4f} s/iter. "
                        f"ETA={eta}"
                    ),
                    n=5,
                )
            start_data_time = time.perf_counter()

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )

    # collect predictions from all GPUs; every process must take part in the gather
    synchronize()
    all_predictions = all_gather(predictions)
    all_predictions = list(itertools.chain(*all_predictions))
    if not is_main_process():
        return None

    result_save_json = "%s_results.json" % (dataset_name)
    results_path = os.path.join(output_folder, result_save_json)
    # Create the target directory up front so a fresh checkout does not fail on open().
    os.makedirs(os.path.dirname(results_path) or ".", exist_ok=True)
    print('Saving to', results_path)
    # Close (and thereby flush) the file before it is read back below.
    with open(results_path, 'w') as results_file:
        json.dump(all_predictions, results_file)

    from detectron2.data import MetadataCatalog
    gt_path_json = MetadataCatalog.get(dataset_name).json_file

    # evaluation
    gt = OmniLabel(gt_path_json)              # load ground truth dataset
    dt = gt.load_res(results_path)         # load prediction results
    ole = OmniLabelEval(gt, dt)
    # ole.params.resThrs = ...                    # set evaluation parameters as desired
    ole.evaluate()
    ole.accumulate()
    ret, score = ole.summarize()

    print(ret)
    return OrderedDict()

class Trainer(Trainer_detectron2):
    """
    Extension of the Trainer class imported from detectron2.projects.train_net.

    Overrides `test` so that evaluation runs through `inference_on_dataset2`.
    """

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Evaluate the given model. The given model is expected to already contain
        weights to evaluate.

        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                ``cfg.DATASETS.TEST``.

        Returns:
            dict: results keyed by dataset name, or the single dataset's result
            directly when exactly one dataset was processed.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset2(model, data_loader, evaluator, dataset_name)
            results[dataset_name] = results_i
            # Only the main process is checked and reported here.
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results


def main(args):
    """Entry point run by `launch`: evaluate or train.

    Builds the config with `setup(args)`. With ``args.eval_only`` set, builds
    the model, loads ``cfg.MODEL.WEIGHTS`` through `DetectionCheckpointer` and
    returns the result of `Trainer.test`. Otherwise constructs a `Trainer`,
    resumes or loads according to ``args.resume`` and returns the result of
    `trainer.train()`.
    """
    cfg = setup(args)

    if not args.eval_only:
        # Training path.
        resume_training = args.resume
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=resume_training)
        return trainer.train()

    # Evaluation-only path.
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    return Trainer.test(cfg, model)


if __name__ == "__main__":
    # Parse the command line with detectron2's default parser, echo it, then
    # hand `main` to `launch` together with the distributed settings.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch_kwargs = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "args": (args,),
    }
    launch(main, args.num_gpus, **launch_kwargs)
