import math
import random
import torch

from typing import Optional
from torch.utils.data._utils.collate import collate_tensor_fn
from .tokenizer import Tokenizer


def pass_text(row, tokenizer, add_bos, add_eos):
    """Encode the row's ``"text"`` field.

    Args:
        row: Mapping that holds the raw string under the ``"text"`` key.
        tokenizer: Object exposing ``encode(string, bos=..., eos=...)``.
        add_bos: Forwarded to ``tokenizer.encode`` as ``bos``.
        add_eos: Forwarded to ``tokenizer.encode`` as ``eos``.

    Returns:
        ``(tokens, None)`` -- this format produces no second sequence.
    """
    encoded = tokenizer.encode(row["text"], bos=add_bos, eos=add_eos)
    return encoded, None

def pass_retrieval_pair(row, tokenizer, add_bos, add_eos, train_group_size=None):
    """Tokenize a retrieval example as a (query, target) pair.

    ``row["data_signature"]["keys"]`` names the columns to read:

    * two keys -> ``(query, target)``; the target is encoded as one sequence.
    * three keys -> ``(query, positives, hard_negatives)``; one positive is
      kept (picked at random when the column is a list) and
      ``train_group_size - 1`` hard negatives are sampled. When fewer
      negatives are available than requested, the negatives list is repeated
      before sampling so the group can still be filled.

    Args:
        row: Mapping with a ``"data_signature"`` entry plus the named columns.
        tokenizer: Object exposing ``encode(string, bos=..., eos=...)``.
        add_bos: Forwarded to ``tokenizer.encode`` as ``bos``.
        add_eos: Forwarded to ``tokenizer.encode`` as ``eos``.
        train_group_size: Number of targets per query (1 positive + negatives).
            Required only when the row carries hard negatives.

    Returns:
        ``(input_tokens, target_tokens)``. With hard negatives,
        ``target_tokens`` is a list of encoded sequences whose first entry is
        the positive; otherwise it is a single encoded sequence.

    Raises:
        ValueError: On the hard-negative path, if ``train_group_size`` is
            missing or smaller than 1, or if negatives are requested but the
            row has none to sample from.
    """
    keys = row["data_signature"]["keys"]
    prefix_string = row[keys[0]]
    positives = row[keys[1]]

    if len(keys) != 3:
        # Plain pair: a single query and a single target.
        input_tokens = tokenizer.encode(prefix_string, bos=add_bos, eos=add_eos)
        target_tokens = tokenizer.encode(positives, bos=add_bos, eos=add_eos)
        return (input_tokens, target_tokens)

    ########### Logic to handle hard negatives ###########
    if train_group_size is None or train_group_size < 1:
        raise ValueError(
            "train_group_size must be a positive integer when the row provides hard negatives, "
            f"got {train_group_size!r}."
        )
    num_negatives = train_group_size - 1
    hard_negatives = row[keys[2]]
    if num_negatives > 0 and len(hard_negatives) == 0:
        raise ValueError(
            f"Row has no hard negatives under key {keys[2]!r} but {num_negatives} were requested."
        )

    positive = random.choice(positives) if isinstance(positives, list) else positives
    if len(hard_negatives) < num_negatives:
        # Not enough distinct negatives: tile the pool so sampling can succeed.
        repeats = math.ceil(num_negatives / len(hard_negatives))
        negatives = random.sample(hard_negatives * repeats, num_negatives)
    else:
        negatives = random.sample(hard_negatives, num_negatives)
    suffixes = [positive, *negatives]

    input_tokens = tokenizer.encode(prefix_string, bos=add_bos, eos=add_eos)
    target_tokens = [tokenizer.encode(suffix, bos=add_bos, eos=add_eos) for suffix in suffixes]
    return (input_tokens, target_tokens)


def concat_input_target(row, tokenizer, add_bos, add_eos):
    """Encode ``row["input"]`` followed by ``row["target"]`` as one sequence.

    Args:
        row: Mapping with ``"input"`` and ``"target"`` entries that support ``+``.
        tokenizer: Object exposing ``encode(string, bos=..., eos=...)``.
        add_bos: Forwarded to ``tokenizer.encode`` as ``bos``.
        add_eos: Forwarded to ``tokenizer.encode`` as ``eos``.

    Returns:
        ``(tokens, None)`` -- this format produces no second sequence.
    """
    joined = row["input"] + row["target"]
    encoded = tokenizer.encode(joined, bos=add_bos, eos=add_eos)
    return encoded, None


