import os
import time
import lmdb
import torch
from tqdm import tqdm
import cv2 as cv
import numpy as np
from collections import OrderedDict

# Opened LMDB environments, keyed by database path (populated by get_lmdb_handle).
LMDB_ENVS = {}
# Read-only transactions begun on those environments, keyed by database path
# (populated and reused by get_lmdb_handle).
LMDB_HANDLES = {}
# Not read or written by any code in this module section; kept for compatibility.
LMDB_FILELISTS = {}


def get_lmdb_handle(name):
    """Return a cached read-only LMDB transaction for the database at ``name``.

    On the first call for a given ``name`` the environment is opened with
    ``readonly=True, lock=False, readahead=False, meminit=False`` and stored in
    ``LMDB_ENVS``; a read transaction (``write=False``) is begun on it and
    stored in ``LMDB_HANDLES``. Later calls with the same ``name`` return that
    same transaction object without reopening anything.
    """
    cached_txn = LMDB_HANDLES.get(name)
    if cached_txn is not None:
        return cached_txn

    env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False)
    LMDB_ENVS[name] = env
    txn = env.begin(write=False)
    LMDB_HANDLES[name] = txn
    return txn

def decode_img(lmdb_fname, key_name):
    """Read an encoded image from an LMDB database and decode it to RGB.

    args:
        lmdb_fname: Path of the LMDB database; the read transaction is obtained
            (and cached) through ``get_lmdb_handle``.
        key_name: Key of the image record, as a ``str`` (it is ``.encode()``-d
            before lookup).
    returns:
        The image decoded with ``cv.IMREAD_COLOR`` and converted BGR -> RGB.
    raises:
        KeyError: if ``key_name`` is not present in the database.
        ValueError: if the stored bytes cannot be decoded as an image.
    """
    handle = get_lmdb_handle(lmdb_fname)
    binfile = handle.get(key_name.encode())
    if binfile is None:
        print("Illegal data detected. %s %s" % (lmdb_fname, key_name))
        # Previously execution continued into np.frombuffer(None, ...), which
        # failed with an unrelated TypeError; fail explicitly instead.
        raise KeyError("key %r not found in LMDB %r" % (key_name, lmdb_fname))
    s = np.frombuffer(binfile, np.uint8)
    decoded = cv.imdecode(s, cv.IMREAD_COLOR)
    if decoded is None:
        # cv.imdecode signals failure by returning None rather than raising.
        raise ValueError("could not decode image %r in LMDB %r" % (key_name, lmdb_fname))
    return cv.cvtColor(decoded, cv.COLOR_BGR2RGB)


