# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import os.path as op
import re
from tabulate import tabulate
from collections import Counter


def comp_purity(p_xy, axis):
    """
    Compute purity of a joint distribution along one axis.

    Args:
        p_xy: array holding the joint probabilities (a 2-D table when it
            comes from comp_joint_prob)
        axis: axis that is reduced; each remaining entry is scored by the
            share of its marginal mass taken by its largest cell

    Returns:
        (individual, aggregate): the per-entry ratio max / marginal, and the
        sum of the per-entry maxima
    """
    peak = p_xy.max(axis=axis)
    marginal = p_xy.sum(axis=axis)
    return peak / marginal, peak.sum()


def comp_entropy(p):
    """
    Return the entropy (in nats) of the probabilities in ``p``.

    A constant 1e-8 is added inside the logarithm so that zero-probability
    entries do not produce log(0).
    """
    log_p = np.log(p + 1e-8)
    return (-p * log_p).sum()


def comp_norm_mutual_info(p_xy):
    """
    Compute mutual information and its entropy-normalized variants.

    Args:
        p_xy: 2-D joint probability table; rows index x, columns index y

    Returns:
        (mi, mi / H(x), mi / H(y), H(x), H(y)), where H(x) and H(y) are the
        entropies of the row and column marginals
    """
    row_marg = p_xy.sum(axis=1, keepdims=True)
    col_marg = p_xy.sum(axis=0, keepdims=True)
    # (rows, 1) @ (1, cols): the product of marginals for every cell
    indep = row_marg @ col_marg
    # same 1e-8 smoothing as in comp_entropy keeps log finite on empty cells
    mi = (p_xy * np.log(p_xy / indep + 1e-8)).sum()
    h_x = comp_entropy(row_marg)
    h_y = comp_entropy(col_marg)
    return mi, mi / h_x, mi / h_y, h_x, h_y


def pad(labs, n):
    """
    Extend a label sequence by repeating its first and last labels.

    Args:
        labs: non-empty sequence of labels (only required when n != 0)
        n: number of copies of the edge label added on each side

    Returns:
        numpy array with the padded labels; for n == 0 simply the labels
        converted to an array
    """
    if n == 0:
        return np.array(labs)
    head = [labs[0]] * n
    tail = [labs[-1]] * n
    return np.concatenate([head, labs, tail])


def comp_avg_seg_dur(labs_list):
    """
    Compute the average segment length, in frames, over all sequences.

    A segment is a maximal run of identical consecutive labels.

    Args:
        labs_list: iterable of label sequences (lists or 1-D arrays)

    Returns:
        total number of frames divided by total number of segments

    Raises:
        ZeroDivisionError: if there are no frames at all (no sequences, or
            only empty sequences)
    """
    n_frms = 0
    n_segs = 0
    for labs in labs_list:
        labs = np.asarray(labs)
        if labs.size == 0:
            # an empty sequence contributes neither frames nor segments;
            # without this guard there is no "first frame" to mark as an edge
            continue
        n_frms += len(labs)
        # one segment starts at frame 0, plus one at every label change
        n_segs += 1 + (labs[1:] != labs[:-1]).sum()
    return n_frms / n_segs


def comp_joint_prob(uid2refs, uid2hyps):
    """
    Build the joint distribution of (reference, hypothesis) label pairs.

    Utterances are paired by uid. When the two sequences of an utterance
    differ in length, only the common prefix is counted.

    Args:
        uid2refs: dict mapping uid to its reference label sequence
        uid2hyps: dict mapping uid to its hypothesis label sequence

    Returns:
        p_xy: (num refs, num hyps) array of joint probabilities
        ref2pid: dict mapping reference label to its row index
        hyp2lid: dict mapping hypothesis label to its column index
        tot: total number of counted frame pairs
        abs_frmdiff: summed absolute length difference over paired utterances
        skipped: uids present in uid2refs but missing from uid2hyps
    """
    pair_cnts = Counter()
    skipped = []
    abs_frmdiff = 0
    for uid, refs in uid2refs.items():
        if uid not in uid2hyps:
            skipped.append(uid)
            continue
        hyps = uid2hyps[uid]
        abs_frmdiff += abs(len(refs) - len(hyps))
        # zip stops at the shorter sequence, i.e. at the common prefix
        pair_cnts.update(zip(refs, hyps))
    tot = sum(pair_cnts.values())

    # sorted labels give a deterministic row / column order
    ref_labels = sorted({ref for ref, _ in pair_cnts})
    hyp_labels = sorted({hyp for _, hyp in pair_cnts})
    ref2pid = {ref: idx for idx, ref in enumerate(ref_labels)}
    hyp2lid = {hyp: idx for idx, hyp in enumerate(hyp_labels)}

    p_xy = np.zeros((len(ref2pid), len(hyp2lid)), dtype=float)
    for (ref, hyp), cnt in pair_cnts.items():
        p_xy[ref2pid[ref], hyp2lid[hyp]] = cnt
    p_xy /= p_xy.sum()
    return p_xy, ref2pid, hyp2lid, tot, abs_frmdiff, skipped


