"""
This file implements the EPO algorithm. See the `epo` function for the main entrypoint.
"""

import contextlib
import dataclasses
import time
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributions
import transformers
from custom_dreamy.callbacks import CombinedCallback, ICallback, ParetoCallback
from custom_dreamy.history import History
from custom_dreamy.i_runner import IRunner
from custom_dreamy.selector import GradientSelector
from custom_dreamy.state import State


@contextlib.contextmanager
def add_fwd_hooks(module_hooks: List[Tuple[torch.nn.Module, Callable]]):
    """
    Temporarily register forward hooks for the duration of a ``with`` block.

    Every handle returned by ``register_forward_hook`` is removed on exit,
    including when the body raises or when a later registration fails part
    way through the list.

    Parameters
    ----------
    module_hooks
        Pairs ``(module, hook_fn)``; ``hook_fn`` is registered as a forward
        hook on ``module``.
    """
    registered_handles = []
    try:
        for target_module, hook_fn in module_hooks:
            handle = target_module.register_forward_hook(hook_fn)
            registered_handles.append(handle)
        yield
    finally:
        # Remove in registration order, including any that succeeded before a
        # failing registration.
        for handle in registered_handles:
            handle.remove()


@torch.no_grad()
def epo(
    cache_run: IRunner,
    model: torch.nn.Module,
    tokenizer: transformers.PreTrainedTokenizerBase,
    seq_len: int = 12,
    population_size: int = 8,
    iters: int = 300,
    explore_per_pop: int = 32,
    batch_size: int = 256,
    topk: int = 512,
    mutation_method: str = "gradient",
    x_penalty_min: float = 1.0 / 10.0,
    x_penalty_max: float = 10.0,
    restart_frequency: Optional[int] = 50,
    restart_xentropy: float = 2.0,
    restart_xentropy_max_mult: float = 3.0,
    seed: Optional[int] = None,
    initial_ids: Optional[torch.Tensor] = None,
    fixed_positions: Optional[List[bool]] = None,
    history: Optional[History] = None,
    catch_keyboard_interrupt: bool = False,
    callbacks: Optional[List[ICallback]] = None,
    verbose: bool = True,
    always_recompute_gradients: bool = False,
    device: str = "cuda",
    override_vocab_size: Optional[int] = None,
) -> History:
    """
    Run the EPO algorithm. See the paper for details.

    Each iteration:

    1. The callbacks are invoked and may stop the run.
    2. Every population member is copied to produce
       ``population_size * explore_per_pop`` children.
    3. The selector mutates the children.
    4. The children are scored.
    5. For each population slot ``j``, the candidate (parent or child) that
       maximizes ``target - X[j] * xentropy`` is kept. ``X`` is a log-spaced
       grid of cross-entropy penalties between ``x_penalty_min`` and
       ``x_penalty_max``.

    Parameters
    ----------
    cache_run
        A IRunner that handles running the model and getting the target and logits
    model
        The model; moved to ``device`` and frozen (``requires_grad=False``).
    tokenizer
        Used for the vocabulary size and random initialization.
    seq_len, optional
        The number of tokens in the optimized prompt, by default 12
    population_size, optional
        The population to keep at each iteration, by default 8
    iters, optional
        Number of iterations to run EPO, by default 300
    explore_per_pop, optional
        Number of children per population member per iteration, by default 32
    batch_size, optional
        GPU batch size, by default 256
    topk, optional
        When selecting token replacements, we select the `topk` tokens by
        gradient magnitude and choose uniformly at random between those, by
        default 512.
    mutation_method, optional
        Only "gradient" is supported, by default "gradient"
    x_penalty_min, optional
        The minimum cross-entropy penalty, by default 1.0/10.0. If either
        bound is None, every slot uses a penalty of 0.
    x_penalty_max, optional
        The maximum cross-entropy penalty, by default 10.0
    restart_frequency, optional
        How often (in iterations, skipping iteration 0) the whole population
        is reset to a single member. None or 0 disables restarts. By default
        50.
    restart_xentropy, optional
        When we reset, every population slot is filled with the member that
        is optimal under a cross-entropy penalty ``restart_xentropy * mult``.
        ``mult`` is drawn uniformly at random from
        [1 / restart_xentropy_max_mult, restart_xentropy_max_mult].
        By default 2.0.
    restart_xentropy_max_mult, optional
        See the explanation for restart_xentropy, by default 3.0
    seed, optional
        If not None, passed to ``torch.manual_seed``, by default None
    initial_ids, optional
        The initial token ids, of shape (n, seq_len), to begin optimizing
        from. If None, the initial token ids are sampled uniformly at random.
        By default None.
    fixed_positions: List[bool], optional
        List of bools, True means the corresponding token should be fixed,
        False means it can be modified.
    history, optional
        The history of an EPO run that we want to continue; the run resumes
        from the population kept at its last iteration. By default None.
    catch_keyboard_interrupt, optional
        Should we catch keyboard interrupts and end the EPO loop?, by default False
    callbacks, optional
        ICallbacks called at the beginning of each iteration and once more
        with ``final=True`` after the loop. A callback's return value may be
        "terminate" (or a truthy bool / tensor) to stop, or
        "recompute_gradients". The list is copied, never mutated. By
        default None (no callbacks).
    verbose : bool, optional
        If True, a ParetoCallback that prints Pareto frontier updates is
        added, by default True
    always_recompute_gradients, optional
        If a population member is retained across an iteration, we default to
        not recomputing that population member's token gradients. If your
        cache_run stores internal state that changes, you may want to override
        this behavior and recompute gradients every iteration.
    device, optional
        Device to run the model on, by default "cuda"
    override_vocab_size, optional
        Vocabulary size handed to the selector and fitness evaluation instead
        of ``tokenizer.vocab_size``, by default None

    Returns
    -------
    History
        The ids, target values and cross-entropies of every evaluated
        candidate per iteration, and which candidates were kept.

    Raises
    ------
    ValueError
        If both ``history`` and ``initial_ids`` are given.
    ValueError
        If ``initial_ids`` is not 2-D with ``seq_len`` columns.
    ValueError
        If ``mutation_method`` is not supported.
    """
    model = model.to(device)

    # Freeze the weights; they are never optimized here.
    for param in model.parameters():
        param.requires_grad_(False)

    start = time.time()
    explore_size = population_size * explore_per_pop
    vocab_size = tokenizer.vocab_size
    if override_vocab_size is not None:
        vocab_size = override_vocab_size

    if seed is not None:
        torch.manual_seed(seed)

    # One cross-entropy penalty per population slot, log-spaced between the
    # bounds (all zeros when either bound is None).
    if x_penalty_min is None or x_penalty_max is None:
        X = torch.zeros(population_size, device=device)
    else:
        X = torch.exp(
            torch.linspace(
                np.log(x_penalty_min), np.log(x_penalty_max), population_size
            )
        ).to(device)

    # Work on a private copy so neither the caller's list nor a shared default
    # accumulates another ParetoCallback on every call.
    callbacks = list(callbacks) if callbacks is not None else []
    if verbose:
        callbacks.append(
            ParetoCallback(
                cache_run,
                model,
                tokenizer,
                x_penalty_min if x_penalty_min is not None else 0.1,
                x_penalty_max if x_penalty_max is not None else 10.0,
                fixed_positions,
            )
        )
    callback = CombinedCallback(callbacks) if callbacks else None

    #### history and initial_ids ####
    if history is not None:
        if initial_ids is not None:
            raise ValueError("Cannot specify both history and initial_ids.")
        # Resume from the population kept at the last recorded iteration.
        input_ids = torch.as_tensor(
            history.ids[-1, history.keep[-1]], device=device
        )
    elif initial_ids is not None:
        if initial_ids.ndim != 2 or initial_ids.shape[1] != seq_len:
            raise ValueError(f"initial_ids must have shape (*, {seq_len})")
        history = History(pop_size=population_size, explore_per_pop=explore_per_pop)
        input_ids = initial_ids.to(device)
    else:
        history = History(pop_size=population_size, explore_per_pop=explore_per_pop)
        # NOTE(review): samples from tokenizer.vocab_size even when
        # override_vocab_size is set -- confirm this is intended.
        input_ids = torch.randint(
            0, tokenizer.vocab_size, (population_size, seq_len)
        ).to(device)

    #### choose an update selection method ####
    if mutation_method == "gradient":
        selector_type = GradientSelector
    else:
        raise ValueError(f"Unknown selection method: {mutation_method}")
    selector = selector_type(
        model, cache_run, X, batch_size, vocab_size, fixed_positions
    )

    #### Run the EPO loop: ####
    if hasattr(cache_run, "setup"):
        cache_run.setup(input_ids)

    # Positions that may be mutated; forwarded to the cross-entropy computation.
    if fixed_positions is not None:
        nonfixed_positions_tensor = torch.tensor(
            [pos for pos, is_fixed in enumerate(fixed_positions) if not is_fixed],
            device=device,
        )
    else:
        nonfixed_positions_tensor = None

    state = selector.setup(input_ids)
    # Bound before the loop so the final callback below also works when
    # iters == 0.
    i = 0
    # We use a try/except block so that we can catch keyboard interrupts and
    # still return results. This is useful for interactive use when it's nice
    # to launch with a large `iters` parameter and then just stop the run when
    # the results look good enough.
    try:
        for i in range(iters):
            ########################################
            # 1) Report! Callbacks may stop the run or request recomputation.
            ########################################
            if callback is not None:
                terminate_flag = callback(
                    i, state, time.time() - start, history, selector
                )
            else:
                terminate_flag = False
            if _should_terminate(terminate_flag):
                if i == 0:
                    # Record at least the initial population.
                    history._insert(
                        state.ids,
                        state.target,
                        state.xentropy,
                        torch.arange(state.ids.shape[0]),
                        time.time() - start,
                        X,
                    )
                break
            # The timer restarts every iteration, so elapsed values are per iteration.
            start = time.time()
            recompute_gradients = always_recompute_gradients or (
                isinstance(terminate_flag, str)
                and terminate_flag == "recompute_gradients"
            )

            ########################################
            # 2) Birth children from parents: tile the parents out to
            #    explore_size new candidates.
            ########################################
            n_parents = state.ids.shape[0]
            source_idx = torch.cat(
                (
                    torch.arange(n_parents, device=device).repeat(
                        explore_size // n_parents
                    ),
                    torch.arange(explore_size % n_parents, device=device),
                )
            )
            assert source_idx.shape[0] == explore_size
            assert (source_idx < n_parents).all()

            new_ids = state.ids[source_idx, :].clone()

            ########################################
            # 3) Run the selector (its return value is unused; presumably it
            #    mutates new_ids in place).
            ########################################
            selector.mutate(state, source_idx, new_ids, topk)

            ########################################
            # 4) Evaluate fitness
            ########################################
            new_state = evaluate_fitness(
                model,
                cache_run,
                new_ids,
                batch_size=batch_size,
                vocab_size=vocab_size,
                nonfixed_positions=nonfixed_positions_tensor,
            )

            all_state = state.cat(new_state)

            ########################################
            # 5) Selection. all_loss has one row per population slot because
            #    each slot uses a different xentropy penalty.
            ########################################
            all_loss = (
                -all_state.target[None, :] + X[:, None] * all_state.xentropy[None, :]
            )
            keep = (-all_loss).argmax(dim=1).to(torch.int)
            cur_x_values = X
            # Truthiness check: None or 0 disables restarts (0 would divide by zero).
            if restart_frequency and i > 0 and i % restart_frequency == 0:
                min_mult = 1.0 / restart_xentropy_max_mult
                max_mult = restart_xentropy_max_mult
                mult = min_mult + (max_mult - min_mult) * torch.rand(1).item()
                restart_X = restart_xentropy * mult
                # Select with the randomly drawn penalty; the base value alone
                # would make the random draw meaningless.
                restart_loss = -all_state.target + restart_X * all_state.xentropy
                print(f"restarting with xentropy penalty of {restart_X:.2f}")
                keep[:] = restart_loss.argmin()
                cur_x_values = torch.full_like(X, restart_X)

            history._insert(
                all_state.ids,
                all_state.target,
                all_state.xentropy,
                keep,
                time.time() - start,
                cur_x_values,
            )

            ########################################
            # 6) Build the next population (and its gradients).
            ########################################
            if i != iters - 1:
                if not selector.uses_gradient:
                    state = all_state.subset(keep)
                elif recompute_gradients:
                    state = selector.setup(all_state.ids[keep])
                else:
                    state = _reuse_survivor_state(selector, state, all_state, keep)

    # it's handy to sometimes be able to interrupt the loop and still get
    # results!
    except KeyboardInterrupt:
        if not catch_keyboard_interrupt:
            raise

    if callback is not None:
        callback(i, state, time.time() - start, history, selector, final=True)

    history._finalize()

    return history


