# --------------------------------------------------------
# Python Single Object Tracking Evaluation
# Licensed under The MIT License [see LICENSE for details]
# Written by Fangyi Zhang
# @author fangyi.zhang@vipl.ict.ac.cn
# @project https://github.com/StrangerZhang/pysot-toolkit.git
# Revised for SiamMask by foolwood
# --------------------------------------------------------

import itertools
import warnings

import numpy as np

from ..utils import calculate_accuracy, calculate_failures


class AccuracyRobustnessBenchmark:
    """
    Args:
        dataset:
        burnin:
    """
    def __init__(self, dataset, burnin=10):
        self.dataset = dataset
        self.burnin = burnin

    def eval(self, eval_trackers=None):
        """
        Args:
            eval_tags: list of tag
            eval_trackers: list of tracker name
        Returns:
            ret: dict of results
        """
        if eval_trackers is None:
            eval_trackers = self.dataset.tracker_names
        if isinstance(eval_trackers, str):
            eval_trackers = [eval_trackers]

        result = {}
        for tracker_name in eval_trackers:
            accuracy, failures = self._calculate_accuracy_robustness(
                tracker_name)
            result[tracker_name] = {'overlaps': accuracy, 'failures': failures}
        return result

    def show_result(self,
                    result,
                    eao_result=None,
                    show_video_level=False,
                    helight_threshold=0.5):
        """pretty print result
        Args:
            result: returned dict from function eval
        """
        tracker_name_len = max((max([len(x) for x in result.keys()]) + 2), 12)
        if eao_result is not None:
            header = "|{:^" + str(
                tracker_name_len) + "}|{:^10}|{:^12}|{:^13}|{:^7}|"
            header = header.format('Tracker Name', 'Accuracy', 'Robustness',
                                   'Lost Number', 'EAO')
            formatter = "|{:^" + str(
                tracker_name_len) + "}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|"
        else:
            header = "|{:^" + str(tracker_name_len) + "}|{:^10}|{:^12}|{:^13}|"
            header = header.format('Tracker Name', 'Accuracy', 'Robustness',
                                   'Lost Number')
            formatter = "|{:^" + str(
                tracker_name_len) + "}|{:^10.3f}|{:^12.3f}|{:^13.1f}|"
        bar = '-' * len(header)
        print(bar)
        print(header)
        print(bar)
        if eao_result is not None:
            tracker_eao = sorted(eao_result.items(),
                                 key=lambda x: x[1]['all'],
                                 reverse=True)[:20]
            tracker_names = [x[0] for x in tracker_eao]
        else:
            tracker_names = list(result.keys())
        for tracker_name in tracker_names:
            ret = result[tracker_name]
            overlaps = list(itertools.chain(*ret['overlaps'].values()))
            accuracy = np.nanmean(overlaps)
            length = sum([len(x) for x in ret['overlaps'].values()])
            failures = list(ret['failures'].values())
            lost_number = np.mean(np.sum(failures, axis=0))
            robustness = np.mean(
                np.sum(np.array(failures), axis=0) / length) * 100
            if eao_result is None:
                print(
                    formatter.format(tracker_name, accuracy, robustness,
                                     lost_number))
            else:
                print(
                    formatter.format(tracker_name, accuracy, robustness,
                                     lost_number,
                                     eao_result[tracker_name]['all']))
        print(bar)

        if show_video_level and len(result) < 10:
            print('\n\n')
            header1 = "|{:^14}|".format("Tracker name")
            header2 = "|{:^14}|".format("Video name")
            for tracker_name in result.keys():
                header1 += ("{:^17}|").format(tracker_name)
                header2 += "{:^8}|{:^8}|".format("Acc", "LN")
            print('-' * len(header1))
            print(header1)
            print('-' * len(header1))
            print(header2)
            print('-' * len(header1))
            videos = list(result[tracker_name]['overlaps'].keys())
            for video in videos:
                row = "|{:^14}|".format(video)
                for tracker_name in result.keys():
                    overlaps = result[tracker_name]['overlaps'][video]
                    accuracy = np.nanmean(overlaps)
                    failures = result[tracker_name]['failures'][video]
                    lost_number = np.mean(failures)

                    accuracy_str = "{:^8.3f}".format(accuracy)
                    if accuracy < helight_threshold:
                        #row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
                        row += '{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
                    else:
                        row += accuracy_str + '|'
                    lost_num_str = "{:^8.3f}".format(lost_number)
                    if lost_number > 0:
                        #row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
                        row += '{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
                    else:
                        row += lost_num_str + '|'
                print(row)
            print('-' * len(header1))

    def write_result(self,
                     result,
                     eao_result=None,
                     show_video_level=False,
                     helight_threshold=0.5,
                     result_file=None):
        """pretty result_file.write result
        Args:
            result: returned dict from function eval
        """
        tracker_name_len = max((max([len(x) for x in result.keys()]) + 2), 12)
        if eao_result is not None:
            header = "|{:^" + str(
                tracker_name_len) + "}|{:^10}|{:^12}|{:^13}|{:^7}|"
            header = header.format('Tracker Name', 'Accuracy', 'Robustness',
                                   'Lost Number', 'EAO')
            formatter = "|{:^" + str(
                tracker_name_len) + "}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|"
        else:
            header = "|{:^" + str(tracker_name_len) + "}|{:^10}|{:^12}|{:^13}|"
            header = header.format('Tracker Name', 'Accuracy', 'Robustness',
                                   'Lost Number')
            formatter = "|{:^" + str(
                tracker_name_len) + "}|{:^10.3f}|{:^12.3f}|{:^13.1f}|"
        bar = '-' * len(header)
        result_file.write(bar + '\n')
        result_file.write(header + '\n')
        result_file.write(bar + '\n')
        if eao_result is not None:
            tracker_eao = sorted(eao_result.items(),
                                 key=lambda x: x[1]['all'],
                                 reverse=True)[:20]
            tracker_names = [x[0] for x in tracker_eao]
        else:
            tracker_names = list(result.keys())
        for tracker_name in tracker_names:
            ret = result[tracker_name]
            overlaps = list(itertools.chain(*ret['overlaps'].values()))
            accuracy = np.nanmean(overlaps)
            length = sum([len(x) for x in ret['overlaps'].values()])
            failures = list(ret['failures'].values())
            lost_number = np.mean(np.sum(failures, axis=0))
            robustness = np.mean(
                np.sum(np.array(failures), axis=0) / length) * 100
            if eao_result is None:
                result_file.write(
                    formatter.format(tracker_name, accuracy, robustness,
                                     lost_number) + '\n')
            else:
                result_file.write(
                    formatter.format(tracker_name, accuracy, robustness,
                                     lost_number, eao_result[tracker_name]
                                     ['all']) + '\n')
        result_file.write(bar + '\n')

        if show_video_level and len(result) < 10:
            result_file.write('\n\n')
            header1 = "|{:^14}|".format("Tracker name")
            header2 = "|{:^14}|".format("Video name")
            for tracker_name in result.keys():
                header1 += ("{:^17}|").format(tracker_name)
                header2 += "{:^8}|{:^8}|".format("Acc", "LN")
            result_file.write('-' * len(header1) + '\n')
            result_file.write(header1 + '\n')
            result_file.write('-' * len(header1) + '\n')
            result_file.write(header2 + '\n')
            result_file.write('-' * len(header1) + '\n')
            videos = list(result[tracker_name]['overlaps'].keys())
            for video in videos:
                row = "|{:^14}|".format(video)
                for tracker_name in result.keys():
                    overlaps = result[tracker_name]['overlaps'][video]
                    accuracy = np.nanmean(overlaps)
                    failures = result[tracker_name]['failures'][video]
                    lost_number = np.mean(failures)

                    accuracy_str = "{:^8.3f}".format(accuracy)
                    if accuracy < helight_threshold:
                        # row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
                        row += '{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
                    else:
                        row += accuracy_str + '|'
                    lost_num_str = "{:^8.3f}".format(lost_number)
                    if lost_number > 0:
                        # row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
                        row += '{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
                    else:
                        row += lost_num_str + '|'
                result_file.write(row + '\n')
            result_file.write('-' * len(header1) + '\n')

    def _calculate_accuracy_robustness(self, tracker_name):
        overlaps = {}
        failures = {}
        for i in range(len(self.dataset)):
            video = self.dataset[i]
            gt_traj = video.gt_traj
            if tracker_name not in video.pred_trajs:
                tracker_trajs = video.load_tracker(self.dataset.tracker_path,
                                                   tracker_name, False)
            else:
                tracker_trajs = video.pred_trajs[tracker_name]
            overlaps_group = []
            num_failures_group = []
            for tracker_traj in tracker_trajs:
                num_failures = calculate_failures(tracker_traj)[0]
                overlaps_ = calculate_accuracy(tracker_traj,
                                               gt_traj,
                                               burnin=10,
                                               bound=(video.width,
                                                      video.height))[1]
                overlaps_group.append(overlaps_)
                num_failures_group.append(num_failures)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                overlaps[video.name] = np.nanmean(overlaps_group,
                                                  axis=0).tolist()
                failures[video.name] = num_failures_group
        return overlaps, failures
