import torch
import torch.nn.functional as F
import sys
import os
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.stable_diffusion import load_text_components
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask


@torch.no_grad()
def get_tokens_from_token_embeddings(token_embeddings, text_encoder):
    """Decode position-augmented token embeddings to their nearest vocabulary ids.

    The position embeddings are subtracted, then every resulting vector is
    matched against the token-embedding table by cosine similarity and the
    best-matching row index is returned.

    Args:
        token_embeddings: Tensor whose last two dimensions are
            ``[seq_len, dim]`` (typically ``[batch, seq_len, dim]``),
            holding token embeddings with position embeddings already added.
        text_encoder: Object exposing ``text_model.embeddings`` with
            ``position_ids``, ``position_embedding`` and ``token_embedding``.

    Returns:
        Integer tensor of vocabulary indices, one per input position.
    """
    embeddings = text_encoder.text_model.embeddings

    # Only use as many positions as the input actually has, so sequences
    # shorter than the model's maximum length decode correctly instead of
    # failing (or silently broadcasting) against the full position table.
    seq_len = token_embeddings.shape[-2]
    position_ids = embeddings.position_ids[..., :seq_len]
    position_embeddings = embeddings.position_embedding(position_ids)

    content_embeddings = token_embeddings - position_embeddings
    vocab_embeddings = embeddings.token_embedding.weight

    # Cosine similarity between every position and every vocabulary entry.
    similarity = torch.matmul(
        F.normalize(content_embeddings, p=2, dim=-1),
        F.normalize(vocab_embeddings, p=2, dim=-1).T,
    )

    return similarity.argmax(dim=-1)


def get_text_embeddings_from_token_embeddings(token_embeddings, text_encoder):
    """Run pre-computed token embeddings through the encoder stack.

    A 4D causal attention mask is built from the first two dimensions of
    ``token_embeddings``, the embeddings are passed to
    ``text_encoder.text_model.encoder``, and the first element of the encoder
    output is normalised with ``text_model.final_layer_norm``.

    Args:
        token_embeddings: Input embeddings; the first two dimensions are used
            as the mask shape (presumably ``[batch, seq_len]`` -- confirm
            against the caller).
        text_encoder: Object exposing ``text_model.encoder`` and
            ``text_model.final_layer_norm``.

    Returns:
        The layer-normalised first output of the encoder.
    """
    mask_shape = token_embeddings.shape[:2]
    causal_mask = _create_4d_causal_attention_mask(
        mask_shape, token_embeddings.dtype, device=token_embeddings.device
    )

    text_model = text_encoder.text_model
    hidden_states = text_model.encoder(
        token_embeddings,
        causal_attention_mask=causal_mask,
        output_attentions=None,
        output_hidden_states=None,
    )[0]

    return text_model.final_layer_norm(hidden_states)


def get_text_embeddings_from_tokens(tokens, text_encoder):
    """Call ``text_encoder`` on ``tokens`` and return the first output element.

    Args:
        tokens: Input passed unchanged to ``text_encoder``.
        text_encoder: Callable whose result supports indexing.

    Returns:
        Element ``0`` of the encoder's output (presumably the last hidden
        state for a Hugging Face text model -- confirm against the model).
    """
    encoder_outputs = text_encoder(tokens)
    return encoder_outputs[0]


