"""Facade on Pylic for solving marble drop puzzles."""
from pathlib import Path
from examples.marble_drop.environment import Puzzle
from examples.marble_drop.environment import Action
from examples.marble_drop.environment import ButtonRow
from examples.marble_drop.environment import FPS
from examples.marble_drop.environment import simulation
from examples.marble_drop.environment import get_puzzle_button_n
from examples.marble_drop.environment import get_puzzle_width
from examples.marble_drop.environment import get_puzzle_height
from examples.plotting import SolverTime
from pylic.planner import concolic_planner
from pylic.planner import select_node_depth_first
from pylic.planner import SearchNode
from pylic.predicates import IfOr, SolverFailedException
from pylic.predicates import Filter
from pylic.predicates import Predicate
from pylic.predicates import Constants
from pylic.tape import Tape
from pylic.solvers.cma_es import solver as cma_solver
import random
import torch
import cma
import traceback
import time


# Simulation timesteps budgeted for each action; callers multiply this by an
# action count to obtain the `timestep_n` passed to `simulation`.
FPS_PER_ACTION = FPS*2  # 2 seconds per action is more than enough


def is_sat(actions: list[Action], puzzle: Puzzle) -> bool:
    """Check whether running `actions` presses every button of `puzzle`.

    The puzzle is simulated for `FPS_PER_ACTION` timesteps per action and the
    number of pressed buttons is compared with the puzzle's button count.
    A one-line status message is printed as a side effect.
    """
    horizon = int(FPS_PER_ACTION*len(actions))
    pressed = simulation(actions, puzzle, None, horizon)
    remaining_n = get_puzzle_button_n(puzzle) - len(pressed)
    solved = remaining_n == 0
    print(f"Is remaining {remaining_n} SAT? {solved}")
    return solved


def get_next_predicates(actions: list[Action], puzzle: Puzzle) -> list[Predicate]:
    """Build the predicates used to expand the search tree from `actions`.

    Returns an empty list when `actions` already solve the puzzle. Otherwise
    returns one `IfOr` predicate per unpressed button of the first row that
    still has unpressed buttons. Rows that are not `ButtonRow` instances are
    treated as having exactly one button.
    """
    # A solved puzzle has no further branches to explore.
    if is_sat(actions, puzzle):
        return []

    # At least one button is still unpressed: re-simulate to learn which
    # (row, button) pairs have been pressed so far.
    horizon = int(FPS_PER_ACTION*len(actions))
    pressed_buttons = simulation(actions, puzzle, None, horizon)

    # For every row, collect the indices of the buttons not pressed yet.
    missing_by_row = []
    for row_index, row in enumerate(puzzle):
        button_count = row.button_n if isinstance(row, ButtonRow) else 1
        already_pressed = {
            button
            for (pressed_row, button) in pressed_buttons
            if pressed_row == row_index
        }
        missing_by_row.append(set(range(button_count)) - already_pressed)

    # Index of the first row that still has something left to press.
    target_row = min(
        row_index
        for row_index, missing in enumerate(missing_by_row)
        if missing
    )

    # Each filter is naively encoded as the str() of the full desired list of
    # pressed buttons; only the last entry defines the filter. The whole list
    # is stored so that predicates built for different button combinations
    # compare as different under `!=` (every predicate is just an `IfOr` of a
    # filter).
    return [
        IfOr(
            Filter(
                str(pressed_buttons + [(target_row, button)]),
                Constants.input_trace,
            )
        )
        for button in missing_by_row[target_row]
    ]


def sim_wrapper(p: torch.Tensor, puzzle: Puzzle, timestep_n: int):
    """Adapt `simulation` to a tensor-based solver interface.

    The parameter tensor is converted to nested Python lists and the
    simulation is run without writing an animation.
    """
    return simulation(p.tolist(), puzzle, None, timestep_n)


class CustomFilterMapping:
    """Mapping-like lookup from a serialized filter string to a tape filter.

    Only `__getitem__` is provided: indexing with a filter string returns a
    callable `(tape, i) -> bool`.
    """

    def __getitem__(self, filter_str: str):
        """Return the filter function encoded by `filter_str`.

        `filter_str` is the `str()` of a list of `(row_i, button_i)` pairs;
        only the last pair defines the filter. The returned function is true
        for tape position `i` when that node is a "collision_check" whose
        program state carries the same `row_i` and `button_i`.

        Raises `ValueError`/`SyntaxError` if `filter_str` is not a plain
        Python literal, and `IndexError` if it encodes an empty list.
        """
        # Function-scope import keeps this block self-contained.
        import ast

        # Load filter values stored naively. `literal_eval` accepts only
        # literals (lists, tuples, numbers, ...), so unlike `eval` it cannot
        # execute arbitrary code embedded in the string.
        (row_i, button_i) = ast.literal_eval(filter_str)[-1]

        # Define custom filter semantics. The annotation is quoted so it is
        # not evaluated every time a filter closure is created.
        def custom_filter(tape: "Tape", i: int) -> bool:
            node = tape[i]
            if node.id != "collision_check":
                return False
            program_state = node.program_state
            # Both coordinates must have been recorded at this node.
            if "row_i" not in program_state or "button_i" not in program_state:
                return False
            if program_state["row_i"] != row_i:
                return False
            if program_state["button_i"] != button_i:
                return False
            return True

        return custom_filter


