"""
Embedding generation utilities.
"""

import logging
import os
from typing import List, Dict, Tuple, Optional
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

logger = logging.getLogger(__name__)


class SimpleTextDataset(Dataset):
    """Map-style dataset that pairs each ID with its text for encoding."""

    def __init__(self, ids: List[str], texts: List[str]):
        # Parallel lists: texts[i] is the text that belongs to ids[i].
        self.ids, self.texts = ids, texts

    def __len__(self):
        # The dataset length is defined by the number of IDs.
        return len(self.ids)

    def __getitem__(self, idx):
        # Return one example as a dict with "id" and "text" keys.
        item_id = self.ids[idx]
        item_text = self.texts[idx]
        return dict(id=item_id, text=item_text)


def generate_embeddings(
    model: torch.nn.Module,
    ids: List[str],
    id_to_text: Dict[str, str],
    tokenizer,
    max_length: int,
    batch_size: int,
    device: torch.device,
    is_query: bool = True,
    show_progress: bool = True,
    prefix: str = "",
) -> Tuple[np.ndarray, List[str]]:
    """
    Generate embeddings for a list of texts.

    Switches the model to eval mode (it is not switched back afterwards),
    looks up the text of every ID and encodes the texts batch by batch under
    ``torch.no_grad()``. Each batch of embeddings is moved to CPU as soon as
    it has been computed.

    Args:
        model: The dual encoder model; must provide ``encode_queries`` and
            ``encode_documents`` taking ``(input_ids, attention_mask)``
        ids: List of IDs (query IDs or doc IDs)
        id_to_text: Mapping from ID to text
        tokenizer: HuggingFace tokenizer
        max_length: Maximum sequence length
        batch_size: Batch size for encoding
        device: Device to run on
        is_query: Whether encoding queries (True) or documents (False)
        show_progress: Whether to show progress bar
        prefix: Prefix to prepend to texts (e.g., "query: " or "passage: ")

    Returns:
        Tuple of (embeddings array, ordered list of IDs)
        embeddings: (num_items, embedding_dim); an empty (0, 0) float32 array
            when ``ids`` is empty
        ids: List of IDs in the same order as embeddings

    Raises:
        KeyError: If an ID in ``ids`` is missing from ``id_to_text``.
    """
    model.eval()

    item_kind = "queries" if is_query else "documents"

    # np.concatenate raises ValueError on an empty list, so short-circuit when
    # there is nothing to encode. The embedding dimension cannot be known
    # without running the model, hence the (0, 0) shape.
    if len(ids) == 0:
        logger.info(f"Generated embeddings for 0 {item_kind}")
        return np.zeros((0, 0), dtype=np.float32), []

    # Get texts in order
    texts = [id_to_text[id_] for id_ in ids]

    # Create dataset and dataloader (no shuffling: output order must match ids)
    dataset = SimpleTextDataset(ids, texts)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

    # Pick the encoder once instead of branching on every batch
    encode = model.encode_queries if is_query else model.encode_documents

    # Generate embeddings
    all_embeddings = []
    all_ids = []

    with torch.no_grad():
        iterator = (
            tqdm(dataloader, desc=f"Encoding {item_kind}")
            if show_progress
            else dataloader
        )

        for batch in iterator:
            batch_ids = batch["id"]
            batch_texts = batch["text"]

            # Apply prefix if specified
            if prefix:
                batch_texts = [prefix + text for text in batch_texts]

            # Tokenize (every sequence is padded/truncated to max_length)
            encoded = tokenizer(
                batch_texts,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            # Move to device
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            # Encode
            embeddings = encode(input_ids, attention_mask)

            # Move to CPU immediately and collect
            all_embeddings.append(embeddings.cpu().numpy())
            all_ids.extend(batch_ids)

            # Clear CUDA cache to free memory
            del embeddings, input_ids, attention_mask
            if device.type == "cuda":
                torch.cuda.empty_cache()

    # Concatenate all embeddings
    all_embeddings = np.concatenate(all_embeddings, axis=0)

    logger.info(f"Generated embeddings for {len(all_ids)} {item_kind}")
    logger.info(f"Embeddings shape: {all_embeddings.shape}")

    return all_embeddings, all_ids


def save_embeddings(
    embeddings: np.ndarray,
    ids: List[str],
    save_path: str,
):
    """
    Save embeddings and IDs to disk.

    Writes two files: ``{save_path}_embeddings.npy`` (via ``np.save``) and
    ``{save_path}_ids.txt`` (one ID per line, in the given order). The parent
    directory of ``save_path`` is created if it does not exist.

    Args:
        embeddings: Embeddings array (num_items, embedding_dim)
        ids: List of IDs corresponding to embeddings
        save_path: Path to save (without extension). May be a bare file-name
            prefix, in which case the files are written to the current
            working directory.
    """
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when save_path actually contains one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    embeddings_file = f"{save_path}_embeddings.npy"
    ids_file = f"{save_path}_ids.txt"

    # Save embeddings as npy
    np.save(embeddings_file, embeddings)

    # Save IDs as text file, one per line
    with open(ids_file, "w") as f:
        f.writelines(f"{id_}\n" for id_ in ids)

    logger.info(f"Saved embeddings to {embeddings_file}")
    logger.info(f"Saved IDs to {ids_file}")


def load_embeddings(
    load_path: str,
) -> Tuple[np.ndarray, List[str]]:
    """
    Load embeddings and IDs from disk.

    Reads ``{load_path}_embeddings.npy`` with ``np.load`` and
    ``{load_path}_ids.txt`` as one ID per line.

    Args:
        load_path: Path to load from (without extension)

    Returns:
        Tuple of (embeddings, ids)
    """
    embeddings_file = f"{load_path}_embeddings.npy"
    ids_file = f"{load_path}_ids.txt"

    embeddings = np.load(embeddings_file)

    with open(ids_file, "r") as f:
        # Remove only the line terminator. A full strip() would also eat
        # leading/trailing whitespace that is part of an ID, so the loaded
        # IDs would no longer match the ones that were written.
        ids = [line.rstrip("\n") for line in f]

    # A count mismatch means the two files are out of sync; surface it here
    # rather than letting it show up later as misaligned results.
    if embeddings.ndim > 0 and embeddings.shape[0] != len(ids):
        logger.warning(
            f"Loaded {len(ids)} IDs but {embeddings.shape[0]} embedding rows "
            f"from {load_path}"
        )

    logger.info(f"Loaded embeddings from {embeddings_file}")
    logger.info(f"Loaded {len(ids)} IDs")

    return embeddings, ids


def generate_embeddings_distributed(
    model: torch.nn.Module,
    ids: List[str],
    id_to_text: Dict[str, str],
    tokenizer,
    max_length: int,
    batch_size: int,
    device: torch.device,
    is_query: bool = True,
    show_progress: bool = True,
    rank: int = 0,
    world_size: int = 1,
    prefix: str = "",
) -> Tuple[np.ndarray, List[str]]:
    """
    Generate embeddings for a list of texts in a distributed manner.

    Each rank processes a contiguous shard of the data, then:
    1. Gather embeddings to rank 0
    2. Rank 0 has complete embeddings
    3. Broadcast complete embeddings to all ranks

    This eliminates redundant computation where all ranks process all data.

    When ``world_size > 1`` this function calls ``torch.distributed``
    collectives (barrier, all_gather, all_reduce, gather, broadcast), so every
    rank must call it with the same ``ids``. The model is switched to eval
    mode and is not switched back afterwards.

    Args:
        model: The dual encoder model; must provide ``encode_queries`` and
            ``encode_documents`` taking ``(input_ids, attention_mask)``
        ids: List of IDs (query IDs or doc IDs) - FULL list on all ranks
        id_to_text: Mapping from ID to text - FULL mapping on all ranks
        tokenizer: HuggingFace tokenizer
        max_length: Maximum sequence length
        batch_size: Batch size for encoding
        device: Device to run on
        is_query: Whether encoding queries (True) or documents (False)
        show_progress: Whether to show progress bar (only on rank 0)
        rank: Current process rank
        world_size: Total number of processes
        prefix: Prefix to prepend to texts (e.g., "query: " or "passage: ")

    Returns:
        Tuple of (embeddings array, ordered list of IDs) - SAME on all ranks
        embeddings: (num_items, embedding_dim); an empty (0, 0) float32 array
            when ``ids`` is empty. Always float32 when ``world_size > 1``.
        ids: The input ``ids`` list; rows of ``embeddings`` follow its order

    Raises:
        KeyError: If an ID in this rank's shard is missing from ``id_to_text``.
    """
    model.eval()

    total_items = len(ids)

    # Nothing to encode. Every rank sees the same full `ids`, so all ranks take
    # this branch together and no collective is left waiting. Without it,
    # rank 0 would fail in torch.cat([]) while the other ranks block in
    # broadcast.
    if total_items == 0:
        if rank == 0:
            logger.info("Distributed embedding generation: no items to encode")
        return np.zeros((0, 0), dtype=np.float32), ids

    # Step 1: Shard the data across ranks
    items_per_rank = (total_items + world_size - 1) // world_size  # Ceiling division

    start_idx = rank * items_per_rank
    end_idx = min(start_idx + items_per_rank, total_items)

    # Each rank gets its shard (empty when start_idx is past the end)
    local_ids = ids[start_idx:end_idx]
    local_texts = [id_to_text[id_] for id_ in local_ids]

    item_kind = "queries" if is_query else "documents"

    if rank == 0:
        logger.info(
            f"Distributed embedding generation: {total_items} items across {world_size} ranks"
        )
        logger.info(f"Each rank processes ~{items_per_rank} items")

    logger.info(
        f"Rank {rank}: processing {len(local_ids)} items (indices {start_idx}:{end_idx})"
    )

    # Step 2: Each rank generates embeddings for its shard
    if len(local_ids) > 0:
        dataset = SimpleTextDataset(local_ids, local_texts)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
        )

        # Pick the encoder once instead of branching on every batch
        encode = model.encode_queries if is_query else model.encode_documents

        local_embeddings = []

        with torch.no_grad():
            iterator = (
                tqdm(
                    dataloader,
                    desc=f"Rank {rank}: Encoding {item_kind}",
                )
                if (show_progress and rank == 0)
                else dataloader
            )

            for batch in iterator:
                batch_texts = batch["text"]

                # Apply prefix if specified
                if prefix:
                    batch_texts = [prefix + text for text in batch_texts]

                # Tokenize (every sequence is padded/truncated to max_length)
                encoded = tokenizer(
                    batch_texts,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                )

                # Move to device
                input_ids = encoded["input_ids"].to(device)
                attention_mask = encoded["attention_mask"].to(device)

                # Encode
                embeddings = encode(input_ids, attention_mask)

                # Move to CPU immediately
                local_embeddings.append(embeddings.cpu())

                # Clear CUDA cache
                del embeddings, input_ids, attention_mask
                if device.type == "cuda":
                    torch.cuda.empty_cache()

        # Concatenate local embeddings
        local_embeddings = torch.cat(local_embeddings, dim=0)  # (local_size, dim)
    else:
        # This rank has no data (happens if world_size > total_items).
        # The embedding dimension is obtained from the other ranks below.
        local_embeddings = None

    logger.info(f"Rank {rank}: generated {len(local_ids)} embeddings")

    # Step 3: Synchronize to ensure all ranks are ready
    if world_size > 1:
        dist.barrier()

    # Step 4: Gather embeddings to rank 0
    if world_size > 1:
        if rank == 0:
            logger.info("Gathering embeddings to rank 0...")

        # First, gather the sizes from all ranks
        local_size = len(local_ids)
        all_sizes = [
            torch.zeros(1, dtype=torch.long, device=device) for _ in range(world_size)
        ]
        dist.all_gather(
            all_sizes, torch.tensor([local_size], dtype=torch.long, device=device)
        )
        all_sizes = [int(s.item()) for s in all_sizes]

        if rank == 0:
            logger.info(f"Sizes per rank: {all_sizes}")

        # Agree on the embedding dimension: ranks without data contribute 0,
        # so the MAX over all ranks is the real dimension.
        embedding_dim = local_embeddings.shape[1] if local_embeddings is not None else 0
        embedding_dim_tensor = torch.tensor(
            [embedding_dim], dtype=torch.long, device=device
        )
        dist.all_reduce(embedding_dim_tensor, op=dist.ReduceOp.MAX)
        embedding_dim = int(embedding_dim_tensor.item())

        # Prepare local embeddings for gathering (move to the collective device).
        # Everything exchanged below is forced to float32: the padding, the
        # gather buffers and the broadcast buffer are float32, and collectives
        # need matching dtypes on every rank even if the model emits fp16/bf16.
        if local_embeddings is not None:
            local_embeddings_gpu = local_embeddings.to(
                device=device, dtype=torch.float32
            )
        else:
            local_embeddings_gpu = torch.zeros(
                (0, embedding_dim), dtype=torch.float32, device=device
            )

        # Pad local embeddings to max size so every rank sends the same shape
        max_size = max(all_sizes)
        if local_embeddings_gpu.shape[0] < max_size:
            padding = torch.zeros(
                (max_size - local_embeddings_gpu.shape[0], embedding_dim),
                dtype=torch.float32,
                device=device,
            )
            local_embeddings_gpu = torch.cat([local_embeddings_gpu, padding], dim=0)

        # Gather to rank 0 (only the destination rank provides receive buffers)
        if rank == 0:
            gathered_embeddings = [
                torch.zeros(
                    (max_size, embedding_dim), dtype=torch.float32, device=device
                )
                for _ in range(world_size)
            ]
        else:
            gathered_embeddings = None

        dist.gather(local_embeddings_gpu, gathered_embeddings, dst=0)

        # Rank 0: drop the padding rows and concatenate in rank order, which
        # matches the order of `ids` because shards are contiguous.
        if rank == 0:
            all_embeddings_list = [
                emb_tensor[:actual_size]
                for emb_tensor, actual_size in zip(gathered_embeddings, all_sizes)
                if actual_size > 0
            ]

            all_embeddings = torch.cat(all_embeddings_list, dim=0).cpu().numpy()
            logger.info(
                f"Rank 0: gathered complete embeddings, shape: {all_embeddings.shape}"
            )
        else:
            all_embeddings = None

        # Clean up
        del local_embeddings_gpu
        if device.type == "cuda":
            torch.cuda.empty_cache()

    else:
        # Single process, no gathering needed. Keep the documented 2-D shape
        # even in the degenerate case where this process has no data.
        all_embeddings = (
            local_embeddings.numpy()
            if local_embeddings is not None
            else np.zeros((0, 0), dtype=np.float32)
        )

    # Step 5: Broadcast embeddings from rank 0 to all ranks
    if world_size > 1:
        if rank == 0:
            logger.info("Broadcasting embeddings to all ranks...")

        # First broadcast shape so the other ranks can allocate a buffer
        if rank == 0:
            shape = torch.tensor(all_embeddings.shape, dtype=torch.long, device=device)
        else:
            shape = torch.zeros(2, dtype=torch.long, device=device)

        dist.broadcast(shape, src=0)

        # Prepare buffer on non-rank-0 processes
        if rank != 0:
            all_embeddings = np.zeros(tuple(shape.tolist()), dtype=np.float32)

        # Convert to tensor for broadcasting
        embeddings_tensor = torch.from_numpy(all_embeddings).to(device)
        dist.broadcast(embeddings_tensor, src=0)

        # Convert back to numpy
        all_embeddings = embeddings_tensor.cpu().numpy()

        logger.info(
            f"Rank {rank}: received complete embeddings, shape: {all_embeddings.shape}"
        )

        # Clean up
        del embeddings_tensor
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Step 6: Synchronize to ensure all ranks have embeddings
    if world_size > 1:
        dist.barrier()

    if rank == 0:
        logger.info("Distributed embedding generation complete")
        logger.info(f"Final embeddings shape: {all_embeddings.shape}")

    return all_embeddings, ids
