import torch
from datasets import concatenate_datasets, load_dataset
from transformers import set_seed, DynamicCache
from torch.amp import autocast

def stack_kv_cache(kv_cache):
    """
    Stack a HuggingFace-style KV cache into a single tensor.

    Args:
        kv_cache: Iterable yielding one ``(key, value)`` tensor pair per
            layer. All keys and values must share the same shape (the
            expected layout is [batch, heads, seq_len, head_dim]). The
            iterable is consumed exactly once, so generators are accepted.

    Returns:
        Tensor of shape [2, num_layers, batch, heads, seq_len, head_dim];
        index 0 of the first dimension holds the keys, index 1 the values.

    Raises:
        RuntimeError: Propagated from ``torch.stack`` when the cache is
            empty or the per-layer tensors have mismatched shapes.
    """
    layer_keys = []
    layer_values = []
    # Single pass: iterating twice would silently exhaust one-shot iterables.
    for key, value in kv_cache:
        layer_keys.append(key)
        layer_values.append(value)

    stacked_keys = torch.stack(layer_keys, dim=0)      # [num_layers, B, H, T, D]
    stacked_values = torch.stack(layer_values, dim=0)  # [num_layers, B, H, T, D]

    # New leading axis separates keys (0) from values (1).
    return torch.stack([stacked_keys, stacked_values], dim=0)

def unstack_kv_tensor(
    kv_tensor: "torch.Tensor | None",
) -> "tuple[tuple[torch.Tensor, torch.Tensor], ...] | None":
    """
    Convert a stacked KV tensor back into a HuggingFace-style legacy cache.

    Args:
        kv_tensor: Tensor of shape
            [2, num_layers, batch, heads, seq_len, head_dim], or ``None``.

    Returns:
        ``None`` when ``kv_tensor`` is ``None``; otherwise a tuple with one
        ``(key, value)`` pair per layer:
            ((k_0, v_0), (k_1, v_1), ..., (k_{L-1}, v_{L-1}))
        where each element has shape [batch, heads, seq_len, head_dim] and
        is obtained by indexing into ``kv_tensor``.

    Raises:
        AssertionError: If ``kv_tensor`` is not 6-dimensional or its first
            dimension is not 2.
    """
    if kv_tensor is None:
        return None
    assert kv_tensor.ndim == 6, f"Expected a 6D KV tensor, got {kv_tensor.ndim}D"
    assert kv_tensor.shape[0] == 2, "First dimension must be 2 (keys, values)"

    keys = kv_tensor[0]    # [num_layers, batch, heads, seq_len, head_dim]
    values = kv_tensor[1]  # same layout as keys

    return tuple(
        (keys[layer], values[layer]) for layer in range(kv_tensor.shape[1])
    )

def extract_and_stack_last_token_kv(past_key_values):
    """
    Keep only the final sequence position of every layer's key/value and
    pack the result into one tensor.

    `past_key_values` is iterated as per-layer (key, value) pairs, each
    indexed as [batch, num_heads, seq_len, head_dim]. The returned tensor
    has shape [2, num_layers, batch, num_heads, 1, head_dim], with keys at
    index 0 and values at index 1 of the first dimension.
    """
    def _last_position(tensor):
        # Slice (not index) so the seq_len axis is kept with size 1.
        return tensor[:, :, -1:, :].contiguous()

    per_layer = [
        (_last_position(layer_k), _last_position(layer_v))
        for layer_k, layer_v in past_key_values
    ]
    last_keys, last_values = zip(*per_layer)

    keys_block = torch.stack(last_keys, dim=0)      # [num_layers, batch, heads, 1, head_dim]
    values_block = torch.stack(last_values, dim=0)  # [num_layers, batch, heads, 1, head_dim]
    return torch.stack([keys_block, values_block], dim=0)

def concat_stacked_kv_caches(prev_kv_tensor, last_token_kv_tensor):
    """
    Append a single-token stacked KV tensor to an existing stacked KV tensor.

    Both inputs use the layout
    [2, num_layers, batch, heads, seq_len, head_dim] and are joined along
    the seq_len axis (dim=4). When there is no previous cache
    (`prev_kv_tensor` is None) the new tensor is returned unchanged.
    """
    seq_axis = 4

    if prev_kv_tensor is not None:
        # Leading (2, num_layers) and (batch, heads) dimensions must agree,
        # and the appended piece must cover exactly one sequence position.
        assert prev_kv_tensor.shape[:2] == last_token_kv_tensor.shape[:2], "Mismatch in (2, num_layers)"
        assert prev_kv_tensor.shape[2:4] == last_token_kv_tensor.shape[2:4], "Mismatch in (batch, heads)"
        assert last_token_kv_tensor.shape[4] == 1, "Last token cache must have sequence length 1"
        return torch.cat((prev_kv_tensor, last_token_kv_tensor), dim=seq_axis)

    return last_token_kv_tensor

def make_logp_fn(fmodel, buffers, input_ids, attention_mask):
    """
    Build a closure that computes log-probabilities for a fixed input batch.

    Args:
        fmodel: Callable invoked as
            ``fmodel(params, buffers, input_ids=..., attention_mask=...)``
            (plus ``past_key_values=...`` when a cache is supplied). Its
            result must expose ``.logits`` and ``.past_key_values``.
            Presumably a functorch-style functional model; confirm against
            the caller.
        buffers: Forwarded unchanged as the second positional argument.
        input_ids: Forwarded unchanged to ``fmodel``.
        attention_mask: Forwarded unchanged to ``fmodel``.

    Returns:
        ``logp_fn(params, past_key_values=None)``; see its docstring.
    """
    def logp_fn(params, past_key_values=None):  # params unpacked by jvp
        """
        Run the model and return ``(log_probs, kv_tensor)``.

        Args:
            params: Model parameters, passed through to ``fmodel``.
            past_key_values: Optional stacked KV tensor of shape
                [2, num_layers, batch, heads, seq_len, head_dim].

        Returns:
            log_probs: ``log_softmax`` over the last dimension of the
                logits, computed in float32.
            kv_tensor: Without ``past_key_values``, the full cache returned
                by the model, stacked to
                [2, num_layers, batch, heads, seq_len, head_dim]. With
                ``past_key_values``, only the last sequence position of the
                returned cache (seq_len == 1); presumably the caller appends
                it to the running cache, e.g. via concat_stacked_kv_caches.
        """
        use_cache = past_key_values is not None

        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if use_cache:
            # The model takes a cache object, not the stacked tensor, so
            # rebuild the per-layer (key, value) tuples first.
            legacy_cache = unstack_kv_tensor(past_key_values)
            model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(legacy_cache)

        outputs = fmodel(params, buffers, **model_kwargs)

        # Upcast before the softmax so it runs in float32 regardless of the
        # dtype of the logits.
        log_probs = torch.log_softmax(outputs.logits.float(), dim=-1)

        if use_cache:
            return log_probs, extract_and_stack_last_token_kv(outputs.past_key_values)
        return log_probs, stack_kv_cache(outputs.past_key_values)
    return logp_fn