from typing import Optional
from dataclasses import dataclass

import torch
import torch.nn as nn
from safetensors.torch import save_file

from .embedding import DtiTokenEmbedding


@dataclass()
class NewToken:
    """Bookkeeping record for a placeholder token registered with a tokenizer.

    Instances are produced by ``add_new_token`` in this module. A placeholder may be
    backed by several tokenizer entries (one per learned vector), which is why both
    the individual ids and the joined identifier string are kept.
    """

    # The primary placeholder string as passed by the caller.
    placeholder: str
    # Every placeholder sub-token string joined with add_new_token's ``connector``.
    identifier: str
    # The initializer text when one was given as a string, otherwise "".
    initializer: str
    # Tokenizer ids of the placeholder sub-tokens, in sub-token order.
    token_ids: list[int]
    # How many embedding vectors back this placeholder.
    num_vectors: int


def replace_token_embedding(
    model: nn.Module, reparameterize: bool = True, legacy: bool = False
) -> nn.Module:
    """Swap ``model``'s input embedding for an equivalent ``DtiTokenEmbedding``.

    The replacement has the same vocabulary size, embedding dim and padding index as
    the current input embedding and is moved to the same device. The existing weight
    rows are copied in; if the old embedding exposes a ``scales`` sub-module, its
    weights are copied too.

    Args:
        model: Module providing ``get_input_embeddings`` / ``set_input_embeddings``.
        reparameterize: Forwarded to ``DtiTokenEmbedding``.
        legacy: Forwarded to ``DtiTokenEmbedding``.

    Returns:
        The same ``model`` object, now holding the new embedding.

    Raises:
        RuntimeError: If the installed embedding's weights are not bit-identical to
            the previous ones.
    """
    previous = model.get_input_embeddings()
    vocab_size = previous.num_embeddings

    replacement = DtiTokenEmbedding(
        vocab_size,
        previous.embedding_dim,
        padding_idx=previous.padding_idx,
        reparameterize=reparameterize,
        legacy=legacy,
    )
    replacement = replacement.to(previous.weight.device)

    # Carry over the pretrained rows (and per-token scales when the source has them).
    replacement.weight.data[:vocab_size] = previous.weight.data
    if hasattr(previous, "scales"):
        replacement.scales.weight.data[:vocab_size] = previous.scales.weight.data

    model.set_input_embeddings(replacement)

    # Sanity check: read back through the model to be sure the swap preserved weights.
    installed = model.get_input_embeddings()
    if not torch.equal(previous.weight, installed.weight):
        raise RuntimeError("Embedding weights are not exactly the same after resizing.")

    return model