def condition_input_supervise_target(row, tokenizer, add_bos, add_eos):
    """Encode input+target jointly and mask the input span in the labels.

    The input is encoded on its own (never with eos) purely to measure how
    many leading positions of the joint encoding belong to it; those
    positions are overwritten with ``tokenizer.pad_id`` in the labels.

    Args:
        row: Mapping with ``"input"`` and ``"target"`` entries.
        tokenizer: Object exposing ``encode(string, bos=..., eos=...)`` and
            ``pad_id``. The encoding must support ``.clone()`` and slice
            assignment (presumably a torch tensor -- confirm with tokenizer).
        add_bos: Forwarded to ``tokenizer.encode`` as ``bos``.
        add_eos: Forwarded as ``eos`` for the joint encoding only.

    Returns:
        ``(joint_tokens, label_tokens)`` of equal length.
    """
    prompt, completion = row["input"], row["target"]
    prompt_tokens = tokenizer.encode(prompt, bos=add_bos, eos=False)
    full_tokens = tokenizer.encode(prompt + completion, bos=add_bos, eos=add_eos)

    labels = full_tokens.clone()
    # Blank out the prompt positions so only the target span is supervised.
    labels[: len(prompt_tokens)] = tokenizer.pad_id
    return full_tokens, labels


def apply_chat_template_supervise_all(row, tokenizer, add_bos, add_eos):
    """Render a chat through the tokenizer's chat template and supervise every token.

    Args:
        row: Mapping whose ``"data_signature"]["keys"`` lists exactly one key;
            the value under that key is handed to the chat template.
        tokenizer: Object exposing ``processor.apply_chat_template(..., tokenize=False)``
            and ``encode(string, bos=..., eos=...)``.
        add_bos: Forwarded to ``tokenizer.encode`` as ``bos``.
        add_eos: Forwarded to ``tokenizer.encode`` as ``eos``.

    Returns:
        ``(input_tokens, label_tokens)`` where the labels are a clone of the inputs.

    Raises:
        AssertionError: If the data signature does not name exactly one key.
    """
    keys = row["data_signature"]["keys"]
    assert (
        len(keys) == 1
    ), "Ambiguous row format for chat template call. data signature should spec the single intended key."

    rendered = tokenizer.processor.apply_chat_template(row[keys[0]], tokenize=False)
    tokens = tokenizer.encode(rendered, bos=add_bos, eos=add_eos)
    return tokens, tokens.clone()


def apply_chat_template_supervise_assistant(row, tokenizer, add_bos, add_eos):
    """Placeholder for assistant-only supervision of chat data; always raises.

    Raises:
        NotImplementedError: Unconditionally.
    """
    message = "User-assistant chat format masking function is not yet implemented."
    raise NotImplementedError(message)


# Maps the ``format_fn`` name found in a row's data signature to the function
# that turns that row into a pair of token sequences.
format_fn_registry = dict(
    pass_text=pass_text,
    pass_retrieval_pair=pass_retrieval_pair,
    concat_input_target=concat_input_target,
    condition_input_supervise_target=condition_input_supervise_target,
    apply_chat_template_supervise_all=apply_chat_template_supervise_all,
    apply_chat_template_supervise_assistant=apply_chat_template_supervise_assistant,
)


