"""Pylic adapter for the trajectory optimization case, using the
Cross-entropy method.
"""
from collections import defaultdict
from pylic.tape import Tape
from pylic.tape import IfNode
from pylic.tape import ForNode
from pylic.code_transformations import get_tape
from pylic.predicates import Constants, Predicate
from pylic.predicates import Negation
from pylic.predicates import Conjunction
from pylic.predicates import LessThan
from pylic.predicates import predicate_interpreter
from pylic.predicates import SolverFailedException
from pylic.predicates import FunctionCall
from pylic.predicates import Filter
from pylic.predicates import IfOr
from pylic.solvers.cma_es import solver as cma_solver
from pylic.planner import SearchNode
from dynamics import simulation
from pylic.planner import concolic_planner
from dataclasses import dataclass
import random
import traceback
import torch
import cma
import time


@dataclass(frozen=True)
class AnnotatedConjunction(Conjunction):
    """A pylic `Conjunction` that also carries a human-readable label.

    `get_next_predicates` labels each "push this button" predicate with
    ``f"Push {i}"``, and `select_node_depth_first` compares labels so the
    same labelled goal is not re-explored from identical parameters.
    """
    # Free-form label identifying the goal this conjunction encodes.
    annotation: str


# Alias for the decision variables fed to `simulation` and optimized by the
# solvers. In this file they are built as torch tensors of shape
# (timesteps, 1, actuator_n) (see the `torch.zeros` calls below).
Parameters: type = torch.Tensor


def _was_explored(node: SearchNode, explored_nodes: list[SearchNode]) -> bool:
    """Return True if `node` duplicates one of `explored_nodes`.

    Two nodes are duplicates when they start from equal parameters (compared
    with `.equal`) and either share the very same predicate object, or are
    both `AnnotatedConjunction`s with the same annotation. The latter avoids
    retrying the same labelled goal (e.g. the same button push) from the same
    starting parameters.
    """
    for explored_node in explored_nodes:
        p1 = node.predicate
        p2 = explored_node.predicate
        same_goal = p1 is p2 or (
            isinstance(p1, AnnotatedConjunction)
            and isinstance(p2, AnnotatedConjunction)
            and p1.annotation == p2.annotation
        )
        if same_goal and node.parameters.equal(explored_node.parameters):
            return True
    return False


def select_node_depth_first(
        search_tree: SearchNode,
        explored_nodes: list[SearchNode],
        depth_bound: int,  # longest branch length
        ) -> SearchNode:
    """Select an unexplored leaf of `search_tree` by randomized depth-first
    search.

    This function is provided as an example showing how to design a
    tree-search with pylic. Children are visited in random order; a leaf is
    rejected when `_was_explored` reports it as a duplicate. Leaves more than
    `depth_bound` levels below `search_tree` are never selected. Parameters
    are compared with their `.equal(...)` method (e.g. `torch.Tensor.equal`).

    Raises:
        ValueError: if no valid leaf exists within the depth bound.
    """
    # `< 0` (rather than `== -1`) so a negative initial bound is honoured too.
    if depth_bound < 0:
        raise ValueError("Cannot find valid node to explore")
    # A leaf is selected if and only if it has not been explored yet.
    if len(search_tree.children) == 0:
        if _was_explored(search_tree, explored_nodes):
            raise ValueError("Cannot find valid node to explore!")
        return search_tree
    # Search depth first for a non-explored leaf, in random child order.
    children = list(search_tree.children)
    random.shuffle(children)
    for child in children:
        try:
            return select_node_depth_first(
                search_tree=child,
                explored_nodes=explored_nodes,
                depth_bound=depth_bound-1,
            )
        except ValueError:
            continue
    raise ValueError("Cannot find valid node to explore")


def distance_to_goal(tape: Tape) -> torch.Tensor:
    """Return the smallest ant-to-goal distance at the last recorded timestep.

    Only `ForNode`s with ``id == 1`` are inspected (presumably the simulation's
    per-timestep loop -- confirm against `simulation`). Each contributes the
    Euclidean distance between ``state["ant_position"]`` and
    ``(env.goal.x, env.goal.y)``, grouped by ``state["t"]``. The minimum over
    the nodes recorded at the largest ``t`` is returned.

    Raises:
        ValueError: if the tape contains no ``ForNode`` with ``id == 1``.
    """
    distances = defaultdict(list)
    for node in tape:
        if not (isinstance(node, ForNode) and node.id == 1):
            continue
        state = node.program_state
        ant_position = torch.tensor(state["ant_position"])
        env = state["env"]
        t = state["t"]
        goal_position = torch.tensor((env.goal.x, env.goal.y))
        distance = (ant_position-goal_position).norm()
        distances[t].append(distance)
    if not distances:
        # Previously surfaced as an opaque "max() arg is an empty sequence".
        raise ValueError(
            "distance_to_goal: tape contains no ForNode with id 1"
        )
    final_distances = distances[max(distances)]
    return torch.stack(final_distances).min()


