import configparser
import model
import json
import jax
import jax.numpy as jnp
from snapshot import Snapshot
import os
import mlflow
import numpy as np


def prepare_training_env(tracking_uri: str, mem_fraction: str=".99", mlflow_config: str="~/.mlflow/credentials.ini"):
    """Configure the XLA memory fraction and the Mlflow tracking client.

    Sets ``XLA_PYTHON_CLIENT_MEM_FRACTION``, points Mlflow at
    ``tracking_uri``, ends any active Mlflow run, and copies credentials
    found in the ``[mlflow]`` section of the credentials file into the
    process environment.

    Args:
        tracking_uri: URI passed to ``mlflow.set_tracking_uri``.
        mem_fraction: Value written to ``XLA_PYTHON_CLIENT_MEM_FRACTION``.
        mlflow_config: Path to an INI file; ``~`` is expanded.

    Raises:
        FileNotFoundError: If the credentials file does not exist.
    """
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = mem_fraction

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.end_run()

    config_path = os.path.expanduser(mlflow_config)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Mlflow credentials not found at {config_path}")

    parser = configparser.ConfigParser()
    parser.read(config_path)

    # A file without an [mlflow] section is accepted and simply sets nothing.
    if 'mlflow' not in parser:
        return

    credentials = parser['mlflow']
    # Export only the authentication entries that are actually present.
    for variable in (
        'MLFLOW_TRACKING_USERNAME',
        'MLFLOW_TRACKING_PASSWORD',
        'MLFLOW_TRACKING_TOKEN',
    ):
        if variable in credentials:
            os.environ[variable] = credentials[variable]


def top_k_sampling(logits, k, temperature=1.0):
    """Apply temperature and top-k filtering to logits.

    Args:
        logits: Array whose last axis is the vocabulary, e.g.
            (batch, seq_len, vocab_size), (batch, vocab_size) or
            (vocab_size,). Any number of leading axes is accepted.
        k: Number of top tokens to keep. ``k <= 0`` disables filtering, and
            ``k >= vocab_size`` keeps every token.
        temperature: Sampling temperature the logits are divided by.

    Returns:
        Temperature-scaled logits with the same shape as the input, where
        every entry outside the k largest along the last axis is ``-inf``.
    """
    # Scale on every path so a disabled filter still honours the temperature,
    # matching top_p_sampling and min_p_sampling.
    logits = logits / temperature

    vocab_size = logits.shape[-1]
    # jax.lax.top_k rejects k larger than the axis; keeping everything is the
    # natural meaning of such a request.
    if k <= 0 or k >= vocab_size:
        return logits

    # Collapse all leading axes so one 2-D scatter handles every input rank.
    flat_logits = logits.reshape(-1, vocab_size)
    top_values, top_indices = jax.lax.top_k(flat_logits, k)

    # Start from -inf everywhere and write back only the surviving logits.
    row_indices = jnp.arange(flat_logits.shape[0])[:, None]
    filtered = jnp.full_like(flat_logits, -jnp.inf)
    filtered = filtered.at[row_indices, top_indices].set(top_values)

    return filtered.reshape(logits.shape)


def top_p_sampling(logits, p, temperature=1.0):
    """Apply temperature and top-p (nucleus) filtering to logits.

    Args:
        logits: Array whose last axis is the vocabulary, e.g.
            (batch, seq_len, vocab_size) or (vocab_size,).
        p: Cumulative probability threshold. ``p >= 1.0`` disables filtering.
        temperature: Sampling temperature the logits are divided by.

    Returns:
        Temperature-scaled logits with the same shape as the input, where
        filtered-out tokens are ``-inf``.
    """
    logits = logits / temperature
    if p >= 1.0:
        return logits

    # Rank tokens from most to least likely along the vocab axis.
    sorted_indices = jnp.argsort(logits, axis=-1)[..., ::-1]
    sorted_logits = jnp.take_along_axis(logits, sorted_indices, axis=-1)

    cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits, axis=-1), axis=-1)

    # Keep tokens while the running mass stays within p. The token that would
    # push the mass past p is excluded, so the kept mass never exceeds p.
    # The most likely token is always kept so the result is never all -inf.
    keep_sorted = (cumulative_probs <= p).at[..., 0].set(True)
    kept_sorted_logits = jnp.where(keep_sorted, sorted_logits, -jnp.inf)

    # argsort of a permutation is its inverse: ranks[..., i] is the position
    # of vocabulary entry i in the sorted order. Gathering along the vocab
    # axis with it restores the original token order.
    ranks = jnp.argsort(sorted_indices, axis=-1)
    return jnp.take_along_axis(kept_sorted_logits, ranks, axis=-1)