def _should_terminate(flag) -> bool:
    """Return True if a callback return value asks EPO to stop."""
    if isinstance(flag, str):
        return flag == "terminate"
    if isinstance(flag, torch.Tensor):
        return bool(flag.item())
    if isinstance(flag, bool):
        return flag
    return False


def _reuse_survivor_state(selector, state, all_state, keep):
    """Assemble the next state in ``keep`` order, running ``selector.setup``
    only for newly born children.

    ``all_state`` is ``state.cat(new_state)``. Entries of ``keep`` below
    ``state.ids.shape[0]`` therefore index surviving members of ``state`` and
    reuse their existing state. All other entries are new children.
    """
    needs_setup = keep >= state.ids.shape[0]

    new_indices = keep[needs_setup]
    state_new = None
    if new_indices.shape[0] > 0:
        state_new = selector.setup(all_state.ids[new_indices])

    survived_indices = keep[~needs_setup]
    state_survived = None
    if survived_indices.shape[0] > 0:
        state_survived = state.subset(survived_indices)

    # Interleave the two groups back into the original keep order.
    parts = []
    new_counter = 0
    survived_counter = 0
    for is_new in needs_setup.tolist():
        if is_new:
            parts.append(state_new.subset(torch.tensor([new_counter])))
            new_counter += 1
        else:
            parts.append(state_survived.subset(torch.tensor([survived_counter])))
            survived_counter += 1

    combined = parts[0]
    for part in parts[1:]:
        combined = combined.cat(part)
    return combined