@torch.no_grad()
def add_new_token(
    tokenizer,
    text_encoder,
    placeholder: str,
    num_vectors: int,
    scale: str = "max",
    init_token: Optional[torch.Tensor | str] = None,
    init_method: str = "random",
    connector: str = "",
) -> tuple[NewToken, torch.Tensor]:
    """Register a (possibly multi-vector) placeholder token and initialise its rows.

    The placeholder is expanded into ``placeholder``, ``placeholder_1``, ...,
    ``placeholder_{n-1}``. These are added to ``tokenizer``, and the text encoder's
    input embedding (and its ``scales``) are resized. Each new row is then
    initialised.

    Args:
        tokenizer: Provides ``encode``, ``add_tokens``, ``convert_tokens_to_ids`` and
            ``len()``.
        text_encoder: Provides ``resize_token_embeddings`` and
            ``get_input_embeddings``. The embedding must expose
            ``resize_token_scales``, ``scales``, ``get_{min,mean,max}_scale`` and
            ``original_vocab_size``.
        placeholder: Base string of the new token.
        num_vectors: Number of vectors. Overridden by the length of ``init_token``
            when that is a string (number of encoded ids) or a tensor
            (``shape[0]``).
        scale: ``"min"``, ``"mean"``, ``"max"`` (taken from the embedding's
            existing scales) or anything convertible with ``float()``. Written into
            ``scales`` for every new id.
        init_token: Source of the returned ``init_embeds``.
            - ``None``: random vectors.
            - ``str``: embedding rows of its token ids.
            - Tensor: its rows.
        init_method: How the new embedding rows are set.
            - ``"token"``: the normalised ``init_embed``.
            - ``"random"``: a random unit vector.
            - ``"mean"``: the normalised mean of the non-zero rows of the original
              vocabulary.
        connector: Separator used to join the sub-tokens into ``identifier``.

    Returns:
        The ``NewToken`` record, and the per-vector initialiser embeddings stacked
        along a new first dimension.

    Raises:
        ValueError: For ``num_vectors < 1``, an unknown ``init_method`` or ``scale``,
            or a placeholder already known to the tokenizer. Argument errors are
            raised before the tokenizer or text encoder is modified.
    """
    init_token_ids: Optional[list[int]] = None
    if isinstance(init_token, str):
        init_token_ids = tokenizer.encode(init_token, add_special_tokens=False)
        num_vectors = len(init_token_ids)
    elif isinstance(init_token, torch.Tensor):
        num_vectors = init_token.shape[0]

    # Validate arguments up front so a bad call does not leave half-registered tokens
    # in the tokenizer / a resized embedding behind.
    if num_vectors < 1:
        raise ValueError(
            f"`num_vectors` must be at least 1, got {num_vectors} "
            "(an empty `init_token` string encodes to zero tokens)."
        )
    if init_method not in ("token", "random", "mean"):
        raise ValueError(f"Unknown init method: {init_method}.")
    if scale not in ("min", "mean", "max"):
        try:
            scale = float(scale)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Unknown scale: {scale}. Please use 'min', 'mean', 'max' or a float value."
            ) from err

    placeholder_tokens = [placeholder] + [
        f"{placeholder}_{i}" for i in range(1, num_vectors)
    ]

    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
    if num_added_tokens != num_vectors:
        raise ValueError(
            f"The tokenizer already contains one of the tokens {placeholder_tokens}. Please pass a different `placeholder`."
        )
    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

    # Grow the embedding matrix (and its per-token scales) to cover the new ids.
    text_encoder.resize_token_embeddings(len(tokenizer))
    embeddings = text_encoder.get_input_embeddings()
    embeddings.resize_token_scales(len(tokenizer))

    token_embeds = embeddings.weight.data
    token_scales = embeddings.scales.weight.data

    # Named scales are resolved after the resize, matching the original ordering.
    if scale == "min":
        scale = embeddings.get_min_scale()
    elif scale == "mean":
        scale = embeddings.get_mean_scale()
    elif scale == "max":
        scale = embeddings.get_max_scale()

    # The "mean" direction only reads original-vocabulary rows, which this function
    # never writes, so it is computed once.
    # NOTE(review): assumes placeholder ids are >= original_vocab_size — confirm.
    mean_direction = None
    if init_method == "mean":
        valid_embeds = token_embeds[: embeddings.original_vocab_size]
        # Skip all-zero rows so they do not drag the mean toward the origin.
        valid_embeds = valid_embeds[torch.norm(valid_embeds, dim=-1) > 0.0]
        mean_direction = torch.mean(valid_embeds, dim=0)
        mean_direction = mean_direction / torch.norm(
            mean_direction, dim=-1, keepdim=True
        )

    init_embeds = []
    for i, token_id in enumerate(placeholder_token_ids):
        if init_token is None:
            init_embed = torch.randn_like(token_embeds[0])
        elif init_token_ids is not None:
            init_embed = token_embeds[init_token_ids[i]].clone()
        else:
            init_embed = init_token[i].clone()
        init_embeds.append(init_embed)

        # Stored embedding rows are unit-norm directions; the magnitude lives in
        # `scales`.
        if init_method == "token":
            new_embed = init_embed / torch.norm(init_embed, dim=-1, keepdim=True)
        elif init_method == "random":
            new_embed = torch.randn_like(init_embed)
            new_embed = new_embed / torch.norm(new_embed, dim=-1, keepdim=True)
        else:  # "mean" (validated above)
            new_embed = mean_direction
        token_embeds[token_id] = new_embed
        token_scales[token_id] = scale

    new_token = NewToken(
        placeholder=placeholder,
        identifier=connector.join(placeholder_tokens),
        initializer=init_token if isinstance(init_token, str) else "",
        token_ids=placeholder_token_ids,
        num_vectors=num_vectors,
    )
    return new_token, torch.stack(init_embeds, dim=0)


@torch.no_grad()
def save_progress(
    text_encoder,
    new_tokens: list[NewToken],
    save_path: str,
    safe_serialization: bool = True,
) -> None:
    """Write the learned embedding of each placeholder to ``save_path``.

    The output maps each ``NewToken.placeholder`` to a float32 tensor with one row per
    token id, in ``token_ids`` order. When the input embedding is a
    ``DtiTokenEmbedding``, each stored row is the unit direction multiplied by its
    learned scale. Otherwise the raw weight rows are saved.

    Args:
        text_encoder: Provides ``get_input_embeddings``.
        new_tokens: Placeholders to export.
        save_path: Destination file.
        safe_serialization: Use safetensors (``save_file``) when True, otherwise
            ``torch.save``.
    """
    embeddings = text_encoder.get_input_embeddings()
    is_dti = isinstance(embeddings, DtiTokenEmbedding)

    learned_embeds_dict = {}
    for new_token in new_tokens:
        token_ids = list(new_token.token_ids)
        # Gather the rows by explicit id list. Advanced indexing always returns a fresh
        # tensor. A slice view (which `.float()` returns unchanged for fp32 weights)
        # would share storage with the whole embedding matrix. safetensors rejects
        # several such views, and torch.save would serialise the full matrix.
        learned_embeds = embeddings.weight[token_ids].detach().float()

        if is_dti:
            learned_scales = embeddings.scales.weight[token_ids].detach().float()
            # Guard against zero-norm rows before normalising.
            norms = torch.norm(learned_embeds, dim=-1, keepdim=True).clamp(min=1e-6)
            learned_embeds = learned_scales * (learned_embeds / norms)

        learned_embeds_dict[new_token.placeholder] = learned_embeds

    if safe_serialization:
        save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(learned_embeds_dict, save_path)


