import os
import sys
import argparse

# Make the project root (the parent of this script's directory) importable so
# the ``lib.*`` imports below resolve when this file is run as a script.
# NOTE(review): the path is not made absolute, so it is interpreted relative to
# the current working directory when ``__file__`` is relative — confirm that is
# acceptable for how this script is launched.
prj_path = os.path.join(os.path.dirname(__file__), '..')
if prj_path not in sys.path:
    sys.path.append(prj_path)

# These imports must stay below the sys.path change above.
from lib.test.evaluation import get_dataset
from lib.test.evaluation.running import run_dataset
from lib.test.evaluation.tracker import Tracker

from lib.test.evaluation.environment import env_settings
# Module-level environment settings; ``set_epoch`` reads ``env.config_path``
# from this object to locate tracker config files.
env = env_settings()

import warnings
# NOTE(review): this silences every warning for the whole process, including
# warnings emitted by the tracker and its dependencies.
warnings.filterwarnings("ignore")


def run_tracker(tracker_name, tracker_param, save_name, vis_attn, vis_bbox, run_id=None, dataset_name='otb', sequence=None,
                debug=0, threads=0, num_gpus=8):
    """Evaluate one tracker configuration on a whole dataset or a single sequence.

    args:
        tracker_name: Name of tracking method.
        tracker_param: Name of parameter file.
        save_name: Name forwarded to ``Tracker`` (presumably used to label results).
        vis_attn: Value forwarded unchanged to ``Tracker``.
        vis_bbox: Value forwarded unchanged to ``Tracker``.
        run_id: The run id.
        dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).
        sequence: Sequence number or name; when given, only that entry of the
            dataset is run, otherwise the full dataset is run.
        debug: Debug level.
        threads: Number of threads.
        num_gpus: Forwarded to ``run_dataset`` (presumably the number of GPUs to use).
    """
    full_dataset = get_dataset(dataset_name)
    # Restrict to a single sequence when one was requested.
    selected = full_dataset if sequence is None else [full_dataset[sequence]]

    tracker = Tracker(tracker_name, tracker_param, dataset_name, save_name, run_id, vis_attn, vis_bbox)
    run_dataset(selected, [tracker], debug, threads, num_gpus)


def main():
    """Parse command-line arguments and run the tracker once per requested epoch.

    ``--test_epoch`` takes a comma-separated list of integers (e.g. ``100,200``).
    For each epoch, ``set_epoch`` writes it into the tracker's YAML config and
    ``run_tracker`` is invoked with ``save_name`` set to ``'ep{:04d}'.format(epoch)``.
    """
    parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.')
    parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
    parser.add_argument('tracker_param', type=str, help='Name of config file.')
    # The default must be a string: argparse applies ``type`` only to string
    # defaults, and the value is split on ',' below. An int default would reach
    # ``split`` unconverted and raise AttributeError when --test_epoch is omitted.
    parser.add_argument('--test_epoch', type=str, default='100', help='The param epoch for test.')
    parser.add_argument('--vis_attn', type=int, default=0)
    parser.add_argument('--vis_bbox', type=int, default=0)
    parser.add_argument('--runid', type=int, default=None, help='The run id.')
    parser.add_argument('--dataset_name', type=str, default='otb', help='Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).')
    parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.')
    parser.add_argument('--debug', type=int, default=0, help='Debug level.')
    parser.add_argument('--threads', type=int, default=0, help='Number of threads.')
    parser.add_argument('--num_gpus', type=int, default=8)

    args = parser.parse_args()

    # Blank entries (e.g. from a trailing comma in "100,200,") are skipped;
    # any other non-integer entry is reported as a usage error.
    try:
        test_epochs = [int(epoch) for epoch in args.test_epoch.split(',') if epoch.strip()]
    except ValueError:
        parser.error('--test_epoch must be a comma-separated list of integers, got {!r}'.format(args.test_epoch))
    if not test_epochs:
        parser.error('--test_epoch must list at least one epoch')

    # The sequence selector does not depend on the epoch, so resolve it once.
    # A numeric value becomes an int index; a name (or None when --sequence is
    # omitted) is passed through unchanged. int(None) raises TypeError and
    # int('name') raises ValueError; nothing else should be swallowed here.
    try:
        seq_name = int(args.sequence)
    except (TypeError, ValueError):
        seq_name = args.sequence

    for epoch in test_epochs:
        save_name = 'ep{:04d}'.format(epoch)

        # Rewrites the tracker's YAML config so the run below uses this epoch.
        set_epoch(epoch, args.tracker_name, args.tracker_param)

        run_tracker(args.tracker_name, args.tracker_param, save_name, args.vis_attn, args.vis_bbox,
                    args.runid, args.dataset_name, seq_name, args.debug, args.threads, args.num_gpus)

import yaml
def set_epoch(test_epoch, tracker_name, tracker_param):
    """Write ``test_epoch`` into ``TEST.EPOCH`` of the tracker's YAML config, in place.

    The file is ``<env.config_path>/<tracker_name>/<tracker_param>.yaml``.

    Raises:
        OSError: if the config file cannot be read or written.
        KeyError: if the config has no top-level ``TEST`` mapping.
    """
    path = os.path.join(env.config_path, tracker_name, tracker_param + '.yaml')
    with open(path, 'r', encoding='utf-8') as f:
        doc = yaml.safe_load(f)

    doc['TEST']['EPOCH'] = test_epoch

    # Serialize before opening for writing: open(..., 'w') truncates the file
    # immediately, so a failure while dumping would otherwise leave an empty
    # config behind. sort_keys=False keeps the file's existing key order rather
    # than alphabetizing every section on each run; safe_dump mirrors the
    # safe_load used to read the file.
    text = yaml.safe_dump(doc, sort_keys=False)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)

if __name__ == '__main__':
    main()