"""Functions necessary for running Pylic in the trajectory optimization marble
environment.
"""
from pylic.solvers.gradient_descent import solver as gd_solver
from pylic.solvers.cma_es import solver as cma_solver
from dataclasses import dataclass
from pylic.predicates import Constants
from pylic.predicates import Filter
from pylic.predicates import FunctionCall
from pylic.predicates import IfNode
from pylic.predicates import IfOr
from pylic.predicates import LessThan
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.predicates import SolverFailedException
from plotting import plot_animation
from trajectory_dynamics import get_trajectory
from pylic.tape import Tape
from pylic.code_transformations import get_tape
from pylic.planner import concolic_planner
from pylic.planner import SearchNode
from pathlib import Path
from tasks import TaskImpulse as Task
from trajectory_dynamics import State
from trajectory_dynamics import simulation
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from examples.plotting import SolverTime
import time
import math
import torch
import cma
import random
import datetime
import traceback
import multiprocessing as mp
# Multiprocessing context that uses the 'spawn' start method.
# NOTE(review): `ctx` is not referenced in the visible part of this file;
# confirm nothing else imports it before removing.
ctx = mp.get_context('spawn')


@dataclass(frozen=True)
class AnnotatedIfOr(IfOr):
    """An ``IfOr`` predicate that carries a free-form annotation string.

    In this module the annotation records which collision sequence a
    search-tree node is trying to produce. ``select_node_depth_first``
    compares annotations to avoid re-exploring equivalent nodes, and
    ``serialize_tree`` uses the annotation as the node label.
    """
    # Human-readable label. Because this is a (frozen) dataclass with the
    # default ``eq=True``, the annotation also takes part in equality checks.
    annotation: str


# Type alias for optimized parameters. In the marble environment these encode
# a time-indexed action trajectory; this module builds them as tensors of
# shape ``(timesteps, 2)`` (e.g. ``torch.zeros((task.max_timesteps, 2))``).
Parameters = torch.Tensor


def get_reach_goal_predicate(task: Task) -> Predicate:
    """Build the predicate "the marble finishes inside the goal circle".

    The predicate compares the Euclidean distance between the marble's
    position in the last trace entry and ``task.goal_circle.position``
    against ``task.goal_circle.radius`` using ``LessThan``.
    """
    def final_distance_to_goal(trace):
        # Distance from the marble's final position to the goal centre.
        final_position = trace[-1].program_state["state"].marble.position
        return (final_position - task.goal_circle.position).norm()

    distance = FunctionCall(
        custom_function=final_distance_to_goal,
        trace=Constants.input_trace,
    )
    goal_radius = torch.tensor(task.goal_circle.radius)
    return LessThan(distance, goal_radius)


def is_sat(parameters: Parameters, task: Task) -> bool:
    """Check whether ``parameters`` drive the marble into the goal circle.

    The action trajectory is simulated from ``task.initial_state`` and the
    resulting tape is scored against the goal predicate. The task counts as
    solved when the robustness value is strictly positive.
    """
    goal_predicate = get_reach_goal_predicate(task)
    trace_tape = get_tape(
        f=simulation,
        fixed_tape=None,
        actions=parameters,
        state=task.initial_state,
    )
    robustness = predicate_interpreter(
        predicate=goal_predicate,
        input_tape=trace_tape,
        max_value=torch.tensor(10),
        custom_functions={},
        custom_filters={},
    )
    return float(robustness) > 0.0


def gd_wrap(starting_parameters, **kwargs):
    """Run Pylic's gradient-descent solver on list-encoded parameters.

    ``starting_parameters`` is converted to a tensor before solving. The
    optimized parameters come back as nested Python lists (``tolist``), so
    that tensors do not need to cross process boundaries. All other keyword
    arguments are forwarded unchanged to the solver.
    """
    optimized, _unused_info = gd_solver(
        starting_parameters=torch.tensor(starting_parameters),
        **kwargs,
    )
    return optimized.tolist()


