# %%
from bs_utils import *  # expected to provide read_pickle / save_pickle
from copy import deepcopy
import random
import numpy as np
from scipy.stats import spearmanr, kendalltau, pearsonr
from tabulate import tabulate
from tqdm import tqdm, trange


class SUMStat:
    """ A class used to get stats of SUM trained data """

    def __init__(self, path):
        self.path = path
        self.data = read_pickle(path)
        self.sample_id = list(self.data.keys())[0]
        self.sample_sys = list(self.data[self.sample_id]['sys_summs'].keys())[0]
        self._metrics = list(self.data[self.sample_id]['sys_summs'][self.sample_sys]['scores'].keys())
        self._auto_metrics = [x for x in self.metrics if x not in self.human_metrics]

    def save_data(self, path=None):
        if path is None:
            path = self.path
        save_pickle(self.data, path)

    def evaluate_summary(self, human_metric, auto_metrics=None, table=None):
        """ Evaluate summaries. Conduct summary-level correlations w.r.t each document """
        assert human_metric in self.human_metrics
        if auto_metrics is None:
            auto_metrics = self.auto_metrics
        print(f'Human metric: {human_metric}')
        headers = ['metric', 'spearman', 'kendalltau']
        metric_with_corr = []
        for metric in auto_metrics:
            correlations = []
            for doc_id in self.data:
                target_scores = []
                prediction_scores = []

                sys_summs = self.data[doc_id]['sys_summs']
                for sys_name in sys_summs:
                    prediction_scores.append(sys_summs[sys_name]['scores'][metric])
                    target_scores.append(sys_summs[sys_name]['scores'][human_metric])
                if len(set(prediction_scores)) == 1 or len(set(target_scores)) == 1:
                    continue
                correlations.append([spearmanr(target_scores, prediction_scores)[0],
                                     kendalltau(target_scores, prediction_scores)[0]])
            corr_mat = np.array(correlations)
            spearman, ktau = np.mean(corr_mat[:, 0]), np.mean(corr_mat[:, 1])
            metric_with_corr.append([metric, spearman, ktau])
        sorted_metric_with_corr = sorted(metric_with_corr, key=lambda x: x[1], reverse=True)
        if table is not None:
            with open(table, 'w') as f:
                for each in sorted_metric_with_corr:
                    print(f'{each[0]}\t{each[1]}\t{each[2]}', file=f)
        print(tabulate(sorted_metric_with_corr, headers=headers, tablefmt='simple'))

    def get_fact_pearson(self, auto_metrics=None):
        assert 'QAGS' in self.path
        headers = ['metric', 'pearson']
        metric_with_corr = []
        if auto_metrics is None:
            auto_metrics = self.auto_metrics
        for metric in auto_metrics:
            human_scores = []
            metric_scores = []
            for doc_id in self.data:
                human_scores.append(self.data[doc_id]['sys_summs'][0]['scores']['fact'])
                metric_scores.append(self.data[doc_id]['sys_summs'][0]['scores'][metric])
            pearson, _ = pearsonr(human_scores, metric_scores)
            metric_with_corr.append([metric, pearson])
        metric_with_corr = sorted(metric_with_corr, key=lambda x: x[1], reverse=True)
        print(tabulate(metric_with_corr, headers=headers, tablefmt='simple'))

    def fact_pearson_sig_test(self, metric_list):
        for m in metric_list:
            assert m in self.auto_metrics
        comp_tab = np.zeros((len(metric_list), len(metric_list)), dtype=int)
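        # comp_tab[a][b] == 1 means metric a was significantly beaten by metric b;
        # a metric whose row sums to 0 is never significantly outperformed.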
        for i in range(len(metric_list)):  # row
            for j in range(i + 1, len(metric_list)):  # col
                m1 = metric_list[i]
                m2 = metric_list[j]
                # Test if m1 is significantly better than m2
                out = self.fact_pearson_sig_test_two(m1, m2)
                if out == 1:
                    comp_tab[j][i] = 1
                elif out == -1:
                    comp_tab[i][j] = 1
                else:
                    pass
        result = comp_tab.sum(axis=1)
        best_metrics = []
        for i in range(len(result)):
            if result[i] == 0:
                best_metrics.append(metric_list[i])
        print(f'Best metrics are: {best_metrics}')

    def fact_pearson_sig_test_two(self, m1, m2):
        assert 'QAGS' in self.path
        random.seed(666)
        doc_ids = list(self.data.keys())
        better = 0
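        # Bootstrap: 1000 resamples of 80% of the documents; m1 counts as significantly
        # better (worse) if it wins on more than 95% (fewer than 5%) of resamples.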
        for i in trange(1000):
            random.shuffle(doc_ids)
            sub_ids = doc_ids[:int(0.8 * len(doc_ids))]
            m1_scores, m2_scores, human_scores = [], [], []

            for doc_id in sub_ids:
                human_scores.append(self.data[doc_id]['sys_summs'][0]['scores']['fact'])
                m1_scores.append(self.data[doc_id]['sys_summs'][0]['scores'][m1])
                m2_scores.append(self.data[doc_id]['sys_summs'][0]['scores'][m2])
            pearson_m1, _ = pearsonr(human_scores, m1_scores)
            pearson_m2, _ = pearsonr(human_scores, m2_scores)
            if pearson_m1 > pearson_m2:
                better += 1
        if better > 950:
            return 1
        elif better < 50:
            return -1
        else:
            return 0

    def get_fact_acc(self, auto_metrics=None):
        """ Used for the Rank19 dataset. """
        assert 'Rank' in self.path
        headers = ['metric', 'acc']
        metric_with_acc = []
        if auto_metrics is None:
            auto_metrics = self.auto_metrics
        for metric in auto_metrics:
            correct = 0
            for doc_id in self.data:
                if self.data[doc_id]['sys_summs']['correct']['scores'][metric] > \
                        self.data[doc_id]['sys_summs']['incorrect']['scores'][metric]:
                    correct += 1
            metric_with_acc.append([metric, correct / len(self.data)])
        metric_with_acc = sorted(metric_with_acc, key=lambda x: x[1], reverse=True)
        print(tabulate(metric_with_acc, headers=headers, tablefmt='simple'))

    def fact_acc_sig_test(self, metric_list):
        for m in metric_list:
            assert m in self.auto_metrics
        comp_tab = np.zeros((len(metric_list), len(metric_list)), dtype=int)
        for i in range(len(metric_list)):  # row
            for j in range(i + 1, len(metric_list)):  # col
                m1 = metric_list[i]
                m2 = metric_list[j]
                # Test if m1 is significantly better than m2
                out = self.fact_acc_sig_test_two(m1, m2)
                if out == 1:
                    comp_tab[j][i] = 1
                elif out == -1:
                    comp_tab[i][j] = 1
                else:
                    pass
        result = comp_tab.sum(axis=1)
        best_metrics = []
        for i in range(len(result)):
            if result[i] == 0:
                best_metrics.append(metric_list[i])
        print(f'Best metrics are: {best_metrics}')

    def fact_acc_sig_test_two(self, m1, m2):
        """ Return 1 if m1 significant better than m2, or -1 if m1 significant worse than m2
            or 0 if cannot decide.
        """
        assert 'Rank' in self.path
        random.seed(666)
        doc_ids = list(self.data.keys())
        better = 0
        for i in trange(1000):
            random.shuffle(doc_ids)
            sub_ids = doc_ids[:int(0.8 * len(doc_ids))]
            m1_correct = 0
            m2_correct = 0
            for doc_id in sub_ids:
                if self.data[doc_id]['sys_summs']['correct']['scores'][m1] > \
                        self.data[doc_id]['sys_summs']['incorrect']['scores'][m1]:
                    m1_correct += 1
                if self.data[doc_id]['sys_summs']['correct']['scores'][m2] > \
                        self.data[doc_id]['sys_summs']['incorrect']['scores'][m2]:
                    m2_correct += 1
            if m1_correct > m2_correct:
                better += 1
        if better > 950:
            return 1
        elif better < 50:
            return -1
        else:
            return 0

    def sig_test(self, metric_list, human_metric):
        """ Comparisons between all pairs of metrics. Using Spearman correlation. """
        for m in metric_list:
            assert m in self.auto_metrics
        comp_tab = np.zeros((len(metric_list), len(metric_list)), dtype=int)
        for i in range(len(metric_list)):  # row
            for j in range(i + 1, len(metric_list)):  # col
                m1 = metric_list[i]
                m2 = metric_list[j]
                # Test if m1 is significantly better than m2
                out = self.sig_test_two(m1, m2, human_metric)
                if out == 1:
                    comp_tab[j][i] = 1
                elif out == -1:
                    comp_tab[i][j] = 1
                else:
                    pass
        result = comp_tab.sum(axis=1)
        best_metrics = []
        for i in range(len(result)):
            if result[i] == 0:
                best_metrics.append(metric_list[i])
        print(f'Best metrics are: {best_metrics}')

    def sig_test_two(self, m1, m2, human_metric):
        """ Comparisons between a pair of metrics. Using Spearman correlation.
            Test if m1 is significant better than m2. return 1 if m1 is better,
            return -1 if m2 is better, otherwise return 0
        """
        assert ('Rank' not in self.path) and ('QAGS' not in self.path)
        random.seed(666)
        doc_ids = list(self.data.keys())
        better = 0
        for i in trange(1000):
            random.shuffle(doc_ids)
            sub_ids = doc_ids[:int(0.8 * len(doc_ids))]
            corr1, corr2 = [], []
            for doc_id in sub_ids:
                target, pred1, pred2 = [], [], []
                sys_summs = self.data[doc_id]['sys_summs']
                for sys_name in sys_summs:
                    pred1.append(sys_summs[sys_name]['scores'][m1])
                    pred2.append(sys_summs[sys_name]['scores'][m2])
                    target.append(sys_summs[sys_name]['scores'][human_metric])
                if len(set(pred1)) == 1 or len(set(pred2)) == 1 or len(set(target)) == 1:
                    continue
                corr1.append(spearmanr(target, pred1)[0])
                corr2.append(spearmanr(target, pred2)[0])

            corr1 = np.mean(corr1)
            corr2 = np.mean(corr2)
            if corr1 > corr2:
                better += 1
        if better > 950:
            return 1
        elif better < 50:
            return -1
        else:
            return 0

    def combine_prompt(self):
        """ Take the average of all prompted results for a single prediction.
            We consider encoder-based prompts and decoder-based prompts separately.
        """

        def get_keys(s):
            """ Get the first key and second key in MAP """
            k1, k2 = None, None
            if s.startswith('bart_score_cnn'):
                k1 = 'bart_score_cnn'
            elif s.startswith('bart_score_para'):
                k1 = 'bart_score_para'
            else:
                k1 = 'bart_score'
            if 'src' in s:
                if '_en_' in s:
                    k2 = 'src_hypo_en'
                else:
                    k2 = 'src_hypo_de'
            if 'hypo_ref' in s:
                if '_en_' in s:
                    k2 = 'hypo_ref_en'
                else:
                    k2 = 'hypo_ref_de'
            if 'ref_hypo' in s:
                if '_en_' in s:
                    k2 = 'ref_hypo_en'
                else:
                    k2 = 'ref_hypo_de'
            if 'avg_f' in s:
                if '_en_' in s:
                    k2 = 'avg_f_en'
                else:
                    k2 = 'avg_f_de'
            if 'harm_f' in s:
                if '_en_' in s:
                    k2 = 'harm_f_en'
                else:
                    k2 = 'harm_f_de'
            return k1, k2

        for doc_id in self.data:
            sys_summs = self.data[doc_id]['sys_summs']
            for sys_name in sys_summs:
                types = {
                    'src_hypo_en': [],
                    'src_hypo_de': [],
                    'ref_hypo_en': [],
                    'ref_hypo_de': [],
                    'hypo_ref_en': [],
                    'hypo_ref_de': [],
                    'avg_f_en': [],
                    'avg_f_de': [],
                    'harm_f_en': [],
                    'harm_f_de': []
                }
                MAP = {
                    'bart_score': deepcopy(types),
                    'bart_score_cnn': deepcopy(types),
                    'bart_score_para': deepcopy(types)
                }
                scores = sys_summs[sys_name]['scores']
                for k in scores:
                    if '_en_' in k or '_de_' in k:
                        k1, k2 = get_keys(k)
                        MAP[k1][k2].append(scores[k])
                for k, v in MAP.items():
                    for kk, vv in v.items():
                        if len(vv) == 0:
                            continue
                        new_m = k + '_' + kk
                        if new_m not in self.auto_metrics:
                            print(f'new_metric: {new_m}')
                            self._metrics.append(new_m)
                            self._auto_metrics.append(new_m)
                        self.data[doc_id]['sys_summs'][sys_name]['scores'][new_m] = sum(vv) / len(vv)

    @property
    def auto_metrics(self):
        return self._auto_metrics

    @property
    def metrics(self):
        return self._metrics

    @property
    def human_metrics(self):
        """ All available human metrics. """
        if 'REALSumm' in self.path:
            return ['litepyramid_recall']
        if 'SummEval' in self.path:
            return ['coherence', 'consistency', 'fluency', 'relevance']
        if 'Newsroom' in self.path:
            return ['coherence', 'fluency', 'informativeness', 'relevance']
        if 'Rank19' in self.path or 'QAGS' in self.path:
            return ['fact']
        return []


class D2TStat:
    """ A class used to get stats of D2T trained data """

    def __init__(self, path):
        self.path = path
        self.data = read_pickle(path)
        self.sample_id = list(self.data.keys())[0]
        self._metrics = list(self.data[self.sample_id]['scores'].keys())
        self._auto_metrics = [x for x in self.metrics if x not in self.human_metrics]

    def evaluate_text(self, human_metric, auto_metrics=None, table=None):
        print(f'Human metric: {human_metric}')
        headers = ['metric', 'spearman', 'kendalltau']
        metric_with_corr = []
        if auto_metrics is None:
            auto_metrics = self.auto_metrics
        for metric in auto_metrics:
            human_scores = []
            metric_scores = []
            for doc_id in self.data:
                human_scores.append(self.data[doc_id]['scores'][human_metric])
                metric_scores.append(self.data[doc_id]['scores'][metric])
            spearman = spearmanr(human_scores, metric_scores)[0]
            ktau = kendalltau(human_scores, metric_scores)[0]
            metric_with_corr.append([metric, spearman, ktau])
        sorted_metric_with_corr = sorted(metric_with_corr, key=lambda x: x[1], reverse=True)
        if table is not None:
            with open(table, 'w') as f:
                for each in sorted_metric_with_corr:
                    print(f'{each[0]}\t{each[1]}\t{each[2]}', file=f)
        print(tabulate(sorted_metric_with_corr, headers=headers, tablefmt='simple'))

    def sig_test_two(self, m1, m2, human_metric):
        """ Test whether m1 is significantly better than m2 w.r.t. human_metric,
            using Spearman correlation over bootstrap resamples.
        """
        doc_ids = list(self.data.keys())
        better = 0
        random.seed(666)
        for i in trange(1000):
            random.shuffle(doc_ids)
            sub_ids = doc_ids[:int(0.8 * len(doc_ids))]
            # Reset the score lists for each bootstrap resample
            human_scores, m1_scores, m2_scores = [], [], []
            for doc_id in sub_ids:
                human_scores.append(self.data[doc_id]['scores'][human_metric])
                m1_scores.append(self.data[doc_id]['scores'][m1])
                m2_scores.append(self.data[doc_id]['scores'][m2])
            spearman1, _ = spearmanr(human_scores, m1_scores)
            spearman2, _ = spearmanr(human_scores, m2_scores)
            if spearman1 > spearman2:
                better += 1
        if better > 950:
            return 1
        elif better < 50:
            return -1
        else:
            return 0

    def combine_prompt(self):
        def get_keys(s):
            """ Get the first key and second key in MAP """
            k1, k2 = None, None
            if s.startswith('bart_score_cnn'):
                k1 = 'bart_score_cnn'
            elif s.startswith('bart_score_para'):
                k1 = 'bart_score_para'
            else:
                k1 = 'bart_score'
            if 'src' in s:
                if '_en_' in s:
                    k2 = 'src_hypo_en'
                else:
                    k2 = 'src_hypo_de'
            if 'hypo_ref' in s:
                if '_en_' in s:
                    k2 = 'hypo_ref_en'
                else:
                    k2 = 'hypo_ref_de'
            if 'ref_hypo' in s:
                if '_en_' in s:
                    k2 = 'ref_hypo_en'
                else:
                    k2 = 'ref_hypo_de'
            if 'avg_f' in s:
                if '_en_' in s:
                    k2 = 'avg_f_en'
                else:
                    k2 = 'avg_f_de'
            if 'harm_f' in s:
                if '_en_' in s:
                    k2 = 'harm_f_en'
                else:
                    k2 = 'harm_f_de'
            return k1, k2

        for doc_id in self.data:
            types = {
                'src_hypo_en': [],
                'src_hypo_de': [],
                'ref_hypo_en': [],
                'ref_hypo_de': [],
                'hypo_ref_en': [],
                'hypo_ref_de': [],
                'avg_f_en': [],
                'avg_f_de': [],
                'harm_f_en': [],
                'harm_f_de': []
            }
            MAP = {
                'bart_score': deepcopy(types),
                'bart_score_cnn': deepcopy(types),
                'bart_score_para': deepcopy(types)
            }
            scores = self.data[doc_id]['scores']
            for k in scores:
                if '_en_' in k or '_de_' in k:
                    k1, k2 = get_keys(k)
                    MAP[k1][k2].append(scores[k])
            for k, v in MAP.items():
                for kk, vv in v.items():
                    if len(vv) == 0:
                        continue
                    new_m = k + '_' + kk
                    if new_m not in self.auto_metrics:
                        print(f'new_metric: {new_m}')
                        self._metrics.append(new_m)
                        self._auto_metrics.append(new_m)
                    self.data[doc_id]['scores'][new_m] = sum(vv) / len(vv)

    def save_data(self, path=None):
        if path is None:
            path = self.path
        save_pickle(self.data, path)

    @property
    def auto_metrics(self):
        return self._auto_metrics

    @property
    def metrics(self):
        return self._metrics

    @property
    def human_metrics(self):
        return ['informativeness', 'naturalness', 'quality']


class WMTStat:
    """ A class used to get stats of WMT trained data """

    def __init__(self, path):
        self.path = path
        self.data = read_pickle(path)
        self._metrics = list(self.data[0]['better']['scores'].keys())
        pos = path.find('-en')
        self.lp = path[pos - 2: pos + 3]
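        # Assumes the pickle path contains the language pair, e.g. a path with
        # 'de-en' in it yields self.lp == 'de-en'.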
        # systems ranked by their DA score
        self._systems = {
            'de-en': ['Facebook_FAIR.6750', 'RWTH_Aachen_System.6818', 'MSRA.MADL.6910', 'online-B.0', 'JHU.6809',
                      'MLLP-UPV.6899', 'dfki-nmt.6478', 'UCAM.6461', 'online-A.0', 'NEU.6801', 'uedin.6749',
                      'online-Y.0', 'TartuNLP-c.6502', 'online-G.0', 'PROMT_NMT_DE-EN.6683', 'online-X.0'],
            'fi-en': ['MSRA.NAO.6983', 'online-Y.0', 'GTCOM-Primary.6946', 'USYD.6995', 'online-B.0',
                      'Helsinki_NLP.6889', 'online-A.0', 'online-G.0', 'TartuNLP-c.6905', 'online-X.0', 'parfda.6526',
                      'apertium-fin-eng-unconstrained-fien.6449'],
            'gu-en': ['NEU.6756', 'UEDIN.6534', 'GTCOM-Primary.6969', 'CUNI-T2T-transfer-guen.6431',
                      'aylien_mt_gu-en_multilingual.6826', 'NICT.6603', 'online-G.0', 'IITP-MT.6824', 'UdS-DFKI.6861',
                      'IIITH-MT.6688', 'Ju_Saarland.6525'],
            'kk-en': ['online-B.0', 'NEU.6753', 'rug_kken_morfessor.6677', 'online-G.0', 'talp_upc_2019_kken.6657',
                      'NRC-CNRC.6895', 'Frank_s_MT.6127', 'NICT.6770', 'CUNI-T2T-transfer-kken.6436', 'UMD.6736',
                      'DBMS-KU_KKEN.6726'],
            'lt-en': ['GTCOM-Primary.6998', 'tilde-nc-nmt.6881', 'NEU.6759', 'MSRA.MASS.6945', 'tilde-c-nmt.6876',
                      'online-B.0', 'online-A.0', 'TartuNLP-c.6908', 'online-G.0', 'JUMT.6616', 'online-X.0'],
            'ru-en': ['Facebook_FAIR.6937', 'online-G.0', 'eTranslation.6598', 'online-B.0', 'NEU.6803',
                      'MSRA.SCA.6976', 'rerank-re.6540', 'online-Y.0', 'online-A.0', 'afrl-syscomb19.6782',
                      'afrl-ewc.6659', 'TartuNLP-u.6650', 'online-X.0', 'NICT.6561'],
            'zh-en': ['Baidu-system.6940', 'KSAI-system.6927', 'MSRA.MASS.6996', 'MSRA.MASS.6942', 'NEU.6832',
                      'BTRANS.6825', 'online-B.0', 'BTRANS-ensemble.6992', 'UEDIN.6530', 'online-Y.0', 'NICT.6814',
                      'online-A.0', 'online-G.0', 'online-X.0', 'Apprentice-c.6706']
        }

    def save_data(self, path=None):
        if path is None:
            path = self.path
        save_pickle(self.data, path)

    def retrieve_scores(self, metric, doc_ids):
        """ retrieve better, worse scores """
        better, worse = [], []
        for doc_id in doc_ids:
            better.append(float(self.data[doc_id]['better']['scores'][metric]))
            worse.append(float(self.data[doc_id]['worse']['scores'][metric]))
        return better, worse

    def kendall(self, hyp1_scores: list, hyp2_scores: list):
        """ Computes the official WMT19 shared task Kendall correlation score. """
        assert len(hyp1_scores) == len(hyp2_scores)
        conc, disc = 0, 0
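        # hyp1_scores come from the human-preferred ("better") hypotheses and
        # hyp2_scores from the dispreferred ("worse") ones; ties count as discordant.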

        for x1, x2 in zip(hyp1_scores, hyp2_scores):
            if x1 > x2:
                conc += 1
            else:
                disc += 1
        return (conc - disc) / (conc + disc)

    def print_ktau(self, metrics=None):
        headers = ['metric', 'k-tau']
        metric_with_ktau = []
        doc_ids = list(self.data.keys())
        if metrics is None:
            metrics = self.metrics
        for metric in tqdm(metrics):
            better, worse = self.retrieve_scores(metric, doc_ids)
            ktau = self.kendall(better, worse)
            metric_with_ktau.append([metric, ktau])
        sorted_metric_with_ktau = sorted(metric_with_ktau, key=lambda x: x[1], reverse=True)
        print(tabulate(sorted_metric_with_ktau, headers=headers, tablefmt='simple'))

    def print_ref_len(self):
        """ Get the length of reference texts """
        ref_lens = []
        for doc_id in self.data:
            ref = self.data[doc_id]['ref']
            ref_len = len(ref.split(' '))
            ref_lens.append(ref_len)
        print(f'Mean reference length: {np.mean(ref_lens)}')
        print(f'Max reference length: {np.max(ref_lens)}')
        print(f'Min reference length: {np.min(ref_lens)}')
        print(f'20th percentile: {np.percentile(ref_lens, 20)}')
        print(f'80th percentile: {np.percentile(ref_lens, 80)}')
        print(f'90th percentile: {np.percentile(ref_lens, 90)}')

    def print_len_ktau(self, min_len, max_len, metrics=None):
        headers = ['metric', 'k-tau']
        metric_with_ktau = []
        sub_ids = []
        for doc_id in tqdm(self.data):
            ref_len = len(self.data[doc_id]['ref'].split(' '))
            if min_len <= ref_len <= max_len:
                sub_ids.append(doc_id)
        print(f'Considered samples: {len(sub_ids)}')
        if metrics is None:
            metrics = self.metrics
        for metric in tqdm(metrics):
            better, worse = self.retrieve_scores(metric, sub_ids)
            ktau = self.kendall(better, worse)
            metric_with_ktau.append([metric, ktau])
        sorted_metric_with_ktau = sorted(metric_with_ktau, key=lambda x: x[1], reverse=True)
        print(tabulate(sorted_metric_with_ktau, headers=headers, tablefmt='simple'))

    def sig_test_two(self, m1, m2):
        random.seed(666)
        doc_ids = list(self.data.keys())
        better = 0
        for _ in trange(1000):
            random.shuffle(doc_ids)
            sub_ids = doc_ids[:int(0.8 * len(doc_ids))]
            better_m1, worse_m1, better_m2, worse_m2 = [], [], [], []
            for doc_id in sub_ids:
                better_m1.append(float(self.data[doc_id]['better']['scores'][m1]))
                worse_m1.append(float(self.data[doc_id]['worse']['scores'][m1]))
                better_m2.append(float(self.data[doc_id]['better']['scores'][m2]))
                worse_m2.append(float(self.data[doc_id]['worse']['scores'][m2]))
            m1_ktau = self.kendall(better_m1, worse_m1)
            m2_ktau = self.kendall(better_m2, worse_m2)
            if m1_ktau > m2_ktau:
                better += 1

        if better > 950:
            return 1
        elif better < 50:
            return -1
        else:
            return 0

    @property
    def metrics(self):
        return self._metrics

    @property
    def systems(self):
        return self._systems[self.lp]