def select_node_depth_first_with_fallback(
        search_tree: SearchNode,
        explored_nodes: list[SearchNode],
        depth_bound: int,  # longest branch length
        puzzle: Puzzle,
        timeout_unix: float,
        ) -> SearchNode:
    """Select depth-first unexplored if possible, otherwise do weighted
    random node sampling over the whole tree.

    The fallback weight of a node is `1 + fraction of buttons pressed by its
    parameters`, divided by the number of tree nodes sharing the same
    `(parameters, predicate)` pair so duplicated nodes are not over-sampled.

    Raises `ValueError` when `time.time()` is already past `timeout_unix`.
    """
    if time.time() > timeout_unix:
        raise ValueError("Cannot find node because of timeout!")

    # First try depth-first unexplored
    try:
        return select_node_depth_first(search_tree, explored_nodes, depth_bound)
    except ValueError as e:
        print(f"Depth-first error: {e}")

    # Otherwise, flatten tree (pre-order) and sample a node
    def flatten_tree(search_tree: SearchNode):
        nodes = [search_tree]
        for child in search_tree.children:
            nodes.extend(flatten_tree(child))
        return nodes
    nodes = flatten_tree(search_tree)

    # Defensive guard, checked before any simulation is run. The flattened
    # list always contains at least the root, so this is not expected to fire.
    if not nodes:
        raise ValueError("Cannot find valid node to explore")

    # The button count does not depend on the node: compute it once.
    button_n = max(1, get_puzzle_button_n(puzzle))

    # Keys are compared with `==` via `list.count`; they are not hashed.
    node_keys = [
        (node.parameters, node.predicate)
        for node in nodes
    ]

    # Assign a sampling weight to each node
    weights = list()
    for node, key in zip(nodes, node_keys):
        actions = node.parameters
        timestep_n = int(FPS_PER_ACTION*(len(actions)+1))
        pressed_buttons = simulation(actions, puzzle, None, timestep_n)
        score = 1.0+len(pressed_buttons)/button_n
        weights.append(score/max(1, node_keys.count(key)))

    return random.choices(nodes, weights=weights, k=1)[0]


def local_cma_solver(
        predicate: Predicate,
        actions: list[Action],
        puzzle: Puzzle,
        worker_n: int,
        random_restart_n: int,
        ) -> list[Action]:
    """Extend `actions` with one action so that `predicate` is satisfied,
    searching for that action with CMA-ES.

    `actions` is returned unchanged when the predicate is `Constants.Top`.
    Otherwise the inner solver is tried once per restart and the first
    success is returned as the previous actions plus the new one.

    Raises `SolverFailedException` when every restart fails.
    """
    # The trivial predicate needs no search.
    if predicate is Constants.Top:
        return actions

    # NOTE(review): `min` caps the horizon at 3 actions' worth of timesteps,
    # while other call sites in this file use `len(actions)+1` uncapped --
    # confirm that the cap is intended.
    timestep_n = int(FPS_PER_ACTION*(min(3, len(actions)+2)))

    def preprocessor(p: torch.Tensor) -> torch.Tensor:
        # Append the candidate action to the already-fixed actions.
        candidate = tuple(p.tolist())
        return torch.tensor(actions + [candidate])

    # Starting points for the search. They are repeated cyclically until
    # there are at least `random_restart_n` of them: CMA-ES samples different
    # values in its first iteration even from identical starting values, so
    # repeating them is enough to obtain distinct restarts.
    base_inits = [
        [0.0, 0.0, 0.0, 0.0],
    ]
    attempt_n = max(len(base_inits), random_restart_n)
    initializations = [
        base_inits[attempt_i % len(base_inits)]
        for attempt_i in range(attempt_n)
    ]

    # Call inner solver on each initialization; first success wins.
    for init in initializations:
        try:
            result, _ = cma_solver(
                predicate=predicate,
                starting_parameters=torch.tensor(init),
                f=simulation,
                max_value=torch.tensor(10.0),
                custom_functions=dict(),
                custom_filters=CustomFilterMapping(),
                initial_stdev=1.0,
                max_f_eval_n=1000,
                verbose=True,
                multiprocessing_workers=worker_n,
                opts=cma.evolution_strategy.CMAOptions(
                    popsize=worker_n*2,
                    #seed=289518,
                    tolfun=0.0,
                    tolflatfitness=500,
                ),
                candidate_processor=preprocessor,
                puzzle=puzzle,
                timestep_n=timestep_n,
                output_animation_path=None,
            )
        except SolverFailedException:
            # This restart failed; move on to the next one.
            continue
        else:
            return preprocessor(result).tolist()
    raise SolverFailedException("Solver failed!")


