from evotorch.algorithms import Cosyne
from tasks import TaskImpulse as Task
import torch
import random
from policy_dynamics import State
from policy_dynamics import simulation
from pylic.solvers.evotorch_neuroevolution import solver as evotorch_solver
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.tape import IfNode
from pylic.tape import Tape
from pylic.code_transformations import get_tape
from policy_dynamics import get_coordinate_scale
from policy_dynamics import get_observation
from policy_dynamics import get_policy_actions
from pylic.planner import concolic_planner
from pylic.planner import SearchNode
from trajectory_pylic import get_reach_goal_predicate
from trajectory_pylic import get_next_predicates
import time
import torch.nn as nn


# Type alias for the planner's final output: pylic_evotorch_solve is annotated
# to return this and returns a tensor built with torch.stack.
Parameters = torch.Tensor


def is_sat(parameters: torch.nn.Module, step_n: int, task: Task) -> bool:
    """Check whether rolling out ``parameters`` reaches the task goal.

    The policy is simulated from ``task.initial_state`` for ``step_n`` steps.
    The resulting execution tape is scored against the task's reach-goal
    predicate. The task counts as solved when the robustness value is
    strictly positive.
    """
    goal_predicate = get_reach_goal_predicate(task)
    rollout_tape = get_tape(
        f=simulation,
        fixed_tape=None,
        policy=parameters,
        state=task.initial_state,
        step_n=step_n,
    )
    robustness = predicate_interpreter(
        predicate=goal_predicate,
        input_tape=rollout_tape,
        max_value=torch.tensor(10),
        custom_functions={},
        custom_filters={},
    )
    return 0.0 < float(robustness)


def _networks_have_equal_parameters(
        model1: torch.nn.Module,
        model2: torch.nn.Module,
        ) -> bool:
    """Return True iff both networks have the same number of parameter
    tensors and every corresponding pair has identical shape and values."""
    # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/2
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        # zip() would silently truncate and could report a false match.
        return False
    # torch.equal returns False on a shape mismatch, instead of broadcasting
    # or raising the way elementwise ne() does.
    return all(
        torch.equal(p1.data, p2.data) for p1, p2 in zip(params1, params2)
    )


def select_node_depth_first(
        search_tree: SearchNode,
        explored_nodes: list[SearchNode],
        depth_bound: int,  # longest branch length
        max_time: float,
        ) -> SearchNode:
    """Select an unexplored leaf of the search tree via depth-first search.

    This function is provided as an example showing how to design a
    tree-search with pylic.

    Args:
        search_tree: Root of the (sub)tree to search.
        explored_nodes: Nodes already explored. A leaf counts as explored
            if some entry has an equal ``predicate`` and identical network
            ``parameters``.
        depth_bound: Maximum number of edges that may still be descended
            below ``search_tree``. Deeper nodes are never selected.
        max_time: Absolute deadline, compared against ``time.time()``.

    Returns:
        The first unexplored leaf found, visiting children in order.

    Raises:
        ValueError: "Planning time exceeded!" once ``max_time`` has passed.
            This is propagated immediately, even from deep in the recursion.
        ValueError: "Cannot find valid node to explore!" when no unexplored
            leaf exists within ``depth_bound``.
    """
    # Branch is deeper than allowed.
    if depth_bound < 0:
        raise ValueError("Cannot find valid node to explore!")

    if time.time() > max_time:
        raise ValueError("Planning time exceeded!")

    node = search_tree
    if len(node.children) == 0:
        # A leaf is selectable if and only if it has not been explored.
        for explored_node in explored_nodes:
            if node.predicate == explored_node.predicate and \
                    _networks_have_equal_parameters(
                        node.parameters, explored_node.parameters):
                raise ValueError("Cannot find valid node to explore!")
        return node

    # Search depth first for a non-explored leaf.
    for child in node.children:
        try:
            return select_node_depth_first(
                search_tree=child,
                explored_nodes=explored_nodes,
                depth_bound=depth_bound - 1,
                max_time=max_time,
            )
        except ValueError:
            # After the deadline, abort the whole search. Trying the remaining
            # siblings would replace the timeout error with a misleading
            # "Cannot find valid node" one.
            if time.time() > max_time:
                raise
    raise ValueError("Cannot find valid node to explore!")


def local_evotorch_solver(
        predicate: Predicate,
        starting_network: torch.nn.Module,
        initial_state: State,
        episode_timestep_n: int,
        seed: int,
        worker_n: int,
        max_generations: int,
        ) -> torch.nn.Module:
    """Wrapper around Pylic's evotorch neuroevolution solver, using CoSyNE.

    Runs the solver on ``starting_network`` with ``simulation`` as the
    program under test. Each rollout starts from ``initial_state`` and runs
    for ``episode_timestep_n`` steps, and ``predicate`` is the specification
    to satisfy.

    Args:
        predicate: Predicate the simulated execution should satisfy.
        starting_network: Network the optimization starts from.
        initial_state: Initial simulation state for every rollout.
        episode_timestep_n: Number of simulation steps per rollout.
        seed: Currently unused; see the note where the evotorch options are
            built. Kept for interface compatibility.
        worker_n: Number of parallel evotorch actors.
        max_generations: Number of CoSyNE generations to run.

    Returns:
        The network returned by the pylic solver.
    """
    def algorithm(problem):
        # Hyperparameters from
        # https://docs.evotorch.ai/v0.4.0/examples/notebooks/Gym_Experiments_with_PGPE_and_CoSyNE/#training-with-cosyne
        return Cosyne(
            problem,
            popsize=50,
            tournament_size=4,
            mutation_stdev=0.3,
        )

    # Call Pylic's evotorch solver with the CoSyNE algorithm above.
    parameters = evotorch_solver(
        predicate=predicate,
        network=starting_network,
        f=simulation,
        max_value=torch.tensor(10),
        num_generations=max_generations,
        algorithm=algorithm,
        custom_functions=dict(),
        custom_filters=dict(),
        verbose=True,
        state=initial_state,
        step_n=episode_timestep_n,
        # NOTE: `seed` is intentionally not forwarded. Evotorch crashes when
        # a seed is set together with parallel workers, so this code keeps
        # parallel workers and leaves the run unseeded.
        evotorch_kwargs=dict(
            num_actors=worker_n,
        )
    )
    return parameters


