from functools import partial
import jax
import jax.numpy as jnp
from typing import Tuple

from src.tasks.base import Task, EvalOutput


@partial(jax.jit, static_argnames=("solution_dim", "descriptor_dim"))
def _constant_fn(
    solution: jnp.ndarray, solution_dim: int, descriptor_dim: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the constant-objective benchmark on a single solution.

    The objective is always ``1.0``. The descriptors are obtained by clipping
    the solution to ``[-5.12, 5.12]``, splitting it into ``descriptor_dim``
    contiguous, equally sized chunks and taking the mean of each chunk.

    Args:
        solution: 1-D array of length ``solution_dim``.
        solution_dim: Length of ``solution`` (static under ``jax.jit``).
        descriptor_dim: Number of descriptors to produce; must be positive
            and divide ``solution_dim`` (static under ``jax.jit``).

    Returns:
        A ``(objective, descriptors)`` pair where ``objective`` is the
        constant ``1.0`` (returned by ``jax.jit`` as a scalar array) and
        ``descriptors`` has shape ``(descriptor_dim,)``.

    Raises:
        ValueError: If ``descriptor_dim`` is not positive, if ``solution_dim``
            is not divisible by ``descriptor_dim``, or if ``solution`` does
            not have shape ``(solution_dim,)``.
    """
    # Both dimensions are static arguments, so these checks run once at trace
    # time on plain Python ints and never end up in the compiled computation.
    if descriptor_dim <= 0:
        raise ValueError("Descriptor dimension must be a positive integer.")
    if solution_dim % descriptor_dim != 0:
        raise ValueError(
            "Solution dimension must be divisible by descriptor dimension."
        )
    # Shapes are static under tracing as well. A mismatch would otherwise be
    # silently truncated (too long) or produce NaN chunk means (too short).
    if solution.shape != (solution_dim,):
        raise ValueError(
            f"Expected a solution of shape ({solution_dim},), "
            f"got {tuple(solution.shape)}."
        )

    objective = 1.0
    clipped = jnp.clip(solution, -5.12, 5.12)

    # Row i of the reshaped array is the i-th contiguous chunk, so a single
    # row-wise mean replaces one slice + mean op per descriptor.
    chunk_size = solution_dim // descriptor_dim
    descriptors = clipped.reshape(descriptor_dim, chunk_size).mean(axis=1)

    return objective, descriptors


class ConstantTask(Task):
    """Benchmark task with a constant objective.

    Each solution is scored by ``_constant_fn``: the fitness is the same
    constant for every solution, and the descriptors are the means of
    ``descriptor_dim`` contiguous chunks of the clipped solution.

    The public methods are compiled with ``jax.jit`` and treat ``self`` as a
    static argument, so instances must be hashable.
    """

    def __init__(self, solution_dim: int, descriptor_dim: int):
        """Set up the task and its batched value / Jacobian functions.

        Args:
            solution_dim: Number of parameters in one solution.
            descriptor_dim: Number of descriptors; must be positive and
                divide ``solution_dim``.

        Raises:
            ValueError: If ``descriptor_dim`` is not positive or does not
                divide ``solution_dim``.
        """
        # Fail at construction time rather than at the first traced
        # evaluation, where the same constraint would surface much later.
        if descriptor_dim <= 0:
            raise ValueError("Descriptor dimension must be a positive integer.")
        if solution_dim % descriptor_dim != 0:
            raise ValueError(
                "Solution dimension must be divisible by descriptor dimension."
            )

        # Shape of a single solution.
        self.solution_size = (solution_dim,)
        self.descriptor_dim = descriptor_dim
        # Bounds used when sampling random initial solutions.
        self.init_range = [-5.12, 5.12]

        # Concatenate objective and descriptors to compute the Jacobian at once
        def _combined_fn(solution):
            objective, descriptors = _constant_fn(
                solution, solution_dim, self.descriptor_dim
            )
            objective = jnp.expand_dims(objective, axis=0)
            return jnp.concatenate([objective, descriptors])

        self._vmapped_jac_fn = jax.jit(jax.vmap(jax.jacrev(_combined_fn)))
        self._vmapped_value_fn = jax.jit(jax.vmap(_combined_fn))

    @partial(jax.jit, static_argnames=("self", "return_grad"))
    def evaluate(
        self, solutions: jnp.ndarray, key: jnp.ndarray, return_grad: bool = True
    ) -> EvalOutput:
        """Evaluate a batch of solutions, optionally with gradients.

        Args:
            solutions: Batch of solutions; the leading axis is the batch axis
                that the evaluation is vmapped over.
            key: PRNG key. Not used by this method.
            return_grad: Static flag selecting whether Jacobians are computed.

        Returns:
            An ``EvalOutput`` holding the per-solution fitnesses and
            descriptors. With ``return_grad=True`` it also holds the gradient
            of the fitness and of each descriptor with respect to the
            solution; otherwise both gradient fields are ``jnp.zeros(1)``
            placeholders.
        """
        combined_values = self._vmapped_value_fn(solutions)  # Shape: (K, 1 + m)
        fitnesses = combined_values[:, 0]
        descriptors = combined_values[:, 1:]

        # `return_grad` is static, so only the selected branch is traced.
        if return_grad:
            jacobians = self._vmapped_jac_fn(solutions)  # Shape: (K, 1 + m, D)
            fitness_grads = jacobians[:, 0, :]
            descriptor_grads = jacobians[:, 1:, :]
        else:
            fitness_grads = jnp.zeros(1)
            descriptor_grads = jnp.zeros(1)

        return EvalOutput(
            fitnesses=fitnesses,
            descriptors=descriptors,
            fitness_grads=fitness_grads,
            descriptor_grads=descriptor_grads,
        )

    @partial(jax.jit, static_argnames=("self",))
    def vanilla_evaluate(
        self, solution: jnp.ndarray, key: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Evaluate a single solution without gradients.

        Args:
            solution: One solution of shape ``self.solution_size``.
            key: PRNG key. Not used by this method.

        Returns:
            A ``(fitness, descriptors)`` pair; ``fitness`` is a scalar array.
        """
        fit, desc = _constant_fn(solution, self.solution_size[0], self.descriptor_dim)
        # `fit` is a tracer while this method is being jit-compiled, so it
        # cannot be converted with float() here; jit returns arrays regardless.
        return fit, desc

    @partial(jax.jit, static_argnames=("self", "n"))
    def get_random_solution(self, n, key):
        """Sample ``n`` solutions uniformly within ``self.init_range``.

        Args:
            n: Number of solutions to draw (static under ``jax.jit``).
            key: PRNG key; it is split once and the second sub-key is used
                for sampling.

        Returns:
            Array of shape ``(n, solution_dim)``.
        """
        # Keep the split (rather than sampling with `key` directly) so the
        # generated values stay identical for a given input key.
        _, init_key = jax.random.split(key)
        init_solutions = jax.random.uniform(
            init_key,
            shape=(n, self.solution_size[0]),
            minval=self.init_range[0],
            maxval=self.init_range[1],
        )
        return init_solutions