@dataclasses.dataclass
class ParetoFrontier:
    """
    Trade-off between target value and cross-entropy, as built by
    ``build_pareto_frontier``.

    For each penalty ``Xv`` in ``Xvs``, the selected candidate is the one that
    minimizes ``-target + Xv * xentropy``. The ``full_*`` arrays hold one
    entry per penalty. ``unique``, ``target``, ``xentropy``, ``ids`` and
    ``text`` hold one entry per distinct selected candidate.
    """

    # the range of cross-entropy penalties used
    Xvs: np.ndarray
    # the target and xentropy values for each penalty level
    full_target: np.ndarray
    full_xentropy: np.ndarray
    # the unique indices in full_target/full_xentropy that make up the pareto frontier.
    # Each is the position (in Xvs) of the first penalty that selected a given
    # candidate; entries follow np.unique order of the candidate indices.
    unique: np.ndarray
    # the target and xentropy values for the unique entries
    target: np.ndarray
    xentropy: np.ndarray
    # the token ids for each unique point on the frontier (a list of 1-D id arrays).
    ids: List[np.ndarray]
    # the detokenized text for each unique point on the frontier.
    text: List[str]


def build_pareto_frontier(tokenizer, histories, Xvs=None):
    """
    Construct a pareto frontier from the history of several EPO runs. We allow
    multiple histories to be passed so that we can construct the Pareto
    frontier across several different runs of EPO with different random
    initializations.

    For every penalty ``Xv`` the candidate minimizing
    ``-target + Xv * xentropy`` over all recorded candidates is selected.

    Parameters
    ----------
    tokenizer
        Used to decode the token ids of each frontier point.
    histories
        A History object, or a list/tuple of History objects returned by the
        EPO algorithm. Multiple independent histories are combined.
    Xvs, optional
        The cross-entropy penalties to use; any 1-D array-like.
        By default Xvs = 1.0 / np.linspace(0, 50, 1000)[1:]

    Returns
    -------
        A ParetoFrontier object.

    Raises
    ------
    ValueError
        If no histories are given.
    """
    if Xvs is None:
        Xvs = 1.0 / np.linspace(0, 50, 1000)[1:]
    # Accept plain lists/tuples of penalties as well as arrays.
    Xvs = np.asarray(Xvs)

    if not isinstance(histories, (list, tuple)):
        histories = [histories]
    if len(histories) == 0:
        raise ValueError("build_pareto_frontier requires at least one History.")

    # Flatten every history into one pool of candidates.
    history_x = np.concatenate([h.xentropy.flatten() for h in histories])
    history_t = np.concatenate([h.target.flatten() for h in histories])
    history_ids = np.concatenate(
        [h.ids.reshape((-1, h.ids.shape[-1])) for h in histories], axis=0
    )

    n_penalties = Xvs.shape[0]
    pareto_t = np.empty(n_penalties)
    pareto_x = np.empty(n_penalties)
    pareto_idxs = []
    # A loop (rather than a broadcast) keeps memory at O(#candidates).
    for i, Xv in enumerate(Xvs):
        loss = -history_t + Xv * history_x
        idx = loss.argmin()
        pareto_idxs.append(idx)
        pareto_t[i] = history_t[idx]
        pareto_x[i] = history_x[idx]

    # Position in Xvs of the first penalty that selected each distinct candidate.
    _, pareto_unique = np.unique(pareto_idxs, return_index=True)
    pareto_ids = [history_ids[pareto_idxs[i]] for i in pareto_unique]
    pareto_text = [tokenizer.decode(ids) for ids in pareto_ids]
    return ParetoFrontier(
        np.array(Xvs),
        pareto_t,
        pareto_x,
        pareto_unique,
        pareto_t[pareto_unique],
        pareto_x[pareto_unique],
        pareto_ids,
        pareto_text,
    )


