import torch

def collate_fn_everything(batch, use_corrupted_activations: bool, tokenizer, padding_left: bool = False):
    """Tokenize a batch of context/target examples (optionally with corrupted twins).

    Args:
        batch: Sequence of dicts with string values under ``"context"`` and
            ``"target"``; when ``use_corrupted_activations`` is true each dict
            must also provide ``"corrupted_context"`` and ``"corrupted_target"``.
        use_corrupted_activations: Whether to tokenize the corrupted inputs too.
        tokenizer: Callable tokenizer exposing ``padding_side``,
            ``pad_token_id`` and ``add_tokens`` (Hugging Face style).
        padding_left: Pad on the left when true, on the right otherwise.

    Returns:
        An 8-tuple:
            0. tokenizer output for the contexts (padded, ``"pt"`` tensors),
            1. 1-D tensor holding the first token id of every target,
            2. 1-D tensor with the number of non-padding tokens per context,
            3. tuple of the raw target strings,
            4. tokenizer output for the corrupted contexts, or ``None``,
            5. tokenizer output for ``context + target``,
            6. tokenizer output for ``corrupted_context + corrupted_target``,
               or ``None``,
            7. 1-D tensor with the token ids of all targets concatenated in
               batch order.

    Side effects:
        Sets ``tokenizer.padding_side`` and registers a placeholder token via
        ``tokenizer.add_tokens``; both persist on the tokenizer after the call.
    """
    tokenizer.padding_side = "left" if padding_left else "right"

    contexts = tuple(item["context"] for item in batch)
    targets = tuple(item["target"] for item in batch)
    contexts_with_targets = [context + target for context, target in zip(contexts, targets)]

    tokenized_contexts = tokenizer(contexts, padding=True, return_tensors="pt")
    tokenized_contexts_with_targets = tokenizer(contexts_with_targets, padding=True, return_tensors="pt")

    # Targets are tokenized behind a placeholder token whose id is then sliced
    # off (presumably so the target is split the way it would be mid-sequence
    # rather than at the start of a text -- confirm against the tokenizer used).
    special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
    tokenizer.add_tokens([special_token])
    # Each target is tokenized on its own, i.e. without padding, so the result
    # cannot be affected by the configured padding side.
    per_target_ids = [
        tokenizer([special_token + target], return_tensors="pt", add_special_tokens=False)["input_ids"][:, 1:]
        for target in targets
    ]
    full_tokenized_targets = torch.hstack(per_target_ids).squeeze(0)

    # First token of every target. Reading column 1 of a padded batch is only
    # correct for right padding; with left padding that column holds padding
    # (or the placeholder) for the shorter rows. An empty target falls back to
    # the pad id, which is what the right-padded batch lookup produced.
    first_ids = []
    for ids in per_target_ids:
        if ids.shape[1] > 0:
            first_ids.append(ids[0, 0])
        else:
            first_ids.append(ids.new_tensor(tokenizer.pad_token_id))
    tokenized_targets = torch.stack(first_ids)

    if use_corrupted_activations:
        corrupted_contexts = tuple(item["corrupted_context"] for item in batch)
        corrupted_targets = tuple(item["corrupted_target"] for item in batch)
        corrupted_contexts_with_targets = [
            context + target for context, target in zip(corrupted_contexts, corrupted_targets)
        ]
        tokenized_corrupted_contexts = tokenizer(corrupted_contexts, padding=True, return_tensors="pt")
        tokenized_corrupted_contexts_with_targets = tokenizer(
            corrupted_contexts_with_targets, padding=True, return_tensors="pt"
        )
    else:
        tokenized_corrupted_contexts = None
        tokenized_corrupted_contexts_with_targets = None

    # Prefer the attention mask for the real lengths: comparing ids against
    # pad_token_id undercounts whenever a genuine token shares the pad id
    # (e.g. pad token configured to be the eos/bos token).
    if "attention_mask" in tokenized_contexts:
        lens = tokenized_contexts["attention_mask"].sum(dim=1)
    else:
        lens = (tokenized_contexts["input_ids"] != tokenizer.pad_token_id).sum(dim=1)

    return tokenized_contexts, tokenized_targets, lens, targets, \
           tokenized_corrupted_contexts, tokenized_contexts_with_targets, \
           tokenized_corrupted_contexts_with_targets, full_tokenized_targets

