import logging

import os
import sys
import torch
import pandas
from utils import *
from itertools import product
from config import parse_args
from collections import OrderedDict
from data_helper import create_val_datasets
from ltr.models.tracker.tracker import Tracker
from ltr.models.tracker.sttracker import STTracker

from ltr.data.test_data_builder import Sequence
from ltr.utils.evaluate import *


# Process-wide environment knobs; worker processes started later inherit os.environ.
#   CUDA_LAUNCH_BLOCKING=1  -> synchronous CUDA kernel launches (clearer error
#                              locations at the cost of speed).
#   OPENBLAS_NUM_THREADS=1  -> one OpenBLAS thread per process, so parallel
#                              workers do not oversubscribe the CPU.
# NOTE(review): torch (and possibly numpy) is imported above this point, so these
# may not affect libraries already initialised in this process — confirm whether
# they should be set before the imports.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'OPENBLAS_NUM_THREADS': '1',
})

def _result_file(config, seq, length, dataset_name):
    """Return the prediction-file path for ``seq``, creating its directory if needed."""
    # Results live in a tree parallel to the checkpoint path ('checkpoints' -> 'test').
    test_save_dir = config.common.test_model.replace('checkpoints', 'test')
    subdir = f"{dataset_name}_{length}" if length != -1 else dataset_name
    test_save_dir = os.path.join(test_save_dir, subdir)
    # exist_ok=True: several pool workers may create the same directory concurrently.
    os.makedirs(test_save_dir, exist_ok=True)
    return os.path.join(test_save_dir, seq.video_name + '.txt')


def _compute_metrics(gt_bboxes, pred_bboxes):
    """Return (mean success, precision at threshold index 20, normalized precision at 0.20)."""
    n_frame = len(gt_bboxes)
    gt_bboxes = np.array(gt_bboxes)
    pred_bboxes = np.array(pred_bboxes)

    # Success: per-threshold overlap scores from success_overlap, averaged to one number.
    success = np.mean(success_overlap(gt_bboxes, pred_bboxes, n_frame))

    # Precision: center-distance curve over thresholds 0..50 (presumably pixels).
    thresholds = np.arange(0, 51, 1)
    precision = success_error(convert_bb_to_center(gt_bboxes),
                              convert_bb_to_center(pred_bboxes), thresholds, n_frame)

    # Normalized precision: centers normalized by the GT box size (columns 2:4),
    # thresholds 0.00..0.50.
    gt_size = gt_bboxes[:, 2:4]
    norm_precision = success_error(convert_bb_to_norm_center(gt_bboxes, gt_size),
                                   convert_bb_to_norm_center(pred_bboxes, gt_size),
                                   thresholds / 100, n_frame)

    # Index 20 -> threshold 20 (precision) and 0.20 (normalized precision).
    return success, precision[20], norm_precision[20]


def run_sequence(seq: Sequence, tracker: Tracker, num_gpu=1, config=None, length=-1, dataset_name=None, ):
    """Run ``tracker`` on one sequence and score its predictions.

    Multiple-GPU support (added 2021.1.2): a pool worker whose process name ends
    in ``-<k>`` (k starting at 1) is assigned GPU ``(k - 1) % num_gpu``. Any other
    process, e.g. the main process in sequential mode, uses GPU 0.

    If a non-empty prediction file for the sequence already exists, the tracker is
    not run again. The cached predictions are scored instead, and FPS is reported as 0.

    Returns:
        ``[success, precision, norm_precision, fps]`` for this sequence.
    """
    gpu_id = 0
    worker_name = torch.multiprocessing.current_process().name
    try:
        worker_id = int(worker_name[worker_name.find('-') + 1:]) - 1
    except ValueError:
        # No numeric suffix (e.g. 'MainProcess'): keep the GPU 0 fallback.
        pass
    else:
        if num_gpu:
            gpu_id = worker_id % num_gpu

    video_save_file = _result_file(config, seq, length, dataset_name)

    if os.path.exists(video_save_file) and os.path.getsize(video_save_file):
        # Predictions from an earlier run exist: score them instead of re-tracking.
        pred_bboxes = pandas.read_csv(video_save_file, delimiter=',', header=None, dtype=np.float32,
                                      na_filter=False, low_memory=False).values
        # NOTE(review): keeps length+1 GT boxes; presumably matches the 'origin_bbox'
        # that tracker.run_sequence returns for a length-limited run — confirm.
        gt_bboxes = seq.bboxes[:length + 1] if length != -1 else seq.bboxes
        fps = 0
    else:
        mftracker = STTracker(config.model.name, config, dataset_name, gpu_id)
        output = tracker.run_sequence(seq, gpu_id, mftracker)

        sys.stdout.flush()

        # Per-frame timings are either plain numbers or dicts of per-stage timings.
        times = output['time']
        if isinstance(times[0], (dict, OrderedDict)):
            exec_time = sum(sum(frame_times.values()) for frame_times in times)
        else:
            exec_time = sum(times)
        # A zero total (e.g. coarse timer) would otherwise raise ZeroDivisionError.
        fps = len(times) / exec_time if exec_time > 0 else 0.0

        pred_bboxes = output['target_bbox']
        gt_bboxes = output['origin_bbox']

        # Cache predictions so a later run can skip this sequence.
        pandas.DataFrame(pred_bboxes).to_csv(video_save_file, header=None, index=False, sep=',')

    success, precision, norm_precision = _compute_metrics(gt_bboxes, pred_bboxes)

    print(f'{seq.idx+1}-gpu{gpu_id} Suc:{round(success,3)}, P:{round(precision,3)}, NP:{round(norm_precision,3)} FPS:{round(fps,3)}')

    torch.cuda.empty_cache()

    return [success, precision, norm_precision, fps]