def gumbel_search(
    model,
    seq_len: int = 77,
    vocab_size: int = 49408,
    device: str = "cuda",
    batch_size: int = 8,
    num_steps: int = 2000,
    tau_start: float = 1.0,
    tau_end: float = 0.1,
    lr: float = 0.1,
    is_max: bool = True,
):
    """Search for token sequences whose text embedding has an extreme norm.

    A single logit tensor of shape ``[1, seq_len, vocab_size]`` is optimised
    with Adam. At every step ``batch_size`` Gumbel-softmax samples are drawn
    from it, turned into soft token embeddings, encoded by ``model`` and
    scored by the norm of the resulting text embedding. The temperature is
    annealed geometrically from ``tau_start`` to ``tau_end``.

    Args:
        model: Text encoder exposing ``text_model.embeddings``; it is moved
            to ``device`` and put in eval mode.
        seq_len: Number of token positions to optimise.
        vocab_size: Size of the vocabulary (rows of the token-embedding table).
        device: Device on which the search runs.
        batch_size: Number of Gumbel-softmax samples drawn per step.
        num_steps: Number of optimisation steps.
        tau_start: Gumbel-softmax temperature at the first step.
        tau_end: Gumbel-softmax temperature at the last step.
        lr: Adam learning rate for the logits.
        is_max: Maximise the norm when true, minimise it otherwise.

    Returns:
        ``(best_tokens, best_scores)``: per batch slot, the decoded tokens
        and the (relaxed, soft-embedding) score of the best sample seen.
        ``best_scores`` is detached from the autograd graph.
    """
    model = model.to(device).eval()

    z = torch.randn(1, seq_len, vocab_size, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([z], lr=lr)

    initial_best = -float("inf") if is_max else float("inf")
    best_scores = torch.full((batch_size,), initial_best, device=device)
    best_tokens = torch.zeros(batch_size, seq_len, dtype=torch.long, device=device)

    # Loop-invariant lookups, hoisted out of the optimisation loop.
    embeddings = model.text_model.embeddings
    token_weight = embeddings.token_embedding.weight  # [V, D]
    # Slice to seq_len so the addition below does not rely on seq_len being
    # equal to the size of the position table.
    pos_embeds = embeddings.position_embedding.weight[:seq_len].unsqueeze(
        0
    )  # [1, L, D]

    # Guard against division by zero when only a single step is requested.
    anneal_span = max(num_steps - 1, 1)

    pb = tqdm(range(num_steps))

    for step in pb:
        tau = tau_start * ((tau_end / tau_start) ** (step / anneal_span))
        probs = torch.cat(
            [
                F.gumbel_softmax(z, tau=tau, hard=False, dim=-1)
                for _ in range(batch_size)
            ]
        )  # [B, L, V]

        token_embeds = probs @ token_weight  # [B, L, D]
        token_embeddings = token_embeds + pos_embeds  # [B, L, D]

        # Use the `model` argument rather than a module-level global so the
        # function also works when imported from another module.
        text_embeddings = get_text_embeddings_from_token_embeddings(
            token_embeddings, model
        )
        scores = text_embeddings.norm(dim=(1, 2))  # [B]

        tokens = get_tokens_from_token_embeddings(token_embeddings, model)
        with torch.no_grad():
            # Score of the decoded (hard) tokens; reported for monitoring only.
            hard_embeddings = get_text_embeddings_from_tokens(tokens, model)
            real_score = hard_embeddings.norm(dim=(1, 2)).mean().item()

            # Detach before book-keeping: writing a grad-tracking tensor into
            # `best_scores` would attach it to the autograd graph and chain
            # graph nodes across iterations.
            current_scores = scores.detach()
            if is_max:
                mask = current_scores > best_scores
            else:
                mask = current_scores < best_scores
            best_scores[mask] = current_scores[mask]
            best_tokens[mask] = tokens[mask]

        loss = -scores.mean() if is_max else scores.mean()
        optimizer.zero_grad()
        # NOTE(review): backward() also accumulates gradients into any model
        # parameters that require grad; only `z` is updated by the optimizer.
        loss.backward()
        optimizer.step()

        norm = z.grad.norm(dim=-1).mean().item() if z.grad is not None else 0

        pb.set_postfix(
            loss=loss.item(),
            norm=f"{norm:.2f}",
            real_score=f"{real_score:.2f}",
        )

    return best_tokens, best_scores


if __name__ == "__main__":
    # Load the "v1-4" text components; only the text encoder is used below.
    tokenizer, text_encoder = load_text_components("v1-4")

    # Run the search twice: first maximising the score, then minimising it.
    for maximize in (True, False):
        best_tokens, best_scores = gumbel_search(text_encoder, is_max=maximize)
        print(best_tokens, best_scores)
