import os
import sys
import torch
from utils import *
from itertools import product
from config import parse_args
from collections import OrderedDict
from data_helper import create_val_datasets
from ltr.models.tracker.tracker import Tracker
from ltr.models.tracker.mixformer_tracker import MixFormerTracker

from ltr.data.test_data_builder import Sequence
from ltr.utils.evaluate import *


# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (errors surface at
# the failing call); OPENBLAS_NUM_THREADS=1 pins OpenBLAS to a single thread,
# presumably to avoid CPU oversubscription when several pool workers run at once.
os.environ.update(CUDA_LAUNCH_BLOCKING='1', OPENBLAS_NUM_THREADS='1')

# TNL2K sequences of interest (challenge attributes in parentheses):
#   test_003_BianLian_video_04 (appearance variations)
#   test_009_Wukong_video_p709_done (occlusion)
#   test_011_SpiderMan_test_001_done (viewpoint changing, fast motion)
#   test_012_SpaceShip_video_34_done (semantic understanding, appearance variations)
#   test_014_SpaceShip_video_02_done (appearance variations, fast motion, occlusion,
#       semantic understanding, viewpoint changing, ambiguity)

# Only sequences whose video_name appears here are processed by run_sequence.
vis_list = [
    'test_024_ManMiddle_video_01_done',
]

def run_sequence(seq: Sequence, tracker: Tracker, num_gpu=1, config=None, length=-1, dataset_name=None, ):
    """Run a tracker on one sequence, but only if the sequence is listed in ``vis_list``.

    Multi-GPU support (added 2021-01-02): when called inside a pool worker, the
    numeric suffix of the worker's process name picks the GPU id
    (``worker_index % num_gpu``). Outside a pool worker, or when the id cannot be
    derived, GPU 0 is used.

    args:
        seq: Sequence to process; skipped unless ``seq.video_name`` is in ``vis_list``.
        tracker: Tracker whose ``run_sequence(seq, gpu_id, mftracker)`` is invoked.
        num_gpu: Number of GPUs the pool workers are spread across.
        config: Experiment config; ``config.model.name`` and ``config`` itself are
            passed to ``MixFormerTracker``.
        length: Not used by this function; accepted so pool argument tuples match.
        dataset_name: Forwarded to ``MixFormerTracker``.
    returns:
        None
    """
    # Always bind gpu_id so the MixFormerTracker call below cannot hit a NameError
    # when the worker name is not parseable (e.g. 'MainProcess' in sequential mode).
    gpu_id = 0
    try:
        worker_name = torch.multiprocessing.current_process().name
        # Presumably pool workers are named '<Prefix>PoolWorker-<k>' with k >= 1.
        worker_id = int(worker_name[worker_name.find('-') + 1:]) - 1
        gpu_id = worker_id % num_gpu
        # torch.cuda.set_device(gpu_id)
    except (ValueError, TypeError, ZeroDivisionError):
        # Not a numbered worker, or num_gpu is 0/unusable: best-effort, keep GPU 0.
        pass

    if seq.video_name in vis_list:
        mftracker = MixFormerTracker(config.model.name, config, dataset_name, gpu_id)
        tracker.run_sequence(seq, gpu_id, mftracker)

    sys.stdout.flush()

    torch.cuda.empty_cache()

    return None

def run_dataset(dataset, trackers, threads=1, num_gpus=1, config=None, length=-1, dataset_name=None, begin=0, end=-1):
    """Run every tracker in ``trackers`` on every sequence in ``dataset``.

    args:
        dataset: List of Sequence instances, forming a dataset.
        trackers: List of Tracker instances.
        threads: Worker processes per GPU; 0 runs all jobs sequentially in the
            current process (default 1).
        num_gpus: Number of GPUs; the pool size is ``threads * num_gpus``
            (clamped to at least 1).
        config: Config forwarded to ``run_sequence``.
        length: Forwarded to ``run_sequence``.
        dataset_name: Forwarded to ``run_sequence``.
        begin: Index of the first (sequence, tracker) job to run.
        end: Index one past the last job to run; -1 means "through the last job".
    """
    print('Visualization {:4d} trackers on {:5d} sequences'.format(len(trackers), len(dataset)))

    # One job per (sequence, tracker) pair, sequence-major -- the same order as
    # nested `for seq ... for tracker ...` loops. Built once so both the sequential
    # and the parallel path pass identical arguments and honour begin/end.
    param_list = [(seq, tracker_info, num_gpus, config, length, dataset_name)
                  for seq, tracker_info in product(dataset, trackers)]
    # end == -1 is a sentinel for "no upper bound", not Python's "drop the last item".
    param_list = param_list[begin:] if end == -1 else param_list[begin:end]

    if threads == 0:
        for params in param_list:
            run_sequence(*params)
    else:
        ctx = torch.multiprocessing.get_context("spawn")
        # Pool() rejects processes < 1 (e.g. num_gpus == 0 on a CPU-only host).
        with ctx.Pool(processes=max(1, threads * num_gpus)) as pool:
            pool.starmap(run_sequence, param_list)





if __name__ == '__main__':
    # Count visible GPUs, then override args.gpus with that count before building the env.
    visible_gpu_count = torch.cuda.device_count()
    args = parse_args()
    args.gpus = visible_gpu_count
    cfg = prepare_val_env(args, sys.argv)

    val_dataset, dataset_name = create_val_datasets(cfg)

    # A single tracker, named after the configured model.
    trackers = [
        Tracker(cfg.model.name, cfg, dataset_name, None, length=args.length),
    ]

    run_dataset(
        val_dataset,
        trackers,
        args.threads,
        num_gpus=args.gpus,
        config=cfg,
        length=args.length,
        dataset_name=dataset_name,
        begin=args.begin,
        end=args.end,
    )