import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
import os.path as op
import logging
import json
import random
import shutil
import numpy as np
from tqdm import tqdm
from pathlib import Path
from argparse import ArgumentParser
from typing import List, Tuple
from collections import OrderedDict
from image_synthesis.data.utils.tsv_file import TSVFile, tsv_writer, load_list_file
from image_synthesis.data.utils.tsv_utils import parallel_map, try_delete, concat_files, append_files, write_to_file, ensure_directory
from image_synthesis.data.tsv_dataset import TSVImageTextDataset


def delete_tsv_files(tsvs: List[str]):
    """Remove every TSV in ``tsvs`` together with its sidecar files.

    For each path the TSV itself, then ``<stem>.lineidx``, then
    ``<stem>.chunks`` is handed to ``try_delete`` if it currently exists as a
    regular file. Paths that do not exist are skipped silently.
    """
    for tsv_path in tsvs:
        stem = op.splitext(tsv_path)[0]
        # Existence is checked right before each deletion, in this order.
        for candidate in (tsv_path, stem + '.lineidx', stem + '.chunks'):
            if op.isfile(candidate):
                try_delete(candidate)


def concat_tsv_files(tsvs: List[str], out_tsv: str):
    """Concatenate ``tsvs`` into ``out_tsv`` and rebuild its sidecar files.

    The data itself is concatenated by ``concat_files``. Afterwards:

    * ``<out_stem>.lineidx`` is written by merging the inputs' ``.lineidx``
      files; entries of every input after the first are shifted by the
      combined byte size of the inputs that precede it.
    * ``<out_stem>.chunks`` is written only when *every* input has a
      ``.chunks`` file. Each ``[low, high]`` row range is shifted by the
      number of rows covered by the preceding chunk files. If only some
      inputs have a ``.chunks`` file a warning is logged and none is written.

    Args:
        tsvs: Paths of the TSV files to concatenate, in output order.
        out_tsv: Path of the concatenated TSV to create.
    """
    concat_files(tsvs, out_tsv)

    # sizes[i] == total bytes of tsvs[0..i], i.e. where tsvs[i + 1] starts.
    sizes = np.cumsum([os.stat(t).st_size for t in tsvs])
    all_idx = []
    for i, t in enumerate(tsvs):
        entries = load_list_file(op.splitext(t)[0] + '.lineidx')
        if i == 0:
            # The first file keeps its offsets untouched.
            all_idx.extend(entries)
        else:
            offset = int(sizes[i - 1])
            all_idx.extend(str(int(idx) + offset) for idx in entries)
    write_to_file('\n'.join(all_idx), op.splitext(out_tsv)[0] + '.lineidx')

    chunk_files = [op.splitext(t)[0] + '.chunks' for t in tsvs]
    present = [op.isfile(f) for f in chunk_files]
    if all(present):
        chunk_info = OrderedDict()
        accum_size = 0
        for file in chunk_files:
            with open(file, 'r') as fp:
                chunks = json.load(fp)
            # Track the highest row index explicitly instead of relying on the
            # loop variable of the last iterated class: that breaks for an
            # empty chunk file and when the last entry is not the largest.
            last_row = -1
            for cls_name, (low, high) in chunks.items():
                chunk_info[cls_name] = (accum_size + low, accum_size + high)
                last_row = max(last_row, high)
            accum_size += last_row + 1
        with open(op.splitext(out_tsv)[0] + '.chunks', 'w') as fp:
            json.dump(chunk_info, fp)
    elif any(present):
        logging.warning("Not all class boundary files exist!")


def append_tsv_files(src_tsv: str, tsv_list: List[str]):
    """Append the TSVs in ``tsv_list`` to ``src_tsv`` and refresh its index.

    The merged ``.lineidx`` of ``src_tsv`` is computed from the file sizes
    *before* any data is appended: entries of ``src_tsv`` are kept as they
    are, entries of each appended file are shifted by the combined byte size
    of the files in front of it. The index is written first, then
    ``append_files`` performs the actual append.
    """
    ordered = [src_tsv] + tsv_list
    # cumulative[k] is the byte size of ordered[0..k] taken together.
    cumulative = np.cumsum([os.stat(path).st_size for path in ordered])

    merged = []
    for pos, path in enumerate(ordered):
        entries = load_list_file(op.splitext(path)[0] + '.lineidx')
        if pos == 0:
            merged.extend(entries)
        else:
            merged.extend(str(int(entry) + cumulative[pos - 1]) for entry in entries)

    write_to_file('\n'.join(merged), op.splitext(src_tsv)[0] + '.lineidx')
    append_files(src_tsv, tsv_list)


