import socket
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import load_file


def find_free_port() -> int:
    """Return a TCP port number that the OS reports as currently unused.

    A temporary IPv4 stream socket is bound to port 0 on the wildcard address
    (``""``), which lets the operating system choose the port. The socket is
    closed before returning, so the port is only known to have been free at
    the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        # getsockname() yields (host, port) for AF_INET sockets.
        address = probe.getsockname()
        return address[1]
    finally:
        probe.close()


def str2bool(v):
    """Interpret ``v`` as a boolean flag.

    * ``bool`` values are returned unchanged.
    * Strings are compared case-insensitively; "yes", "true", "t" and "1"
      give ``True``, every other string gives ``False``.
    * Integers give ``True`` only when equal to 1.

    Raises:
        ValueError: if ``v`` is not a bool, str or int.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(v, bool):
        return v

    if isinstance(v, str):
        truthy_words = {"yes", "true", "t", "1"}
        return v.lower() in truthy_words

    if isinstance(v, int):
        return v == 1

    raise ValueError(
        f"Invalid value for str2bool: {v}. Expected a boolean, string, or integer."
    )


def load_embedding(
    tokenizer, text_encoder, state_dict, placeholder=None, connector: str = ""
) -> list[str]:
    """Register learned token embeddings with a tokenizer / text encoder pair.

    Args:
        tokenizer: Object providing ``add_tokens``, ``convert_tokens_to_ids``
            and ``__len__``.
        text_encoder: Object providing ``resize_token_embeddings`` and
            ``get_input_embeddings``.
        state_dict: Either a mapping ``{token_name: tensor}`` or a path to a
            file that is read with ``load_file``. Each tensor is iterated
            along its first dimension, one row per new token.
        placeholder: If given, replaces the key of *every* entry as the base
            token name.
        connector: String used to join the token names of one entry.

    Returns:
        One identifier per entry: the entry's token names joined by
        ``connector``. The first token is the base name; further rows are
        named ``"{name}_{i}"`` for ``i >= 1``.

    Raises:
        AssertionError: if ``state_dict`` is a path that does not exist.
    """
    if not isinstance(state_dict, dict):
        # Only a path can be checked on disk; running this on a dict would
        # raise TypeError inside Path(). Kept as an assertion so callers see
        # the same exception type as before.
        assert Path(state_dict).exists(), f"File not found: {state_dict}"
        state_dict = load_file(state_dict)

    # Load multiple token embeddings.
    identifiers = []
    for key, embs in state_dict.items():
        if placeholder is not None:
            key = placeholder
        tokens = [key] + [f"{key}_{i}" for i in range(1, embs.size(0))]

        tokenizer.add_tokens(tokens)
        # Grow the embedding matrix so the new token ids have a row.
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        # Fetched after the resize so it refers to the resized matrix.
        embeddings = text_encoder.get_input_embeddings().weight.data
        for token_id, emb in zip(token_ids, embs):
            embeddings[token_id] = emb.clone()

        identifiers.append(connector.join(tokens))
    return identifiers


def generate_exponential_samples_in_range(scale, size=100, low=0, high=20):
    """
    Draws samples from an exponential distribution, clips them to ``[low, high]``
    and rounds them down to integers.

    Parameters:
    scale (float): The scale parameter (beta) of the exponential distribution.
                   This is the inverse of the rate parameter (lambda).
    size (int): The number of samples to generate. Default is 100.
    low (int or float): The minimum value of the range. Default is 0.
    high (int or float): The maximum value of the range. Default is 20.

    Returns:
    numpy.ndarray: Integer array of the clipped samples after ``floor``.
                   With integer ``low`` and ``high`` every value lies in
                   [low, high]. Because flooring happens after clipping, a
                   non-integer ``low`` can produce values down to
                   ``floor(low)``.

    Raises:
    ValueError: If ``scale`` is not positive or ``low >= high``.
    """
    if scale <= 0:
        raise ValueError("Scale parameter (scale) must be positive.")
    if low >= high:
        raise ValueError("low must be strictly less than high.")

    # Generate samples from the exponential distribution.
    samples = np.random.exponential(scale=scale, size=size)

    # Clip the samples to the specified range, then truncate to whole numbers.
    clipped_samples = np.clip(samples, low, high)
    return np.floor(clipped_samples).astype(int)