def marble_grad_solver(
        predicate: Predicate,
        starting_parameters: Parameters,
        initial_state: State,
        episode_timestep_n: int,
        random_restart_n: int,
        seed: int,
        worker_n: int,
        ) -> Parameters:
    """Wrapper around Pylic's gradient-based solver with random restarts
    tailored to the marble environment.

    In the marble environment the parameters encode a time-indexed
    trajectory. If ``starting_parameters`` already satisfy ``predicate``
    they are returned unchanged. Otherwise the trajectory is trimmed at the
    first collision with the highest-indexed obstacle hit so far. A new
    suffix of ``episode_timestep_n`` steps is then appended and optimized,
    while the kept prefix is frozen through the gradient mask.

    Args:
        predicate: Predicate the returned trajectory should satisfy.
        starting_parameters: Trajectory to extend.
        initial_state: State the simulation starts from.
        episode_timestep_n: Number of timesteps in the new suffix.
        random_restart_n: Number of initial guesses, tried in sequence. With
            a single restart the suffix starts at zero action. Otherwise each
            restart draws a constant action from a fixed menu (one entry is a
            random direction); non-zero choices are rescaled to norm 0.9.
        seed: Seed for the restart random number generator.
        worker_n: Accepted for interface compatibility. Restarts currently
            run sequentially in this process, so this value is unused.

    Returns:
        The kept prefix followed by the optimized suffix, taken from the
        first restart for which the solver succeeds.

    Raises:
        SolverFailedException: If every restart fails. ``final_parameters``
            is the last non-None ``final_parameters`` reported by a failed
            restart, or None.
    """
    print(f"Starting parameters size: {starting_parameters.size()}")

    # Simulate the starting trajectory once. The tape is used both for the
    # early-exit check and to locate collisions.
    tape = get_tape(
        f=simulation,
        fixed_tape=None,
        actions=starting_parameters,
        state=initial_state,
    )
    # First, check if starting parameters satisfy predicate already
    starting_robustness = predicate_interpreter(
        predicate=predicate,
        input_tape=tape,
        max_value=torch.tensor(10.0),
        custom_filters=CustomFilterMapping(),
        custom_functions=dict(),
    )
    if starting_robustness > 0.0:
        return starting_parameters

    # Otherwise, trim the starting parameters at the first collision with
    # the highest-indexed obstacle (keep everything before it).
    collisions = get_collisions(tape)
    target_i = max((i for _, i in collisions), default=-1)
    start_t = next((t for t, i in collisions if i == target_i), 0)
    trimmed_starting_parameters = starting_parameters[:start_t]
    # Index at which the new suffix begins. This equals start_t unless the
    # collision happened after the last action (start_t > number of actions).
    # In that case starting the suffix at start_t would leave its first steps
    # frozen at zero, so the length of the kept prefix is used instead.
    suffix_start = len(trimmed_starting_parameters)

    # Seed a random number generator
    _random = random.Random(seed)

    # Build one (grad_mask, initial_guess) pair per restart
    initial_parameters = list()
    for _ in range(random_restart_n):
        episodic_starting_parameters = torch.cat([
            trimmed_starting_parameters,
            torch.zeros((episode_timestep_n, 2)),
        ])
        # Gradient mask: 0 keeps the prefix frozen, 1 frees the new suffix
        grad_mask = torch.zeros_like(episodic_starting_parameters)

        # Initial guess: a suffix that applies the same constant action.
        # Note: the list literal (including the random entry) is evaluated
        # before `choice`, so the RNG draws happen on every restart.
        if random_restart_n == 1:
            init_value = torch.tensor([0.0, 0.0])
        else:
            init_value = _random.choice([
                torch.tensor([0.0, 0.0]),
                torch.tensor([0.9, 0.0]),
                torch.tensor([-0.9, 0.0]),
                torch.tensor([0.0, -0.9]),
                torch.tensor([0.0, 0.9]),
                torch.tensor([_random.random()*2-1, _random.random()*2-1]),
            ])
        if init_value.norm() > 0:
            init_value = 0.9*(init_value/init_value.norm())

        # Unfreeze the suffix and fill it with the initial guess
        # (init_value broadcasts over every suffix row).
        grad_mask[suffix_start:] = 1.0
        episodic_starting_parameters[suffix_start:] = init_value

        initial_parameters.append((grad_mask, episodic_starting_parameters.detach().clone()))

    # Try each restart in turn; return the first success
    failed_parameters = list()
    for (grad_mask, episodic_starting_parameters) in initial_parameters:
        try:
            result = gd_wrap(
                starting_parameters=episodic_starting_parameters.tolist(),
                predicate=predicate,
                f=simulation,
                grad_mask=grad_mask,
                iter_n=100,
                learning_rate=0.2,
                momentum_beta=0.1,
                grad_norm_min=0.01,
                normalize_gradient=True,
                verbose=True,
                dampening=0.999,
                line_search_scale=0.9,
                line_search_max_iter_n=1,
                max_consecutive_non_improvements=5,
                max_consecutive_sat_improvements=3,
                max_value=torch.tensor(10.0),
                custom_functions=dict(),
                custom_filters=CustomFilterMapping(),
                state=initial_state,
            )
            return torch.tensor(result)
        except SolverFailedException as e:
            traceback.print_exc()
            fp = e.final_parameters
            if fp is not None:
                failed_parameters.append(fp)
            print("Continuing...")

    # Every restart failed
    raise SolverFailedException(
        "Solver failed",
        final_parameters=None if len(failed_parameters) == 0 else failed_parameters[-1]
    )