def image_search_to_compact_tsv(query: str, search_results: List[Tuple[str, str, str]], filename: str) -> None:
    """Write image-search results for one query as a three-column TSV.

    Each result becomes one row:
    ``("bing_image_search_<5-digit index>", <annotation JSON>, <url>)``,
    where the annotation is a JSON list holding ``{"class": query}`` and
    ``{"caption": "<title> (<criterion>)"}``.

    Args:
        query: Search query, stored as the ``class`` of every row.
        search_results: ``(url, title, criterion)`` triples. Every item must
            unpack into exactly three values.
        filename: Output TSV path, passed to ``tsv_writer``.
    """
    def gen_rows():
        for idx, (url, title, criterion) in enumerate(search_results):
            annotation = json.dumps([
                {'class': query},
                {'caption': "%s (%s)" % (title, criterion)}
            ])
            yield "bing_image_search_%05d" % idx, annotation, url

    tsv_writer(gen_rows(), filename)


def tsv_subset_process(info):
    """Process one contiguous slice of a TSV and write it to a temp file.

    ``info`` is a dict with the keys ``row_processor``, ``idx_process``,
    ``tmp_out``, ``idx_range_start``, ``idx_range_end``, ``head`` and
    ``out_sep``, plus either ``in_tsv`` (an indexable row source) or
    ``in_tsv_file`` (a path opened with ``TSVFile``).

    Rows ``idx_range_start .. idx_range_end - 1`` are passed through
    ``row_processor``; results that are ``None`` are dropped. The slice with
    ``idx_process == 0`` additionally emits ``head`` first when it is not
    ``None``. Nothing is done if ``tmp_out`` already exists.
    """
    # All mandatory keys are read up front so a malformed task fails early.
    process_row = info['row_processor']
    rank = info['idx_process']
    out_path = info['tmp_out']
    first = info['idx_range_start']
    stop = info['idx_range_end']
    header = info['head']
    separator = info['out_sep']

    if op.isfile(out_path):
        logging.info('skip since exist: {}'.format(out_path))
        return

    source = info['in_tsv'] if 'in_tsv' in info else TSVFile(info['in_tsv_file'])

    def rows():
        if rank == 0 and header is not None:
            yield header
        for row_idx in tqdm(range(first, stop)):
            processed = process_row(source[row_idx])
            if processed is not None:
                yield processed

    tsv_writer(rows(), out_path, sep=separator)


def parallel_tsv_process(row_processor, in_tsv_file,
                         out_tsv_file, num_process, num_jobs=None, head=None, out_sep='\t'):
    """Apply ``row_processor`` to every row of a TSV using parallel jobs.

    The input rows are split into ``num_jobs`` contiguous slices of equal
    size (the last ones may be shorter or empty). Each slice is handled by
    ``tsv_subset_process`` and written to
    ``<out_tsv_file>.<job>.<num_jobs>.tsv``; the pieces are then merged with
    ``concat_tsv_files`` and removed with ``delete_tsv_files``.

    Args:
        row_processor: Callable applied to each row; returning ``None`` drops
            the row.
        in_tsv_file: Path of the input TSV, or an already-built row source.
        out_tsv_file: Path of the merged output TSV.
        num_process: Worker count forwarded to ``parallel_map``.
        num_jobs: Number of slices; defaults to ``num_process`` (or 1 when
            ``num_process`` is 0).
        head: Optional header row emitted by the first job only.
        out_sep: Column separator used for the output.
    """
    in_tsv = TSVFile(in_tsv_file) if isinstance(in_tsv_file, str) else in_tsv_file
    total = len(in_tsv)

    if num_jobs is None:
        num_jobs = num_process if num_process != 0 else 1
    rows_each_rank = (total + num_jobs - 1) // num_jobs

    if isinstance(in_tsv, TSVFile):
        # Drop whatever TSVFile has cached (e.g. the lineidx) before the
        # object is shipped to the workers; copying that cache is slow when
        # the dataset is huge.
        in_tsv.close()

    all_task = []
    for rank in range(num_jobs):
        first = rank * rows_each_rank
        all_task.append({
            'row_processor': row_processor,
            'idx_process': rank,
            'in_tsv': in_tsv,
            'tmp_out': out_tsv_file + '.{}.{}.tsv'.format(rank, num_jobs),
            'idx_range_start': first,
            'idx_range_end': min(first + rows_each_rank, total),
            'head': head,
            'out_sep': out_sep,
        })

    parallel_map(tsv_subset_process, all_task, num_worker=num_process)

    all_out = [task['tmp_out'] for task in all_task]
    concat_tsv_files(all_out, out_tsv_file)
    delete_tsv_files(all_out)