def compute_kappa(R_bar, p):
    """Evaluate ``(R_bar * p - R_bar**3) / (1 - R_bar**2)``.

    NOTE(review): this has the form of the usual closed-form approximation of
    the von Mises-Fisher concentration from a mean resultant length ``R_bar``
    in ``p`` dimensions -- confirm against the intended reference.

    The denominator is zero when ``R_bar**2 == 1``; no guard is applied.
    """
    numerator = R_bar * p - R_bar**3
    denominator = 1 - R_bar**2
    return numerator / denominator


@torch.no_grad()
def get_close_words(
    target,
    tokenizer,
    text_encoder,
    n=20,
    distance="cosine",
):
    """Find the ``n`` tokens whose input embeddings are closest to ``target``.

    Args:
        target: A string that must encode to a single token, or any non-string
            value, which is used directly as the token id.
        tokenizer: Object providing ``encode`` and ``convert_ids_to_tokens``.
        text_encoder: Object providing ``get_input_embeddings``.
        n: Number of neighbours to return.
        distance: ``"cosine"`` (1 - cosine similarity) or ``"l2"`` (Euclidean
            distance between L2-normalised embeddings).

    Returns:
        ``(ids, tokens)`` where ``ids`` is a NumPy array of token ids sorted
        by increasing distance and ``tokens`` is the result of
        ``tokenizer.convert_ids_to_tokens(ids)``. The first entry of the
        ranking is skipped on the assumption that it is the target itself.

        For a string target that encodes to more than one token, the ids are
        printed and the message ``"Only single tokens are supported."`` is
        returned instead of a tuple.

    Raises:
        ValueError: if ``distance`` is not one of the supported metrics.
    """
    if isinstance(target, str):
        token_ids = tokenizer.encode(target, add_special_tokens=False)
        if len(token_ids) > 1:
            print(token_ids)
            return "Only single tokens are supported."
        token_id = token_ids[0]
    else:
        token_id = target

    embeds = text_encoder.get_input_embeddings().weight.data

    if distance == "l2":
        # `embeds` aliases the model's weights, so normalise out of place;
        # an in-place `/=` would overwrite the embedding table itself.
        unit = embeds / embeds.norm(dim=-1, keepdim=True)
        l2 = F.pairwise_distance(unit, unit[token_id].unsqueeze(0), p=2)
        ranking = l2.cpu().numpy().argsort()
    elif distance == "cosine":
        cos_sim = F.cosine_similarity(embeds, embeds[token_id].unsqueeze(0), dim=-1)
        cos_dist = 1 - cos_sim.cpu().numpy()
        ranking = cos_dist.argsort()
    else:
        raise ValueError("Distance metric not supported.")

    # Position 0 of the ranking is dropped (expected to be the target).
    topk = ranking[1 : n + 1]
    return topk, tokenizer.convert_ids_to_tokens(topk)


def estimate_kappa(
    target,
    tokenizer,
    text_encoder,
    n=500,
    as_tensor=False,
):
    """Estimate a mean direction and a kappa value for each token of ``target``.

    ``target`` is encoded without special tokens. For every resulting token id
    the ``n`` nearest tokens are looked up with ``get_close_words``, their
    input embeddings are scaled to unit length and averaged, and the length of
    that average is passed to ``compute_kappa`` together with the embedding
    dimension.

    Returns:
        ``(mu, kappa)``: ``mu`` stacks the unit-normalised mean vectors, one
        row per token; ``kappa`` holds the matching ``compute_kappa`` results.
        Both are NumPy arrays, or torch tensors when ``as_tensor`` is true.
    """
    token_ids = tokenizer.encode(target, add_special_tokens=False)

    weights = text_encoder.get_input_embeddings().weight
    table = weights.detach().cpu().numpy().astype(np.float32)
    dim = table.shape[1]

    directions = []
    concentrations = []
    for tok_id in token_ids:
        neighbour_ids, _ = get_close_words(tok_id, tokenizer, text_encoder, n=n)

        neighbours = table[neighbour_ids]
        lengths = np.linalg.norm(neighbours, axis=1, keepdims=True)
        centroid = (neighbours / lengths).mean(axis=0)

        # Length of the mean of unit vectors, passed on as R_bar.
        resultant = np.linalg.norm(centroid)

        directions.append(centroid / resultant)
        concentrations.append(compute_kappa(resultant, dim))

    mu = np.asarray(directions)
    kappa = np.asarray(concentrations)
    if as_tensor:
        return torch.as_tensor(mu), torch.as_tensor(kappa)
    return mu, kappa