def get_collisions(tape: Tape) -> list[tuple[int, int]]:
    """Extract ``(t, i)`` collision events from the execution trace in ``tape``.

    Only ``IfNode`` entries whose value is strictly positive count as
    collisions. ``t`` and ``i`` are read from the variables in scope at that
    node. Consecutive hits on the same obstacle index are collapsed into
    the first one, so bouncing on one obstacle is reported once. Hitting it
    again after some other obstacle is a new collision.
    """
    collisions = []
    for entry in tape:
        if not isinstance(entry, IfNode) or float(entry.value) <= 0.0:
            continue
        scope = dict(entry.variables_in_scope)
        obstacle_i = scope["i"]
        # The last recorded obstacle index lives in the tuple's second slot.
        if collisions and obstacle_i == collisions[-1][1]:
            continue
        collisions.append((scope["t"], obstacle_i))
    return collisions


def get_child_parameters(_: Tape, parameters: torch.nn.Module) -> torch.nn.Module:
    """Return the parameters that children of a tape start from.

    The tape argument is ignored. Every child inherits its parent's network
    object as-is; no copy is made.
    """
    child_parameters = parameters
    return child_parameters


def pylic_evotorch_solve(
        task: Task,
        seed: int,
        episode_timestep_n: int,
        worker_n: int,
        max_generations_per_episode: int,
        timeout_s: float,
        *,
        depth_bound: int = 18,
        hidden_size: int = 32,
        ) -> Parameters:
    """Solve the given task using Pylic's concolic planner with CoSyNE as a
    solver. CoSyNE optimizes a neural policy, which is converted to a
    sequence of actions by performing a rollout.

    Args:
        task: Task to solve. Its ``initial_state`` and ``max_timesteps``
            configure the simulation and the satisfaction check.
        seed: Seeds torch's global RNG before the policy network is built,
            and a local RNG that generates the seeds passed to
            ``get_next_predicates``.
        episode_timestep_n: Number of simulation steps used by the planner,
            by each solver call, and by the final rollout.
        worker_n: Number of parallel evotorch actors per solver call.
        max_generations_per_episode: CoSyNE generation budget per solver
            call.
        timeout_s: Wall-clock planning budget in seconds. It is converted to
            an absolute deadline for the depth-first node selection.
        depth_bound: Maximum tree depth explored by the depth-first node
            selection. Defaults to 18.
        hidden_size: Width of the policy network's hidden layer. Defaults
            to 32.

    Returns:
        The actions produced by rolling out the final policy, stacked into
        a single tensor.
    """
    _random = random.Random(seed)
    max_time = time.time() + timeout_s
    coordinate_scale = get_coordinate_scale(task.initial_state)
    input_size = len(get_observation(
        task.initial_state,
        t=0,
        T=1,
        coordinate_scale=coordinate_scale,
    ))
    # Seed before building the network so its initial weights are
    # reproducible.
    torch.manual_seed(seed)
    network = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, 2),
    )
    policy = concolic_planner(
        f=simulation,
        # NOTE(review): satisfaction is checked over task.max_timesteps
        # steps, while the planner, the solver and the final rollout below
        # use episode_timestep_n. If the two differ, the returned action
        # sequence may not span the horizon that was verified. Confirm this
        # is intended.
        is_sat=lambda *p: is_sat(*p, step_n=task.max_timesteps, task=task),
        select_node=lambda *p: select_node_depth_first(
            *p,
            depth_bound=depth_bound,
            max_time=max_time,
        ),
        get_next_predicates=lambda tape, _, __: get_next_predicates(
            tape=tape,
            task=task,
            seed=str(_random.random()),
            episode_timestep_n=episode_timestep_n,
            max_timesteps=task.max_timesteps,
        ),
        solver=lambda *p: local_evotorch_solver(
            *p,
            initial_state=task.initial_state,
            seed=seed,
            worker_n=worker_n,
            max_generations=max_generations_per_episode,
            episode_timestep_n=episode_timestep_n,
        ),
        starting_parameters=network,
        get_child_parameters=get_child_parameters,
        state=task.initial_state,
        verbose=True,
        step_n=episode_timestep_n,
    )

    # Convert the policy into a sequence of actions by rolling it out.
    actions = get_policy_actions(
        policy=policy,
        state=task.initial_state,
        step_n=episode_timestep_n,
    )
    return torch.stack(actions)