def read_phn(tsv_path, rm_stress=True):
    """
    Read per-utterance phone sequences from a tab-separated file.

    Each line holds a uid, a tab, and a comma-separated list of phones.

    Args:
        tsv_path: path of the file to read
        rm_stress: if True, strip every digit from each phone

    Returns:
        dict mapping uid to its list of phones
    """
    digit = re.compile("[0-9]")
    uid2phns = {}
    with open(tsv_path) as f:
        for line in f:
            uid, joined = line.rstrip().split("\t")
            phns = joined.split(",")
            if rm_stress:
                phns = [digit.sub("", phn) for phn in phns]
            uid2phns[uid] = phns
    return uid2phns


def read_lab(tsv_path, lab_path, pad_len=0, upsample=1):
    """
    Read frame labels and key them by utterance id.

    The tsv is needed to retrieve the uids for the labels: its first line is
    skipped, and for each following line the uid is the basename (without
    extension) of the first whitespace-separated field. The label file
    holds one whitespace-separated label sequence per line, in the same
    order as the tsv.

    Args:
        tsv_path: path of the tsv providing the uids
        lab_path: path of the label file
        pad_len: number of edge-label copies added on each side (see pad)
        upsample: number of times each (padded) label is repeated

    Returns:
        dict mapping uid to a numpy array of labels
    """
    uids = []
    with open(tsv_path) as tsv_f:
        tsv_f.readline()  # skip the first line of the tsv
        for line in tsv_f:
            first_field = line.rstrip().split()[0]
            stem, _ = op.splitext(op.basename(first_field))
            uids.append(stem)

    labs_list = []
    with open(lab_path) as lab_f:
        for line in lab_f:
            padded = pad(line.rstrip().split(), pad_len)
            labs_list.append(padded.repeat(upsample))

    assert len(uids) == len(labs_list)
    return dict(zip(uids, labs_list))


def main_lab_lab(
    tsv_dir,
    lab_dir,
    lab_name,
    lab_sets,
    ref_dir,
    ref_name,
    pad_len=0,
    upsample=1,
    verbose=False,
):
    """
    Score hypothesis labels against another set of labels.

    Args:
        tsv_dir: directory with the "<set>.tsv" files; None means lab_dir
        lab_dir: directory with the hypothesis "<set>.<lab_name>" files
        lab_name: extension of the hypothesis label files
        lab_sets: names of the sets to load for both label sources
        ref_dir: directory with the reference "<set>.<ref_name>" files
        ref_name: extension of the reference label files
        pad_len: padding applied to the hypothesis labels only
        upsample: upsample factor applied to the hypothesis labels only
        verbose: forwarded to _main
    """
    # assume tsv_dir is the same for both the reference and the hypotheses
    if tsv_dir is None:
        tsv_dir = lab_dir

    # references are read as-is, without padding or upsampling
    uid2refs = {}
    for s in lab_sets:
        ref_labs = read_lab(f"{tsv_dir}/{s}.tsv", f"{ref_dir}/{s}.{ref_name}")
        uid2refs.update(ref_labs)

    uid2hyps = {}
    for s in lab_sets:
        hyp_labs = read_lab(
            f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample
        )
        uid2hyps.update(hyp_labs)

    _main(uid2refs, uid2hyps, verbose)


