import pickle
import jsonlines
import nltk
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize
import numpy as np
from tabulate import tabulate
from mosestokenizer import *
import random
from random import choices
import os
import sys
import re
import json
from collections import defaultdict as ddict
from scipy.stats import pearsonr, spearmanr, kendalltau
import logging

# nltk.download('stopwords')
detokenizer = MosesDetokenizer('en')

def save_json(file_name, data):
    with open(file_name, 'w') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

def load_json(file_name):
    with open(file_name, 'r') as f:
        return json.load(f)

def read_pickle(file):
    with open(file, 'rb') as f:
        data = pickle.load(f)
    return data


def save_pickle(file,data):
    with open(file, 'wb') as f:
        pickle.dump(data, f)
    print(f'Saved to {file}.')


def read_file_to_list(file_name):
    lines = []
    with open(file_name, 'r', encoding='utf8') as f:
        for line in f.readlines():
            lines.append(line.strip())
    return lines


def write_list_to_file(list_to_write, filename):
    out_file = open(filename, 'w')
    for line in list_to_write:
        print(line, file=out_file)
    out_file.flush()
    out_file.close()
    print(f'Saved to {filename}.')


def read_jsonlines_to_list(file_name):
    lines = []
    with jsonlines.open(file_name, 'r') as reader:
        for obj in reader:
            lines.append(obj)
    return lines


def write_list_to_jsonline(list_to_write, filename):
    with jsonlines.open(filename, 'w') as writer:
        writer.write_all(list_to_write)
    print(f'Saved to {filename}.')


def capitalize_sents(text: str):
    """ Given a string, capitalize the initial letter of each sentence. """
    sentences = sent_tokenize(text)
    sentences = [sent.strip() for sent in sentences]
    sentences = [sent.capitalize() for sent in sentences]
    sentences = " ".join(sentences)
    return sentences


def is_capitalized(text: str):
    """ Given a string (system output etc.) , check whether it is lowercased,
        or normally capitalized.
    """
    return not text.islower()


def tokenize(text: str):
    words = word_tokenize(text)
    return " ".join(words)


def detokenize(text: str):
    words = text.split(" ")
    return detokenizer(words)


def use_original_bracket(text: str):
    return text.replace('-lrb-', '(').replace('-rrb-', ')').replace('-LRB-', '(').replace('-RRB-', ')').replace('-lsb-',
                                                                                                                '[').replace(
        '-rsb-', ']').replace('-LSB-', '[').replace('-RSB-', ']')


# Disable print
def blockPrint():
    sys.stdout = open(os.devnull, 'w')


# Restore print
def enablePrint():
    sys.stdout = sys.__stdout__


def retrieve_scores(saveto):
    def get_r_p_f(line):
        line = line.split(" ")
        r = float(line[-3][-7:])
        p = float(line[-2][-7:])
        f = float(line[-1][-7:])
        return r, p, f

    lines = read_file_to_list(saveto)
    rouge1_list, rouge2_list, rougel_list = [], [], []
    for line in lines:
        if line.startswith('1 ROUGE-1 Eval'):
            rouge1_list.append(get_r_p_f(line))
        if line.startswith('1 ROUGE-2 Eval'):
            rouge2_list.append(get_r_p_f(line))
        if line.startswith('1 ROUGE-L Eval'):
            rougel_list.append(get_r_p_f(line))
    return rouge1_list, rouge2_list, rougel_list


def get_rank(data, metric):
    """ Rank all systems based on a metric (avg score) """
    scores = {}  # {sysname: [scores]}
    for doc_id in data:
        sys_summs = data[doc_id]['sys_summs']
        for sys_name in sys_summs:
            score = sys_summs[sys_name]['scores'][metric]
            scores.setdefault(sys_name, []).append(score)
    scores = {k: sum(v) / len(v) for k, v in scores.items()}
    new_scores = dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
    # for k in new_scores:
    #     print(k)
    return new_scores.keys()


def get_sents_from_tags(text, sent_start_tag, sent_end_tag):
    sents = re.findall(r'%s (.+?) %s' % (sent_start_tag, sent_end_tag), text)
    sents = [sent for sent in sents if len(sent) > 0]  # remove empty sents
    return sents


def get_metrics_list(sd):
    """
    Does each system summary dict have same all_metrics?
    :param sd: scores dict
    :return: list of all_metrics in the scores dict
    """
    metrics_tuple_set = set(
        [tuple(sorted(list(x['scores'].keys())))
         for d in sd.values() for x in d['sys_summs'].values()])
    assert len(metrics_tuple_set) == 1, (
        metrics_tuple_set, "all system summary score dicts should have the same set of all_metrics")
    metrics_list = list(list(metrics_tuple_set)[0])
    return metrics_list


def print_score_ranges(sd):
    metrics_list = get_metrics_list(sd)
    print_list = []
    headers = ["min", "25-perc", "median", "75-perc", "max", "mean"]
    for m in metrics_list:
        scores = [s['scores'][m] for d in sd.values() for s in d['sys_summs'].values() if s['sys_summ'] != 'EMPTY']
        print_list.append([m,
                           np.min(scores),
                           np.percentile(scores, 25),
                           np.median(scores),
                           np.percentile(scores, 75),
                           np.max(scores),
                           np.mean(scores)])
    print(tabulate(print_list, headers=headers, floatfmt=".6f", tablefmt="simple"))


