import shutil
import os, sys
from subprocess import check_call, check_output
import glob
import argparse
import shutil
import pathlib
import itertools

def call_output(cmd):
    print(f"Executing: {cmd}")
    ret = check_output(cmd, shell=True)
    print(ret)
    return ret

def call(cmd):
    print(cmd)
    check_call(cmd, shell=True)


WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None)

if WORKDIR_ROOT is None or  not WORKDIR_ROOT.strip():
    print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."')
    sys.exit(-1)

SPM_PATH = os.environ.get('SPM_PATH', None)

if SPM_PATH is None or not SPM_PATH.strip():
    print("Please install sentence piecence from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. Exitting...")
    sys.exit(-1)


SPM_MODEL = f'{WORKDIR_ROOT}/sentence.bpe.model'
SPM_VOCAB = f'{WORKDIR_ROOT}/dict_250k.txt'

SPM_ENCODE = f'{SPM_PATH}'

if not os.path.exists(SPM_MODEL):
    call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/sentence.bpe.model -O {SPM_MODEL}")


if not os.path.exists(SPM_VOCAB):
    call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/dict_250k.txt -O {SPM_VOCAB}")



def get_data_size(raw):
    cmd = f'wc -l {raw}'
    ret = call_output(cmd)
    return int(ret.split()[0])

def encode_spm(model, direction, prefix='', splits=['train', 'test', 'valid'], pairs_per_shard=None):
    src, tgt = direction.split('-')

    for split in splits:
        src_raw, tgt_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{src}', f'{RAW_DIR}/{split}{prefix}.{direction}.{tgt}'
        if os.path.exists(src_raw) and os.path.exists(tgt_raw):
            cmd = f"""python {SPM_ENCODE} \
            --model {model}\
            --output_format=piece \
            --inputs {src_raw} {tgt_raw}  \
            --outputs {BPE_DIR}/{direction}{prefix}/{split}.bpe.{src} {BPE_DIR}/{direction}{prefix}/{split}.bpe.{tgt} """
            print(cmd)
            call(cmd)


def binarize_(
    bpe_dir,
    databin_dir,
    direction, spm_vocab=SPM_VOCAB, 
    splits=['train', 'test', 'valid'],
):
    src, tgt = direction.split('-')

    try:
        shutil.rmtree(f'{databin_dir}', ignore_errors=True)
        os.mkdir(f'{databin_dir}')
    except OSError as error:
        print(error)
    cmds = [
        "fairseq-preprocess",
        f"--source-lang {src} --target-lang {tgt}",
        f"--destdir {databin_dir}/",
        f"--workers 8",
    ]
    if isinstance(spm_vocab, tuple):
        src_vocab, tgt_vocab = spm_vocab
        cmds.extend(
            [
                f"--srcdict {src_vocab}",
                f"--tgtdict {tgt_vocab}",
            ]
        )
    else:
        cmds.extend(
            [
                f"--joined-dictionary",
                f"--srcdict {spm_vocab}",
            ]
        )
    input_options = []
    if 'train' in splits and glob.glob(f"{bpe_dir}/train.bpe*"):
        input_options.append(
            f"--trainpref {bpe_dir}/train.bpe",
        )        
    if 'valid' in splits and glob.glob(f"{bpe_dir}/valid.bpe*"):
        input_options.append(f"--validpref {bpe_dir}/valid.bpe")
    if 'test' in splits and glob.glob(f"{bpe_dir}/test.bpe*"):
        input_options.append(f"--testpref {bpe_dir}/test.bpe")   
    if len(input_options) > 0:    
        cmd = " ".join(cmds + input_options)
        print(cmd)
        call(cmd)


def binarize(
    databin_dir,
    direction, spm_vocab=SPM_VOCAB, prefix='',
    splits=['train', 'test', 'valid'],
    pairs_per_shard=None,
):
    def move_databin_files(from_folder, to_folder):
        for bin_file in glob.glob(f"{from_folder}/*.bin") \
            +  glob.glob(f"{from_folder}/*.idx") \
            +  glob.glob(f"{from_folder}/dict*"):
            try:
                shutil.move(bin_file, to_folder)
            except OSError as error:
                print(error)      
    bpe_databin_dir = f"{BPE_DIR}/{direction}{prefix}_databin"
    bpe_dir = f"{BPE_DIR}/{direction}{prefix}"
    if pairs_per_shard is None:
        binarize_(bpe_dir, bpe_databin_dir, direction, spm_vocab=spm_vocab, splits=splits)
        move_databin_files(bpe_databin_dir, databin_dir)
    else:
        # binarize valid and test which will not be sharded
        binarize_(
            bpe_dir, bpe_databin_dir, direction,
            spm_vocab=spm_vocab, splits=[s for s in splits if s != "train"])
        for shard_bpe_dir in glob.glob(f"{bpe_dir}/shard*"):
            path_strs = os.path.split(shard_bpe_dir)
            shard_str = path_strs[-1]
            shard_folder = f"{bpe_databin_dir}/{shard_str}"
            databin_shard_folder = f"{databin_dir}/{shard_str}"
            print(f'working from {shard_folder} to {databin_shard_folder}')
            os.makedirs(databin_shard_folder, exist_ok=True)
            binarize_(
                shard_bpe_dir, shard_folder, direction,
                spm_vocab=spm_vocab, splits=["train"])

            for test_data in glob.glob(f"{bpe_databin_dir}/valid.*") + glob.glob(f"{bpe_databin_dir}/test.*"):
                filename = os.path.split(test_data)[-1]
                try:
                    os.symlink(test_data, f"{databin_shard_folder}/{filename}")
                except OSError as error:
                    print(error)                
            move_databin_files(shard_folder, databin_shard_folder)


def load_langs(path):
    with open(path) as fr:
        langs = [l.strip() for l in fr]
    return langs

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", default=f"{WORKDIR_ROOT}/ML50")
    parser.add_argument("--raw-folder", default='raw')
    parser.add_argument("--bpe-folder", default='bpe')    
    parser.add_argument("--databin-folder", default='databin')    

    args = parser.parse_args()

    DATA_PATH = args.data_root #'/private/home/yuqtang/public_data/ML50'   
    RAW_DIR = f'{DATA_PATH}/{args.raw_folder}'
    BPE_DIR = f'{DATA_PATH}/{args.bpe_folder}'
    DATABIN_DIR = f'{DATA_PATH}/{args.databin_folder}'
    os.makedirs(BPE_DIR, exist_ok=True)

    raw_files = itertools.chain(
        glob.glob(f'{RAW_DIR}/train*'),
        glob.glob(f'{RAW_DIR}/valid*'),
        glob.glob(f'{RAW_DIR}/test*'),
    )

    directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files]

    for direction in directions:
        prefix = ""
        splits = ['train', 'valid', 'test']
        try:
            shutil.rmtree(f'{BPE_DIR}/{direction}{prefix}', ignore_errors=True)
            os.mkdir(f'{BPE_DIR}/{direction}{prefix}')
            os.makedirs(DATABIN_DIR, exist_ok=True)
        except OSError as error: 
            print(error)     
        spm_model, spm_vocab = SPM_MODEL, SPM_VOCAB
        encode_spm(spm_model, direction=direction, splits=splits)
        binarize(DATABIN_DIR, direction, spm_vocab=spm_vocab, splits=splits)