def project_to_tangent_space(
    embeds: torch.Tensor,
    grads: torch.Tensor,
) -> torch.Tensor:
    """Remove the radial component of each gradient row.

    At a unit vector ``e`` on the sphere, the tangent-space projection of a gradient
    ``g`` is ``g - <g, e> e``. The inner product is taken over the last dimension,
    e.g. inputs of shape ``(N, D)``. See
    https://en.wikipedia.org/wiki/Tangent_space#Tangent_space_of_a_sphere

    ``embeds`` is assumed to already be unit-norm; it is not normalised here. A new
    tensor is returned and neither input is modified.
    """
    radial_coeff = torch.sum(grads * embeds, dim=-1, keepdim=True)
    radial_part = radial_coeff * embeds
    return grads - radial_part


@torch.no_grad()
def project_grads_to_tangent_space(
    text_encoder: nn.Module,
    added_token_ids: list[int],
    *,
    kappa: float = 0.0,
    target_embeds: Optional[torch.Tensor] = None,
) -> None:
    """Rewrite the added tokens' embedding gradients in place.

    The gradient rows covering ``min(added_token_ids)..max(added_token_ids)`` are
    processed in three steps:

    1. Optionally (``kappa > 0`` and ``target_embeds`` given), subtract
       ``(kappa / 1000) * target_embeds``. Under a descent-style optimizer this pulls
       the embeddings toward ``target_embeds``.
    2. Project onto the tangent space of the unit sphere at the current embeddings.
    3. Rescale every row to unit L2 norm (the norm is clamped at 1e-6).

    Args:
        text_encoder: Provides ``get_input_embeddings``.
        added_token_ids: Ids of the trainable tokens. The contiguous range between
            their min and max is processed. An empty list is a no-op.
        kappa: Strength of the optional pull toward ``target_embeds``.
        target_embeds: Rows aligned with that id range. Presumably unit-norm; confirm
            with the caller.

    Raises:
        RuntimeError: If the embedding weight has no gradient yet.
    """
    if not added_token_ids:
        return

    weight = text_encoder.get_input_embeddings().weight
    if weight.grad is None:
        raise RuntimeError(
            "The input embedding weight has no gradient; run backward() before "
            "projecting the gradients."
        )

    rows = slice(min(added_token_ids), max(added_token_ids) + 1)
    embeds = weight[rows]
    grad = weight.grad[rows]

    # (Optional) cosine-similarity regularisation: - kappa * mu, with a fixed 1/1000 scale.
    if kappa > 0.0 and target_embeds is not None:
        grad = grad - (kappa / 1000) * target_embeds

    grad = project_to_tangent_space(embeds, grad)

    # Normalise each gradient row to unit L2 norm.
    grad_norm = torch.clamp(torch.linalg.norm(grad, dim=-1, keepdim=True), min=1e-6)
    grad = grad / grad_norm

    weight.grad[rows] = grad


@torch.no_grad()
def retract_token_embeddings(
    embeddings: nn.Module,
    index_updates: torch.BoolTensor,
) -> None:
    """Renormalise the selected rows of ``embeddings.weight`` to unit L2 norm, in place.

    Runs under ``torch.no_grad()``. Writing in place into a leaf parameter that
    requires grad is otherwise rejected by autograd, and the retraction must not be
    recorded in the graph.

    Args:
        embeddings: Module with a 2-D ``weight`` parameter, e.g. an embedding.
        index_updates: Row selector (a boolean mask in the intended use).

    The norm is clamped at 1e-6, as in the gradient helpers of this file, so an
    all-zero row stays zero instead of turning into NaN.
    """
    # Advanced indexing already returns a copy, so no explicit clone is needed.
    rows = embeddings.weight[index_updates]
    norms = torch.linalg.norm(rows, dim=-1, keepdim=True).clamp(min=1e-6)
    embeddings.weight[index_updates] = rows / norms