def scan_class_boundary(filename: str):
    """Record the row range of each class in a TSV and save it as ``.chunks``.

    The class label of a row is the part of its first column before the first
    ``'_'``. Consecutive rows with the same label form one run; the mapping
    ``label -> [first_row, last_row]`` (inclusive) is dumped as JSON to
    ``<stem>.chunks`` next to ``filename`` and also returned.

    If a label occurs in several separate runs, only the last run is kept
    for that label.

    Args:
        filename: Path of the TSV file to scan.

    Returns:
        Dict mapping each label to its ``[start, end]`` row indices. Empty
        for a TSV without rows.
    """
    tsv = TSVFile(filename)
    metadata_filename = op.splitext(filename)[0] + '.chunks'

    num_rows = len(tsv)
    class_boundaries = {}
    # None marks "no run open yet"; an empty string can be a real label, so
    # it must not double as the sentinel.
    current_label = None
    start = -1
    for idx in tqdm(range(num_rows)):
        label = tsv[idx][0].split('_')[0]
        if label != current_label:
            if current_label is not None:
                # The previous run ends on the row right before this one.
                class_boundaries[current_label] = [start, idx - 1]
            current_label = label
            start = idx
    if current_label is not None:
        class_boundaries[current_label] = [start, num_rows - 1]

    # Close the file explicitly so the JSON is flushed before returning.
    with open(metadata_filename, 'w') as fp:
        json.dump(class_boundaries, fp)

    return class_boundaries