def pylic_cma_solver(
        puzzle: Puzzle,
        timeout_s: float,
        worker_n: int,
        random_restart_n: int,
        ) -> tuple[list[tuple[list[Action], SolverTime]], bool]:
    """Use pylic to search for a list of actions that solves the puzzle.

    `random_restart_n` controls the random restart number of the inner
    numerical search; `worker_n` is forwarded to that search as well.

    Returns `(log, success)`. `log` holds `(actions, seconds since start)`
    entries: the initial empty action list, every result of the inner solver
    and, on success, the planner's final result. `success` is False when the
    planner raised `ValueError` (for instance from the node-selection
    timeout); the traceback is printed in that case.
    """
    button_n = get_puzzle_button_n(puzzle)
    # At most one action per button is budgeted for the full simulation.
    timestep_n = int(FPS_PER_ACTION*button_n)
    start_t = time.time()
    timeout_unix = start_t + timeout_s

    # Every intermediate solution is logged, starting with the empty one.
    log = [([], 0.0)]

    def solver_wrap(p: Predicate, a: list[Action]) -> list[Action]:
        # Run the inner solver and record its result with a timestamp.
        solved_actions = local_cma_solver(
            p, a, puzzle, worker_n, random_restart_n
        )
        log.append((solved_actions, time.time()-start_t))
        return solved_actions

    try:
        result = concolic_planner(
            f=simulation,
            is_sat=lambda a: is_sat(a, puzzle),
            select_node=lambda t, n: select_node_depth_first_with_fallback(
                t,
                n,
                button_n,
                puzzle,
                timeout_unix,
            ),
            get_next_predicates=lambda _, a, __: get_next_predicates(a, puzzle),
            solver=solver_wrap,
            starting_parameters=[],
            get_child_parameters=lambda _, p: p,
            verbose=True,
            puzzle=puzzle,
            output_animation_path=None,
            timestep_n=timestep_n,
        )
    except ValueError:
        traceback.print_exc()
        return log, False
    else:
        log.append((result, time.time()-start_t))
        return log, True


if __name__ == "__main__":
    # Debug solver with task
    puzzle: Puzzle = [
        ButtonRow(2),
        ButtonRow(2),
        ButtonRow(1),
    ]

    # Inspect transformed simulation code
    import inspect
    from pylic.code_transformations import get_tracing_transformed_source
    source = inspect.getsource(simulation)
    new_source, _ = get_tracing_transformed_source(source)
    print(f"Transformed simulation:\n{new_source}")

    # Measure solver time (without plotting); `time` is already imported at
    # the top of the module.
    solver_timeout_s = 60*60*1  # 1 hour
    start_t = time.time()
    log, _ = pylic_cma_solver(puzzle, solver_timeout_s, 12, 1)
    end_t = time.time()

    def score(a) -> float:
        """Return minus the fraction of buttons pressed by actions `a`, so
        that the best candidate has the lowest score."""
        timestep_n = int(FPS_PER_ACTION*(len(a)+1))
        pressed_buttons = simulation(a, puzzle, None, timestep_n)
        pressed_fraction = len(pressed_buttons)/max(1, get_puzzle_button_n(puzzle))
        return -pressed_fraction

    # Find best logged actions; `min` keeps the earliest entry on ties, the
    # same choice a stable sort followed by `[0]` would make.
    actions = min(log, key=lambda entry: score(entry[0]))[0]
    timestep_n = int(FPS_PER_ACTION*(len(actions)+1))
    print(f"Solver time: {end_t-start_t}s")
    print(f"Action number: {len(actions)}")
    print(f"Fraction of buttons pressed: {-score(actions)}")

    # Plot simulation
    animation_path = Path()/"debugging_marble_drop.mp4"
    simulation(actions, puzzle, animation_path, timestep_n)
    print(f"Wrote {animation_path}")