# Goal predicate: the value returned by `distance_to_goal` (smallest
# ant-to-goal distance at the final timestep) must be less than 0.01,
# presumably in the simulation's position units.
_goal_distance = FunctionCall(
    custom_function=distance_to_goal,
    trace=Constants.input_trace,
)
reach_goal_predicate = LessThan(left=_goal_distance, right=torch.tensor(0.01))


def is_sat(
        parameters: Parameters,
        **simulation_kwargs,
        ) -> bool:
    """Return True iff the ant reaches the goal when simulating `parameters`.

    The parameters are simulated to obtain a tape. The tape is then scored
    against `reach_goal_predicate`; a strictly positive robustness counts as
    satisfied.
    """
    trace = get_tape(
        f=simulation,
        fixed_tape=None,
        parameters=parameters,
        **simulation_kwargs,
    )
    robustness = predicate_interpreter(
        reach_goal_predicate,
        trace,
        max_value=torch.tensor(10),
        custom_functions={},
        custom_filters={},
    )
    return robustness.item() > 0.0


def get_button_push_times(tape: Tape) -> list[tuple[int, int]]:
    """Return the sorted, de-duplicated (t, i) button pushes found in `tape`.

    A push is an `IfNode` whose value is not <= 0.0 and whose program state
    records both a timestep ``t`` and a button index ``i``. The result is
    also printed for debugging.
    """
    pushes = set()
    for node in tape:
        # Note: NaN values count as pushes here, since NaN <= 0.0 is False.
        if isinstance(node, IfNode) and not (float(node.value) <= 0.0):
            state = node.program_state
            if 't' in state.keys() and 'i' in state.keys():
                pushes.add((state["t"], state["i"]))
    times = sorted(pushes)
    print(f"Discovered tape: {times}")
    return times


def _episode_window_filter(start_t, end_t):
    """Build a trace filter keeping nodes with 't' and 'i' and start_t <= t < end_t."""
    def custom_filter(trace, node_i):
        state = trace[node_i].program_state
        return (
            't' in state.keys()
            and 'i' in state.keys()
            and state['t'] < end_t
            and state['t'] >= start_t
        )
    return custom_filter


def _button_filter(button_i):
    """Build a trace filter keeping nodes whose state has ``i == button_i``."""
    def custom_filter(trace, node_i):
        return trace[node_i].program_state['i'] == button_i
    return custom_filter


def _buttons_filter(buttons):
    """Build a trace filter keeping nodes whose state ``i`` is in `buttons`."""
    buttons = frozenset(buttons)

    def custom_filter(trace, node_i):
        return trace[node_i].program_state['i'] in buttons
    return custom_filter