def min_p_sampling(logits, min_p, temperature=1.0):
    """Apply temperature and min-p filtering to logits.

    A token survives when its softmax probability is at least ``min_p``
    times the probability of the most likely token along the last axis.

    Args:
        logits: Array whose last axis is the vocabulary, e.g.
            (batch, seq_len, vocab_size) or (vocab_size,).
        min_p: Threshold as a fraction of the maximum probability.
            ``min_p <= 0.0`` disables filtering.
        temperature: Sampling temperature the logits are divided by.

    Returns:
        Temperature-scaled logits with the same shape as the input, where
        filtered-out tokens are ``-inf``.
    """
    scaled = logits / temperature
    if min_p <= 0.0:
        return scaled

    probabilities = jax.nn.softmax(scaled, axis=-1)

    # Per-position cutoff, broadcast back over the vocab axis.
    cutoff = min_p * jnp.max(probabilities, axis=-1, keepdims=True)
    survives = probabilities >= cutoff

    return jnp.where(survives, scaled, -jnp.inf)


def apply_repetition_penalty(logits, token_sequence, penalty=1.0):
    """Lower the logits of tokens that already appear in the history.

    Positive logits of seen tokens are scaled exactly as before: divided by
    ``penalty`` when it is above 1.0, multiplied by it otherwise. Non-positive
    logits are scaled the opposite way, so they move away from zero and the
    token becomes less likely rather than more likely. Tokens absent from the
    history are left untouched.

    Args:
        logits: Shape (batch, seq_len, vocab_size).
        token_sequence: Shape (batch, history_len) - tokens generated so far.
            ``None`` disables the penalty.
        penalty: Repetition penalty factor; 1.0 disables the penalty.

    Returns:
        Penalized logits with the same shape and dtype as the input.
    """
    if penalty == 1.0 or token_sequence is None:
        return logits

    batch_size, _, vocab_size = logits.shape

    # One scatter marks every token seen in each batch row, replacing a
    # Python loop that issued one device update per history token.
    history = jnp.asarray(token_sequence, dtype=jnp.int32)[:batch_size]
    row_indices = jnp.arange(batch_size)[:, None]
    seen = jnp.zeros((batch_size, vocab_size), dtype=bool)
    seen = seen.at[row_indices, history].set(True)

    # "toward_zero" is what the penalty does to a positive logit;
    # "away_from_zero" is its mirror image for non-positive logits.
    if penalty > 1.0:
        toward_zero, away_from_zero = logits / penalty, logits * penalty
    else:
        toward_zero, away_from_zero = logits * penalty, logits / penalty
    penalized = jnp.where(logits > 0, toward_zero, away_from_zero)

    # The history applies to every sequence position of its batch row.
    return jnp.where(seen[:, None, :], penalized, logits)


