import torch
from torch.utils.data._utils.collate import collate_tensor_fn
import numpy as np
import bisect

def replacement_func(token, vocab_size, device, dtype):
    """
    Draw a random token id to stand in for ``token`` as a "wrong" value.

    Args:
        token: the original token being replaced (0-d / 1-element tensor or int).
        vocab_size: ids are drawn from ``[0, vocab_size)``.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape ``(1,)``. When ``vocab_size > 1`` the id never equals ``token``;
        for ``token`` in ``[0, vocab_size)`` it is uniform over the remaining ids.
    """
    if vocab_size <= 1:
        # No alternative id exists; keep the plain unrestricted draw.
        return torch.randint(0, vocab_size, (1,), device=device, dtype=dtype)
    # Sample among vocab_size - 1 ids and shift ids at/after ``token`` up by one,
    # which skips the original token while staying uniform over the others.
    candidate = torch.randint(0, vocab_size - 1, (1,), device=device, dtype=dtype)
    return torch.where(candidate >= int(token), candidate + 1, candidate)

def pad_or_truncate(tensor, pad_to_length, pad_id):
    """
    Force a 1-D tensor to exactly ``pad_to_length`` entries along dim 0.

    Longer inputs keep only their leading ``pad_to_length`` entries. Shorter inputs are
    extended on the right with ``pad_id``, using the input's dtype and device. An input
    that already has the target length is returned unchanged (the same object).
    """
    shortfall = pad_to_length - tensor.size(0)
    if shortfall < 0:
        # Too long: drop the tail.
        return tensor[:pad_to_length]
    if shortfall == 0:
        return tensor
    filler = tensor.new_full((shortfall,), pad_id)
    return torch.cat((tensor, filler))

def random_delete_function(tokenised_list, answer_mask, delete_token, vocab_size, pad_to_length, pad_id, flag_for_nested_delete=False, decay=0.99, p_delete=0.1, max_delete=4, delete_cfg=None):
    """
    Insert random "wrong" tokens followed by delete tokens into a 1-D token tensor.

    At every deletable position ``i`` a number of deletes ``k`` is sampled:
    ``k = 0`` with probability ``1 - p_delete``, and ``k = 1..K`` with probabilities
    proportional to ``decay ** (k - 1)`` that together sum to ``p_delete``. Here ``K`` is
    ``max_delete`` capped by the number of consecutive deletable tokens starting at ``i``.
    For ``k > 0`` the output receives ``k`` random replacement tokens (``replacement_func``),
    then ``k`` copies of ``delete_token``, then -- unless nested deletes are enabled -- the
    ``k`` original tokens ``tokenised_list[i:i+k]``.

    Args:
        tokenised_list: 1-D tensor of token ids.
        answer_mask: boolean tensor same shape as tokenised_list -- True when we can delete a token.
        delete_token: 1-element tensor holding the delete id (same dtype as tokenised_list).
        vocab_size: vocabulary size forwarded to ``replacement_func``.
        pad_to_length: length both outputs are padded/truncated to.
        pad_id: value used to pad the token output.
        flag_for_nested_delete: if True, after inserting wrong tokens + deletes the same
            position is sampled again. A wrong value can then be deleted, replaced with another
            wrong value and re-deleted before the original token is finally emitted.
        decay: ratio between the probabilities of ``k + 1`` and ``k`` deletes.
        p_delete: total probability of inserting at least one delete at a position.
        max_delete: maximum number of deletes at one position (``<= 0`` disables deletes).
        delete_cfg: accepted for interface compatibility; currently unused.

    Returns:
        ``(mask, tokens)``, both of length ``pad_to_length``. ``mask`` is a bool tensor:
        - False for non-deletable positions and for the random replacement tokens.
        - True for untouched deletable tokens, delete tokens, re-emitted original tokens
          and padding.

    Raises:
        ValueError: if nested deletes are requested with ``p_delete >= 1``, because the
            position could then never advance.
    """
    device, dtype = tokenised_list.device, tokenised_list.dtype
    n_tokens = len(tokenised_list)
    max_delete = max(int(max_delete), 0)
    if flag_for_nested_delete and max_delete > 0 and p_delete >= 1:
        raise ValueError("p_delete must be < 1 when nested deletes are enabled.")

    # cumulative_probs[K - 1] is the CDF over k = 0..K deletes for a position where at most
    # K deletes fit: [1 - p_delete, 1 - p_delete + f, 1 - p_delete + f + decay * f, ...]
    # with f chosen so that f * (1 + decay + ... + decay ** (K - 1)) == p_delete.
    cumulative_probs = []
    for span in range(1, max_delete + 1):
        weights = [decay ** j for j in range(span)]
        total_weight = sum(weights)
        probs = [p_delete * w / total_weight for w in weights]
        probs.insert(0, 1 - sum(probs))
        cdf = []
        running = 0.0
        for p in probs:
            running += p
            cdf.append(running)
        # Pin the last entry to exactly 1 so float rounding can never let bisect return
        # len(cdf) (np.random.uniform(0, 1) is strictly below 1).
        cdf[-1] = 1.0
        cumulative_probs.append(cdf)

    deletable = [bool(flag) for flag in answer_mask.tolist()]
    # run_length[i] = number of consecutive deletable tokens starting at i, so a delete span
    # never replaces a token that answer_mask marks as non-deletable.
    run_length = [0] * (n_tokens + 1)
    for idx in range(n_tokens - 1, -1, -1):
        run_length[idx] = run_length[idx + 1] + 1 if deletable[idx] else 0

    mask_list = []
    output_list = []
    i = 0
    # A while loop is required: the position advances by a variable amount (or not at all
    # for nested deletes), which reassigning a ``for`` loop variable cannot achieve.
    while i < n_tokens:
        if not deletable[i]:  # i.e. we cannot put a delete here
            output_list.append(tokenised_list[i:i + 1])
            mask_list.append(0)
            i += 1
            continue

        num_deletes = 0
        if max_delete > 0:
            cdf = cumulative_probs[min(run_length[i], max_delete) - 1]
            # cdf is cumulative, so bisect finds the bucket the uniform sample falls into.
            num_deletes = bisect.bisect_right(cdf, np.random.uniform(0, 1))

        if num_deletes == 0:
            output_list.append(tokenised_list[i:i + 1])
            mask_list.append(1)
            i += 1
            continue

        next_n_tokens = tokenised_list[i:i + num_deletes]
        for n_token in next_n_tokens:
            output_list.append(replacement_func(n_token, vocab_size, device, dtype))
        mask_list.extend([0] * num_deletes)
        output_list.extend([delete_token] * num_deletes)
        mask_list.extend([1] * num_deletes)

        if flag_for_nested_delete:
            # Stay on position i: it is sampled again and may receive further wrong tokens.
            continue
        output_list.append(next_n_tokens)
        mask_list.extend([1] * num_deletes)
        i += num_deletes

    mask = torch.tensor(mask_list, device=device, dtype=torch.bool)
    if output_list:
        tokens = torch.cat(output_list, dim=0)
    else:
        # torch.cat rejects an empty list; an empty input yields an empty (then padded) output.
        tokens = tokenised_list.new_empty((0,))

    mask = pad_or_truncate(mask, pad_to_length, True)
    tokens = pad_or_truncate(tokens, pad_to_length, pad_id)

    return mask, tokens