def apply_formatting(row, tokenizer, add_bos, add_eos, train_group_size=None):
    """Turn one raw dataset row into a pair of token sequences.

    Args:
        row: Either an already-tokenized ``torch.Tensor`` (returned as-is with
            no second element) or a dict carrying a ``"data_signature"`` entry
            whose ``"format_fn"`` names an entry of ``format_fn_registry``.
        tokenizer: Passed through to the selected format function.
        add_bos: Default bos flag; overridden by ``data_signature["add_bos"]``
            when that entry is present and not ``None``.
        add_eos: Default eos flag; overridden by ``data_signature["add_eos"]``
            when that entry is present and not ``None``.
        train_group_size: Forwarded only to ``pass_retrieval_pair``.

    Returns:
        A 2-tuple; the second element may be ``None``.

    Raises:
        NotImplementedError: If ``row`` is a tuple (tensor pairs are planned
            but not supported yet).
        KeyError: If the row names a ``format_fn`` that is not registered.
        ValueError: If ``row`` is none of the recognized types.
    """
    # pkds, single tensor
    if isinstance(row, torch.Tensor):
        return row, None
    # pkds, tuple of tensors
    if isinstance(row, tuple):
        raise NotImplementedError("Tuple format not supported, but direct tensor pairs planned.")
    # hfds, dict with format_fn from data signature
    if isinstance(row, dict):
        signature = row["data_signature"]
        # we can locally override the add_bos or add_eos args if they exist in the row's data_signature
        if signature.get("add_bos") is not None:
            add_bos = signature["add_bos"]
        if signature.get("add_eos") is not None:
            add_eos = signature["add_eos"]

        format_name = signature["format_fn"]
        try:
            format_fn = format_fn_registry[format_name]
        except KeyError:
            raise KeyError(
                f"Unknown format_fn {format_name!r}; expected one of {sorted(format_fn_registry)}."
            ) from None

        # Only the retrieval formatter accepts train_group_size.
        if format_name == "pass_retrieval_pair":
            return format_fn(row, tokenizer, add_bos, add_eos, train_group_size=train_group_size)
        return format_fn(row, tokenizer, add_bos, add_eos)
    raise ValueError("Row format not recognized.")


def shift_inputs_and_labels(inputs_batch: torch.Tensor, labels_batch: torch.Tensor, tokenizer: Tokenizer):
    """Produce next-token-prediction inputs and labels from two aligned batches.

    Inputs drop their last column and labels drop their first, so position
    ``t`` of the returned inputs lines up with position ``t + 1`` of the
    original labels. Both results are converted to ``long``.

    Args:
        inputs_batch: 2-D tensor; its second dimension is the sequence length.
        labels_batch: 2-D tensor sliced with the same sequence length.
        tokenizer: Object exposing ``pad_id`` and ``eos_id``.

    Returns:
        ``(input_ids, label_ids)``, each one column narrower than the inputs.
        ``inputs_batch`` is never modified; ``label_ids`` may share storage
        with ``labels_batch`` because it is only sliced, not rewritten.
    """
    seq_len = inputs_batch.shape[1]

    input_ids = inputs_batch[:, : seq_len - 1].contiguous().long()
    label_ids = labels_batch[:, 1:seq_len].contiguous().long()

    # for the input we need to replace any pad ids with the eos token
    # knowing that they're trailing so they wont contrib to activations
    # but that they do need to be valid indices in the model's embedding layer
    if tokenizer.eos_id is not None and tokenizer.pad_id is not None:
        # Out-of-place on purpose: for a single-row int64 batch the slice above
        # is still a view of ``inputs_batch``, so an in-place write would
        # silently corrupt the caller's tensor.
        input_ids = input_ids.masked_fill(input_ids == tokenizer.pad_id, tokenizer.eos_id)
    # Note that we are _not_ doing this operation for the labels,
    # since this is where we actually need the pad tokens to be present for loss to ignore them.

    return input_ids, label_ids


