import chardet
import magic
import os
import re
import glob
import subprocess
import tempfile
# import pylatexpand

# Start of a LaTeX document body; used to recognise the main .tex file.
# The capturing parentheses make re.split() keep the matched
# "\begin{document}" in its output, so split + join round-trips the text.
MAIN_TEX_PATT = re.compile(r'(\\begin\s*\{\s*document\s*\})', re.I)
# File-extension matchers, applied to os.path.splitext(name)[1].
PDF_EXT_PATT = re.compile(r'^\.pdf$', re.I)
GZ_EXT_PATT = re.compile(r'^\.gz$', re.I)
TEX_EXT_PATT = re.compile(r'^\.tex$', re.I)
# Extensions of files that are never searched for \begin{document}.
NON_TEXT_PATT = re.compile(r'^\.(pdf|eps|jpg|png|gif)$', re.I)
# Marker text of bibliography entries (not referenced in the code shown here).
BBL_SIGN = '\\bibitem'
# natbib fix: matches \citet, \citep, \citealt, \citeauthor, ... with an
# optional star and optional [..] arguments; group 3 captures the cite keys,
# so substituting r'\\cite{\3}' yields a plain \cite{keys}.
# Both pattern halves are raw strings: in a plain string "\[" , "\s", "\{"
# are invalid escape sequences (DeprecationWarning / SyntaxWarning).
PRE_FIX_NATBIB = True
NATBIB_PATT = re.compile((r'\\cite(t|p|alt|alp|author|year|yearpar)\s*?\*?\s*?'
                          r'(\[[^\]]*?\]\s*?)*?\s*?\*?\s*?\{([^\}]+?)\}'),
                         re.I)
# bibitem option fix: matches "\bibitem[label]" so the optional label
# argument can be dropped (substitute r'\\bibitem').
PRE_FIX_BIBOPT = True
BIBOPT_PATT = re.compile(r'\\bibitem\s*?\[[^]]*?\]', re.I | re.M)

