import os
import argparse
import random
import collections
import itertools
import json
import math
from pathos.helpers import mp as multiprocessing

from dsl import constraint
from dsl.types import INT, LIST
from dsl.program import Program, get_used_indices, get_unused_indices
from dsl.impl import ALL_FUNCTIONS, LAMBDAS
from dsl.example import Example
from env.statement import Statement
import params


# Programs of length 1 and 2 form a small space in this DSL, so the requested args.num_train is
# likely to exceed the number of programs that actually exist. For those lengths main() caps the
# request at the known totals below.
KNOWN_TRAIN_SIZES = {
    1: 47,
    2: 2883,
}

def get_free_indices(program, program_len):
    """
    Return the variable indices that get_used_indices() does not report for `program`.

    The candidate index space is range(program_len + number of inputs), so for a partially built
    program it also includes the slots of statements that have not been appended yet.
    """
    consumed = get_used_indices(program)
    slot_count = program_len + len(program.input_types)
    return set(range(slot_count)) - consumed


def get_input_type_combinations(num_inputs):
    """
    Return every input-type signature with between 1 and `num_inputs` inputs.

    Each signature lists its LIST inputs first, then its INT inputs, and always has at least one
    LIST because no valid program takes only ints. Signatures are ordered by total input count,
    then by the number of LIST inputs.
    """
    return [
        [LIST] * list_count + [INT] * (arity - list_count)
        for arity in range(1, num_inputs + 1)
        for list_count in range(1, arity + 1)
    ]


def iterate_inputs(function, type_to_vars):
    """
    Yield every argument tuple for `function` drawn from `type_to_vars`.

    `function.input_type` is either a single type or a tuple of types, one per parameter. Each
    parameter position ranges over `type_to_vars[<its type>]`, and the tuples come out in
    itertools.product order.
    """
    signature = function.input_type
    param_types = signature if isinstance(signature, tuple) else (signature,)
    candidates = [type_to_vars[param_type] for param_type in param_types]
    yield from itertools.product(*candidates)


def init_gen_prog_worker(*args):
    """
    Pool initializer for gen_program_worker.

    Stores the shared progress counter, the target number of programs and the program length as
    module globals of the worker process.
    """
    global progress_counter, num_programs, program_len
    counter, target, length = args
    progress_counter = counter
    num_programs = target
    program_len = length


def init_gen_examples_worker(*args):
    """
    Pool initializer for gen_examples_worker.

    Stores the shared progress/valid counters, the number of programs, the number of examples
    per program and the number of example-generation tries as module globals of the worker.
    """
    global progress_counter, valid_counter, num_programs, num_examples, num_example_tries
    counter, valid, total, per_program, tries = args
    progress_counter = counter
    valid_counter = valid
    num_programs = total
    num_examples = per_program
    num_example_tries = tries