def marble_cma_solver(
        predicate: Predicate,
        starting_parameters: Parameters,
        initial_state: State,
        cma_max_f_eval_n: int,
        episode_timestep_n: int,
        worker_n: int,
        seed: int,
        ) -> Parameters:
    """Wrapper around Pylic's CMA-ES solver.

    In the marble environment the parameters encode a time-indexed
    trajectory. This method optimizes a new suffix for the starting
    parameters and returns the kept prefix followed by that suffix.

    If ``starting_parameters`` already satisfy ``predicate`` they are
    returned unchanged. Otherwise the trajectory is trimmed at the first
    collision with the highest-indexed obstacle hit so far. A single
    zero-action step is used as the prefix when nothing is kept. CMA-ES then
    optimizes a suffix of ``episode_timestep_n`` steps, starting from zeros,
    with at most ``cma_max_f_eval_n`` evaluations, ``worker_n``
    multiprocessing workers and ``seed`` as the CMA seed.
    """
    # Simulate the starting trajectory once; the tape is reused below
    tape = get_tape(
        f=simulation,
        fixed_tape=None,
        actions=starting_parameters,
        state=initial_state,
    )
    # First, check if starting parameters satisfy predicate already
    starting_robustness = predicate_interpreter(
        predicate=predicate,
        input_tape=tape,
        max_value=torch.tensor(10.0),
        custom_filters=CustomFilterMapping(),
        custom_functions=dict(),
    )
    if starting_robustness > 0.0:
        return starting_parameters

    # Otherwise, trim the starting parameters at the first collision with
    # the highest-indexed obstacle (keep everything before it).
    collisions = get_collisions(tape)
    target_i = max((i for _, i in collisions), default=-1)
    start_t = next((t for t, i in collisions if i == target_i), 0)
    prefix_trajectory = starting_parameters[:start_t]
    if len(prefix_trajectory) == 0:
        prefix_trajectory = torch.zeros((1, 2))

    def candidate_processor(p: Parameters) -> Parameters:
        """Prepend the kept prefix to a candidate suffix."""
        return torch.tensor([
            *(prefix_trajectory.tolist()),
            *(p.tolist()),
        ])

    # NOTE(review): the solver is given an identity `candidate_processor`,
    # so (presumably) each candidate suffix is simulated on its own from
    # `initial_state`, while the value returned below is prefix + suffix.
    # Confirm whether `candidate_processor=candidate_processor` was intended
    # (and whether `cma_solver` returns raw or processed candidates) before
    # changing it.
    suffix_parameters, _ = cma_solver(
        predicate,
        starting_parameters=torch.zeros((episode_timestep_n, 2)),
        f=simulation,
        custom_functions=dict(),
        custom_filters=CustomFilterMapping(),
        max_f_eval_n=cma_max_f_eval_n,
        initial_stdev=0.3,
        verbose=True,
        multiprocessing_workers=worker_n,
        opts=cma.evolution_strategy.CMAOptions(
            seed=seed,
        ),
        candidate_processor=lambda x: x,
        state=initial_state,
        max_value=torch.tensor(10),
    )
    return candidate_processor(suffix_parameters)


