from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import torch
from pathos.helpers import mp as multiprocessing
from copy import deepcopy
import numpy as np

import params
from model.model import PCCoder, Generator
from scripts.gen_programs import init_discard_identical_worker, discard_identical_worker, gen_programs, KNOWN_TRAIN_SIZES
from scripts.solve_problems import solve_problems

# Settings for the evolutionary search performed by train_adv().
num_episodes = 7  # number of episodes (outer-loop iterations) to run
pop_size = 25  # number of Generator instances created for the initial population
top_limit = 5  # number of highest-reward generators kept as parents after each episode
# Passed to solve_problems and used as the threshold against solution['time'];
# units are presumably seconds -- TODO confirm against solve_problems.
timeout = 60


def main():
    """Parse the command-line arguments and hand them to ``train_adv``."""
    arg_parser = argparse.ArgumentParser()

    # Required positional argument.
    arg_parser.add_argument(
        'output_path',
        type=str,
        help='Output path of trained model',
    )

    # Optional settings.
    arg_parser.add_argument(
        '--model_path',
        type=str,
        default='',
        help='Path of previous model to load',
    )
    arg_parser.add_argument('--num_workers', default=16, type=int)
    arg_parser.add_argument('--num_examples', default=params.num_examples, type=int)
    arg_parser.add_argument(
        '--num_example_tries',
        default=500,
        type=int,
        help='total amount of tries to generate examples to try to generate',
    )

    train_adv(arg_parser.parse_args())


def _discard_identical(new_examples, existing_programs, num_workers):
    """Run ``discard_identical_worker`` over ``new_examples`` in a worker pool.

    The programs are dealt round-robin into ``num_workers`` shards, each shard
    is mapped to ``discard_identical_worker``, and the dicts returned by the
    workers are merged into one.

    Args:
        new_examples: dict mapping each newly generated program to its examples.
        existing_programs: list of programs kept from earlier lengths; passed
            to ``init_discard_identical_worker``.
        num_workers: number of pool processes (and of shards).

    Returns:
        A single dict containing every entry returned by the workers.
        NOTE(review): presumably these are the programs that are not
        equivalent to any existing one -- confirm in scripts.gen_programs.
    """
    new_programs = list(new_examples.keys())
    counter = multiprocessing.Value('i', 0)
    discard_pool = multiprocessing.Pool(processes=num_workers, initializer=init_discard_identical_worker,
                                        initargs=(existing_programs, counter, len(new_programs)))
    try:
        new_program_parts = [new_programs[i::num_workers] for i in range(num_workers)]
        new_example_parts = [{p: new_examples[p] for p in programs} for programs in new_program_parts]
        res = discard_pool.map(discard_identical_worker, new_example_parts)
    finally:
        # Always release the workers: a new pool is created for every
        # generator and program length, so un-joined pools would pile up.
        discard_pool.close()
        discard_pool.join()

    kept = {}
    for d in res:
        kept.update(d)
    return kept


def _evaluate_generator(generator, discriminator, args):
    """Score one generator by how often the discriminator fails on its programs.

    Programs of length 1..5 are generated in turn; only the final length uses
    ``generator`` and is turned into problems for ``solve_problems``. The
    reward is the fraction of those problems whose reported ``'time'`` is at
    least ``timeout``.

    Args:
        generator: the Generator being evaluated.
        discriminator: the model handed to ``solve_problems``.
        args: parsed command-line arguments (forwarded to ``gen_programs``;
            ``num_workers`` is read here).

    Returns:
        The mean of the per-problem timeout flags, or 0 when fewer than 25
        problems were available and nothing was solved.
    """
    target_len = 5  # only programs of this length use the generator and are scored

    print(generator.integer_min, generator.integer_max, generator.list_len_min, generator.list_len_max)
    reward = []
    prev_examples = {}
    for program_len in range(1, target_len + 1):
        num_programs = 500
        if program_len in KNOWN_TRAIN_SIZES:
            num_programs = min(num_programs, KNOWN_TRAIN_SIZES[program_len] - 5)
        new_examples, _ = gen_programs(program_len, num_programs, args,
                                       generator if program_len == target_len else None)

        examples = _discard_identical(new_examples, list(prev_examples.keys()), args.num_workers)
        print('')
        # Shorter programs are only accumulated so that later lengths can be
        # filtered against them.
        prev_examples.update(examples)
        if program_len < target_len:
            continue

        problems = [
            {'program': program.encoded,
             'examples': [{'inputs': [x.val for x in inp], 'output': out.val}
                          for inp, out in program_examples]}
            for program, program_examples in examples.items()
        ]
        print(len(problems))
        problems = problems[:100]

        # Too few problems would give a noisy estimate; leave the reward empty.
        if len(problems) >= 25:
            print(problems[0])
            solutions = solve_problems(problems, 'beam', discriminator, timeout, program_len, 819200, 32)
            timed_out = [solution['time'] >= timeout for solution in solutions]
            reward.extend(timed_out)
            print(np.mean(timed_out))

    # Guard the empty case: np.mean([]) is nan and emits a RuntimeWarning.
    mean_reward = np.mean(reward) if reward else 0
    print('')
    print(mean_reward)
    print('')
    return mean_reward


def _breed_next_generation(generators, parent_indexes, episode):
    """Build the next population from the selected parents.

    Each parent contributes one unmodified copy plus
    ``pop_size // top_limit - 1`` copies on which ``mutate(episode)`` is called.
    """
    children = []
    for idx in parent_indexes:
        children.append(deepcopy(generators[idx]))
        for _ in range(pop_size // top_limit - 1):
            child = deepcopy(generators[idx])
            child.mutate(episode)
            children.append(child)
    return children


def train_adv(args):
    """Evolve a population of generators against a fixed discriminator.

    A ``PCCoder`` discriminator is created (optionally loaded from
    ``args.model_path``) and put in eval mode. For ``num_episodes`` episodes
    every generator in the population is scored, the ``top_limit`` best are
    kept as parents, and a new population is bred from them. The
    highest-reward generator seen over all episodes is saved to
    ``args.output_path + ".<last episode index>"``.

    Args:
        args: parsed command-line arguments; ``model_path``, ``output_path``
            and ``num_workers`` are read here, and the object is forwarded to
            ``gen_programs``.
    """
    discriminator = PCCoder()
    if args.model_path:
        discriminator.load(args.model_path)
    discriminator.eval()
    generators = [Generator() for _ in range(pop_size)]

    best_generator = None
    best_reward = 0
    with torch.no_grad():
        for episode in range(num_episodes):
            print("Episode %d" % episode)

            rewards = [_evaluate_generator(generator, discriminator, args) for generator in generators]

            sorted_parent_indexes = np.argsort(rewards)[::-1][:top_limit]
            new_best_reward = rewards[sorted_parent_indexes[0]]
            # Always record the first episode's winner so that a run in which
            # every reward is 0 still has a generator to save at the end.
            if best_generator is None or new_best_reward > best_reward:
                best_generator = generators[sorted_parent_indexes[0]]
                best_reward = new_best_reward

            generators = _breed_next_generation(generators, sorted_parent_indexes, episode)

            print()
            print("Avg Reward: %f" % np.mean(rewards))
            print("Best Reward: %f" % max(rewards))
            print()

        if best_generator is None:
            # Only reachable when no episode ran (num_episodes == 0); in that
            # case ``episode`` is unbound as well, so there is nothing to save.
            print("No generator was evaluated; nothing to save.")
            return
        best_generator.save(args.output_path + ".%d" % episode)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