def gen_program_worker(input_types):
    """
    Generate programs with the given input types.

    Runs a depth-first search over statements. At each step the candidate functions are tried in
    random order, and variables that are still unused are tried first as arguments. Every new
    statement must consume at least one unused variable. Complete programs of length
    `program_len` that still have unused variables, or that were already found, are discarded.

    The search stops when the shared `progress_counter` reaches `num_programs` (both installed by
    init_gen_prog_worker), or when the search space for these input types is exhausted.

    :param input_types: input types of the programs to generate.
    :return: set of generated Program objects.
    """
    def helper(functions, program, programs):
        # Return True when the caller should stop: a new program was recorded or the global
        # target was reached. A falsy return means this subtree contains no new program.
        if progress_counter.value >= num_programs:
            return True

        if len(program) >= program_len:
            if get_unused_indices(program) or program in programs:
                return False
            # `value += 1` on a shared Value is a read-modify-write and not atomic across
            # processes. Re-check and increment under the lock so workers neither lose updates
            # nor overshoot the target.
            with progress_counter.get_lock():
                if progress_counter.value >= num_programs:
                    return True
                programs.add(program)
                progress_counter.value += 1
                count = progress_counter.value
            print("\rGenerating programs... %d\\%d" % (count, num_programs), end="")
            return True

        # Shuffle a private copy. Shuffling the shared list in place would reorder it while an
        # enclosing frame is still iterating over it, so that frame could skip some functions and
        # revisit others.
        functions = random.sample(functions, len(functions))

        # Variables of each type, newest first, with free (unused) ones moved to the front so
        # that arguments consuming them are tried first.
        free_indxs = get_free_indices(program, program_len)
        newest_first = collections.defaultdict(list)
        for i, typ in enumerate(program.var_types):
            newest_first[typ].insert(0, i)
        type_to_vars = collections.defaultdict(list)
        for typ, var_list in newest_first.items():
            type_to_vars[typ] = ([var for var in var_list if var in free_indxs] +
                                 [var for var in var_list if var not in free_indxs])

        for func in LAMBDAS:
            type_to_vars[func.type].append(func)

        used = set(program.statements)
        for function in functions:
            for args in iterate_inputs(function, type_to_vars):
                # Each new statement must consume at least one free variable.
                if not any(arg in free_indxs for arg in args):
                    continue
                statement = Statement(function, args)
                if statement in used:
                    continue

                next_program = Program(program.input_types,
                                       program.statements + [statement])
                if helper(functions, next_program, programs):
                    return True
        return False

    program_base = Program(input_types, [])
    res = set()
    while progress_counter.value < num_programs:
        # A falsy result means a full search found nothing new. The space for these input types
        # is exhausted, so retrying would only spin forever.
        if not helper(ALL_FUNCTIONS, program_base, res):
            break
    return res


def gen_examples_worker(program):
    """
    Generate examples for the given program. Return the examples if successful, or None otherwise.

    Reads num_examples, num_example_tries and the shared counters from the globals installed by
    init_gen_examples_worker. In this file, `progress_counter` (programs processed) and
    `valid_counter` (programs not yet rejected) only feed the progress display.
    """
    print("\rGenerating examples... %d\\%d (remaining programs: %d)" %
          (progress_counter.value, num_programs, valid_counter.value), end="")

    input_output_examples = constraint.get_input_output_examples(program, num_examples=num_examples,
                                                                 num_tries=num_example_tries)

    # `+=` / `-=` on a shared Value is a read-modify-write. Take the lock so concurrent workers
    # do not lose updates.
    with progress_counter.get_lock():
        progress_counter.value += 1
    if input_output_examples:
        return input_output_examples
    with valid_counter.get_lock():
        valid_counter.value -= 1
    return None


def write_programs_to_file(f, programs, examples):
    """
    Write one JSON line per program to the open text file `f`.

    Each line has the form {"program": <program.encoded>, "examples": [{"inputs": [...],
    "output": ...}, ...]}. The raw values are taken from the `.val` attribute of every input and
    output in `examples[program]`.
    """
    for prog in programs:
        serialized = [
            {'inputs': [inp.val for inp in ex_inputs], 'output': ex_output.val}
            for ex_inputs, ex_output in examples[prog]
        ]
        record = {'program': prog.encoded, 'examples': serialized}
        f.write(json.dumps(record) + '\n')


def gen_programs(program_len, num_programs, args):
    """
    Generates the specified amount of programs of the given length. These are the exact steps performed:
    1. Generate <num_programs> programs using gen_program_worker in a process pool
    2. Generate examples for each program by executing gen_examples_worker in a process pool.
       Discard programs for which the required amount of examples could not be generated.
    3. Return a dictionary of the form {program: examples}

    Both pools are shut down before returning, so calling this once per program length does not
    accumulate idle worker processes.

    :param program_len: number of statements per generated program.
    :param num_programs: target number of programs, shared across all input-type combinations.
    :param args: parsed CLI arguments (uses num_workers, num_examples, num_example_tries).
    """
    progress_counter = multiprocessing.Value('i', 0)
    input_type_combinations = get_input_type_combinations(params.num_inputs)
    with multiprocessing.Pool(processes=args.num_workers, initializer=init_gen_prog_worker,
                              initargs=(progress_counter, num_programs, program_len)) as gen_prog_pool:
        programs = gen_prog_pool.map(gen_program_worker, input_type_combinations)
    print('')

    # Flatten the per-input-type sets and drop duplicates.
    programs = [item for sublist in programs for item in sublist]
    programs = list(set(programs))

    # Generate examples and filter out null programs.
    progress_counter.value = 0
    valid_counter = multiprocessing.Value('i', len(programs))
    with multiprocessing.Pool(processes=args.num_workers, initializer=init_gen_examples_worker,
                              initargs=(progress_counter, valid_counter, len(programs),
                                        args.num_examples, args.num_example_tries)) as gen_examples_pool:
        res = gen_examples_pool.map(gen_examples_worker, programs)
    print('')
    return {program: ex for program, ex in zip(programs, res) if ex}