def gcg(
    cache_run: IRunner,
    model: torch.nn.Module,
    tokenizer: transformers.PreTrainedTokenizer,
    seq_len: int = 16,
    iters: int = 1000,
    batch_size: int = 8,
    topk: int = 32,
    x_penalty_min: float = 1.0 / 16.0,
    x_penalty_max: float = 16.0,
    seed: int = 0,
    initial_ids: torch.Tensor = None,
    history: History = None,
    catch_keyboard_interrupt: bool = False,
    callback: Union[Callable, bool] = None,
    always_recompute_gradients: bool = False,
) -> History:
    """
    GCG is a special case of EPO where the population size is 1.

    Each iteration explores ``batch_size`` mutated children of the single
    population member. The remaining arguments are forwarded to ``epo``.

    Parameters
    ----------
    callback, optional
        ``None`` or ``True``: keep ``epo``'s default reporting (``verbose=True``).
        ``False``: disable the default Pareto reporting (``verbose=False``).
        A callable: passed to ``epo`` as ``callbacks=[callback]``. It is
        assumed to follow the ``ICallback`` calling convention (TODO confirm).

    Returns
    -------
    History
        The history returned by ``epo``.
    """
    # epo accepts a list of callbacks and a verbose flag, not a single
    # `callback` keyword (passing `callback=` raised TypeError).
    verbose = callback is not False
    callbacks = [callback] if callable(callback) else None
    return epo(
        cache_run,
        model,
        tokenizer,
        seq_len=seq_len,
        population_size=1,
        iters=iters,
        explore_per_pop=batch_size,
        batch_size=batch_size,
        topk=topk,
        mutation_method="gradient",
        x_penalty_min=x_penalty_min,
        x_penalty_max=x_penalty_max,
        seed=seed,
        initial_ids=initial_ids,
        history=history,
        catch_keyboard_interrupt=catch_keyboard_interrupt,
        callbacks=callbacks,
        verbose=verbose,
        always_recompute_gradients=always_recompute_gradients,
    )