def get_collisions(tape: Tape) -> list[tuple[int, int]]:
    """Collect ``(t, obstacle_i)`` pairs for the collisions in ``tape``.

    A collision is an ``IfNode`` whose value is strictly positive. ``t`` and
    ``i`` are read from that node's ``variables_in_scope``. A collision with
    the same obstacle as the immediately preceding reported collision is
    skipped, so consecutive bounces on one obstacle are reported once.
    """
    found: list[tuple[int, int]] = []
    for entry in tape:
        if not isinstance(entry, IfNode):
            continue
        if float(entry.value) <= 0.0:
            continue
        scope = dict(entry.variables_in_scope)
        obstacle = scope["i"]
        # Collapse consecutive hits on the same obstacle
        if found and obstacle == found[-1][1]:
            continue
        found.append((scope["t"], obstacle))
    return found


def get_timestep_n(tape: Tape) -> int:
    """Return the largest timestep index ``t`` recorded in the tape.

    Every node whose ``program_state`` contains a ``'t'`` entry contributes
    that value. The result is 0 when no node records one. Callers compare
    this against a timestep budget to decide whether to keep expanding a
    node.

    NOTE(review): if ``t`` is a 0-based step index, the number of simulated
    steps is this value + 1.
    """
    # Seed with 0 so an empty tape (or one without 't') yields 0.
    timesteps = {0}
    timesteps.update(
        node.program_state['t']
        for node in tape
        if 't' in node.program_state.keys()
    )
    # `max`, not `min`: with the 0 seed, `min` always returned 0 for
    # non-negative timesteps, so the timestep cutoff could never trigger.
    return max(timesteps)


class _PlanningTimeExceeded(ValueError):
    """Raised by ``select_node_depth_first`` once ``max_time`` has passed.

    It subclasses ``ValueError`` so callers that catch ``ValueError`` keep
    working. The depth-first search re-raises it instead of treating it as
    "this subtree has no explorable leaf", so the timeout is not masked.
    """