def _apply_length_shortcut_ablation(batch, length_shortcut_ablation, tokenizer, generator=None):
    """Return a new batch whose prefixes are rewritten by the named ablation.

    Supported modes:

    * ``"permute_batch_tokens"``: shuffle all prefix tokens across the batch,
      then re-split them using the original per-row lengths.
    * ``"rand_toks_const_lens"``: random tokens, every row 1440 long.
    * ``"rand_toks_doc_lens"``: random tokens, original per-row lengths.
    * ``"rand_toks_rand_lens"``: random tokens, per-row length drawn from [512, 2048).
    * ``"truncate_lens_100_uniform"`` / ``"truncate_lens_100_normal"``: rows
      with at least 2048 prefix tokens are cut to ``2048 - offset`` with one
      offset drawn per batch; shorter rows are returned unchanged.

    Every rewritten row has ``None`` as its second element.

    Raises:
        NotImplementedError: For an unrecognized mode.
    """
    mode = length_shortcut_ablation
    batch_size = len(batch)
    const_length, min_rand_length, full_length = 1440, 512, 2048

    if mode == "permute_batch_tokens":
        prefix_lengths = [len(row[0]) for row in batch]
        packed_tokens = torch.cat([row[0] for row in batch], dim=0)
        shuffled = packed_tokens[torch.randperm(packed_tokens.size(0), generator=generator)]
        # reconstructing the batch with permuted tokens
        return [(prefix, None) for prefix in torch.split(shuffled, prefix_lengths)]

    if mode == "rand_toks_const_lens":
        random_tokens = torch.randint(
            0, tokenizer.vocab_size, (batch_size, const_length), generator=generator
        )
        return [(random_tokens[i], None) for i in range(batch_size)]

    if mode == "rand_toks_doc_lens":
        prefix_lengths = [len(row[0]) for row in batch]
        return [
            (torch.randint(0, tokenizer.vocab_size, (length,), generator=generator), None)
            for length in prefix_lengths
        ]

    if mode == "rand_toks_rand_lens":
        # Draw all lengths first, then the tokens, to keep the generator stream order.
        random_lengths = torch.randint(
            min_rand_length, full_length, (batch_size,), generator=generator
        ).tolist()
        return [
            (torch.randint(0, tokenizer.vocab_size, (length,), generator=generator), None)
            for length in random_lengths
        ]

    if mode in ("truncate_lens_100_uniform", "truncate_lens_100_normal"):
        if mode == "truncate_lens_100_uniform":
            offset = torch.randint(0, 100, (1,), generator=generator).item()
        else:
            offset = int(torch.normal(0, 100, (1,), generator=generator).abs().item())
        return [
            (row[0][: full_length - offset], None) if len(row[0]) >= full_length else row
            for row in batch
        ]

    raise NotImplementedError(f"Length shortcut ablation {length_shortcut_ablation} not implemented.")


def _is_all_padding(tokens, pad_id, eos_id):
    """Return True when every entry of ``tokens`` equals eos_id, or every entry equals pad_id.

    Ids that are ``None`` are skipped, since comparing a tensor with ``None``
    does not yield a tensor that ``torch.all`` can reduce.
    """
    return any(
        token_id is not None and bool(torch.all(tokens == token_id))
        for token_id in (eos_id, pad_id)
    )