def sample_tokens(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0,
                  repetition_penalty=1.0, token_sequence=None, key=None):
    """Draw one token id per position from optionally filtered logits.

    The repetition penalty, when enabled, is applied first. After that
    exactly one filter runs, chosen by priority: top-k if ``top_k > 0``,
    otherwise top-p if ``top_p < 1.0``, otherwise min-p if ``min_p > 0.0``.
    With no filter enabled only the temperature is applied.

    Args:
        logits: Shape (batch, seq_len, vocab_size)
        temperature: Sampling temperature
        top_k: Number of top tokens for top-k sampling (0 to disable)
        top_p: Cumulative probability for nucleus sampling (1.0 to disable)
        min_p: Minimum probability threshold for min-p sampling (0.0 to disable)
        repetition_penalty: Penalty for repeated tokens (1.0 for no penalty)
        token_sequence: Previous tokens for repetition penalty, shape (batch, history_len)
        key: JAX random key. When omitted, ``jax.random.key(0)`` is used, so
            repeated calls with the same inputs return the same tokens.

    Returns:
        Sampled int32 tokens with shape (batch, seq_len)
    """
    has_history = token_sequence is not None
    if has_history and repetition_penalty != 1.0:
        logits = apply_repetition_penalty(logits, token_sequence, repetition_penalty)

    # Each filter helper also applies the temperature; only the first enabled
    # filter is used.
    if top_k > 0:
        processed = top_k_sampling(logits, top_k, temperature)
    elif top_p < 1.0:
        processed = top_p_sampling(logits, top_p, temperature)
    elif min_p > 0.0:
        processed = min_p_sampling(logits, min_p, temperature)
    else:
        processed = logits / temperature

    rng = jax.random.key(0) if key is None else key
    sampled = jax.random.categorical(rng, processed, axis=-1)
    return sampled.astype(jnp.int32)


def load_reasoner(reasoner_snap: str, vocab_size: int, size: str):
    """Build a Reasoner from its config and restore it from a snapshot.

    Hyperparameters are read from the ``model_{size}`` entry of
    ``configs/reasoner_cfg.json`` (relative to the working directory). The
    model is created in bfloat16 with a fixed initialisation key, switched
    into eval mode, and then loaded from the snapshot.

    Args:
        reasoner_snap: Snapshot path. Its directory is handed to ``Snapshot``
            and its final component to ``Snapshot.load``.
        vocab_size: Vocabulary size passed to the model constructor.
        size: Suffix selecting the ``model_{size}`` config entry.

    Returns:
        Whatever ``Snapshot.load`` returns for the reasoner (loaded with
        ``skip_ema=True``).
    """
    with open('configs/reasoner_cfg.json', 'r') as cfg_file:
        hparams = json.load(cfg_file)[f'model_{size}']

    # Constructor hyperparameters, converted from their JSON representation.
    architecture = {
        'feature': int(hparams['Feature']),
        'attn_feature': int(hparams['ATTN Feature']),
        'ffn_feature': int(hparams['FFN Feature']),
        'num_head': int(hparams['Head Count']),
        'decoder_count': int(hparams['Decoder Count']),
        'init_scalar': float(hparams['Init Scalar']),
    }
    # Settings that are only needed when switching to eval mode.
    eval_settings = {
        'max_len': int(hparams['Max Length']),
        'rope_base': float(hparams['RoPE Base']),
    }

    reasoner = model.Reasoner(
        is_causal=True,
        vocab_size=vocab_size,
        key=jax.random.key(0),
        dtype=jnp.bfloat16,
        **architecture,
    )
    reasoner.eval(**eval_settings)

    snapshot_dir, snapshot_name = os.path.split(reasoner_snap)
    return Snapshot(snapshot_dir).load(snapshot_name, reasoner, skip_ema=True)