def get_answer_mask(tensor, marker):
    """
    Return a boolean mask that is True strictly between pairs of ``marker`` tokens.

    Each marker toggles an "inside" state. A token is True when an odd number of markers
    have been seen up to and including its position, and it is not itself a marker.
    Marker positions are therefore always False.

    Example: with markers at positions 1 and 4, only positions 2 and 3 are True.
    """
    is_marker = tensor == marker
    # Inclusive running count of markers; an odd count means we are inside a region.
    markers_seen = torch.cumsum(is_marker, dim=0)
    inside_region = torch.remainder(markers_seen, 2).eq(1)
    return inside_region & is_marker.logical_not()

def delete_collate_fn(
    batch,
    tokenizer=None,
    block_size=None,
    pad_to_block_size=False,
    add_bos=True,
    add_eos=True,
    collate_checks_enabled=True,
    all_block_size_tensors=False,
    use_delete=None,
):
    """
    Collate a batch of rows into ``{'data': ..., 'labels': ...}`` tensors, optionally injecting deletes.

    Row types:
    - Strings are tokenised with ``tokenizer.encode``.
    - Tensors are taken as already tokenised.
    - Chat message lists (when ``use_delete["use_template"]`` is True) are rendered with
      ``tokenizer.processor.apply_chat_template``. Tokens strictly between pairs of
      ``tokenizer.sep_id`` markers form the answer, and the markers are then removed.

    If the tokenizer has a ``delete_id``, every row is passed through
    ``random_delete_function``. Labels at positions that function marks False (non-answer
    tokens and inserted random tokens) are set to ``tokenizer.pad_id`` so the loss ignores them.

    Args:
        batch: list of rows (see above).
        tokenizer: provides ``encode`` and ``pad_id``. The delete/template paths also use
            ``delete_id``, ``vocab_size``, ``sep_id`` and ``processor``.
        block_size: maximum sequence length; outputs are sliced to it.
        pad_to_block_size: non-delete path only. Pad every row to ``block_size`` instead of
            to the longest row.
        add_bos, add_eos: forwarded to ``tokenizer.encode``.
        collate_checks_enabled: run the input sanity assertions.
        all_block_size_tensors: fast path for equal-length tensors; no padding or deletes.
        use_delete: dict of delete options. ``None`` or missing keys fall back to
            ``{'use_delete': False, 'random': True, 'nested_deletes': False,
            'pad_matrix': True, 'use_template': False}``.

    Returns:
        dict with ``'data'`` and ``'labels'`` tensors of the same shape.

    Raises:
        StopIteration: if the collated batch consists only of padding tokens.
        NotImplementedError: for non-list rows in template mode, or for non-random deletes.
    """
    # Build a fresh config per call (avoids a shared mutable default) and tolerate
    # partial configs by filling in missing keys with the defaults.
    delete_cfg = {
        'use_delete': False,
        'random': True,
        'nested_deletes': False,
        'pad_matrix': True,
        'use_template': False,
    }
    if use_delete is not None:
        delete_cfg.update(use_delete)
    use_delete = delete_cfg

    if all_block_size_tensors:
        # If we are only dealing with tensors that we _know_ are the same size,
        # we can just use the default collate_tensor_fn
        collated_batch = collate_tensor_fn(batch)
        return {'data': collated_batch, 'labels': collated_batch.clone()}

    if collate_checks_enabled:
        assert isinstance(batch, list), "Batch must be a list."
        if any(isinstance(x, str) for x in batch):
            assert tokenizer is not None, "If batch contains strings, tokenizer must be provided."
            assert tokenizer.pad_id is not None, "Tokenizer must have pad token id since we are dynamically padding."

    # For now, we assume that if we need it, the tokenizer is always present.
    if use_delete["use_template"]:
        temp_batch = []
        answer_masks = []
        for row in batch:
            if not isinstance(row, list):
                raise NotImplementedError()
            string_format = tokenizer.processor.apply_chat_template(row, tokenize=False)
            tokenized_format = tokenizer.encode(string_format, bos=add_bos, eos=add_eos)
            # True for tokens strictly between pairs of sep markers, i.e. the answer.
            answer_mask = get_answer_mask(tokenized_format, tokenizer.sep_id)
            keep = tokenized_format != tokenizer.sep_id
            # Drop the sep markers from both tokens and mask so the two stay aligned.
            temp_batch.append(tokenized_format[keep])
            answer_masks.append(answer_mask[keep])
        batch = temp_batch
    else:
        batch = [tokenizer.encode(row, bos=add_bos, eos=add_eos) if isinstance(row, str) else row for row in batch]

    # NOTE(review): use_delete["use_delete"] is never consulted -- deletes are enabled purely
    # by the tokenizer exposing a delete_id. Confirm this is intended before relying on the flag.
    delete_id = getattr(tokenizer, 'delete_id', None)
    if delete_id is not None:  # simple random deletes for now
        if use_delete["random"]:
            delete_function = random_delete_function
        else:
            raise NotImplementedError()

        vocab_size = tokenizer.vocab_size
        delete_token = torch.tensor([delete_id], dtype=batch[0].dtype, device=batch[0].device)
        delete_masks = []
        for i, row in enumerate(batch):
            if use_delete["use_template"]:
                # delete_function skips positions where the answer mask is False.
                answer_mask = answer_masks[i]
            else:
                # No question/answer divide, so deletes may go anywhere.
                answer_mask = torch.ones_like(row, dtype=torch.bool, device=row.device)
            mask, tokens = delete_function(
                row,
                answer_mask,
                delete_token,
                vocab_size,
                pad_to_length=block_size,
                pad_id=tokenizer.pad_id,
                flag_for_nested_delete=use_delete["nested_deletes"],
                delete_cfg=use_delete,
            )
            batch[i] = tokens
            delete_masks.append(mask)
        collated_batch = collate_tensor_fn(batch)[:, :block_size]
        collated_delete_masks = collate_tensor_fn(delete_masks)[:, :block_size]

        if torch.all(collated_batch == tokenizer.pad_id):
            raise StopIteration("All tokens in batch are padding tokens.")

        # random_delete_function returns mask=False for non-answer tokens and for the inserted
        # random "wrong" tokens (the tokens which would be deleted). It returns True for answer
        # tokens, delete tokens and the re-emitted originals; padding is True too, but its token
        # is already pad_id. The False positions must not be learned, so they become pad_id,
        # which is the ignore index in the CE loss (References: line 228 of XXXX-40/train.py,
        # line 255 of XXXX-40/litgpt/utils.py). The labels are then shifted on line 372 of
        # XXXX-40/train.py.
        collated_delete_labels = collated_batch.clone()
        collated_delete_labels[~collated_delete_masks] = tokenizer.pad_id
        return {'data': collated_batch, 'labels': collated_delete_labels}

    # Now all rows are tokenized tensors.
    # Logic is a bit generic, could be tightened under encode -> tensor assumption.
    if pad_to_block_size:
        batch = [torch.tensor(x[:block_size].tolist() + [tokenizer.pad_id] * (block_size - len(x))) for x in batch]
    else:
        # Pad to the longest row in the batch.
        max_len = max(len(x) for x in batch)
        batch = [torch.tensor(x.tolist() + [tokenizer.pad_id] * (max_len - len(x))) for x in batch]

    # Now all rows are tensors of the same length.
    # Always slice to block size since the max row length realized could be longer than block size.
    collated_batch = collate_tensor_fn(batch)[:, :block_size]

    # We need to check whether the entire batch consists of padding tokens.
    # If so, we raise a StopIteration to signal the exhaustion of all data sources,
    # since no real tokens are present in the batch.
    if torch.all(collated_batch == tokenizer.pad_id):
        raise StopIteration("All tokens in batch are padding tokens.")

    return {'data': collated_batch, 'labels': collated_batch.clone()}
