# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import itertools
import os
import sys

import pandas as pd


WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None)

# Abort at import time when WORKDIR_ROOT is unset or only whitespace.
# Errors go to stderr so they are not mixed into the script's regular output.
if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip():
    print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exiting...',
          file=sys.stderr)
    sys.exit(-1)

def load_langs(path):
    """Return the lines of the file at ``path``, each stripped of surrounding whitespace.

    Blank lines are kept as empty strings.
    """
    with open(path) as handle:
        return list(map(str.strip, handle))



def load_sentences(raw_data, split, direction):
    """Load the parallel ``split`` files for ``direction`` from ``raw_data``.

    For ``direction == "src-tgt"`` the files read are
    ``{raw_data}/{split}.src-tgt.src`` and ``{raw_data}/{split}.src-tgt.tgt``.

    Returns:
        ``[(src, src_lines), (tgt, tgt_lines)]`` when both files exist, otherwise
        ``[]``. Each line has its trailing newline removed and is otherwise untouched.

    Raises:
        ValueError: if ``direction`` does not contain exactly one ``-``.
    """
    src, tgt = direction.split('-')
    src_path = f"{raw_data}/{split}.{direction}.{src}"
    tgt_path = f"{raw_data}/{split}.{direction}.{tgt}"
    if not (os.path.exists(src_path) and os.path.exists(tgt_path)):
        return []

    result = []
    for lang, path in ((src, src_path), (tgt, tgt_path)):
        # `with` closes each file. Lines are split on "\n" only, exactly like
        # `for line in f` does for the training files; str.splitlines() would also
        # break on \x0c, \x1c, \u2028, ..., misaligning src/tgt lines and making
        # test sentences disagree with their training-side counterparts.
        with open(path) as f:
            result.append((lang, [line.rstrip('\n') for line in f]))
    return result

def swap_direction(d):
    """Turn a ``src-tgt`` direction string into ``tgt-src``.

    Raises ValueError unless ``d`` contains exactly one ``-``.
    """
    first, second = d.split('-')
    return '-'.join((second, first))

def get_all_test_data(raw_data, directions, split='test'):
    """Load every available test sentence for ``directions`` and their reverses.

    For each ``src-tgt`` in ``directions``, both ``src-tgt`` and ``tgt-src`` are
    tried via ``load_sentences``; directions whose file pair is missing are skipped.

    Returns:
        ``(all_test_data, test_data)`` where ``all_test_data`` maps each stripped
        sentence to the set of language codes it appeared under, and ``test_data``
        is the flat list of ``(lang, sentences)`` pairs in load order (it contains
        repeats if a direction and its reverse are both listed).
    """
    test_data = [
        x
        for dd in directions
        for d in (dd, swap_direction(dd))
        for x in load_sentences(raw_data, split, d)
    ]
    all_test_data = {}
    for lang, sentences in test_data:
        for s in sentences:
            all_test_data.setdefault(s.strip(), set()).add(lang)
    return all_test_data, test_data