def main_phn_lab(
    tsv_dir,
    lab_dir,
    lab_name,
    lab_sets,
    phn_dir,
    phn_sets,
    pad_len=0,
    upsample=1,
    verbose=False,
):
    """
    Score hypothesis labels against phone sequences.

    Args:
        tsv_dir: directory with the "<set>.tsv" files; None means lab_dir
        lab_dir: directory with the hypothesis "<set>.<lab_name>" files
        lab_name: extension of the hypothesis label files
        lab_sets: names of the hypothesis sets to load
        phn_dir: directory with the phone "<set>.tsv" files
        phn_sets: names of the phone sets to load
        pad_len: padding applied to the hypothesis labels
        upsample: upsample factor applied to the hypothesis labels
        verbose: forwarded to _main
    """
    if tsv_dir is None:
        tsv_dir = lab_dir

    uid2refs = {}
    for s in phn_sets:
        phns = read_phn(f"{phn_dir}/{s}.tsv")
        uid2refs.update(phns)

    uid2hyps = {}
    for s in lab_sets:
        hyp_labs = read_lab(
            f"{tsv_dir}/{s}.tsv", f"{lab_dir}/{s}.{lab_name}", pad_len, upsample
        )
        uid2hyps.update(hyp_labs)

    _main(uid2refs, uid2hyps, verbose)


def _main(uid2refs, uid2hyps, verbose):
    """
    Compute all quality measures and print them as a one-row table.

    Args:
        uid2refs: dict mapping uid to its reference label sequence
        uid2hyps: dict mapping uid to its hypothesis label sequence
        verbose: accepted for interface compatibility; currently unused
    """
    joint = comp_joint_prob(uid2refs, uid2hyps)
    p_xy, ref2pid, hyp2lid, tot, frmdiff, skipped = joint

    # reducing over axis 0 scores each hypothesis label, axis 1 each reference
    ref_pur_by_hyp, ref_pur = comp_purity(p_xy, axis=0)
    hyp_pur_by_ref, hyp_pur = comp_purity(p_xy, axis=1)

    mi_stats = comp_norm_mutual_info(p_xy)
    mi, mi_norm_by_ref, mi_norm_by_hyp, h_ref, h_hyp = mi_stats

    ref_seg_len = comp_avg_seg_dur(uid2refs.values())
    hyp_seg_len = comp_avg_seg_dur(uid2hyps.values())

    outputs = {
        "ref pur": ref_pur,
        "hyp pur": hyp_pur,
        "H(ref)": h_ref,
        "H(hyp)": h_hyp,
        "MI": mi,
        "MI/H(ref)": mi_norm_by_ref,
        "ref segL": ref_seg_len,
        "hyp segL": hyp_seg_len,
        "p_xy shape": p_xy.shape,
        "frm tot": tot,
        "frm diff": frmdiff,
        "utt tot": len(uid2refs),
        "utt miss": len(skipped),
    }
    table = tabulate([outputs.values()], outputs.keys(), floatfmt=".4f")
    print(table)


if __name__ == "__main__":
    # Compute the quality of labels with respect to phones, or with respect
    # to another set of labels when both --ref_lab_dir and --ref_lab_name
    # are given.
    import argparse

    parser = argparse.ArgumentParser()
    for positional in ("tsv_dir", "lab_dir", "lab_name"):
        parser.add_argument(positional)
    parser.add_argument("--lab_sets", nargs="+", type=str, default=["valid"])
    parser.add_argument(
        "--phn_dir",
        default="/checkpoint/wnhsu/data/librispeech/960h/fa/raw_phn/phone_frame_align_v1",
    )
    parser.add_argument(
        "--phn_sets", nargs="+", type=str, default=["dev-clean", "dev-other"]
    )
    parser.add_argument("--pad_len", type=int, default=0, help="padding for hypotheses")
    parser.add_argument(
        "--upsample", type=int, default=1, help="upsample factor for hypotheses"
    )
    parser.add_argument("--ref_lab_dir", default="")
    parser.add_argument("--ref_lab_name", default="")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # arguments shared by both entry points, in their positional order
    hyp_args = (args.tsv_dir, args.lab_dir, args.lab_name, args.lab_sets)
    opt_args = (args.pad_len, args.upsample, args.verbose)

    if args.ref_lab_dir and args.ref_lab_name:
        main_lab_lab(*hyp_args, args.ref_lab_dir, args.ref_lab_name, *opt_args)
    else:
        main_phn_lab(*hyp_args, args.phn_dir, args.phn_sets, *opt_args)