def load_cache(path):
    """
    Given a dataset path, loads the programs from it to a form returned by gen_programs(): A dict with
    programs as keys and examples as values

    Each non-blank line must be a JSON object with a 'program' key that Program.parse accepts and
    whatever Example.from_line needs. This is presumably the format produced by
    write_programs_to_file. Blank lines are skipped.

    :param path: path of the JSON-lines dataset file.
    :return: dict mapping Program -> list of (inputs, output) pairs.
    """
    with open(path, 'r') as f:
        lines = [json.loads(x) for x in f if x.strip()]
    examples = {}
    for i, line in enumerate(lines):
        print("\rLoading program cache... %d\\%d" % (i, len(lines)), end="")
        program = Program.parse(line['program'])
        p_examples = Example.from_line(line)
        examples[program] = [(ex.inputs, ex.output) for ex in p_examples]
    print('')

    return examples


def init_discard_identical_worker(*args):
    """
    Pool initializer for discard_identical_worker.

    Stores the programs of the current dataset, the shared progress counter and the total number
    of new programs as module globals of the worker process.
    """
    global existing_programs, progress_counter, new_program_count
    known, counter, total = args
    existing_programs = known
    progress_counter = counter
    new_program_count = total


def discard_identical_worker(new_examples):
    """
    Given a dictionary of {program: examples}, and a current dataset (given via init_discard_identical_worker),
    this function deletes programs which are equivalent to any program in the current dataset.
    Equivalence is measured by using the examples from new_examples

    The dict is modified in place and also returned. When this runs in a process pool the caller
    must use the returned dict, because the worker only operates on a serialized copy.
    """
    for program in list(new_examples):
        program_examples = new_examples[program]
        if any(constraint.is_same(program, other, program_examples) for other in existing_programs):
            del new_examples[program]
        print("\rDiscarding identical programs... %d\\%d" % (progress_counter.value, new_program_count), end="")
        # `+=` on a shared Value is not atomic across processes; hold the lock while updating.
        with progress_counter.get_lock():
            progress_counter.value += 1
    return new_examples

def shuffle_and_divide(orig_filename, split_ratio, train_filename, val_filename):
    '''
    Shuffle all the lines of `orig_filename` and split them into train and val files.

    The first floor(len(lines) * split_ratio) shuffled lines go to `val_filename` and the rest go
    to `train_filename`. Every written line is newline-terminated. An empty split produces an
    empty file, not a file holding one blank line, which json.loads-based readers would reject.
    '''
    with open(orig_filename) as f:
        lines = f.read().splitlines()
    random.shuffle(lines)
    val_size = math.floor(len(lines) * split_ratio)
    val, train = lines[:val_size], lines[val_size:]
    with open(train_filename, 'w') as f:
        f.writelines(line + "\n" for line in train)
    with open(val_filename, 'w') as f:
        f.writelines(line + "\n" for line in val)

def create_data_for_1(input_filename, output_filename):
    '''
    Create PEPS data from a GPS-format dataset file, with one output record per example.

    Each JSON line of `input_filename` holds a 'program' and a list of 'examples'. For every
    example, one line {"program": ..., "examples": [example]} is written to `output_filename`.
    Blank input lines are skipped.
    '''
    with open(input_filename) as src, open(output_filename, 'w') as dst:
        for raw in src:
            raw = raw.strip()
            if not raw:
                continue
            data = json.loads(raw)
            program = data['program']
            for example in data['examples']:
                dst.write(json.dumps(dict(program=program, examples=[example])) + '\n')


