from typing import List, Optional

import torch
import torch.nn.functional as F
from custom_dreamy.i_runner import IRunner
from custom_dreamy.state import State


# based on https://github.com/llm-attacks/llm-attacks/blob/main/llm_attacks/gcg/gcg_attack.py
def token_grads(
    model: torch.nn.Module,
    cache_run: IRunner,
    input_ids: torch.Tensor,
    x_penalty: torch.Tensor,
    batch_size: int,
    vocab_size: int,
    fixed_positions: Optional[List[int]] = None,
    one_hot_dtype: torch.dtype = torch.bfloat16,
):
    """
    Compute gradients with respect to one-hot encoded input tokens.

    This is an infinitesimal approximation to each token's influence on the
    loss, so it is a very noisy indicator of which substitutions might reduce
    the loss.

    The differentiated per-sequence loss is ``-target + xentropy * x_penalty``.
    Here ``target`` is the first value returned by
    ``cache_run.run_with_embeddings``. ``xentropy`` is the mean next-token
    cross-entropy of the model's logits against the input tokens, taken over
    the non-fixed positions.

    Args:
        model: module whose gradients are zeroed before and after every batch
            (the forward pass itself goes through ``cache_run``).
        cache_run: runner providing ``one_hot_to_embed`` and
            ``run_with_embeddings``.
        input_ids: token ids indexed as (sequence, position).
        x_penalty: per-sequence weight of the cross-entropy term; sliced with
            the same row ranges as ``input_ids``.
        batch_size: number of sequences per forward/backward pass.
        vocab_size: size of the one-hot (vocabulary) dimension.
        fixed_positions: optional per-position flags, one per sequence
            position. Truthy entries are excluded from the cross-entropy term.
            Position 0 is always excluded because no logit predicts it.
        one_hot_dtype: dtype of the one-hot input fed to the embedding
            (default bfloat16, the value previously hard-coded).

    Returns:
        ``State(input_ids, target, xentropy, final_token, token_grads, extra)``.
        Every per-sequence field covers all rows of ``input_ids``.
    """
    n_seq, seq_len = input_ids.shape[0], input_ids.shape[1]
    device = input_ids.device

    grads = torch.empty(
        (n_seq, seq_len, vocab_size),
        dtype=torch.float,
        device=device,
    )
    target = torch.empty(n_seq, device=device)
    xentropy = torch.empty(n_seq, device=device)
    final_token = torch.empty(n_seq, device=device, dtype=torch.long)

    if fixed_positions is not None:
        # The logit at position i - 1 predicts token i, hence the -1 offset.
        # An explicit long dtype keeps an empty selection usable as an index
        # (torch.tensor([]) would otherwise default to float).
        valid_positions_offset = torch.tensor(
            [i - 1 for i, val in enumerate(fixed_positions) if not val and i != 0],
            dtype=torch.long,
            device=device,
        )
    else:
        valid_positions_offset = torch.arange(seq_len - 1, device=device)

    extra = {}

    with torch.enable_grad():
        model.zero_grad()

        for i in range(0, n_seq, batch_size):
            imax = min(i + batch_size, n_seq)

            # Using a one-hot matrix as input to the model gives us gradients
            # with respect to every potential input token.
            one_hot = F.one_hot(input_ids[i:imax], num_classes=vocab_size).to(
                one_hot_dtype
            )
            one_hot.requires_grad = True

            # Distinct name so the full-batch ``target`` buffer is not shadowed.
            this_target, logits, extras = cache_run.run_with_embeddings(
                cache_run.one_hot_to_embed(one_hot)
            )

            logits_offset = logits[:, :-1]

            this_xentropy = (
                -(
                    torch.log_softmax(logits_offset, dim=-1)[:, valid_positions_offset]
                    * one_hot[:, 1:][:, valid_positions_offset]
                )
                .sum(dim=-1)
                .mean(dim=-1)
            )

            this_loss = -this_target + this_xentropy * x_penalty[i:imax]
            this_loss.sum().backward()

            # Detach before storing so the returned buffers do not hold on to
            # the autograd graph.
            target[i:imax] = this_target.detach()
            xentropy[i:imax] = this_xentropy.detach()
            final_token[i:imax] = logits[:, -1, :].argmax(dim=-1)
            grads[i:imax] = one_hot.grad

            for k in extras:
                e = extras[k]
                if k not in extra:
                    extra[k] = torch.empty(
                        (n_seq, *e.shape[1:]),
                        dtype=e.dtype,
                        device=e.device,
                    )
                extra[k][i:imax] = e.detach()

            # Important to zero out gradients here to release memory.
            model.zero_grad()

    return State(input_ids, target, xentropy, final_token, grads, extra)