def select_node_depth_first(
        search_tree: SearchNode,
        explored_nodes: list[SearchNode],
        depth_bound: int,  # longest branch length
        max_time: float,
        ) -> SearchNode:
    """Select a node in the tree through depth-first search.

    This function is provided as an example showing how to design a
    tree-search with pylic.

    A leaf is selectable unless some already explored node has equal
    parameters and either both predicates are ``LessThan`` (the goal
    predicate) or both are ``AnnotatedIfOr`` with equal annotations.
    Children are visited in order, and the first selectable leaf is returned.

    Args:
        search_tree: Root of the (sub)tree to search.
        explored_nodes: Nodes that have already been explored.
        depth_bound: Remaining depth; decremented at each recursion level.
        max_time: Deadline as a ``time.time()`` timestamp.

    Raises:
        ValueError: If no selectable leaf exists within ``depth_bound``.
        _PlanningTimeExceeded: A ``ValueError`` subclass, raised when the
            deadline has passed. It propagates immediately instead of being
            backtracked over.
    """
    # If we have exceeded the maximum recursion bound, raise error
    if depth_bound+1 == 0:
        raise ValueError("Cannot find valid node to explore!")

    # If we have exceeded the planning time, raise error
    if time.time() > max_time:
        raise _PlanningTimeExceeded("Planning time exceeded!")

    # If node is leaf, select it if and only if it has not
    # been explored
    node = search_tree
    if len(node.children) == 0:
        p1 = node.predicate
        for explored_node in explored_nodes:
            p2 = explored_node.predicate

            # Check if both are goal predicates (goal is the only LessThan
            # predicate)
            if isinstance(p1, LessThan) and isinstance(p2, LessThan)\
                    and node.parameters.equal(explored_node.parameters):
                raise ValueError("Cannot find node to explore!")

            # Otherwise, check the annotation.
            # We want to avoid trying multiple times to collide with the same
            # obstacle when starting with the same parameters.
            if not isinstance(p1, AnnotatedIfOr):
                continue
            if not isinstance(p2, AnnotatedIfOr):
                continue
            if p1.annotation == p2.annotation\
                    and node.parameters.equal(explored_node.parameters):
                raise ValueError("Cannot find valid node to explore!")
        return node

    # Search depth first for a non-explored leaf
    for child in list(search_tree.children):
        try:
            return select_node_depth_first(
                search_tree=child,
                explored_nodes=explored_nodes,
                depth_bound=depth_bound-1,
                max_time=max_time,
            )
        except _PlanningTimeExceeded:
            # Running out of time is not a dead end: stop the whole search
            raise
        except ValueError:
            # Dead end in this subtree; backtrack to the next child
            continue
    raise ValueError("Cannot find valid node to explore!")


class CustomFilterMapping:
    """Lazy mapping from filter strings to custom trace filters.

    An instance is passed as ``custom_filters`` to the predicate interpreter
    and to the solvers. Keys are the ``custom_filter`` strings that
    ``get_next_predicates`` builds as ``str((min_t, o_i))``. Each lookup
    builds a new filter function; nothing is stored on the instance.
    """

    def __getitem__(self, filter_str: str):
        """Return the filter encoded by ``filter_str``.

        Args:
            filter_str: A ``"(min_t, o_i)"`` tuple literal.

        Returns:
            ``custom_filter(tape, i) -> bool``. It is True iff ``tape[i]``
            has ``id == "collision_check"`` and its ``program_state`` has
            ``"t"`` and ``"i"`` entries with ``t >= min_t`` and
            ``i == o_i``.

        Raises:
            ValueError, SyntaxError: If ``filter_str`` is not a Python
                literal. Arbitrary expressions are rejected, never executed.
        """
        import ast

        # Parse with literal_eval instead of eval: the keys are plain tuple
        # literals, and this guarantees a key string can never run code.
        (min_t, o_i) = ast.literal_eval(filter_str)

        def custom_filter(tape: "Tape", i: int) -> bool:
            node = tape[i]
            if node.id != "collision_check":
                return False
            program_state = node.program_state
            # Nodes that do not record both the timestep and the obstacle
            # index cannot match.
            if "t" not in program_state.keys():
                return False
            if "i" not in program_state.keys():
                return False
            if program_state["t"] < min_t:
                return False
            if program_state["i"] != o_i:
                return False
            return True

        return custom_filter