def collate_fn_many_corruptions_per_doc(batch, tokenizer, padding_left: bool = False):
    """Tokenize a batch where every example carries several corrupted contexts.

    Args:
        batch: Sequence of dicts with string values under ``"context"``,
            ``"target"`` and ``"corrupted_target"``, plus a sequence of strings
            under ``"corrupted_contexts"``. The number of corruptions is taken
            from the first example; every example must provide at least that
            many.
        tokenizer: Callable tokenizer exposing ``padding_side``,
            ``pad_token_id`` and ``add_tokens`` (Hugging Face style).
        padding_left: Pad on the left when true, on the right otherwise.

    Returns:
        An 8-tuple:
            0. tokenizer output for the contexts (padded, ``"pt"`` tensors),
            1. 1-D tensor holding the first token id of every target,
            2. 1-D tensor with the number of non-padding tokens per context,
            3. tuple of the raw target strings,
            4. list with one tokenizer output per corruption index, each
               covering that corruption of every example in the batch,
            5. tokenizer output for ``context + target``,
            6. list like item 4, for ``corrupted_context + corrupted_target``,
            7. 1-D tensor with the token ids of all targets concatenated in
               batch order.

    Side effects:
        Sets ``tokenizer.padding_side`` and registers a placeholder token via
        ``tokenizer.add_tokens``; both persist on the tokenizer after the call.
    """
    tokenizer.padding_side = "left" if padding_left else "right"

    contexts = tuple(item["context"] for item in batch)
    targets = tuple(item["target"] for item in batch)
    contexts_with_targets = [context + target for context, target in zip(contexts, targets)]

    many_corrupted_contexts = tuple(item["corrupted_contexts"] for item in batch)
    corrupted_targets = tuple(item["corrupted_target"] for item in batch)
    # Transpose from [example][corruption] to [corruption][example] so each
    # corruption index can be tokenized as one batch.
    num_corruptions = len(many_corrupted_contexts[0])
    corrupted_contexts = [
        [per_example[j] for per_example in many_corrupted_contexts]
        for j in range(num_corruptions)
    ]
    corrupted_contexts_with_targets = [
        [context + target for context, target in zip(context_batch, corrupted_targets)]
        for context_batch in corrupted_contexts
    ]

    tokenized_contexts = tokenizer(contexts, padding=True, return_tensors="pt")
    tokenized_contexts_with_targets = tokenizer(contexts_with_targets, padding=True, return_tensors="pt")

    # Targets are tokenized behind a placeholder token whose id is then sliced
    # off (presumably so the target is split the way it would be mid-sequence
    # rather than at the start of a text -- confirm against the tokenizer used).
    special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
    tokenizer.add_tokens([special_token])
    # Each target is tokenized on its own, i.e. without padding, so the result
    # cannot be affected by the configured padding side.
    per_target_ids = [
        tokenizer([special_token + target], return_tensors="pt", add_special_tokens=False)["input_ids"][:, 1:]
        for target in targets
    ]
    full_tokenized_targets = torch.hstack(per_target_ids).squeeze(0)

    # First token of every target. Reading column 1 of a padded batch is only
    # correct for right padding; with left padding that column holds padding
    # (or the placeholder) for the shorter rows. An empty target falls back to
    # the pad id, which is what the right-padded batch lookup produced.
    first_ids = []
    for ids in per_target_ids:
        if ids.shape[1] > 0:
            first_ids.append(ids[0, 0])
        else:
            first_ids.append(ids.new_tensor(tokenizer.pad_token_id))
    tokenized_targets = torch.stack(first_ids)

    tokenized_corrupted_contexts_per_fewshot = [
        tokenizer(context_batch, padding=True, return_tensors="pt")
        for context_batch in corrupted_contexts
    ]
    tokenized_corrupted_contexts_with_targets_per_fewshot = [
        tokenizer(context_batch, padding=True, return_tensors="pt")
        for context_batch in corrupted_contexts_with_targets
    ]

    # Prefer the attention mask for the real lengths: comparing ids against
    # pad_token_id undercounts whenever a genuine token shares the pad id
    # (e.g. pad token configured to be the eos/bos token).
    if "attention_mask" in tokenized_contexts:
        lens = tokenized_contexts["attention_mask"].sum(dim=1)
    else:
        lens = (tokenized_contexts["input_ids"] != tokenizer.pad_token_id).sum(dim=1)

    return tokenized_contexts, tokenized_targets, lens, targets, \
           tokenized_corrupted_contexts_per_fewshot, tokenized_contexts_with_targets, \
           tokenized_corrupted_contexts_with_targets_per_fewshot, full_tokenized_targets