import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from jax.flatten_util import ravel_pytree

from kheperax.maze import Maze

from kheperax.task import KheperaxTask, KheperaxConfig

seed = 42

random_key = jax.random.PRNGKey(seed)
random_key, subkey = jax.random.split(random_key)

config_kheperax = KheperaxConfig.get_default()
config_kheperax.episode_length = 256

(env, policy_network, scoring_fn,) = KheperaxTask.create_default_task(
    config_kheperax,
    random_key=subkey,
)

# Build example parameters only to obtain the flat-vector <-> pytree mapping;
# below, only the size of `flattened_params` and `array_to_pytree_fn` are used.
# Use a dedicated split key: `random_key` is split again afterwards, and a JAX
# key must not be both consumed directly and split.
random_key, init_key = jax.random.split(random_key)
obs = jnp.zeros(shape=(1, env.observation_size))
example_init_params = policy_network.init(init_key, obs)

flattened_params, array_to_pytree_fn = ravel_pytree(example_init_params)

print(f"flattened_params size = {flattened_params.shape[0]}")

random_key, subkey = jax.random.split(random_key)

state = env.reset(subkey)

print(state)

def get_score(state_descriptions, goal=(0.15, 0.9), offset=0.75 ** 2):
    """Return per-step changes in squared distance to `goal`.

    Args:
        state_descriptions: array of shape (..., T, 2) holding one 2-D state
            description per time step. Any leading (batch) shape is accepted,
            including none at all for a single trajectory.
        goal: 2-D goal position the squared distance is measured to.
        offset: constant added to the first step's value.
            NOTE(review): the meaning of the default 0.75**2 is not evident
            from this file -- confirm with the author.

    Returns:
        float array of shape (..., T). Entry t (for t < T - 1) is
        ``d[t + 1] - d[t]``, where ``d`` is the squared distance to `goal`.
        The last entry is 0 (padding, so the time axis keeps length T), and
        `offset` is added to entry 0. Summing over the last axis therefore
        gives ``d[T - 1] - d[0] + offset``.
    """
    states = np.asarray(state_descriptions)
    sq_dist = np.sum((states - np.asarray(goal)) ** 2, axis=-1)
    delta_diff = np.diff(sq_dist, axis=-1)
    # Pad along the time axis only, so this works for any leading shape.
    padding = np.zeros(delta_diff.shape[:-1] + (1,))
    delta_diff = np.concatenate((delta_diff, padding), axis=-1)
    delta_diff[..., 0] += offset
    return delta_diff

# Fitness deaggregated over time: one objective per consecutive time window.
def time_deagg(rewards, descriptors, n):
    """Split per-step rewards into `n` consecutive time windows and sum each.

    Args:
        rewards: per-step scores of shape (batch, T).
        descriptors: unused; accepted so the call signature matches
            `quad_tree_deagg` (both are called with the same arguments).
        n: number of objectives (time windows); must be >= 1.

    Returns:
        float array of shape (batch, n). Windows have length ``T // n``; the
        last window also absorbs the ``T % n`` leftover steps, so every step
        is counted and each row sums to the row's total reward.

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    rewards = np.asarray(rewards)
    n_steps = rewards.shape[1]
    size_unit = n_steps // n
    deagg_fits = np.zeros((rewards.shape[0], n))

    for i in range(n):
        start = i * size_unit
        # Extend the final window to the end so no trailing steps are lost.
        stop = n_steps if i == n - 1 else (i + 1) * size_unit
        deagg_fits[:, i] = np.sum(rewards[:, start:stop], axis=1)

    return deagg_fits

# Fitness deaggregated over space: one objective per grid cell of the 2-D descriptor.
# From Li Ding.
def quad_tree_deagg(rewards, descriptors, n):
    """Returns a series of deaggregated objectives for a batch of solutions.

    The unit square [0, 1] x [0, 1] of the first two descriptor components is
    split into a ``k x k`` grid with ``k = int(sqrt(n))``. Objective
    ``i * k + j`` is the sum of the per-step rewards collected while the
    descriptor lay in cell ``(i, j)``; ``i`` indexes component 0 and ``j``
    indexes component 1.

    Args:
        rewards: per-step rewards, shape (batch, T).
        descriptors: per-step descriptors, shape (batch, T, >=2).
        n: number of objectives, >= 1. Only ``k * k`` columns are filled; if
            n is not a perfect square, the remaining columns stay 0.

    Returns:
        float array of shape (batch, n). A cell that a solution never visits
        gets -100 for that solution. Descriptors outside [0, 1] fall in no cell.

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1 to build at least one objective, got {n}")
    num_splits = int(np.sqrt(n))
    rewards = np.asarray(rewards)
    descriptors = np.asarray(descriptors)
    objectives = np.zeros((rewards.shape[0], n))

    def in_bin(values, k):
        # Bins are half-open [k/s, (k+1)/s), except the last one, which is
        # closed so that values exactly equal to 1.0 are not dropped.
        lower = k / num_splits
        upper = (k + 1) / num_splits
        above = values >= lower
        if k == num_splits - 1:
            return above & (values <= upper)
        return above & (values < upper)

    for i in range(num_splits):
        in_row = in_bin(descriptors[:, :, 0], i)
        for j in range(num_splits):
            mask = in_row & in_bin(descriptors[:, :, 1], j)
            col = i * num_splits + j
            # np.where (rather than multiplying by the mask) keeps NaN/inf
            # rewards outside this cell from contaminating its sum.
            objectives[:, col] = np.sum(
                np.where(mask, rewards, 0.0), axis=1, dtype=np.float64
            )
            # Large negative value marks cells the solution never visited.
            objectives[~mask.any(axis=1), col] = -100

    return objectives