########################################
# Private implementation details below here.
########################################


def evaluate_fitness(
    model: torch.nn.Module,
    cache_run: IRunner,
    input_ids: torch.Tensor,
    batch_size: int,
    vocab_size: int,
    nonfixed_positions: Optional[torch.Tensor] = None,
):
    """
    Score every candidate row of ``input_ids`` in mini-batches of ``batch_size``.

    Each mini-batch is embedded with ``cache_run.int_ids_to_embed`` and run
    through ``cache_run.run_with_embeddings``. That call yields the target
    values, the logits and a dict of extra per-row tensors.

    Parameters
    ----------
    model
        Not used by this function; kept for interface compatibility.
    cache_run
        Runner providing embedding, forward pass and cross-entropy.
    input_ids
        Candidate token ids, one candidate per row.
    batch_size
        Number of rows per forward pass.
    vocab_size
        Not used by this function; kept for interface compatibility.
    nonfixed_positions, optional
        Forwarded unchanged to ``cache_run.calc_xentropy``.

    Returns
    -------
    State
        ``State(input_ids, target, xentropy, final_token, None, extra)``.
        ``final_token`` is the argmax of each row's last-position logits.
        ``extra`` stitches the per-batch extras back together along the
        first dimension.
    """
    n_rows = input_ids.shape[0]
    ids_device = input_ids.device
    target = torch.empty(n_rows, dtype=torch.float, device=ids_device)
    xentropy = torch.empty(n_rows, dtype=torch.float, device=ids_device)
    final_token = torch.empty(n_rows, dtype=torch.long, device=ids_device)
    extra = {}

    for lo in range(0, n_rows, batch_size):
        hi = min(lo + batch_size, n_rows)
        chunk_ids = input_ids[lo:hi]

        embeds = cache_run.int_ids_to_embed(chunk_ids)
        chunk_target, chunk_logits, chunk_extra = cache_run.run_with_embeddings(
            embeds
        )

        target[lo:hi] = chunk_target
        xentropy[lo:hi] = cache_run.calc_xentropy(
            chunk_logits, chunk_ids, nonfixed_positions
        )
        final_token[lo:hi] = chunk_logits[:, -1, :].argmax(dim=-1)

        for key in chunk_extra:
            value = chunk_extra[key]
            buffer = extra.get(key)
            if buffer is None:
                # Output buffer shaped/typed after the first batch seen for this key.
                buffer = torch.empty(
                    (n_rows, *value.shape[1:]),
                    dtype=value.dtype,
                    device=value.device,
                )
                extra[key] = buffer
            buffer[lo:hi] = value

    return State(input_ids, target, xentropy, final_token, None, extra)