class Selector:
    """Gradient-guided token mutator (GCG style).

    ``setup`` computes per-token gradients for a batch of sequences.
    ``mutate`` uses them to replace, in each row, one randomly chosen non-fixed
    position with one of the ``topk`` tokens that have the most negative
    gradient.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        cache_run: "IRunner",
        X: torch.Tensor,
        batch_size: int,
        vocab_size: int,
        fixed_positions: Optional[List[int]] = None,
    ):
        self.model = model
        self.cache_run = cache_run
        # Per-sequence cross-entropy penalty; setup() uses its first
        # input_ids.shape[0] entries.
        self.X = X
        self.batch_size = batch_size
        # Optional per-position flags; truthy entries are never mutated.
        self.fixed_positions = fixed_positions
        self.vocab_size = vocab_size

    def setup(self, input_ids: torch.Tensor):
        """Compute the gradient state for ``input_ids`` via ``token_grads``."""
        return token_grads(
            self.model,
            self.cache_run,
            input_ids,
            x_penalty=self.X[: input_ids.shape[0]],
            batch_size=self.batch_size,
            vocab_size=self.vocab_size,
            fixed_positions=self.fixed_positions,
        )

    def mutate(self, state, source_idx, input_ids, topk):
        """Replace one random non-fixed token per row of ``input_ids`` in place.

        Args:
            state: object exposing ``token_grads`` indexed as
                (source sequence, position, vocab), e.g. the result of ``setup``.
            source_idx: index, or per-row indices, into ``state.token_grads``
                choosing which source sequence's gradients each row uses.
            input_ids: token ids indexed as (row, position); modified in place.
            topk: number of best candidate tokens to sample the replacement
                from.

        Returns:
            None. Nothing is changed when every position is fixed.
        """
        if self.fixed_positions is None:
            valid_positions = list(range(input_ids.shape[1]))
        else:
            valid_positions = [
                i for i, val in enumerate(self.fixed_positions) if not val
            ]
        if not valid_positions:
            return  # All positions are fixed, no mutation possible

        n_rows = input_ids.shape[0]
        device = input_ids.device
        rows = torch.arange(n_rows, device=device)

        # Position draw uses the default (CPU) generator, as the draw order
        # below matches the previous implementation for seeded runs.
        choice = torch.randint(
            low=0,
            high=len(valid_positions),
            size=(n_rows,),
        )
        pos = torch.tensor(valid_positions)[choice].to(device)

        token_idx = torch.randint(
            low=0,
            high=topk,
            size=(n_rows,),
            device=device,
        )

        # Rank only the gradient rows actually used (one per output row)
        # instead of running topk over every (sequence, position) pair.
        picked = state.token_grads.to(device)[source_idx, pos]
        candidates = (-picked).topk(k=topk, dim=-1).indices

        # Only non-fixed positions are modified.
        input_ids[rows, pos] = candidates[rows, token_idx]


class GradientSelector(Selector):
    """Selector whose mutations are driven by token gradients.

    Construction, ``setup`` and ``mutate`` are inherited unchanged from
    ``Selector``. This class previously duplicated them verbatim; inheriting
    keeps the two from drifting apart.
    """

    # Class-level flag; presumably read by callers to decide whether gradient
    # state must be computed before mutating — NOTE(review): confirm at call
    # sites.
    uses_gradient = True