def generic_collate_fn(
    batch,
    tokenizer: Tokenizer,
    block_size: Optional[int] = None,
    pad_to_block_size: bool = False,
    add_bos=True,
    add_eos=True,
    collate_checks_enabled=True,
    all_block_size_tensors=False,
    train_group_size=None,
    length_shortcut_ablation=None,
    generator=None,
):
    """Collate a list of rows into padded prefix (and optionally suffix) tensors.

    Args:
        batch: List of rows, each a ``torch.Tensor`` or a dict with a
            ``"data_signature"`` entry (see ``apply_formatting``).
        tokenizer: Provides ``pad_id`` / ``eos_id`` (and ``vocab_size`` for the
            random-token ablations). May be ``None`` for tensor-only batches,
            in which case padding uses 0.
        block_size: Maximum sequence length; required unless
            ``all_block_size_tensors`` is set.
        pad_to_block_size: Pad every sequence to ``block_size`` instead of to
            the longest sequence in the batch.
        add_bos: Default bos flag handed to ``apply_formatting``.
        add_eos: Default eos flag handed to ``apply_formatting``.
        collate_checks_enabled: Run O(batch size) sanity asserts up front.
        all_block_size_tensors: Fast path: rows are equally sized tensors and
            are stacked directly with ``collate_tensor_fn``.
        train_group_size: Handed to ``apply_formatting`` (used for hard negatives).
        length_shortcut_ablation: Optional ablation name; see
            ``_apply_length_shortcut_ablation``.
        generator: Optional ``torch.Generator`` used by the ablations.

    Returns:
        ``(prefix_batch, suffix_batch)``. ``prefix_batch`` is an int tensor of
        shape ``(len(batch), prefix_width)``. ``suffix_batch`` is ``None``
        unless the first row's data signature lists more than one key; with
        hard negatives it has one row per suffix rather than one per example.
        On the fast path the result is ``(stacked_tensor, None)``.

    Raises:
        AssertionError: If ``block_size`` is missing or a sanity check fails.
        NotImplementedError: For an unknown ``length_shortcut_ablation``.
        StopIteration: If the prefix (or suffix) batch consists solely of eos
            tokens or solely of pad tokens, signalling exhausted data sources.
    """
    # If we are only dealing with tensors that we _know_ are the same size,
    # we can just use the default collate_tensor_fn.
    # this is theoretically the fastest codepath.
    # for a bleeding edge pretraining run, this flag should be set to True, all data should be pkds
    # and we do minimal to no processing on the fly.
    if all_block_size_tensors:
        return collate_tensor_fn(batch), None

    assert block_size is not None, "block_size must be provided unless all_block_size_tensors is set."

    # This is O(bsz) but it's a more readable error message than the later failure would be.
    if collate_checks_enabled:
        assert isinstance(batch, list), "Batch must be a list."
        types_found = {type(x) for x in batch}
        assert types_found.issubset({dict, torch.Tensor}), "Batch must contain only expected types."

        if dict in types_found:
            assert tokenizer is not None, "If batch contains dicts, tokenizer must be provided."
            assert tokenizer.pad_id is not None, "Tokenizer must have pad token id since we are dynamically padding."

    # HACK: Using a heuristic to determine if it's single column (i.e. text) pqds tokenized data.
    # Usually this won't be necessary if we run with all_block_size_tensors=True, but this is for the case where we don't.
    # Only the first row is inspected, so the batch is assumed to be homogeneous.
    if isinstance(batch[0], torch.Tensor):
        num_keys = 1
    else:
        num_keys = len(batch[0]["data_signature"]["keys"])
    has_suffix = num_keys > 1  # e.g. (question, response) pairs

    # this takes in a heterogeneous list of rows and returns a batch of tensor pairs.
    batch = [apply_formatting(row, tokenizer, add_bos, add_eos, train_group_size) for row in batch]

    if length_shortcut_ablation:
        batch = _apply_length_shortcut_ablation(batch, length_shortcut_ablation, tokenizer, generator)

    # We operate under the assumption that all rows now have a pair of tensors as their elements.
    # Pad-to-longest and pad-to-block-size share the same copy logic below;
    # they only differ in the width chosen here.
    prefix_lengths = [len(row[0]) for row in batch]
    # min against block size since the max realized length could be longer than block size.
    local_prefix_block_size = block_size if pad_to_block_size else min(max(prefix_lengths), block_size)

    num_suffixes = len(batch)
    local_suffix_block_size = block_size
    if has_suffix:
        if isinstance(batch[0][1], list):
            # Hard negatives: each query has several suffixes, so flatten them.
            suffix_lengths = [len(suffix) for row in batch for suffix in row[1]]
        else:
            suffix_lengths = [len(row[1]) for row in batch]
        num_suffixes = len(suffix_lengths)
        if not pad_to_block_size:
            local_suffix_block_size = min(max(suffix_lengths), block_size)

    pad_id = getattr(tokenizer, "pad_id", None)
    eos_id = getattr(tokenizer, "eos_id", None)
    fill_value = pad_id or 0

    # Full tensor copy: allocate pad-filled buffers, then copy each row in.
    prefix_batch = torch.full((len(batch), local_prefix_block_size), fill_value, dtype=torch.int)
    suffix_batch = (
        torch.full((num_suffixes, local_suffix_block_size), fill_value, dtype=torch.int)
        if has_suffix
        else None
    )

    next_suffix_row = 0
    for i, row in enumerate(batch):
        input_tokens = row[0]
        # Slicing the source ensures we don't write past the block size.
        prefix_batch[i, : len(input_tokens)] = input_tokens[:local_prefix_block_size]

        label_tokens = row[1] if has_suffix else None
        if label_tokens is None:
            continue
        if isinstance(label_tokens, list):
            # Multiple suffixes (e.g. hard negatives) occupy consecutive rows.
            for suffix in label_tokens:
                suffix_batch[next_suffix_row, : len(suffix)] = suffix[:local_suffix_block_size]
                next_suffix_row += 1
        else:
            suffix_batch[i, : len(label_tokens)] = label_tokens[:local_suffix_block_size]

    # Now all rows are tensors of the same, valid length, <= block_size.

    # If no real tokens are present we raise a StopIteration to signal the
    # exhaustion of all data sources.
    if _is_all_padding(prefix_batch, pad_id, eos_id):
        raise StopIteration("All tokens in batch are padding tokens.")
    if suffix_batch is not None and _is_all_padding(suffix_batch, pad_id, eos_id):
        raise StopIteration("All tokens in batch are padding tokens.")

    return prefix_batch, suffix_batch