def get_next_predicates(
        tape: Tape,
        _: Parameters,
        __: Predicate,
        password_len: int,
        episode_timestep_n: int,
        ) -> list[Predicate]:
    """Return the predicates to explore next from the node that produced `tape`.

    If ``password_len`` distinct buttons have already been pushed, the only
    next predicate is `reach_goal_predicate`. Otherwise, one
    `AnnotatedConjunction` is returned per button index that appears in the
    tape but has not been pushed yet. Each one reads "push this button, and
    no other not-yet-pushed button". Both conditions are checked over the
    timesteps ``[start_t, start_t + 4*episode_timestep_n)``, where ``start_t``
    is one step after the last push.

    The two unnamed parameters are unused. They are presumably the node's
    parameters and predicate as passed by `concolic_planner` -- confirm.
    """
    push_times = get_button_push_times(tape)
    pushed_is = {i for _t, i in push_times}
    pushed_ts = {t for t, _i in push_times}

    # Once every password button has been pushed, only the goal remains.
    if len(pushed_is) == password_len:
        return [reach_goal_predicate]

    # Get the indices of nonpushed buttons in the trace.
    # This is to "discover" what other buttons exist from the trace.
    nonpushed_button_is = set()
    for node in tape:
        if not isinstance(node, IfNode):
            continue
        if 'i' not in node.program_state.keys():
            continue
        i = node.program_state["i"]
        if i not in pushed_is:
            nonpushed_button_is.add(i)

    # Get the start and end of the current episode
    start_t = max(pushed_ts, default=-1)+1
    end_t = start_t+4*episode_timestep_n  # could be omitted

    # Add predicates for every button that has not been pushed.
    # The filters are built by factory functions on purpose. Lambdas defined
    # inline in this loop would late-bind `nonpushed_button_i` and
    # `other_buttons`; since the filters only run after the loop finishes,
    # every predicate would then target the LAST button.
    predicates = list()
    for nonpushed_button_i in nonpushed_button_is:
        # Filter trace to the nodes between start_t and end_t
        local_trace = Filter(
            custom_filter=_episode_window_filter(start_t, end_t),
            trace=Constants.input_trace,
        )

        # Filter for pushing this button condition check
        push_nonpushed_predicate = IfOr(Filter(
            custom_filter=_button_filter(nonpushed_button_i),
            trace=local_trace,
        ))

        # Filter for pushing any other nonpushed button
        other_buttons = nonpushed_button_is-{nonpushed_button_i}
        if other_buttons:
            dont_push_others = Negation(IfOr(Filter(
                custom_filter=_buttons_filter(other_buttons),
                trace=local_trace,
            )))
        else:
            dont_push_others = Constants.Top

        # "Push this nonpushed button but not any other"
        predicates.append(AnnotatedConjunction(
            [push_nonpushed_predicate, dont_push_others],
            annotation=f"Push {nonpushed_button_i}",
        ))

    print(f"Next predicates: {predicates}")
    return predicates


def get_child_parameters(tape: Tape, parameters: Parameters) -> Parameters:
    """Children start from exactly the same parameters as their parent.

    `tape` is accepted so this function can be passed to `concolic_planner`
    as its `get_child_parameters` callback; it is not consulted.
    """
    del tape  # unused
    return parameters


def local_cma_solver(
        predicate: Predicate,
        starting_parameters: Parameters,
        max_episodes_per_predicate: int,
        episode_timestep_n: int,
        actuator_n: int,
        max_f_eval_n: int,
        worker_n: int,
        **simulation_kwargs,
        ) -> Parameters:
    """Solve `predicate` with CMA-ES by optimizing only the trajectory tail.

    The starting parameters are simulated once. Every timestep up to and
    including the last button push is frozen. CMA-ES then optimizes a
    zero-initialised tail of ``episode_timestep_n`` steps, which is appended
    ``tail_factor`` times after that prefix. If a run fails, the solve is
    restarted with ``tail_factor = 2, ..., max_episodes_per_predicate``,
    i.e. with a progressively longer episode.

    Returns:
        The full parameters (frozen prefix followed by the repeated tail).

    Raises:
        SolverFailedException: if every restart fails. It carries the last
            run's final parameters expanded to a full trajectory (or None if
            the solver reported none). Also raised when
            ``max_episodes_per_predicate < 1``.
    """
    # Freeze the parameters up to (and including) the last button push
    tape = get_tape(
        simulation,
        None,
        starting_parameters,
        **simulation_kwargs,
    )
    push_ts = [t for t, _ in get_button_push_times(tape)]
    max_push_t = max(push_ts, default=-1)
    frozen_parameters = [starting_parameters[t] for t in range(max_push_t+1)]
    # Loop-invariant: stack the prefix once rather than per candidate.
    frozen_prefix = (
        torch.stack(frozen_parameters) if len(frozen_parameters) > 0 else None
    )

    # Optimize the suffix of the trajectory only; restart a few times,
    # increasing the episode length each time.
    for tail_factor in range(1, max_episodes_per_predicate+1):
        starting_tail = torch.zeros(
            (episode_timestep_n, 1, actuator_n)
        )

        def candidate_processor(candidate_tail: Parameters) -> Parameters:
            # The tail is repeated even when nothing is frozen. Otherwise every
            # restart would re-run the identical solve (same seed, same length).
            pieces = [candidate_tail] * tail_factor
            if frozen_prefix is not None:
                pieces.insert(0, frozen_prefix)
            # Nothing to concatenate: hand the candidate through unchanged.
            return pieces[0] if len(pieces) == 1 else torch.cat(pieces)

        try:
            p, _ = cma_solver(
                predicate,
                starting_parameters=starting_tail,
                f=simulation,
                max_value=torch.tensor(10),
                max_f_eval_n=max_f_eval_n,
                custom_filters=dict(),
                custom_functions=dict(),
                initial_stdev=0.1,
                verbose=True,
                multiprocessing_workers=worker_n,
                opts=cma.evolution_strategy.CMAOptions(
                    popsize=worker_n,
                    seed=289518,
                    tolfun=0.001,
                ),
                candidate_processor=candidate_processor,
                **simulation_kwargs,
            )
        except SolverFailedException as e:
            if tail_factor == max_episodes_per_predicate:
                failed_tail = e.final_parameters
                raise SolverFailedException(
                    e.message,
                    None if failed_tail is None
                    else candidate_processor(failed_tail),
                ) from e
        else:
            return candidate_processor(p)
    raise SolverFailedException("Solver failed!")