def merge_split_tsv_files(root: str, pattern: str, leave_out_size: int = -1, n_folds: int = -1,
                          remove_temp: bool = True):
    """Split the TSVs matching ``pattern`` under ``root`` into train/val sets.

    Every matching TSV is loaded and shuffled. Its first ``leave_out_size``
    rows (or ``len(rows) // n_folds`` rows when ``leave_out_size`` is not
    positive) go to the validation set, the rest to the training set. Rows
    are rewritten as ``(filename, [{"class": <prefix>}], image_data)``, where
    the prefix is the part of ``filename`` before the first ``'_'``.

    Outputs are ``<root>/train.tsv`` and ``<root>/val.tsv``; per-input train
    shards are staged in ``<root>/tmp``.

    NOTE(review): ``pattern`` should not match the generated ``train.tsv`` /
    ``val.tsv`` or files under ``tmp``, otherwise a re-run would treat them
    as inputs.

    Args:
        root: Directory that is searched and that receives the outputs.
        pattern: Glob pattern, relative to ``root``, selecting input TSVs.
        leave_out_size: Validation rows per input file; takes precedence
            over ``n_folds`` when positive.
        n_folds: Hold out ``1 / n_folds`` of each input file.
        remove_temp: Delete ``<root>/tmp`` when done.

    Raises:
        ValueError: If input files were found but neither ``leave_out_size``
            nor ``n_folds`` is positive.
    """
    def gen_rows(src):
        for filename, _, image_data in src:
            n_offset = filename.split('_')[0]
            anno = json.dumps([{'class': n_offset}])
            yield filename, anno, image_data

    # Sorted so shard numbering and train order do not depend on the
    # arbitrary order in which the filesystem lists files.
    tsv_files = sorted(str(filename) for filename in Path(root).glob(pattern=pattern))

    if not tsv_files:
        logging.warning("No tsv files found in %s", root)
        return

    # Validate before any file is read or any directory is created.
    if leave_out_size <= 0 and n_folds <= 0:
        raise ValueError("either leave_out_size or n_folds must be positive")

    tmp_folder = os.path.join(root, 'tmp')
    ensure_directory(tmp_folder)

    val = []
    train_files = []
    for idx, tsv_file in enumerate(tqdm(tsv_files)):
        data = list(TSVFile(tsv_file))
        if leave_out_size > 0:
            length = leave_out_size
        else:
            length = len(data) // n_folds
        random.shuffle(data)

        val.extend(data[:length])
        train = data[length:]

        shard = os.path.join(tmp_folder, '%05d.tsv' % idx)
        tsv_writer(gen_rows(train), shard)
        train_files.append(shard)

    print('Concatenating tsv files ...')
    # Use the shards written above rather than globbing the temp folder:
    # that keeps a deterministic order and ignores stale files.
    concat_tsv_files(train_files, os.path.join(root, 'train.tsv'))
    db = TSVFile(os.path.join(root, 'train.tsv'))
    print("Db size:", len(db))

    print('writing tsv ...')
    tsv_writer(gen_rows(val), os.path.join(root, 'val.tsv'))
    db = TSVFile(os.path.join(root, 'val.tsv'))
    print("Db size:", len(db))

    if remove_temp:
        shutil.rmtree(tmp_folder)

def filter_tsv(tsv_path):
    """Build the hard-coded Conceptual Captions validation dataset.

    ``tsv_path`` must contain exactly one element, but its content is not
    used: the dataset name, data root, TSV names and indices file are fixed
    in the body. The dataset object is only constructed, then two progress
    messages are printed; nothing is returned.
    """
    assert len(tsv_path) == 1
    # A plain TSVFile(tsv_path[0]) was used here before; the dataset below
    # replaces it with fixed locations.
    dataset = TSVImageTextDataset(
        name="conceptualcaption/val",
        data_root="/mnt/blob/code/dalle/data",
        image_tsv_file=["gcc-val-image.tsv"],
        text_tsv_file=["gcc-val-text.tsv"],
        text_format="json",
        indices_list_file="/mnt/blob/code/dalle/data/filtered_conceptual_caption_val_index_min27000_max27000.txt",
    )
    print("finished")
    print("over")


if __name__ == '__main__':
    # Command-line entry point: build the list of TSVs (explicit
    # comma-separated list, or generated from --tsv_format over
    # [--start_idx, --end_idx]) and hand it to the function named by --func.
    parser = ArgumentParser()
    parser.add_argument('--func', type=str, default='concat_tsv_files')
    parser.add_argument('--tsvs', type=str, default='')
    parser.add_argument('--tsv_format', default='istockphoto-train-image-low_res-%03d.tsv', type=str)
    parser.add_argument('--start_idx', type=int, default=0)
    parser.add_argument('--end_idx', type=int, default=0)
    parser.add_argument('--out_tsv', type=str, default='')
    args = parser.parse_args()
    if args.tsvs != '':
        args.tsvs = args.tsvs.split(',')
    else:
        # end_idx is inclusive.
        args.tsvs = [args.tsv_format % file_idx for file_idx in range(args.start_idx, args.end_idx + 1)]
    print('total %d tsv files to merge' % len(args.tsvs))

    if args.func == 'filter_tsv':
        # filter_tsv takes only the TSV list; it must not fall through to the
        # two-argument call below.
        filter_tsv(args.tsvs)
    else:
        # Resolve --func by name among module-level callables instead of
        # eval(): eval would execute any expression given on the command line.
        func = globals().get(args.func)
        if not callable(func):
            parser.error('unknown --func: %s' % args.func)
        func(args.tsvs, args.out_tsv)