def get_next_predicates(
        tape: Tape,
        task: Task,
        episode_timestep_n: int,
        seed: str,
        max_timesteps: int,
        ) -> list[Predicate]:
    """Build the predicates used to expand a search-tree node.

    The result is the goal-reaching predicate (included only once the tape
    contains at least one collision), followed by one ``AnnotatedIfOr``
    collision predicate per eligible obstacle segment. A segment is eligible
    if it has been hit fewer than 2 times, its index is at least the highest
    index hit so far (or 0), and it is not segment 0. Each collision
    predicate filters for collisions with that segment no earlier than the
    first collision with the highest-indexed obstacle hit so far.

    ``episode_timestep_n`` and ``seed`` are accepted but not used.

    Returns an empty list when ``get_timestep_n(tape) >= max_timesteps``.
    """
    # Nodes whose simulation already exceeds the timestep budget are leaves
    if get_timestep_n(tape) >= max_timesteps:
        return list()

    collisions = get_collisions(tape)
    collided_is = [obstacle for _, obstacle in collisions]

    # New collisions must not happen before the first hit on the
    # highest-indexed obstacle so far (0 if nothing was hit).
    highest_hit = max(collided_is, default=-1)
    min_t = next(
        (t for t, obstacle in collisions if obstacle == highest_hit),
        0,
    )

    # Going for the goal is only proposed after at least one collision
    reach_goal = get_reach_goal_predicate(task)
    predicates = [reach_goal] if len(collided_is) > 0 else []

    lowest_allowed = max(collided_is, default=0)
    segment_is = [
        segment_i
        for segment_i in range(len(task.initial_state.segments))
        # Collide at most 2 times with a segment, never going backwards
        if collided_is.count(segment_i) < 2 and segment_i >= lowest_allowed
    ]
    print("collided_is", collided_is)
    print("segment_is", segment_is)

    # Segment 0 is never targeted
    segment_is = [segment_i for segment_i in segment_is if segment_i != 0]
    print(f"min_t: {min_t}")

    for segment_i in segment_is:
        predicates.append(
            AnnotatedIfOr(
                trace=Filter(
                    custom_filter=str((min_t, segment_i)),
                    trace=Constants.input_trace,
                ),
                annotation=f"{[collided_is]+[segment_i]}",
            )
        )
    return predicates


def get_child_parameters(tape: Tape, parameters: Parameters) -> Parameters:
    """Choose the starting parameters handed to the children of a node.

    If the tape contains any collision, the node's own ``parameters`` are
    reused. Otherwise a trajectory of a single zero-force timestep is
    returned, because optimization works on a freshly appended suffix.
    """
    if get_collisions(tape):
        return parameters
    return torch.tensor([[0.0, 0.0]])


def serialize_tree(tree: SearchNode):
    """Convert a search tree into nested dicts with ``annotation``/``children``.

    The label is the predicate's annotation for ``AnnotatedIfOr`` nodes and
    ``str(predicate)`` otherwise. Children are serialized recursively, in
    order.
    """
    serialized_children = list(map(serialize_tree, tree.children))
    predicate = tree.predicate
    label = (
        predicate.annotation
        if isinstance(predicate, AnnotatedIfOr)
        else str(predicate)
    )
    return {"annotation": label, "children": serialized_children}


def solver_wrapper(solver, log_path, task, *args, **kwargs):
    """Call ``solver(*args, **kwargs)`` and return its result.

    This is the hook through which the planners in this module invoke a
    sub-problem solver. ``log_path`` and ``task`` are accepted so that
    per-call diagnostics (such as rendering the resulting trajectory under
    ``log_path``) can be added here; they are currently unused. Exceptions
    raised by ``solver``, including ``SolverFailedException``, propagate
    unchanged.
    """
    return solver(*args, **kwargs)