# Wall-clock seconds elapsed since `pylic_cma_solver` started.
SolverTime = float
# One record: when a local solve finished, and the parameters it returned
# (or, on failure, the final parameters attached to its exception).
LogEntry = tuple[SolverTime, Parameters]
Log = list[LogEntry]


def pylic_cma_solver(
        target_password: tuple[int, ...],
        episode_timestep_n: int,
        actuator_n: int,
        button_n: int,
        max_f_eval_n: int,
        max_episodes_per_predicate: int,
        worker_n: int,
        sub_step_s: float,
        ) -> tuple[Parameters, Log]:
    """Search for parameters that push the password buttons and reach the goal.

    Runs pylic's `concolic_planner`. Node selection is randomized depth-first
    (bounded by ``len(password)+1``), and each predicate is solved by the
    tail-optimizing CMA-ES solver `local_cma_solver`.

    Returns:
        The parameters returned by the planner, and a log with one
        ``(elapsed seconds, parameters)`` entry per local solve, whether it
        succeeded or failed.

    Raises:
        SolverFailedException: if the planner raises `ValueError` (presumably
            because no node is left to expand). The exception carries the
            most recent non-None parameters in the log, or the all-zero
            starting parameters if nothing usable was logged.
    """
    password = tuple(target_password)

    # Instantiate log
    log: Log = list()

    # Wrap the solver to track time and current parameters
    start_time = time.time()

    def solver_wrapper(p, i, log):
        """Run `local_cma_solver` and log its outcome with a timestamp."""
        try:
            result = local_cma_solver(
                p, i,
                password=password,
                max_f_eval_n=max_f_eval_n,
                max_episodes_per_predicate=max_episodes_per_predicate,
                episode_timestep_n=episode_timestep_n,
                actuator_n=actuator_n,
                worker_n=worker_n,
                num_buttons=button_n,
                sub_step_s=sub_step_s,
            )
        except SolverFailedException as e:
            # Some raises may not attach final parameters; log None then.
            log.append(
                (time.time()-start_time, getattr(e, "final_parameters", None))
            )
            raise
        log.append((time.time()-start_time, result))
        return result

    # Solve task
    starting_parameters = torch.zeros((episode_timestep_n, 1, actuator_n))
    try:
        parameters = concolic_planner(
            f=simulation,
            is_sat=lambda p: is_sat(
                p,
                password=password,
                num_buttons=button_n,
                sub_step_s=sub_step_s,
            ),
            get_next_predicates=lambda *p: get_next_predicates(
                *p,
                password_len=len(password),
                episode_timestep_n=episode_timestep_n,
            ),
            select_node=lambda *p: select_node_depth_first(
                *p,
                depth_bound=len(password)+1,
            ),
            solver=lambda p, i: solver_wrapper(p, i, log),
            starting_parameters=starting_parameters,
            get_child_parameters=get_child_parameters,
            verbose=True,
            password=password,
            num_buttons=button_n,
            sub_step_s=sub_step_s,
        )
    except ValueError as e:
        # We probably could not find a node to expand
        print(traceback.format_exc())
        # Report the latest parameters any local solve produced instead of
        # the untouched all-zero starting point.
        final_parameters = next(
            (logged for _, logged in reversed(log) if logged is not None),
            starting_parameters,
        )
        raise SolverFailedException(
            "Failed to find solution!", final_parameters,
        ) from e
    return parameters, log