def run_dataset(dataset, trackers, threads=1, num_gpus=1, config=None, length=-1, dataset_name=None, begin=0, end=-1):
    """Run a list of trackers on a dataset and log per-video and mean metrics.

    args:
        dataset: List of Sequence instances, forming a dataset.
        trackers: List of Tracker instances.
        threads: Worker processes per GPU (default 1). 0 runs every job
            sequentially in the current process.
        num_gpus: Number of GPUs the workers are spread over.
        config: Evaluation config forwarded to ``run_sequence``.
        length: Sequence-length limit forwarded to ``run_sequence`` (-1 = no limit).
        dataset_name: Dataset name, used for the result directory and the log.
        begin, end: Slice ``[begin:end]`` of the (sequence, tracker) job list to
            run. ``end=-1`` means "through the last job".
    """
    # Same logger the script's __main__ block creates (__name__ == '__main__' there),
    # but obtained locally so this function also works when the module is imported.
    log = logging.getLogger(__name__)

    print('Evaluating {:4d} trackers on {:5d} sequences'.format(len(trackers), len(dataset)))

    # One job per (sequence, tracker) pair, in run_sequence's positional order.
    param_list = [(seq, tracker_info, num_gpus, config, length, dataset_name)
                  for seq, tracker_info in product(dataset, trackers)]
    param_list = param_list[begin:end] if end != -1 else param_list[begin:]

    if threads == 0:
        results = [run_sequence(*params) for params in param_list]
    else:
        # PyTorch needs the 'spawn' (or 'forkserver') start method to use CUDA in workers.
        ctx = torch.multiprocessing.get_context("spawn")
        # At least one worker, even when no GPU is visible (num_gpus == 0).
        with ctx.Pool(processes=max(1, threads * num_gpus)) as pool:
            results = pool.starmap(run_sequence, param_list)

    # Pair each result with the job that produced it, so begin/end slicing and
    # multiple trackers still label the right video.
    total = len(results)
    for idx, (params, result) in enumerate(zip(param_list, results)):
        video = params[0]
        success, precision, norm_precision, fps = result
        log.info(f'({idx+1}/{total}) Video: {video.video_name:12s} FPS: {fps:5.1f} Sucess: {success:.3f} Norm Precision {norm_precision:.3f} Precision {precision:.3f}')

    if results:
        success, precision, norm_precision, _ = np.mean(results, axis=0)
        log.info(f'======>>>>>> Dataset: {dataset_name} Sucess | Norm Precision | Precision: {success:.3f}  |  {norm_precision:.3f}  |  {precision:.3f}')

    print('Done')
    log.info('model verifying over time: {}'.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))


if __name__ == '__main__':
    # Script logger; its handlers are presumably set up by prepare_val_env — confirm.
    logger = logging.getLogger(__name__)

    # Evaluate on every visible CUDA device: overwrite whatever parse_args put in args.gpus.
    visible_gpus = torch.cuda.device_count()
    args = parse_args()
    args.gpus = visible_gpus
    cfg = prepare_val_env(args, sys.argv)

    val_dataset, dataset_name = create_val_datasets(cfg)
    print("create val dataset over")

    # A single tracker built from the configured model name.
    trackers = [Tracker(cfg.model.name, cfg, dataset_name, None, length=args.length)]

    run_dataset(
        val_dataset,
        trackers,
        args.threads,
        num_gpus=args.gpus,
        config=cfg,
        length=args.length,
        dataset_name=dataset_name,
        begin=args.begin,
        end=args.end,
    )