def get_system_level_scores(sd, metrics, agg='mean', nas=False):
    """
    systems[system_name][metric] = average_score or list of scores
    """
    systems = ddict(lambda: ddict(list))

    for isd in sd.values():
        for system_name, scores in isd['sys_summs'].items():
            for m in metrics:
                # Empty summary
                if scores['sys_summ'] == 'EMPTY':
                    systems[system_name][m].append(None)
                else:
                    systems[system_name][m].append(scores['scores'][m])

    for system_name, scores in systems.items():
        for m in scores:
            all_scores = systems[system_name][m]
            if agg == 'mean':
                all_scores = [x for x in all_scores if x is not None]
                systems[system_name][m] = np.mean(all_scores)

    if nas:
        min_scores = {}
        max_scores = {}
        for m in metrics:
            min_scores[m] = np.min([systems[sys][m] for sys in systems.keys()])
            max_scores[m] = np.max([systems[sys][m] for sys in systems.keys()])
        for sys in systems:
            systems[sys]['nas'] = np.mean([
                (systems[sys][m] - min_scores[m]) / (max_scores[m] - min_scores[m]) for m in metrics
            ])

    return systems


def get_topk(systems, k, metric='rouge2_f'):
    systems_l = [(name, score[metric]) for name, score in systems.items()]
    systems_l = sorted(systems_l, key=lambda x: x[1], reverse=True)
    topk_system_names = [tup[0] for tup in systems_l[:k]]
    return {name: systems[name] for name in topk_system_names}


def print_correlation(topk_systems, metric_pairs):
    # disagreement between every pair of metrics for the topk
    headers = ['metric_pair', 'pearson', 'spearman', 'kendalltau']
    print_list = []
    for pair in metric_pairs:
        if 'bart_en_sim' in pair[1] or 'bart_sim' in pair[1]:
            continue
        m1_scores = []
        m2_scores = []
        for scores in topk_systems.values():
            m1_scores.append(scores[pair[0]])  # Human metric
            m2_scores.append(scores[pair[1]])

        pearson, _ = pearsonr(m1_scores, m2_scores)
        spearman, _ = spearmanr(m1_scores, m2_scores)
        ktau, _ = kendalltau(m1_scores, m2_scores)
        print_list.append([f'{pair[1]}', pearson, spearman, ktau])
    # sort based on pearson
    print_list = sorted(print_list, key=lambda x: x[2], reverse=True)
    print(tabulate(print_list, headers=headers, tablefmt='simple'))


def get_predictions_br(system_pairs, systems, metric):
    random.seed(666)
    preds = {}
    for pair in system_pairs:
        sys1 = systems[pair[0]][metric]
        sys2 = systems[pair[1]][metric]
        n = len(sys1)
        points = [i for i in range(0, n)]
        is_better = 0
        N = 1000
        for i in range(N):
            sample = choices(points, k=n)
            sys1_, sys2_ = [], []
            # Due to EMPTY summary, we have to ensure sys1_, sys2_ not empty
            while len(sys1_) == 0:
                for p in sample:
                    if sys1[p] is None or sys2[p] is None:
                        continue
                    else:
                        sys1_.append(sys1[p])
                        sys2_.append(sys2[p])
                sample = choices(points, k=n)
            if np.mean(sys1_) > np.mean(sys2_):
                is_better += 1

        if is_better / N >= 0.95:
            preds[pair] = 0  # pair[0] is better
        elif is_better / N <= 0.05:
            preds[pair] = 1  # pair[1] is better
        else:
            preds[pair] = 2  # can't say
    return preds

def reformat_sum_for_bartscore(data_path,our_score_name):
    logging.info(f"Reformatting {data_path} to bartscore format")
    data = load_json(data_path)
    new_data = {}

    for item in data:
        id = item["id"]
        input = item["input"]
        output = item["output"]
        candidates = item["candidates"]
        new_data_item = {
            "src": input,
            "ref_summ": output,
            "sys_summs": {}
        }
        good = True
        for cand in candidates:
            model = cand["model"]
            text = cand["text"]
            scores = cand["scores"]
            if our_score_name not in scores:
                print("xgptscore")
                good = False
                break
            new_data_item["sys_summs"][model] = {
                "sys_summ": text,
                "scores": scores
            }
        if good:
            new_data[id] = new_data_item
    
    logging.info("Total good data: {}/{}".format(len(new_data), len(data)))
    new_data_path = data_path.with_name(data_path.stem + "_bs_format.pkl")
    save_pickle(new_data_path, new_data)
    save_json(new_data_path.with_suffix(".json"), new_data)
    logging.info("Saved data in bartscore format to {}".format(new_data_path))
    
def reformat_d2t_for_bartscore(data_path, our_score_name):
    logging.info(f"Reformatting {data_path} to bartscore format")
    data = load_json(data_path)
    new_data = {}

    for item in data:
        id = item["id"]
        input = item["input"]
        output = item["output"]
        candidates = item["candidates"]
        if our_score_name not in candidates[0]["scores"] or candidates[0]["scores"][our_score_name] is None:
            continue
        scores = candidates[0]["scores"]
        new_data[id] = {
            "src": input,
            "ref_summs": output,
            "sys_summ": candidates[0]["text"],
            "scores": scores
        }

    logging.info("Total good data: {}/{}".format(len(new_data), len(data)))
    new_data_path = data_path.with_name(data_path.stem + "_bs_format.pkl")
    save_pickle(new_data_path, new_data)
    logging.info("Saved data in bartscore format to {}".format(new_data_path))
    