def _discard_identical(new_examples, existing_programs, num_workers):
    """Return the entries of `new_examples` whose programs are not equivalent to any existing program."""
    counter = multiprocessing.Value('i', 0)
    new_programs = list(new_examples.keys())
    # Split the new programs round-robin into one chunk per worker.
    new_program_parts = [new_programs[i::num_workers] for i in range(num_workers)]
    new_example_parts = [{p: new_examples[p] for p in part} for part in new_program_parts]
    with multiprocessing.Pool(processes=num_workers, initializer=init_discard_identical_worker,
                              initargs=(existing_programs, counter, len(new_programs))) as discard_pool:
        # Workers delete from serialized copies, so only the returned dicts reflect the discards.
        res = discard_pool.map(discard_identical_worker, new_example_parts)
    print('')
    kept = {}
    for part in res:
        kept.update(part)
    return kept


def main():
    """
    Generates training programs. These are the basic steps performed:

    D = {}
    for 1 <= i <= max_train_len:
       1. P = Generate programs of length i
       2. E = Generate examples for the generated programs
       3. Discard programs in P that are equivalent to any program in D
       4. D += (P, E)

    With --cache, D starts from that file and generation resumes after its longest program.
    D is written as JSON lines to --train_output_path. It is then shuffled and split into
    train_dataset_gps / val_dataset_gps in the same directory. Each split is expanded into
    train_dataset_pe / val_dataset_pe with one record per example.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--num_train', type=int, required=True)
    parser.add_argument('--train_output_path', type=str, required=True)
    parser.add_argument('--max_train_len', type=int, required=True)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--num_examples', type=int, default=5)
    parser.add_argument('--val_split_ratio', type=float, default=0.1)
    parser.add_argument('--num_example_tries', type=int, default=200,
                        help='total amount of tries to generate examples to try to generate')
    parser.add_argument('--cache', type=str, default=None,
                        help="Dataset cache from which to continue generating programs")
    args = parser.parse_args()

    if args.cache:
        examples = load_cache(args.cache)
        # default=0 keeps an empty cache file from crashing max().
        min_len = max((len(k) for k in examples), default=0)
    else:
        examples = {}
        min_len = 0

    for program_len in range(min_len + 1, args.max_train_len + 1):
        num_programs = args.num_train
        if program_len in KNOWN_TRAIN_SIZES:
            num_programs = min(num_programs, KNOWN_TRAIN_SIZES[program_len])

        print("Generating programs of length %d (current dataset size: %d)" % (program_len, len(examples)))
        new_examples = gen_programs(program_len, num_programs, args)

        # Compare against all existing training data, which may contain programs of other lengths.
        existing_programs = list(examples.keys())
        examples.update(_discard_identical(new_examples, existing_programs, args.num_workers))

    train_programs = list(examples.keys())
    print("Finished generation. Total programs: %d" % len(train_programs))

    print('Writing %d trainval programs to %s' % (len(train_programs), args.train_output_path))
    output_dir = os.path.dirname(args.train_output_path)
    if output_dir:  # a bare filename has no directory to create; os.makedirs('') would raise
        os.makedirs(output_dir, exist_ok=True)
    with open(args.train_output_path, 'w') as f:
        write_programs_to_file(f, train_programs, examples)

    print("Dividing into train and val splits and writing")
    train_path = os.path.join(output_dir, 'train_dataset_gps')
    val_path = os.path.join(output_dir, 'val_dataset_gps')
    shuffle_and_divide(args.train_output_path, args.val_split_ratio, train_path, val_path)

    print("Creating data for training PE model and writing")
    # Build the PE names directly. str.replace('gps', 'pe') would also rewrite any 'gps' in the
    # directory part of the path.
    create_data_for_1(train_path, os.path.join(output_dir, 'train_dataset_pe'))
    create_data_for_1(val_path, os.path.join(output_dir, 'val_dataset_pe'))

if __name__ == "__main__":  # run generation only when executed as a script, not on import
    main()