def load_models(reasoner_snap: str, talker_snap: str | None = None, vocab_size: int=32, size: str='large'):
    """Load the reasoner and build (and optionally restore) the talker.

    Hyperparameters come from the ``model_{size}`` entries of
    ``configs/reasoner_cfg.json`` and ``configs/talker_cfg.json`` (relative
    to the working directory). The talker is created in bfloat16 and
    switched into eval mode.

    Args:
        reasoner_snap: Snapshot path for the reasoner, passed to
            ``load_reasoner``.
        talker_snap: Snapshot path for the talker. When ``None`` no talker
            snapshot is loaded and the talker is returned with the
            parameters it was constructed with.
        vocab_size: Vocabulary size passed to both model constructors.
        size: Suffix selecting the ``model_{size}`` config entries.

    Returns:
        Tuple ``(reasoner, talker)``.
    """
    section = f'model_{size}'

    # The talker needs the reasoner's feature width as its latent width.
    with open('configs/reasoner_cfg.json', 'r') as file:
        latent_feature = int(json.load(file)[section]['Feature'])

    with open('configs/talker_cfg.json', 'r') as file:
        talker_cfg = json.load(file)[section]

    feature = int(talker_cfg['Feature'])
    attn_feature = int(talker_cfg['ATTN Feature'])
    ffn_feature = int(talker_cfg['FFN Feature'])
    num_head = int(talker_cfg['Head Count'])
    decoder_count = int(talker_cfg['Decoder Count'])
    # Encoder depth deliberately reuses 'Decoder Count': same number of
    # encoders as decoders.
    encoder_count = decoder_count
    init_scalar = float(talker_cfg['Init Scalar'])
    max_len = int(talker_cfg['Max Length'])
    rope_base = float(talker_cfg['RoPE Base'])

    # Build model instances
    key = jax.random.key(0)
    reasoner = load_reasoner(reasoner_snap, vocab_size, size=size)
    talker = model.DualTalker(
        feature=feature,
        latent_feature=latent_feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        encoder_count=encoder_count,
        decoder_count=decoder_count,
        init_scalar=init_scalar,
        vocab_size=vocab_size,
        key=key,
        dtype=jnp.bfloat16
    )
    talker.eval(
        rope_base=rope_base,
        max_len=max_len
    )

    # talker_snap defaults to None; os.path.dirname(None) would raise
    # TypeError, so skip the restore instead of crashing.
    if talker_snap is None:
        return reasoner, talker

    snap = Snapshot(os.path.dirname(talker_snap))
    talker = snap.load(os.path.basename(talker_snap), talker)

    return reasoner, talker


def add_noise_to_data(token_data, valid_lengths, replacement_tokens, replacement_prob=0.1, random_seed=None):
    """
    Randomly replace non-padding tokens with specified replacement tokens.

    Args:
        token_data (np.ndarray): Token data of shape (batch, seq_len) with right padding
        valid_lengths (np.ndarray): Array of shape (batch,) containing number of non-padding elements per sequence
        replacement_tokens (list or np.ndarray): Tokens to randomly choose from for replacement
        replacement_prob (float): Probability of replacing each non-padding token (default: 0.1)
        random_seed (int, optional): Random seed for reproducibility. A seeded
            call draws from a private generator and does not touch NumPy's
            global random state; an unseeded call uses the global state.

    Returns:
        np.ndarray: Modified copy of the token data with same shape as input
    """
    # RandomState(seed) produces the same draws as np.random.seed(seed)
    # followed by module-level calls, without reseeding the shared global
    # generator for the rest of the process.
    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    # Work on a copy so the caller's array is left untouched.
    result = token_data.copy()
    batch_size, _ = token_data.shape

    replacement_tokens = np.asarray(replacement_tokens)

    for batch_idx in range(batch_size):
        valid_len = valid_lengths[batch_idx]

        # Only the first valid_len positions are candidates; padding is skipped.
        replace_indices = np.flatnonzero(rng.random_sample(valid_len) < replacement_prob)

        if replace_indices.size > 0:
            # One independently chosen replacement per selected position.
            result[batch_idx, replace_indices] = rng.choice(
                replacement_tokens, size=replace_indices.size
            )

    return result


def add_gaussian_noise_to_array(key, array, noise_ratio, magnitude):
    """
    Adds zero-mean Gaussian noise to every element of an array.
    param: key: A JAX random key; it is split once, one half drives the noise.
    param: array: The array to perturb; the noise matches its shape and dtype.
    param: noise_ratio: Scale factor applied to the noise.
    param: magnitude: Second scale factor; the standard-normal draw is
        multiplied by magnitude and then by noise_ratio.
    Returns: A tuple containing the noisy array and the unused half of the split key.
    """
    noise_key, next_key = jax.random.split(key, num=2)
    standard_normal = jax.random.normal(noise_key, shape=array.shape, dtype=array.dtype)
    # Keep the multiplication order (magnitude first) so results are bit-identical.
    perturbation = standard_normal * magnitude * noise_ratio
    return array + perturbation, next_key

