from __future__ import absolute_import, division, print_function

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import json
from PIL import Image

from ..datasets import OTB
from ..utils.metrics import rect_iou, center_error
from ..utils.viz import show_frame


class ExperimentOTB(object):
    r"""Experiment pipeline and evaluation toolkit for OTB dataset.

    Args:
        root_dir (string): Root directory of OTB dataset.
        version (integer or string): Benchmark version, one of ``2013``,
            ``2015``, ``tb50`` and ``tb100``. Default is ``2015``.
        result_dir (string, optional): Directory for storing tracking
            results. Default is ``results`` (relative to the working dir).
        report_dir (string, optional): Directory for storing performance
            evaluation results. Default is ``reports`` (relative to the
            working dir).
    """
    # Upper bound on re-attempts when a freshly written result file is not
    # visible on disk, so a persistently failing filesystem cannot hang run().
    _MAX_RECORD_RETRIES = 10

    def __init__(self,
                 root_dir,
                 version=2015,
                 result_dir='results',
                 report_dir='reports'):
        super(ExperimentOTB, self).__init__()
        self.dataset = OTB(root_dir, version, download=True)
        # integer versions become e.g. 'OTB2015'; string versions are used as-is
        dump_dirname = ('OTB' +
                        str(version)) if isinstance(version, int) else version
        self.result_dir = os.path.join(result_dir, dump_dirname)
        self.report_dir = os.path.join(report_dir, dump_dirname)
        # as nbins_iou increases, the success score
        # converges to the average overlap (AO)
        self.nbins_iou = 21
        self.nbins_ce = 51

    def run(self,
            tracker,
            visualize=False,
            overwrite_result=True,
            slicing_quantile=(0.0, 1.0)):
        """Run ``tracker`` on (a slice of) the dataset and record its results.

        Arguments
        ---------
        tracker : object
            Must expose ``name`` and ``track(img_files, init_box,
            visualize=...)`` returning either ``(boxes, times)`` or
            ``(boxes, times, trigger_list)``.
        visualize : bool
            forwarded unchanged to ``tracker.track``
        overwrite_result : bool
            whether overwrite existing result or not
        slicing_quantile : Tuple[float, float]
            quantile used for dataset slicing; sequences with index in
            ``[int(N * start), int(N * end))`` are processed, N being the
            number of sequences
        """
        print('Running tracker %s on %s...' %
              (tracker.name, type(self.dataset).__name__))

        start_quantile, end_quantile = slicing_quantile
        len_dataset = len(self.dataset)
        start_idx = int(len_dataset * start_quantile)
        end_idx = int(len_dataset * end_quantile)

        # loop over the complete dataset / dataset slice
        for s in range(start_idx, end_idx):
            img_files, anno = self.dataset[s]
            seq_name = self.dataset.seq_names[s]
            print('--Sequence %d/%d: %s' % (s + 1, len_dataset, seq_name))

            # skip if results exist
            record_file = os.path.join(self.result_dir, tracker.name,
                                       '%s.txt' % seq_name)
            if os.path.exists(record_file) and not overwrite_result:
                print('  Found results, skipping', seq_name)
                continue

            # the first annotation row is handed to the tracker as its initial
            # box; len(boxes) == len(anno) is deliberately not checked because
            # annotations of some benchmarks are withheld
            ret = tracker.track(img_files, anno[0, :], visualize=visualize)
            if len(ret) == 3:
                boxes, times, trigger_list = ret
                self._record(record_file, boxes, times, trigger_list)
            else:
                boxes, times = ret
                self._record(record_file, boxes, times)

    def report(self, tracker_names, plot_curves=True):
        """Evaluate recorded results and write ``performance.json``.

        Args:
            tracker_names (list or tuple): Trackers to evaluate. The report is
                stored under the report directory of ``tracker_names[0]``.
            plot_curves (bool): Also save success/precision plots.

        Returns:
            dict: ``{name: {'overall': {...}, 'seq_wise': {seq: {...}}}}``
            holding curves, scores and speed (``-1`` when no positive timing
            was recorded).
        """
        assert isinstance(tracker_names, (list, tuple))

        # assume tracker_names[0] is your tracker
        report_dir = os.path.join(self.report_dir, tracker_names[0])
        if not os.path.isdir(report_dir):
            os.makedirs(report_dir)
        report_file = os.path.join(report_dir, 'performance.json')

        seq_num = len(self.dataset)
        performance = {}
        for name in tracker_names:
            print('Evaluating', name)
            succ_curve = np.zeros((seq_num, self.nbins_iou))
            prec_curve = np.zeros((seq_num, self.nbins_ce))
            speeds = np.zeros(seq_num)

            performance[name] = {'overall': {}, 'seq_wise': {}}

            for s, (_, anno) in enumerate(self.dataset):
                seq_name = self.dataset.seq_names[s]
                record_file = os.path.join(self.result_dir, name,
                                           '%s.txt' % seq_name)
                # ndmin=2 keeps a single-frame record as a one-row matrix
                # instead of collapsing it to a flat vector
                boxes = np.loadtxt(record_file, delimiter=',', ndmin=2)
                # the first prediction is replaced by the ground-truth box,
                # so frame 0 always counts as a perfect match
                boxes[0] = anno[0]
                if len(boxes) != len(anno):
                    print('warning: %s anno donnot match boxes' % seq_name)
                    len_min = min(len(boxes), len(anno))
                    boxes = boxes[:len_min]
                    anno = anno[:len_min]
                assert len(boxes) == len(anno)

                ious, center_errors = self._calc_metrics(boxes, anno)
                succ_curve[s], prec_curve[s] = self._calc_curves(
                    ious, center_errors)

                # calculate average tracking speed from positive frame times
                time_file = os.path.join(self.result_dir, name,
                                         'times/%s_time.txt' % seq_name)
                if os.path.isfile(time_file):
                    times = np.loadtxt(time_file, ndmin=1)
                    times = times[times > 0]
                    if len(times) > 0:
                        speeds[s] = np.mean(1. / times)

                # store sequence-wise performance; index 20 of the precision
                # curve is the center-error threshold 20 (thresholds are
                # 0..nbins_ce-1) and nbins_iou // 2 is the IoU threshold 0.5
                # for the default odd nbins_iou
                performance[name]['seq_wise'].update({
                    seq_name: {
                        'success_curve': succ_curve[s].tolist(),
                        'precision_curve': prec_curve[s].tolist(),
                        'success_score': np.mean(succ_curve[s]),
                        'precision_score': prec_curve[s][20],
                        'success_rate': succ_curve[s][self.nbins_iou // 2],
                        'speed_fps': speeds[s] if speeds[s] > 0 else -1
                    }
                })

            succ_curve = np.mean(succ_curve, axis=0)
            prec_curve = np.mean(prec_curve, axis=0)
            succ_score = np.mean(succ_curve)
            prec_score = prec_curve[20]
            succ_rate = succ_curve[self.nbins_iou // 2]
            # average only over sequences that actually have a speed
            n_timed = np.count_nonzero(speeds)
            avg_speed = np.sum(speeds) / n_timed if n_timed > 0 else -1

            # store overall performance
            performance[name]['overall'].update({
                'success_curve': succ_curve.tolist(),
                'precision_curve': prec_curve.tolist(),
                'success_score': succ_score,
                'precision_score': prec_score,
                'success_rate': succ_rate,
                'speed_fps': avg_speed
            })

        # report the performance
        with open(report_file, 'w') as f:
            json.dump(performance, f, indent=4)
        # plot precision and success curves
        if plot_curves:
            self.plot_curves(tracker_names)

        return performance

    def show(self, tracker_names, seq_names=None, play_speed=1):
        """Display ground truth and recorded boxes frame by frame.

        Args:
            tracker_names (list or tuple): Trackers whose results are drawn.
            seq_names (str, list or tuple, optional): Sequences to show;
                all dataset sequences when ``None``.
            play_speed (number): Show every ``round(play_speed)``-th frame;
                must round to a positive integer.
        """
        if seq_names is None:
            seq_names = self.dataset.seq_names
        elif isinstance(seq_names, str):
            seq_names = [seq_names]
        assert isinstance(tracker_names, (list, tuple))
        assert isinstance(seq_names, (list, tuple))

        play_speed = int(round(play_speed))
        assert play_speed > 0

        for s, seq_name in enumerate(seq_names):
            print('[%d/%d] Showing results on %s...' %
                  (s + 1, len(seq_names), seq_name))

            # load all tracking results (ndmin=2: single-frame records stay
            # indexable by frame)
            records = {}
            for name in tracker_names:
                record_file = os.path.join(self.result_dir, name,
                                           '%s.txt' % seq_name)
                records[name] = np.loadtxt(record_file, delimiter=',',
                                           ndmin=2)

            # loop over the sequence and display results
            img_files, anno = self.dataset[seq_name]
            for f, img_file in enumerate(img_files):
                if f % play_speed != 0:
                    continue
                image = Image.open(img_file)
                boxes = [anno[f]] + [records[name][f] for name in tracker_names]
                # list() so a tuple of tracker names can be concatenated
                show_frame(image,
                           boxes,
                           legends=['GroundTruth'] + list(tracker_names),
                           colors=[
                               'w', 'r', 'g', 'b', 'c', 'm', 'y', 'orange',
                               'purple', 'brown', 'pink'
                           ])

    def _record(self, record_file, boxes, times, trigger_list=None):
        """Write boxes, per-frame times and (optionally) triggers to disk.

        Boxes go to ``record_file`` (comma separated, 3 decimals), times to
        ``<dir>/times/<name>_time.txt`` and triggers to
        ``<dir>/triggers/<name>_trig.txt``.

        Raises:
            IOError: if ``record_file`` is still missing after
                ``_MAX_RECORD_RETRIES`` re-attempts.
        """
        # record bounding boxes
        record_dir = os.path.dirname(record_file)
        if record_dir and not os.path.isdir(record_dir):
            os.makedirs(record_dir)
        np.savetxt(record_file, boxes, fmt='%.3f', delimiter=',')
        retries = 0
        while not os.path.exists(record_file):
            if retries >= self._MAX_RECORD_RETRIES:
                raise IOError('failed to record results at %s' % record_file)
            retries += 1
            print('warning: recording failed, retrying...')
            np.savetxt(record_file, boxes, fmt='%.3f', delimiter=',')
        print('  Results recorded at', record_file)

        # record running times
        time_file = self._sidecar_file(record_file, 'times', '_time.txt')
        np.savetxt(time_file, times, fmt='%.8f')

        # record trigger list (written exactly once)
        if trigger_list is not None:
            trig_file = self._sidecar_file(record_file, 'triggers',
                                           '_trig.txt')
            np.savetxt(trig_file, trigger_list, fmt='%d')

    @staticmethod
    def _sidecar_file(record_file, subdir, suffix):
        """Return ``<dir>/<subdir>/<basename, '.txt' -> suffix>``, creating
        ``<dir>/<subdir>`` when missing."""
        side_dir = os.path.join(os.path.dirname(record_file), subdir)
        if not os.path.isdir(side_dir):
            os.makedirs(side_dir)
        return os.path.join(
            side_dir,
            os.path.basename(record_file).replace('.txt', suffix))

    def _calc_metrics(self, boxes, anno):
        """Return per-frame ``(ious, center_errors)`` between predictions and
        ground truth."""
        # can be modified by children classes
        ious = rect_iou(boxes, anno)
        center_errors = center_error(boxes, anno)
        return ious, center_errors

    def _calc_curves(self, ious, center_errors):
        """Turn per-frame metrics into success and precision curves.

        The success curve holds, for ``nbins_iou`` evenly spaced thresholds in
        ``[0, 1]``, the fraction of frames whose IoU is strictly greater than
        the threshold. The precision curve holds, for the integer thresholds
        ``0 .. nbins_ce - 1``, the fraction of frames whose center error is
        less than or equal to the threshold.
        """
        ious = np.asarray(ious, float)[:, np.newaxis]
        center_errors = np.asarray(center_errors, float)[:, np.newaxis]

        thr_iou = np.linspace(0, 1, self.nbins_iou)[np.newaxis, :]
        thr_ce = np.arange(0, self.nbins_ce)[np.newaxis, :]

        bin_iou = np.greater(ious, thr_iou)
        bin_ce = np.less_equal(center_errors, thr_ce)

        succ_curve = np.mean(bin_iou, axis=0)
        prec_curve = np.mean(bin_ce, axis=0)

        return succ_curve, prec_curve

    def plot_curves(self, tracker_names):
        """Save success and precision plots from a previous ``report`` run.

        Only ``tracker_names[0]`` is used, to locate ``performance.json``.
        Every tracker stored in that file is plotted, ranked by its score.
        """
        # assume tracker_names[0] is your tracker
        report_dir = os.path.join(self.report_dir, tracker_names[0])
        assert os.path.exists(report_dir), \
            'No reports found. Run "report" first ' \
            'before plotting curves.'
        report_file = os.path.join(report_dir, 'performance.json')
        assert os.path.exists(report_file), \
            'No reports found. Run "report" first ' \
            'before plotting curves.'

        # load pre-computed performance
        with open(report_file) as f:
            performance = json.load(f)

        key = 'overall'

        # plot success curves
        self._plot_ranked_curves(performance, key,
                                 'success_curve', 'success_score',
                                 np.linspace(0, 1, self.nbins_iou),
                                 os.path.join(report_dir, 'success_plots.png'),
                                 'Saving success plots to',
                                 xlabel='Overlap threshold',
                                 ylabel='Success rate',
                                 xlim=(0, 1),
                                 ylim=(0, 1),
                                 title='Success plots of OPE')

        # plot precision curves
        thr_ce = np.arange(0, self.nbins_ce)
        self._plot_ranked_curves(performance, key,
                                 'precision_curve', 'precision_score',
                                 thr_ce,
                                 os.path.join(report_dir,
                                              'precision_plots.png'),
                                 'Saving precision plots to',
                                 xlabel='Location error threshold',
                                 ylabel='Precision',
                                 xlim=(0, thr_ce.max()),
                                 ylim=(0, 1),
                                 title='Precision plots of OPE')

    def _plot_ranked_curves(self, performance, key, curve_name, score_name,
                            thresholds, out_file, save_msg, **axes_kwargs):
        """Plot one curve per tracker, ranked by score, and save the figure."""
        # line styles (no explicit colour, so matplotlib's cycle applies)
        markers = ['-', '--', '-.']
        markers = [c + m for m in markers for c in [''] * 10]

        # sort trackers by descending score
        names = list(performance.keys())
        scores = [t[key][score_name] for t in performance.values()]
        ranked = [names[i] for i in np.argsort(scores)[::-1]]

        fig, ax = plt.subplots()
        lines = []
        legends = []
        for i, name in enumerate(ranked):
            line, = ax.plot(thresholds, performance[name][key][curve_name],
                            markers[i % len(markers)])
            lines.append(line)
            legends.append('%s: [%.3f]' %
                           (name, performance[name][key][score_name]))
        # font sizes are global rcParams: small for the legend, then 9 for
        # the axis labels/title that follow
        matplotlib.rcParams.update({'font.size': 7.4})
        legend = ax.legend(lines,
                           legends,
                           loc='center left',
                           bbox_to_anchor=(1, 0.5))

        matplotlib.rcParams.update({'font.size': 9})
        ax.set(**axes_kwargs)
        ax.grid(True)
        fig.tight_layout()

        print(save_msg, out_file)
        # the legend sits outside the axes; include it in the saved bbox
        fig.savefig(out_file,
                    bbox_extra_artists=(legend, ),
                    bbox_inches='tight',
                    dpi=300)
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