class Tracker:
    """Wraps the tracker for evaluation and running purposes.

    args:
        name: Name of tracking method.
        cfg: Configuration object. This class reads ``cfg.test.save_all_boxes``
            to decide whether per-frame ``all_boxes``/``all_scores`` are collected.
        dataset_name: Name of the dataset being evaluated.
        run_id: The run id.
        display_name: Name to be displayed in the result plots.
        length: Index of the last frame to track (frame 0 is always used for
            initialization). ``-1`` tracks the whole sequence.
    """

    def __init__(self, name: str, cfg, dataset_name: str, run_id: "int | None" = None,
                 display_name: "str | None" = None, length: int = -1):
        self.name = name
        self.cfg = cfg
        self.dataset_name = dataset_name
        self.run_id = run_id
        self.display_name = display_name

        self.length = length

    def run_sequence(self, seq, gpu_id, tracker):
        """Run ``tracker`` on ``seq`` and return the collected outputs.

        args:
            seq: Sequence to run the tracker on. ``init_info``, ``frames``,
                ``frame_info``, ``bboxes`` and ``nlp`` are read from it.
            gpu_id: Not used by this method; kept for caller compatibility.
            tracker: Object providing ``initialize``, ``tokenize_text`` and ``track``.
        returns:
            dict of per-frame output lists, as built by ``_track_sequence``.
        """
        # Get init information
        init_info = seq.init_info
        return self._track_sequence(tracker, seq, init_info)

    def _track_sequence(self, tracker, seq, init_info):
        """Initialize on frame 0, track the remaining frames and collect outputs.

        returns:
            dict with lists ``'target_bbox'``, ``'time'``, ``'origin_bbox'``
            (plus ``'all_boxes'``/``'all_scores'`` when
            ``cfg.test.save_all_boxes`` is set). ``'target_bbox'``,
            ``'all_boxes'`` and ``'all_scores'`` are removed when they hold at
            most one entry.
        """
        save_all_boxes = self.cfg.test.save_all_boxes

        output = {'target_bbox': [],
                  'time': [],
                  'origin_bbox': []}
        if save_all_boxes:
            output['all_boxes'] = []
            output['all_scores'] = []

        def _store_outputs(tracker_out: dict, defaults=None):
            # For every tracked key, append the tracker's value, falling back to
            # ``defaults``; skip the key if the tracker omitted it and the
            # default is None.
            defaults = {} if defaults is None else defaults
            for key in output.keys():
                val = tracker_out.get(key, defaults.get(key, None))
                if key in tracker_out or val is not None:
                    output[key].append(val)

        # Initialize on the first frame.
        image = self._read_image(seq.frames[0])

        start_time = time.time()
        out = tracker.initialize(image, init_info)
        tracker.tokenize_text(seq.nlp)
        if out is None:
            out = {}

        prev_output = OrderedDict(out)
        init_default = {'target_bbox': init_info.get('init_bbox'),
                        'time': time.time() - start_time,
                        'origin_bbox': init_info.get('init_bbox')}
        if save_all_boxes:
            # .get(): initialize() may return None (handled above) or omit these
            # keys; direct indexing raised KeyError in that case.
            init_default['all_boxes'] = out.get('all_boxes')
            init_default['all_scores'] = out.get('all_scores')

        _store_outputs(out, init_default)

        worker_name = torch.multiprocessing.current_process().name
        worker_id = worker_name[worker_name.find('-') + 1:]

        # enumerate() has no len(), so give tqdm an explicit total that matches
        # the number of frames the loop below will actually process.
        num_track_frames = len(seq.frames) - 1
        if self.length != -1:
            num_track_frames = max(0, min(num_track_frames, self.length))

        for frame_num, frame_path in tqdm(enumerate(seq.frames[1:], start=1),
                                          desc=worker_id, total=num_track_frames):
            if self.length != -1 and frame_num > self.length:
                break
            image = self._read_image(frame_path)

            start_time = time.time()

            info = seq.frame_info[frame_num]
            info['previous_output'] = prev_output

            out = tracker.track(image, info)
            out.setdefault('origin_bbox', seq.bboxes[frame_num])
            prev_output = OrderedDict(out)
            _store_outputs(out, {'time': time.time() - start_time})

        for key in ['target_bbox', 'all_boxes', 'all_scores']:
            if key in output and len(output[key]) <= 1:
                output.pop(key)

        return output

    def _read_image(self, image_file):
        """Load an RGB image from a file path or from an LMDB record.

        args:
            image_file: Either a path ``str`` read with ``cv.imread``, or a
                2-element list/tuple ``(lmdb_path, key)`` passed to ``decode_img``.
        raises:
            FileNotFoundError: if ``cv.imread`` cannot read the path.
            ValueError: if ``image_file`` has any other type/shape.
        """
        if isinstance(image_file, str):
            im = cv.imread(image_file)
            if im is None:
                # cv.imread returns None instead of raising on missing/unreadable files.
                raise FileNotFoundError("Could not read image file: %s" % image_file)
            return cv.cvtColor(im, cv.COLOR_BGR2RGB)
        if isinstance(image_file, (list, tuple)) and len(image_file) == 2:
            return decode_img(image_file[0], image_file[1])
        raise ValueError("type of image_file should be str or list")