def pylic_grad_solve(
        task: Task,
        episode_timestep_n: int,
        random_restart_n: int,
        worker_n: int,
        seed: int,
        timeout_s: float,
        log_path: Path,
        ) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Run Pylic's concolic planner on ``task`` using the gradient solver.

    Each sub-problem is solved by ``marble_grad_solver`` with a fresh seed
    drawn from an RNG seeded with ``seed``. Node selection is depth-first,
    with a depth bound derived from the number of segments and a deadline of
    ``timeout_s`` seconds.

    Returns:
        ``(log, success)``. ``log`` holds ``(seconds since start,
        parameters)`` for every successful sub-problem solve and, on success,
        the final plan as its last entry. ``success`` is False when planning
        stopped with a ``ValueError``; the traceback is printed.
    """
    rng = random.Random(seed)
    initial_actions = torch.zeros((task.max_timesteps, 2))
    depth_bound = math.ceil(len(task.initial_state.segments)*1.5+1)
    deadline = time.time() + timeout_s
    t_begin = time.time()

    history = []

    def solve_and_log(*solver_args):
        solution = solver_wrapper(
            marble_grad_solver,
            log_path,
            task,
            *solver_args,
            initial_state=task.initial_state,
            episode_timestep_n=episode_timestep_n,
            random_restart_n=random_restart_n,
            worker_n=worker_n,
            seed=rng.randint(0, 10000000),
        )
        history.append((time.time()-t_begin, solution))
        return solution

    def check_sat(*args):
        return is_sat(*args, task=task)

    def choose_node(*args):
        return select_node_depth_first(
            *args,
            depth_bound=depth_bound,
            max_time=deadline,
        )

    def expand(tape, _, __):
        return get_next_predicates(
            tape=tape,
            task=task,
            episode_timestep_n=episode_timestep_n,
            seed=str(rng.random()),
            max_timesteps=task.max_timesteps,
        )

    try:
        plan = concolic_planner(
            f=simulation,
            is_sat=check_sat,
            select_node=choose_node,
            get_next_predicates=expand,
            solver=solve_and_log,
            starting_parameters=initial_actions,
            get_child_parameters=get_child_parameters,
            state=task.initial_state,
            verbose=True,
        )
    except ValueError:
        traceback.print_exc()
        return history, False
    history.append((time.time()-t_begin, plan))
    return history, True


def pylic_cma_solve(
        task: Task,
        episode_timestep_n: int,
        cma_max_f_eval_n: int,
        worker_n: int,
        seed: int,
        timeout_s: float,
        log_path: Path,
) -> tuple[list[tuple[SolverTime, Parameters]], bool]:
    """Run Pylic's concolic planner on ``task`` using the CMA-ES solver.

    Each sub-problem is solved by ``marble_cma_solver`` with the same
    ``seed``. Node selection is depth-first, with a depth bound derived from
    the number of segments and a deadline of ``timeout_s`` seconds.

    Returns:
        ``(log, success)``. ``log`` holds ``(seconds since start,
        parameters)`` for every successful sub-problem solve and, on success,
        the final plan as its last entry. ``success`` is False when planning
        stopped with a ``ValueError``; the traceback is printed.
    """
    predicate_rng = random.Random(seed)
    root_actions = torch.zeros((task.max_timesteps, 2))
    segment_count = len(task.initial_state.segments)
    depth_bound = math.ceil(segment_count*1.5+1)
    deadline = time.time() + timeout_s
    started_at = time.time()

    records = []

    def run_cma(*solver_args):
        outcome = solver_wrapper(
            marble_cma_solver,
            log_path,
            task,
            *solver_args,
            initial_state=task.initial_state,
            cma_max_f_eval_n=cma_max_f_eval_n,
            worker_n=worker_n,
            seed=seed,
            episode_timestep_n=episode_timestep_n,
        )
        records.append((time.time()-started_at, outcome))
        return outcome

    def sat_check(*args):
        return is_sat(*args, task=task)

    def node_picker(*args):
        return select_node_depth_first(
            *args,
            depth_bound=depth_bound,
            max_time=deadline,
        )

    def predicate_source(tape, _, __):
        return get_next_predicates(
            tape=tape,
            task=task,
            episode_timestep_n=episode_timestep_n,
            seed=str(predicate_rng.random()),
            max_timesteps=task.max_timesteps,
        )

    try:
        final_parameters = concolic_planner(
            f=simulation,
            is_sat=sat_check,
            select_node=node_picker,
            get_next_predicates=predicate_source,
            solver=run_cma,
            starting_parameters=root_actions,
            get_child_parameters=get_child_parameters,
            state=task.initial_state,
            verbose=True,
        )
    except ValueError:
        traceback.print_exc()
        return records, False
    records.append((time.time()-started_at, final_parameters))
    return records, True