def check_train_sentences(src_path, tgt_path, direction, all_test_data, mess_up_train=None):
    """Find training sentences in a parallel file pair that also occur in the test set.

    Lines of ``src_path`` and ``tgt_path`` are read in lockstep (``zip`` stops at
    the shorter file), stripped, and looked up in ``all_test_data``. Each hit
    records ``direction`` under that sentence in ``mess_up_train``.

    Args:
        src_path: path of the source-side training file.
        tgt_path: path of the target-side training file.
        direction: label (e.g. ``"src-tgt"``) recorded for each hit.
        all_test_data: container of stripped test sentences; only membership is used.
        mess_up_train: optional dict ``sentence -> set of directions`` that is
            updated in place; a new dict is created when omitted.

    Returns:
        ``(mess_up_train, size, overlapped_size_counted_dup)``. ``size`` is the number
        of line pairs read. The overlap count counts the source and target sides
        separately, so a pair whose two sides both hit counts twice. If either file
        is missing, nothing is read and ``(mess_up_train, 0, 0)`` is returned.
    """
    # Create a fresh dict per call. The previous `mess_up_train={}` default was a
    # single shared object that silently accumulated hits across unrelated calls.
    if mess_up_train is None:
        mess_up_train = {}
    print(f'check training data for {direction} in {src_path} and {tgt_path}')
    size = 0
    overlapped_size_counted_dup = 0
    if not os.path.exists(tgt_path) or not os.path.exists(src_path):
        return mess_up_train, size, overlapped_size_counted_dup

    with open(src_path) as f, open(tgt_path) as g:
        for src_line, tgt_line in zip(f, g):
            size += 1
            # Source side first, then target side (same order as before).
            for sentence in (src_line.strip(), tgt_line.strip()):
                if sentence in all_test_data:
                    mess_up_train.setdefault(sentence, set()).add(direction)
                    overlapped_size_counted_dup += 1
    print(f'{direction}: size={size}, overlapped={overlapped_size_counted_dup}')
    return mess_up_train, size, overlapped_size_counted_dup

def check_train_all(raw_data, directions, all_test_data):
    """Check every direction's training corpus for overlap with the test set.

    For each ``src-tgt`` in ``directions`` the training pair is read from
    ``{raw_data}/en_XX/src-tgt/all.src`` and ``.../all.tgt`` via
    ``check_train_sentences``; a missing pair is recorded with zero size.

    Returns:
        ``(mess_up_train, data_sizes)``. ``mess_up_train`` maps each overlapping
        sentence to the set of directions it was found in. ``data_sizes`` maps each
        direction to ``(line_pairs_read, overlap_count_with_duplicates)``.
    """
    mess_up_train = {}
    data_sizes = {}
    print(f'checking training data againsts # {len(all_test_data)} sentences')
    # islice takes the first 10 keys without iterating the whole (possibly large) test set.
    print('example test data: ', list(itertools.islice(all_test_data, 10)))
    for direction in directions:
        src, tgt = direction.split('-')
        # NOTE(review): the corpus layout is hard-coded to an `en_XX` sub-folder.
        path = f'{raw_data}/en_XX/{direction}/all'
        src_path = f'{path}.{src}'
        tgt_path = f'{path}.{tgt}'
        print(f'checking {src_path} {tgt_path}')
        # mess_up_train is passed explicitly so hits accumulate across directions.
        _, size, overlapped_size_counted_dup = check_train_sentences(
            src_path, tgt_path, direction, all_test_data, mess_up_train)
        data_sizes[direction] = (size, overlapped_size_counted_dup)
    return mess_up_train, data_sizes




def main():
    """Command-line entry point: report how much test data appears in the training data.

    Loads test sentences for every requested direction (and its reverse) from
    ``--test-data``, scans the training corpora under ``--folder``, and prints the
    per-direction ``(size, overlap)`` statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--folder", type=str, required=True,
                        help="the data folder ")
    parser.add_argument("--test-data", type=str, required=True,
                        help="the test data folder ")
    parser.add_argument('--directions', type=str, default=None, required=False,
                        help="comma-separated list of src-tgt directions")

    args = parser.parse_args()
    # Ignore surrounding whitespace and empty items (e.g. a trailing comma). Without
    # this, an empty item crashes later in direction.split('-').
    directions = sorted({d.strip() for d in (args.directions or '').split(',') if d.strip()})
    if not directions:
        # Previously a missing --directions crashed with AttributeError on None.split().
        parser.error('argument --directions: expected a comma-separated list of src-tgt directions')

    raw_data = args.folder
    all_test_data, _test_data = get_all_test_data(args.test_data, directions, split='test')
    _mess_up_train, data_sizes = check_train_all(raw_data, directions, all_test_data)
    print(data_sizes)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
