"""Negative sampling strategies for EBM training.

This module provides various methods for generating negative samples for contrastive
learning. These include procedural methods that alter a given text (e.g., by
removing sentences or tokens) and dataset-based methods that pull samples from
other contexts.
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING, List, Union

import torch
from nltk.tokenize import sent_tokenize, word_tokenize

if TYPE_CHECKING:
    from .datasets import PromptResponseDataset
    from .config import DataConfig, TrainingConfig 


def mask_sentences(
    response: str,
    drop_prob: float = 0.3,
    force_drop_at_least_one: bool = True,
) -> str:
    """Builds a negative sample by deleting randomly chosen sentences.

    The text is split with ``sent_tokenize``. Each sentence independently
    survives when a uniform draw is ``>= drop_prob``. The result keeps the
    original sentence order but loses part of its content, which breaks the
    continuity between ideas that the model should learn to detect.

    Args:
        response (str): The text to corrupt.
        drop_prob (float): Per-sentence probability of deletion.
        force_drop_at_least_one (bool): When True and the random draws keep
            every sentence, one sentence chosen uniformly at random is
            removed. A multi-sentence input is then always altered.

    Returns:
        str: The surviving sentences joined by single spaces. If every
            sentence was dropped, the first sentence alone is returned so the
            output is never empty. Input that tokenizes to no sentences
            yields ``""``. Single-sentence input is returned unchanged.
    """
    parts = sent_tokenize(response)
    n_parts = len(parts)

    if n_parts == 0:
        # Nothing tokenizable (e.g. empty or whitespace-only input).
        return ""
    if n_parts == 1:
        # A lone sentence cannot be partially removed; hand it back as-is.
        return response

    # Indices of sentences that survive the independent per-sentence draw.
    survivors = [idx for idx in range(n_parts) if random.random() >= drop_prob]

    if force_drop_at_least_one and len(survivors) == n_parts:
        # Nothing was dropped by chance: evict one sentence at random.
        survivors.remove(random.randrange(n_parts))

    if not survivors:
        # Everything was dropped; fall back to the opening sentence.
        survivors = [0]

    return " ".join(parts[idx] for idx in survivors)


def mask_tokens(
    text: str,
    drop_prob: float = 0.3,
    force_drop_at_least_one: bool = True,
    shuffle_tokens: bool = True,
) -> str:
    """Builds a negative sample by dropping and (optionally) shuffling tokens.

    The text is split with ``word_tokenize``. Each token independently
    survives when a uniform draw is ``>= drop_prob``. The survivors can then
    be randomly permuted, which destroys grammatical structure.

    Args:
        text (str): The text to corrupt.
        drop_prob (float): Per-token probability of deletion.
        force_drop_at_least_one (bool): When True and the random draws keep
            every token, one token chosen uniformly at random is removed.
        shuffle_tokens (bool): When True, randomly permutes the surviving
            tokens.

    Returns:
        str: The surviving tokens joined by single spaces. The output is
            never empty for tokenizable input: if every token was dropped,
            the first token alone is returned. As a consequence, a
            single-token input always comes back as that same token.
            Input with no tokens yields ``""``.
    """
    words = word_tokenize(text)
    if not words:
        return ""

    n_words = len(words)
    # Positions of tokens that survive the independent per-token draw.
    survivors = [pos for pos in range(n_words) if random.random() >= drop_prob]

    if force_drop_at_least_one and len(survivors) == n_words:
        # Nothing was dropped by chance: evict one token at random.
        survivors.remove(random.randrange(n_words))

    # Guarantee non-empty output by restoring the first token if needed.
    remaining = [words[pos] for pos in survivors] or [words[0]]

    if shuffle_tokens:
        random.shuffle(remaining)

    return " ".join(remaining)


def sample_negative_responses(
    sampling_type: str,
    responses: List[str],
    prompts: Union[List[str], None] = None,
    indices: Union[List[int], None] = None,
    model: Union[torch.nn.Module, None] = None,
    dataset: Union["PromptResponseDataset", None] = None,
    offset: int = 1,
    train_cfg: Union["TrainingConfig", None] = None,
    data_cfg: Union["DataConfig", None] = None,
) -> List[str]:
    """Generates a batch of negative responses using a specified strategy.

    This function acts as a dispatcher, calling the appropriate negative sampling
    method. The required arguments depend on the chosen `sampling_type`.

    Args:
        sampling_type (str): The strategy to use for negative sampling.
            Supported types:
                - 'sentence_masking': Randomly removes sentences from the response.
                - 'sentence_masking_hard': Creates k sentence-masked candidates
                  per response and selects the one with the lowest energy.
                - 'token_masking': Randomly drops and shuffles tokens in the response.
                - 'off_context': Uses the response of the next sample in the dataset
                  (index ``(i + 1) % len(dataset)``). If no dataset/indices are
                  provided, it uses the next sample in the batch.
                - 'off_context_batch': Uses response ``(i + offset) % batch_size``
                  from the same batch.
                - 'gpt2': Uses a pre-computed response generated by GPT-2.
                - 'human': Uses a pre-existing human-written response from the dataset.
                - 'langevin': Placeholder (not implemented); shuffles responses
                  within the batch.
        responses (List[str]): A list of ground-truth responses for the batch.
            Required by all procedural methods.
        prompts (List[str], optional): Prompts for the batch, one per response.
            Required for: 'sentence_masking_hard'.
        indices (List[int], optional): Original dataset indices for the batch.
            A 1-D integer tensor is also accepted. Required for: 'gpt2', 'human';
            used by 'off_context' when given together with `dataset`.
        model (torch.nn.Module, optional): The energy model, called as
            ``model(prompts, responses)``. Required for: 'sentence_masking_hard'.
        dataset (PromptResponseDataset, optional): The full dataset object.
            Required for: 'gpt2', 'human'; used by 'off_context' when given.
        offset (int, optional): The index offset for 'off_context_batch'. Defaults to 1.
        train_cfg (TrainingConfig, optional): Config with params like `k_candidates`.
            Required for: 'sentence_masking_hard'.
        data_cfg (DataConfig, optional): Config with data params like column names.
            Required for: 'gpt2', 'human'.

    Returns:
        List[str]: A batch of generated negative responses.

    Raises:
        ValueError: If `sampling_type` is unknown, a required argument is
            missing, `prompts` and `responses` differ in length for
            'sentence_masking_hard', `k_candidates < 1`, or the chosen
            off-context strategy would pair every sample with itself
            (dataset of size 1, batch of size 1, or an `offset` that is a
            multiple of the batch size).
    """
    if sampling_type == "sentence_masking":
        return [mask_sentences(r) for r in responses]

    if sampling_type == "sentence_masking_hard":
        if model is None or train_cfg is None or prompts is None:
            msg = "`model`, `train_cfg`, and `prompts` are required for 'sentence_masking_hard'."
            raise ValueError(msg)
        if len(prompts) != len(responses):
            msg = (
                "'sentence_masking_hard' needs one prompt per response, got "
                f"{len(prompts)} prompts and {len(responses)} responses."
            )
            raise ValueError(msg)
        k_candidates = train_cfg.k_candidates
        if k_candidates < 1:
            msg = f"`train_cfg.k_candidates` must be >= 1, got {k_candidates}."
            raise ValueError(msg)
        if not responses:
            return []

        # Candidates are laid out row-major: the k candidates of response i
        # occupy positions [i * k, (i + 1) * k), matching `prompts_rep`.
        candidates = [mask_sentences(r) for r in responses for _ in range(k_candidates)]
        prompts_rep = [p for p in prompts for _ in range(k_candidates)]
        with torch.no_grad():
            # reshape (not view) also works when the model output is non-contiguous.
            energies = model(prompts_rep, candidates).reshape(len(prompts), k_candidates)

        # Hardest negative = the candidate with the lowest energy for its prompt.
        best = energies.argmin(dim=1).tolist()
        return [candidates[row * k_candidates + col] for row, col in enumerate(best)]

    if sampling_type == "token_masking":
        return [mask_tokens(r) for r in responses]

    if sampling_type == "off_context":
        # Compare against None / check lengths explicitly instead of relying on
        # truthiness: a tensor of indices is ambiguous in a boolean context.
        if (
            dataset is not None
            and indices is not None
            and len(indices) > 0
            and len(dataset) > 0
        ):
            n = len(dataset)
            if n == 1:
                msg = "Cannot use 'off_context' with a dataset of size 1."
                raise ValueError(msg)
            # int() accepts both Python ints and 0-dim integer tensors.
            return [dataset.responses[(int(i) + 1) % n] for i in indices]
        # Otherwise, fall back to using the next sample in the current batch.
        batch_size = len(responses)
        if batch_size == 1:
            msg = "Cannot use 'off_context' batch fallback with a batch size of 1."
            raise ValueError(msg)
        return [responses[(i + 1) % batch_size] for i in range(batch_size)]

    if sampling_type == "off_context_batch":
        batch_size = len(responses)
        if batch_size > 0 and offset % batch_size == 0:
            # Every negative would be identical to its own positive.
            msg = (
                f"'off_context_batch' offset {offset} is a multiple of the batch "
                f"size {batch_size}; negatives would equal their positives."
            )
            raise ValueError(msg)
        # For each response i, use response (i + offset) % B as its negative.
        return [responses[(i + offset) % batch_size] for i in range(batch_size)]

    if sampling_type in ("gpt2", "human"):
        if dataset is None or data_cfg is None or indices is None:
            msg = f"`dataset`, `data_cfg`, and `indices` are required for '{sampling_type}'."
            raise ValueError(msg)
        column = data_cfg.gpt2_col if sampling_type == "gpt2" else data_cfg.human_col
        return [dataset.get_field(i, column) for i in indices]

    if sampling_type == "langevin":
        # NOTE: Langevin sampling is not implemented. As a placeholder this
        # returns a random permutation of the batch, which may leave some
        # responses paired with themselves.
        negs = list(responses)
        random.shuffle(negs)
        return negs

    msg = f"Unknown response sampling_type: {sampling_type}"
    raise ValueError(msg)


def sample_negative_prompts(
    sampling_type: str,
    prompts: List[str],
    indices: Union[List[int], None] = None,
    dataset: Union["PromptResponseDataset", None] = None,
) -> List[str]:
    """Generates a batch of negative prompts using a specified strategy.

    This function is analogous to `sample_negative_responses` but corrupts the
    prompts instead, keeping the responses fixed.

    Args:
        sampling_type (str): The strategy to use for negative prompt sampling.
            Supported types:
                - 'sentence_masking_prompt': Randomly removes sentences in the prompt.
                - 'token_masking_prompt': Randomly drops and shuffles tokens in the prompt.
                - 'off_context_prompt': Uses the prompt of the next sample in the dataset
                  (index ``(i + 1) % len(dataset)``). If no dataset/indices are
                  provided, it uses the next sample in the batch.
        prompts (List[str]): A list of input prompts to be corrupted.
        indices (List[int], optional): Original dataset indices for the batch.
            A 1-D integer tensor is also accepted. Used by 'off_context_prompt'
            together with `dataset`.
        dataset (PromptResponseDataset, optional): The dataset object.
            Used by 'off_context_prompt' together with `indices`.

    Returns:
        List[str]: A batch of generated negative prompts.

    Raises:
        ValueError: If `sampling_type` is unknown, or 'off_context_prompt'
            would pair every sample with itself (dataset of size 1, or batch
            fallback with a batch of size 1).
    """
    if sampling_type == "sentence_masking_prompt":
        return [mask_sentences(p) for p in prompts]

    if sampling_type == "token_masking_prompt":
        return [mask_tokens(p) for p in prompts]

    if sampling_type == "off_context_prompt":
        # Compare against None / check lengths explicitly instead of relying on
        # truthiness: a tensor of indices is ambiguous in a boolean context.
        if (
            dataset is not None
            and indices is not None
            and len(indices) > 0
            and len(dataset) > 0
        ):
            n = len(dataset)
            if n == 1:
                msg = "Cannot use 'off_context_prompt' with a dataset of size 1."
                raise ValueError(msg)
            # int() accepts both Python ints and 0-dim integer tensors.
            return [dataset.prompts[(int(i) + 1) % n] for i in indices]
        # Otherwise, fall back to using the next sample in the current batch.
        batch_size = len(prompts)
        if batch_size == 1:
            msg = "Cannot use 'off_context_prompt' batch fallback with a batch size of 1."
            raise ValueError(msg)
        return [prompts[(i + 1) % batch_size] for i in range(batch_size)]

    msg = f"Unknown prompt sampling_type: {sampling_type}"
    raise ValueError(msg)