def evaluate(params, random_key, n=1, deaggregation="time", domain_type="decep"):
    """Roll out flat policy parameters and compute deaggregated fitness.

    Args:
        params: flat parameter vector(s) convertible by ``array_to_pytree_fn``;
            either a single vector or a batch of shape (batch, n_params).
            A single vector is treated as a batch of one.
        random_key: JAX PRNG key; a subkey split from it drives the rollout.
        n: number of deaggregated objectives.
        deaggregation: "time" (consecutive time windows, via `time_deagg`) or
            "space" (grid cells over the state description, via
            `quad_tree_deagg`).
        domain_type: "decep" -- per-step scores are the negated outputs of
            `get_score` and the returned fitness is their sum; "illu" --
            per-step scores are the transitions' rewards and the fitness is
            the one returned by `scoring_fn`.

    Returns:
        Tuple ``(deagg_fits, fitness, descriptor, info)``. The first three are
        numpy arrays; `info` is passed through from `scoring_fn`.

    Raises:
        ValueError: if `domain_type` or `deaggregation` is not recognized.
    """
    # Validate options before running the (expensive) rollout.
    if domain_type not in ("decep", "illu"):
        raise ValueError(f"domain_type {domain_type} not valid")
    if deaggregation not in ("time", "space"):
        raise ValueError(f"deaggregation {deaggregation} not recognized")

    # vmap below maps over axis 0, so promote a single vector to a batch of one.
    params = jnp.atleast_2d(jnp.asarray(params))
    params_pytree = jax.vmap(array_to_pytree_fn)(params)

    _, subkey = jax.random.split(random_key)
    fitness, descriptor, info, _ = scoring_fn(params_pytree, subkey)

    desc = info['transitions'].state_desc

    if domain_type == "decep":
        delta_diffs = get_score(desc)
        fitness = -np.sum(delta_diffs, axis=-1)
        step_fitness = -delta_diffs
    else:
        step_fitness = info['transitions'].rewards

    if deaggregation == "time":
        deagg_fits = time_deagg(step_fitness, desc, n)
    else:
        deagg_fits = quad_tree_deagg(step_fitness, desc, n)

    # TODO: casting values depending on return_type as suggested by Bryon
    return np.asarray(deagg_fits), np.asarray(fitness), np.asarray(descriptor), info

# Make a random individual and evaluate it.
# Split independent keys: reusing one key for both the parameter sampling and
# the rollout would correlate their randomness.
random_key, params_key, eval_key = jax.random.split(random_key, 3)
individual = jax.random.uniform(params_key, shape=(1, flattened_params.shape[0]))
deagg_fits, fitness, descriptor, info = evaluate(
    individual, eval_key, n=4, deaggregation="space", domain_type="illu"
)
print(fitness, descriptor, info)

print(f"params size = {flattened_params.shape[0]}")
print(f"fitness size = {fitness.shape[0]}, fitness = {fitness}")
print(f"descriptor size = {descriptor.shape[0]}, descriptor = {descriptor}")

desc = info['transitions'].state_desc
print(f"state descriptions: {desc}, shape = {desc.shape}")
print(f"distances: {get_score(desc)}")

# Same goal position that get_score uses by default.
goal_xy = np.array([0.15, 0.9])
print(f"fitness: {fitness}")
print(f"descriptor: {descriptor[0]}")
print(f"descriptor_dist: {(descriptor - goal_xy)**2}")
print(f"acc fitess: {np.sum((np.asarray(descriptor) - goal_xy)**2)}")