def read_file(path):
    """
    Read a text file and return its contents as a str, guessing the encoding.

    The file is first decoded as UTF-8 (explicitly, so the result does not
    depend on the process locale). If that fails, the encoding reported by
    libmagic is tried; if that is unknown or wrong, chardet's guess is used
    with undecodable bytes replaced.

    Returns '' when no usable encoding can be determined.
    OSError (missing file, directory, permissions) propagates to the caller.
    """
    try:
        with open(path, encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        pass

    # Read the raw bytes once, closing the handle (the old code leaked it).
    with open(path, 'rb') as f:
        blob = f.read()

    # libmagic may answer 'binary' / 'unknown-8bit', which raise LookupError.
    encoding = magic.Magic(mime_encoding=True).from_buffer(blob)
    try:
        return blob.decode(encoding)
    except (UnicodeDecodeError, LookupError):
        pass

    encoding = chardet.detect(blob)['encoding']
    if not encoding:
        return ''
    try:
        # errors='replace' means only an unknown codec name can fail here.
        return blob.decode(encoding, errors='replace')
    except LookupError:
        return ''


def remove_math(latex_str):
    """
    Hook for filtering the body of a LaTeX document.

    Splits ``latex_str`` once around the first ``\\begin{document}`` into
    preamble, the matched command and body, then glues the pieces back
    together. No filters are currently active (see the commented-out loop
    below), so the returned string equals the input.
    """
    pieces = MAIN_TEX_PATT.split(latex_str, maxsplit=1)
    # Body filtering would act on pieces[2] (present only when the marker
    # was found), e.g.:
    # for patt in FILTER_PATTS:
    #     pieces[2] = re.sub(patt, '', pieces[2])
    return ''.join(pieces)


def _find_main_tex(path):
    """Return the name of the entry in ``path`` containing \\begin{document}.

    ``*.tex`` files are searched first; other entries (except known binary
    formats matched by NON_TEXT_PATT) are searched only if no .tex file
    matches. Entries are visited in sorted order so the choice is
    deterministic; the first match wins. Returns None if nothing matches.
    """
    tex_names = []
    other_names = []
    for name in sorted(os.listdir(path)):
        ext = os.path.splitext(name)[1]
        if TEX_EXT_PATT.match(ext):
            tex_names.append(name)
        elif not NON_TEXT_PATT.match(ext):
            other_names.append(name)

    for name in tex_names + other_names:
        try:
            cntnt = read_file(os.path.join(path, name))
        except Exception:
            # Unreadable entry (sub-directory, permissions, libmagic error...):
            # it cannot be the main file, so keep looking.
            continue
        if MAIN_TEX_PATT.search(cntnt) is not None:
            return name
    return None


def normalize(path, out_dir, write_logs=True, *, latexpand_cmd='latexpand'):
    """
    Normalize an arXiv file
    Adapted from https://github.com/IllDepence/unarXive
        with modifications

    Identifies the primary *.tex file, the bibliography file,
    and expands other tex files and the bibliography into the
    main tex file

    Args:
        path: directory holding the extracted sources of one paper.
        out_dir: output directory, created if missing. The flattened file is
            written to ``out_dir/<name>/<name>.tex``, where ``<name>`` is the
            last component of ``path``. latexpand's stderr is appended to
            ``out_dir/log_latexpand.txt``.
        write_logs: if True, problems are appended to ``out_dir/log.txt``.
        latexpand_cmd: latexpand executable (name on PATH or full path).

    Returns None. If no main file is found, or latexpand produces no output,
    the problem is logged and nothing is written for this paper.

    Raises:
        OSError (e.g. FileNotFoundError) if ``latexpand_cmd`` cannot be run.
    """
    os.makedirs(out_dir, exist_ok=True)

    def log(msg):
        if write_logs:
            with open(os.path.join(out_dir, 'log.txt'), 'a') as f:
                f.write('{}\n'.format(msg))

    # Paper name = last path component, e.g. '/data/2101.00001/' -> '2101.00001'.
    fn = os.path.basename(os.path.normpath(path))

    main_tex_path = _find_main_tex(path)
    if main_tex_path is None:
        log(('couldn\'t find main tex file in dump archive {}'
             '').format(fn))
        # Nothing to flatten; previously execution fell through and crashed
        # in os.path.join(path, None).
        return None

    # glob.escape: the directory name itself may contain glob metacharacters.
    bbl_files = sorted(glob.glob(os.path.join(glob.escape(path), '*.bbl')))

    with tempfile.TemporaryDirectory() as tmp_dir_path:
        temp_tex_fn = os.path.join(tmp_dir_path, f'{fn}.tex')

        latexpand_args = [latexpand_cmd]
        if bbl_files:
            # latexpand runs with cwd=path, so pass the bare file name.
            latexpand_args += ['--expand-bbl', os.path.basename(bbl_files[0])]
        latexpand_args += [main_tex_path, '--output', temp_tex_fn]

        with open(os.path.join(out_dir, 'log_latexpand.txt'), 'a+') as err:
            result = subprocess.run(latexpand_args, stderr=err, cwd=path)

        if not os.path.exists(temp_tex_fn):
            log('latexpand produced no output for {} (exit code {})'.format(
                fn, result.returncode))
            return None

        # Re-read and later write as UTF-8, because latexpand does not
        # guarantee a consistent output encoding.
        cntnt = read_file(temp_tex_fn)

    if PRE_FIX_NATBIB:
        cntnt = NATBIB_PATT.sub(r'\\cite{\3}', cntnt)
    if PRE_FIX_BIBOPT:
        cntnt = BIBOPT_PATT.sub(r'\\bibitem', cntnt)
    cntnt = remove_math(cntnt)

    paper_out_dir = os.path.join(out_dir, fn)
    os.makedirs(paper_out_dir, exist_ok=True)
    new_tex_fn = os.path.join(paper_out_dir, f'{fn}.tex')
    with open(new_tex_fn, mode='w', encoding='utf-8') as f:
        f.write(cntnt)
    return None
            