import random
from copy import deepcopy
import numpy as np
from karel.world import generateExamples
from nps.evaluate import evaluate_model

# Number of episodes (iterations of the evaluate/select/mutate loop) that
# train() runs.
num_episodes = 2

class Generator:
    """A randomly initialised, mutable set of example-generation settings.

    Attributes:
        rows, cols: integers kept within [2, 16].
        markerRatio, wallRatio: numbers kept within [0, 1].

    The four values are handed to ``generateExamples`` by ``train``;
    presumably they are the world grid size and the marker/wall densities
    (confirm against ``karel.world``).
    """

    @staticmethod
    def _clamp(value, low, high):
        """Return ``value`` limited to the inclusive range [low, high]."""
        return min(max(low, value), high)

    @classmethod
    def _nudge_side(cls, side):
        """Shift a grid dimension by a random integer in [-4, 4], clamped to [2, 16]."""
        return cls._clamp(side + random.randint(-4, 4), 2, 16)

    @classmethod
    def _nudge_ratio(cls, ratio):
        """Shift a ratio by a random amount in [-0.25, 0.25), clamped to [0, 1].

        A ratio pushed past a bound comes back as the int bound (0 or 1).
        """
        delta = random.random() * 0.5 - 0.25
        return cls._clamp(ratio + delta, 0, 1)

    def __init__(self):
        """Draw every setting uniformly at random from its allowed range."""
        self.rows = random.randint(2, 16)
        self.cols = random.randint(2, 16)
        self.markerRatio = random.random()
        self.wallRatio = random.random()

    def mutate(self):
        """Randomly perturb all four settings in place, keeping each in range."""
        # The draw order (rows, cols, markerRatio, wallRatio) is fixed so that
        # runs with a seeded `random` stay reproducible.
        self.rows = self._nudge_side(self.rows)
        self.cols = self._nudge_side(self.cols)
        self.markerRatio = self._nudge_ratio(self.markerRatio)
        self.wallRatio = self._nudge_ratio(self.wallRatio)

# Evolutionary-search settings used by train().
_POPULATION_SIZE = 25     # generators in the initial random population
_NUM_PARENTS = 5          # top-scoring generators kept after each episode
_MUTANTS_PER_PARENT = 4   # mutated copies bred from each kept parent
_MIN_SAMPLE_SIZE = 25     # generators yielding fewer examples score zero


def _evaluate_generator(generator, model_weights):
    """Return the fitness of one generator for the given model weights.

    Calls ``generateExamples`` with the generator's settings and the
    'generated.json' path, then scores the model on that same path.
    Fitness is ``1 - evaluate_model(...)``; when ``generateExamples`` reports
    fewer than ``_MIN_SAMPLE_SIZE`` examples the count is printed and the
    fitness is 0, without evaluating the model.
    """
    sample_size = generateExamples(
        'programs', 'generated.json',
        generator.rows, generator.cols,
        generator.markerRatio, generator.wallRatio)
    if sample_size < _MIN_SAMPLE_SIZE:
        print(sample_size)
        return 0
    # The positional settings are passed through unchanged; their meaning is
    # defined by nps.evaluate.evaluate_model.
    # NOTE(review): presumably evaluate_model returns an accuracy-like score,
    # so a high fitness means the model does poorly on this generator's
    # examples -- confirm.
    return 1 - evaluate_model(
        model_weights, 'data/1m_6ex_karel/new_vocab.vocab', 'generated.json',
        5, 0, False, 'a', 64, 1, 8, True, False)


def train(model_weights):
    """Search for the generator settings with the highest fitness.

    Starts from ``_POPULATION_SIZE`` random generators and runs
    ``num_episodes`` episodes. Each episode scores every generator, keeps the
    ``_NUM_PARENTS`` best, and builds the next population from an unchanged
    copy of each parent plus ``_MUTANTS_PER_PARENT`` mutated copies
    (5 * (1 + 4) = 25, so the population size stays constant).

    Args:
        model_weights: forwarded unchanged to ``evaluate_model``.

    Returns:
        The generator with the highest fitness seen in any episode, or
        ``None`` if no generator ever scored above 0 (or no episode ran).
    """
    generators = [Generator() for _ in range(_POPULATION_SIZE)]
    best_generator = None
    best_reward = 0
    for episode in range(num_episodes):
        print('Episode {}'.format(episode))

        rewards = []
        for generator in generators:
            fitness = _evaluate_generator(generator, model_weights)
            print(generator.rows, generator.cols, generator.markerRatio, generator.wallRatio)
            print(fitness)
            rewards.append(fitness)

        # Indexes of the highest-reward generators, best first.
        sorted_parent_indexes = np.argsort(rewards)[::-1][:_NUM_PARENTS]
        new_best_reward = rewards[sorted_parent_indexes[0]]
        if new_best_reward > best_reward:
            # Safe to keep a reference: the next population is built from
            # deep copies, so this object is never mutated afterwards.
            best_generator = generators[sorted_parent_indexes[0]]
            best_reward = new_best_reward

        children = []
        for idx in sorted_parent_indexes:
            parent = generators[idx]
            # Elitism: each parent survives unchanged alongside its mutants.
            children.append(deepcopy(parent))
            for _ in range(_MUTANTS_PER_PARENT):
                child = deepcopy(parent)
                child.mutate()
                children.append(child)
        generators = children

        print()
        print('Avg Reward: {}'.format(np.mean(rewards)))
        print('Best Reward: {}'.format(max(rewards)))
        print()

    # Previously the best generator was tracked but then discarded.
    return best_generator
