import os
from pathlib import Path
from typing import Callable, List, Dict
from datasets import load_from_disk
import torch
import random
import numpy as np

from patching_gemma.tasks.task import Task, TaskConfig

class FineGrainedTypesDifferentCorruptionsTask(Task):
    # Accepted values for the constructor arguments. ``__init__`` checks
    # ``task_name`` and ``corruption`` against these lists; every corruption
    # listed here must also have an entry in the NUM_CORRUPTIONS table built
    # in ``__init__``.
    ALLOWED = {
        "task_name": ["capitalization", "person_sport", "country_capital", "present_past", "copy"],
        "corruption": [
                        "x_in_all_fewshots_agreed", "y_in_all_fewshots_agreed", "x_y_in_all_fewshots_agreed",
                        "x_within_input_space_in_all_fewshots_agreed",
                        "y_within_input_space_in_all_fewshots_agreed",
                        "full_corruption_in_all_fewshots_agreed",
                        "x_within_input_space_in_all_fewshots_but_current_agreed",
                        "x_in_all_fewshots_but_current_agreed",
                        "y_in_all_fewshots_but_current_agreed",
                        "y_in_all_fewshots_but_current_and_previous_agreed",
                        "y_in_all_fewshots_but_current_and_2_previous_agreed",
                        "y_within_input_space_in_all_fewshots_but_current_agreed",
                        "y_within_input_space_in_all_fewshots_but_current_and_previous_agreed",
                        "y_within_input_space_in_all_fewshots_but_current_and_2_previous_agreed",
        ],
        # Shot counts for which ``__init__`` defines a TOKEN_TYPES value.
        "n_shot": [3, 10],
    }
    
    def __init__(self, task_name, n_shot, corruption):
        """Configure the task for one dataset / shot count / corruption and load its data.

        Args:
            task_name: Dataset to use; must be one of ``ALLOWED["task_name"]``.
            n_shot: Number of few-shot demonstrations; must be one of
                ``ALLOWED["n_shot"]``.
            corruption: Corruption scheme; must be one of ``ALLOWED["corruption"]``.

        Raises:
            ValueError: If any argument is not listed in ``ALLOWED``.
        """
        allowed = FineGrainedTypesDifferentCorruptionsTask.ALLOWED
        # Validate every argument before any disk I/O, with explicit raises so
        # the checks survive ``python -O``. n_shot was previously only checked
        # indirectly (via TOKEN_TYPES) after the dataset had already been loaded.
        for arg_name, value in (("task_name", task_name),
                                ("n_shot", n_shot),
                                ("corruption", corruption)):
            if value not in allowed[arg_name]:
                raise ValueError(
                    f"Unsupported {arg_name}={value!r}; expected one of {allowed[arg_name]}"
                )

        self.dataset_name = task_name
        self.doc_to_text: Callable = lambda doc: str(doc["input"])
        self.doc_to_target: Callable = lambda doc: str(doc["target"])
        # Separator strings consumed by the Task machinery (presumably between
        # demonstrations and between input/target respectively -- confirm in Task).
        self.fewshot_space: str = "\n"
        self.sep: str = "\t"
        self.loss: Callable = self.loss_function
        self.fewshot: int = n_shot
        self.fewshot_seed: int = 42
        self.can_be_token_separable: bool = True
        self.corruption = corruption
        # The "*_but_current*" schemes map to one entry per shot (self.fewshot);
        # every other scheme maps to 1.
        self.NUM_CORRUPTIONS = {
            "x_in_all_fewshots_agreed": 1,
            "y_in_all_fewshots_agreed": 1,
            "x_y_in_all_fewshots_agreed": 1,
            "x_within_input_space_in_all_fewshots_agreed": 1,
            "y_within_input_space_in_all_fewshots_agreed": 1,
            "full_corruption_in_all_fewshots_agreed": 1,
            "x_within_input_space_in_all_fewshots_but_current_agreed": self.fewshot,
            "x_in_all_fewshots_but_current_agreed": self.fewshot,
            "y_in_all_fewshots_but_current_agreed": self.fewshot,
            "y_in_all_fewshots_but_current_and_previous_agreed": self.fewshot,
            "y_in_all_fewshots_but_current_and_2_previous_agreed": self.fewshot,
            "y_within_input_space_in_all_fewshots_but_current_agreed": self.fewshot,
            "y_within_input_space_in_all_fewshots_but_current_and_previous_agreed": self.fewshot,
            "y_within_input_space_in_all_fewshots_but_current_and_2_previous_agreed": self.fewshot,
        }[self.corruption]
        # Token-type count for each supported shot count (3 -> 16, 10 -> 44);
        # n_shot was validated above, so the lookup cannot fail for ALLOWED values.
        self.TOKEN_TYPES = {3: 16, 10: 44}[self.fewshot]
        # ``self.config`` reads dataset_name, fewshot and corruption, so the
        # load must come after those attributes are set.
        self.dataset = load_from_disk(self.config.dataset)

    @staticmethod
    def loss_function(logits, targets, lens, tp_inds):
        """Cross-entropy over the positions tagged as prediction points.

        A position is selected when its token-type id equals ``TARGET_TYPE`` or
        ``LAST_SEP_TYPE``; the last selected position of every row is then
        dropped. The logits at the remaining positions (gathered row-major via
        boolean-mask indexing) are scored against ``targets``.

        Args:
            logits: Per-position logits, indexed as ``logits[mask, :]`` with a
                mask shaped like ``tp_inds`` (presumably (batch, seq, vocab) --
                TODO confirm).
            targets: Class indices aligned one-to-one with the selected positions.
            lens: Not used by this function.
            tp_inds: Per-position token-type ids; treated as 2-D (rows along
                dim 0, positions along the last dim).

        Returns:
            ``torch.nn.functional.cross_entropy`` of the selected logits
            (default ``reduction="mean"``).
        """
        task_cls = FineGrainedTypesDifferentCorruptionsTask
        hits_target = tp_inds == task_cls.TARGET_TYPE
        hits_last_sep = tp_inds == task_cls.LAST_SEP_TYPE
        selected = hits_target | hits_last_sep
        # Running count of selected positions, zeroed where nothing is selected.
        # Its argmax along the last dim is each row's final selected position;
        # rows with no selection yield index 0, where the write below is a no-op.
        running_count = selected.cumsum(dim=-1) * selected
        final_pos = running_count.argmax(dim=-1)
        row_ids = torch.arange(selected.shape[0])
        selected[row_ids, final_pos] = 0
        picked_logits = logits[selected, :]
        return torch.nn.functional.cross_entropy(picked_logits, targets)

    @staticmethod
    def get_name(task_name, n_shot, corruption):
        """Return the task identifier ``<task_name>_<n_shot>_shot_<corruption>``."""
        return "{}_{}_shot_{}".format(task_name, n_shot, corruption)
        
    @property
    def config(self) -> TaskConfig:
        """Build the TaskConfig describing this task's data source and splits.

        The dataset path is ``datasets/<dataset_name>`` located one directory
        above the directory containing this source file (symlinks resolved).
        Evaluation uses the "test" split; demonstrations come from "train".
        """
        source_file = Path(os.path.realpath(__file__))
        dataset_dir = source_file.parent.parent / "datasets" / self.dataset_name
        task_id = FineGrainedTypesDifferentCorruptionsTask.get_name(
            self.dataset_name, self.fewshot, self.corruption
        )
        return TaskConfig(
            dataset=str(dataset_dir),
            subset=None,
            name=task_id,
            evaluation_split="test",
            fewshot_split="train",
        )


    X_Y_CORRUPTION_PER_DATASET = {
        "country_capital": {
            "regular": {(1, 1): [{'input': 'Sweden', 'target': 'Stockholm'}, {'input': 'Norway', 'target': 'Oslo'}, {'input': 'Switzerland', 'target': 'Bern'}, {'input': 'Bulgaria', 'target': 'Sofia'}, {'input': 'Japan', 'target': 'Tokyo'}, {'input': 'Turkey', 'target': 'Ankara'}, {'input': 'Ireland', 'target': 'Dublin'}, {'input': 'Poland', 'target': 'Warsaw'}, {'input': 'Austria', 'target': 'Vienna'}], (2, 6): [{'input': 'Brunei', 'target': 'Bandar Seri Begawan'}], (2, 5): [{'input': 'Grenada', 'target': "St. George's"}], (1, 2): [{'input': 'Uruguay', 'target': 'Montevideo'}, {'input': 'Slovakia', 'target': 'Bratislava'}, {'input': 'Panama', 'target': 'Panama City'}, {'input': 'Bahrain', 'target': 'Manama'}, {'input': 'Tunisia', 'target': 'Tunis'}, {'input': 'Australia', 'target': 'Canberra'}, {'input': 'Senegal', 'target': 'Dakar'}, {'input': 'Jordan', 'target': 'Amman'}, {'input': 'Croatia', 'target': 'Zagreb'}, {'input': 'Georgia', 'target': 'Tbilisi'}, {'input': 'Argentina', 'target': 'Buenos Aires'}], (1, 3): [{'input': 'Tanzania', 'target': 'Dodoma'}, {'input': 'Malawi', 'target': 'Lilongwe'}, {'input': 'Laos', 'target': 'Vientiane'}, {'input': 'Congo', 'target': 'Kinshasa'}, {'input': 'Malaysia', 'target': 'Kuala Lumpur'}], (5, 3): [{'input': 'Democratic Republic of the Congo', 'target': 'Kinshasa'}], (2, 2): [{'input': 'North Korea', 'target': 'Pyongyang'}, {'input': 'Dominica', 'target': 'Roseau'}, {'input': 'Sri Lanka', 'target': 'Colombo'}, {'input': 'Vanuatu', 'target': 'Port Vila'}, {'input': 'Gambia', 'target': 'Banjul'}], (3, 2): [{'input': 'United Arab Emirates', 'target': 'Abu Dhabi'}, {'input': 'Central African Republic', 'target': 'Bangui'}], (2, 1): [{'input': 'United Kingdom', 'target': 'London'}, {'input': 'Bahamas', 'target': 'Nassau'}], (2, 3): [{'input': 'Moldova', 'target': 'Chisinau'}, {'input': 'Montenegro', 'target': 'Podgorica'}]},
            "in_beginning": {(2, 1): [{'input': 'Sweden', 'target': 'Stockholm'}, {'input': 'Norway', 'target': 'Oslo'}, {'input': 'Switzerland', 'target': 'Bern'}, {'input': 'Bulgaria', 'target': 'Sofia'}, {'input': 'Japan', 'target': 'Tokyo'}, {'input': 'Turkey', 'target': 'Ankara'}, {'input': 'Ireland', 'target': 'Dublin'}, {'input': 'Poland', 'target': 'Warsaw'}, {'input': 'Austria', 'target': 'Vienna'}], (3, 6): [{'input': 'Brunei', 'target': 'Bandar Seri Begawan'}], (3, 5): [{'input': 'Grenada', 'target': "St. George's"}], (2, 2): [{'input': 'Uruguay', 'target': 'Montevideo'}, {'input': 'Slovakia', 'target': 'Bratislava'}, {'input': 'Panama', 'target': 'Panama City'}, {'input': 'Bahrain', 'target': 'Manama'}, {'input': 'Tunisia', 'target': 'Tunis'}, {'input': 'Australia', 'target': 'Canberra'}, {'input': 'Senegal', 'target': 'Dakar'}, {'input': 'Jordan', 'target': 'Amman'}, {'input': 'Croatia', 'target': 'Zagreb'}, {'input': 'Georgia', 'target': 'Tbilisi'}, {'input': 'Argentina', 'target': 'Buenos Aires'}], (2, 3): [{'input': 'Tanzania', 'target': 'Dodoma'}, {'input': 'Malawi', 'target': 'Lilongwe'}, {'input': 'Laos', 'target': 'Vientiane'}, {'input': 'Congo', 'target': 'Kinshasa'}, {'input': 'Malaysia', 'target': 'Kuala Lumpur'}], (6, 3): [{'input': 'Democratic Republic of the Congo', 'target': 'Kinshasa'}], (3, 2): [{'input': 'North Korea', 'target': 'Pyongyang'}, {'input': 'Dominica', 'target': 'Roseau'}, {'input': 'Sri Lanka', 'target': 'Colombo'}, {'input': 'Vanuatu', 'target': 'Port Vila'}, {'input': 'Gambia', 'target': 'Banjul'}], (4, 2): [{'input': 'United Arab Emirates', 'target': 'Abu Dhabi'}, {'input': 'Central African Republic', 'target': 'Bangui'}], (3, 1): [{'input': 'United Kingdom', 'target': 'London'}, {'input': 'Bahamas', 'target': 'Nassau'}], (3, 3): [{'input': 'Moldova', 'target': 'Chisinau'}, {'input': 'Montenegro', 'target': 'Podgorica'}]},
        },
        "capitalization": {
            "regular": {(1, 2): [{'input': 'exempt', 'target': 'Exempt'}, {'input': 'amide', 'target': 'Amide'}, {'input': 'capes', 'target': 'Capes'}, {'input': 'announce', 'target': 'Announce'}, {'input': 'plaintext', 'target': 'Plaintext'}], (2, 2): [{'input': 'norsemen', 'target': 'Norsemen'}, {'input': 'keysters', 'target': 'Keysters'}, {'input': 'tamlung', 'target': 'Tamlung'}, {'input': 'beeswings', 'target': 'Beeswings'}, {'input': 'misreference', 'target': 'Misreference'}], (4, 4): [{'input': 'ferronickel', 'target': 'Ferronickel'}, {'input': 'undecisively', 'target': 'Undecisively'}, {'input': 'chamaerops', 'target': 'Chamaerops'}, {'input': 'epidemiographist', 'target': 'Epidemiographist'}, {'input': 'ebracteate', 'target': 'Ebracteate'}], (4, 3): [{'input': 'triozonide', 'target': 'Triozonide'}, {'input': 'slapdashes', 'target': 'Slapdashes'}, {'input': 'ribbandry', 'target': 'Ribbandry'}, {'input': 'staphylohemia', 'target': 'Staphylohemia'}, {'input': 'pentadecahydrate', 'target': 'Pentadecahydrate'}], (3, 3): [{'input': 'nothosaurus', 'target': 'Nothosaurus'}, {'input': 'coranoch', 'target': 'Coranoch'}, {'input': 'socotri', 'target': 'Socotri'}, {'input': 'saponify', 'target': 'Saponify'}, {'input': 'linalool', 'target': 'Linalool'}], (2, 3): [{'input': 'hetaery', 'target': 'Hetaery'}, {'input': 'entothorax', 'target': 'Entothorax'}, {'input': 'codomain', 'target': 'Codomain'}, {'input': 'spurwinged', 'target': 'Spurwinged'}, {'input': 'interruptedness', 'target': 'Interruptedness'}], (3, 4): [{'input': 'dramalogue', 'target': 'Dramalogue'}, {'input': 'unreprovably', 'target': 'Unreprovably'}, {'input': 'nonsolvability', 'target': 'Nonsolvability'}, {'input': 'cachibou', 'target': 'Cachibou'}, {'input': 'pseudoscarus', 'target': 'Pseudoscarus'}], (4, 5): [{'input': 'preadjectival', 'target': 'Preadjectival'}, {'input': 'vapourization', 'target': 'Vapourization'}, {'input': 'pernicketiness', 'target': 'Pernicketiness'}, {'input': 'tersulphuret', 
'target': 'Tersulphuret'}, {'input': 'ombudsperson', 'target': 'Ombudsperson'}], (1, 1): [{'input': 'generators', 'target': 'Generators'}, {'input': 'broker', 'target': 'Broker'}, {'input': 'nail', 'target': 'Nail'}, {'input': 'covered', 'target': 'Covered'}, {'input': 'where', 'target': 'Where'}], (3, 2): [{'input': 'lobstering', 'target': 'Lobstering'}, {'input': 'taotai', 'target': 'Taotai'}, {'input': 'pronota', 'target': 'Pronota'}, {'input': 'dollyman', 'target': 'Dollyman'}, {'input': 'monarchally', 'target': 'Monarchally'}], (6, 6): [{'input': 'calciovolborthite', 'target': 'Calciovolborthite'}, {'input': 'cholecystostomies', 'target': 'Cholecystostomies'}, {'input': 'asymmetranthous', 'target': 'Asymmetranthous'}, {'input': 'uncontumaciousness', 'target': 'Uncontumaciousness'}, {'input': 'dodecasyllabic', 'target': 'Dodecasyllabic'}], (5, 5): [{'input': 'rhipipterous', 'target': 'Rhipipterous'}, {'input': 'ventriloquised', 'target': 'Ventriloquised'}, {'input': 'syphiliphobia', 'target': 'Syphiliphobia'}, {'input': 'redissolubleness', 'target': 'Redissolubleness'}, {'input': 'ochlophobist', 'target': 'Ochlophobist'}], (2, 1): [{'input': 'jews', 'target': 'Jews'}, {'input': 'internation', 'target': 'Internation'}, {'input': 'barton', 'target': 'Barton'}, {'input': 'analyzed', 'target': 'Analyzed'}, {'input': 'sair', 'target': 'Sair'}], (2, 4): [{'input': 'coprocessing', 'target': 'Coprocessing'}], (5, 4): [{'input': 'labretifery', 'target': 'Labretifery'}, {'input': 'hymenogastraceae', 'target': 'Hymenogastraceae'}, {'input': 'rupicaprinae', 'target': 'Rupicaprinae'}, {'input': 'spirillotropism', 'target': 'Spirillotropism'}, {'input': 'astragalotibial', 'target': 'Astragalotibial'}], (7, 7): [{'input': 'pericardiosymphysis', 'target': 'Pericardiosymphysis'}, {'input': 'uranostaphylorrhaphy', 'target': 'Uranostaphylorrhaphy'}], (5, 6): [{'input': 'anhaematopoiesis', 'target': 'Anhaematopoiesis'}, {'input': 'tribromoacetaldehyde', 'target': 
'Tribromoacetaldehyde'}, {'input': 'cingulectomies', 'target': 'Cingulectomies'}], (1, 3): [{'input': 'sistence', 'target': 'Sistence'}]},
            "in_beginning": {(2, 2): [{'input': 'exempt', 'target': 'Exempt'}, {'input': 'amide', 'target': 'Amide'}, {'input': 'capes', 'target': 'Capes'}, {'input': 'announce', 'target': 'Announce'}, {'input': 'plaintext', 'target': 'Plaintext'}], (3, 2): [{'input': 'norsemen', 'target': 'Norsemen'}, {'input': 'keysters', 'target': 'Keysters'}, {'input': 'tamlung', 'target': 'Tamlung'}, {'input': 'beeswings', 'target': 'Beeswings'}, {'input': 'misreference', 'target': 'Misreference'}], (5, 4): [{'input': 'ferronickel', 'target': 'Ferronickel'}, {'input': 'undecisively', 'target': 'Undecisively'}, {'input': 'chamaerops', 'target': 'Chamaerops'}, {'input': 'epidemiographist', 'target': 'Epidemiographist'}, {'input': 'ebracteate', 'target': 'Ebracteate'}], (5, 3): [{'input': 'triozonide', 'target': 'Triozonide'}, {'input': 'slapdashes', 'target': 'Slapdashes'}, {'input': 'ribbandry', 'target': 'Ribbandry'}, {'input': 'staphylohemia', 'target': 'Staphylohemia'}, {'input': 'pentadecahydrate', 'target': 'Pentadecahydrate'}], (4, 3): [{'input': 'nothosaurus', 'target': 'Nothosaurus'}, {'input': 'coranoch', 'target': 'Coranoch'}, {'input': 'socotri', 'target': 'Socotri'}, {'input': 'saponify', 'target': 'Saponify'}, {'input': 'linalool', 'target': 'Linalool'}], (3, 3): [{'input': 'hetaery', 'target': 'Hetaery'}, {'input': 'entothorax', 'target': 'Entothorax'}, {'input': 'codomain', 'target': 'Codomain'}, {'input': 'spurwinged', 'target': 'Spurwinged'}, {'input': 'interruptedness', 'target': 'Interruptedness'}], (4, 4): [{'input': 'dramalogue', 'target': 'Dramalogue'}, {'input': 'unreprovably', 'target': 'Unreprovably'}, {'input': 'nonsolvability', 'target': 'Nonsolvability'}, {'input': 'cachibou', 'target': 'Cachibou'}, {'input': 'pseudoscarus', 'target': 'Pseudoscarus'}], (5, 5): [{'input': 'preadjectival', 'target': 'Preadjectival'}, {'input': 'vapourization', 'target': 'Vapourization'}, {'input': 'pernicketiness', 'target': 'Pernicketiness'}, {'input': 'tersulphuret', 
'target': 'Tersulphuret'}, {'input': 'ombudsperson', 'target': 'Ombudsperson'}], (2, 1): [{'input': 'generators', 'target': 'Generators'}, {'input': 'broker', 'target': 'Broker'}, {'input': 'nail', 'target': 'Nail'}, {'input': 'covered', 'target': 'Covered'}, {'input': 'where', 'target': 'Where'}], (4, 2): [{'input': 'lobstering', 'target': 'Lobstering'}, {'input': 'taotai', 'target': 'Taotai'}, {'input': 'pronota', 'target': 'Pronota'}, {'input': 'dollyman', 'target': 'Dollyman'}, {'input': 'monarchally', 'target': 'Monarchally'}], (7, 6): [{'input': 'calciovolborthite', 'target': 'Calciovolborthite'}, {'input': 'cholecystostomies', 'target': 'Cholecystostomies'}, {'input': 'asymmetranthous', 'target': 'Asymmetranthous'}, {'input': 'uncontumaciousness', 'target': 'Uncontumaciousness'}, {'input': 'dodecasyllabic', 'target': 'Dodecasyllabic'}], (6, 5): [{'input': 'rhipipterous', 'target': 'Rhipipterous'}, {'input': 'ventriloquised', 'target': 'Ventriloquised'}, {'input': 'syphiliphobia', 'target': 'Syphiliphobia'}, {'input': 'redissolubleness', 'target': 'Redissolubleness'}, {'input': 'ochlophobist', 'target': 'Ochlophobist'}], (3, 1): [{'input': 'jews', 'target': 'Jews'}, {'input': 'internation', 'target': 'Internation'}, {'input': 'barton', 'target': 'Barton'}, {'input': 'analyzed', 'target': 'Analyzed'}, {'input': 'sair', 'target': 'Sair'}], (3, 4): [{'input': 'coprocessing', 'target': 'Coprocessing'}], (6, 4): [{'input': 'labretifery', 'target': 'Labretifery'}, {'input': 'hymenogastraceae', 'target': 'Hymenogastraceae'}, {'input': 'rupicaprinae', 'target': 'Rupicaprinae'}, {'input': 'spirillotropism', 'target': 'Spirillotropism'}, {'input': 'astragalotibial', 'target': 'Astragalotibial'}], (8, 7): [{'input': 'pericardiosymphysis', 'target': 'Pericardiosymphysis'}, {'input': 'uranostaphylorrhaphy', 'target': 'Uranostaphylorrhaphy'}], (6, 6): [{'input': 'anhaematopoiesis', 'target': 'Anhaematopoiesis'}, {'input': 'tribromoacetaldehyde', 'target': 
'Tribromoacetaldehyde'}, {'input': 'cingulectomies', 'target': 'Cingulectomies'}], (2, 3): [{'input': 'sistence', 'target': 'Sistence'}]},
        },
        "person_sport": {
            "regular": {(2, 1): [{'input': 'Alessandro Nesta', 'target': 'soccer'}, {'input': 'Mickey Mantle', 'target': 'baseball'}, {'input': 'Maurice Richard', 'target': 'hockey'}, {'input': 'Pete Rose', 'target': 'baseball'}, {'input': 'Kevin Love', 'target': 'basketball'}, {'input': 'Bob Gibson', 'target': 'baseball'}, {'input': 'Jack Kemp', 'target': 'football'}, {'input': 'Bart Starr', 'target': 'football'}, {'input': 'Byron White', 'target': 'football'}, {'input': 'Phil Jackson', 'target': 'basketball'}, {'input': 'Bobby Hull', 'target': 'hockey'}, {'input': 'Kaká', 'target': 'soccer'}, {'input': 'Cy Young', 'target': 'baseball'}, {'input': 'Dean Cain', 'target': 'football'}, {'input': 'Tim Duncan', 'target': 'basketball'}, {'input': 'Claudio Reyna', 'target': 'soccer'}, {'input': 'Tom Brady', 'target': 'football'}, {'input': 'Ernie Banks', 'target': 'baseball'}, {'input': 'Jerry West', 'target': 'basketball'}, {'input': 'Ty Cobb', 'target': 'baseball'}, {'input': 'Bill Goldberg', 'target': 'football'}, {'input': 'Terry Bradshaw', 'target': 'football'}, {'input': 'David Beckham', 'target': 'soccer'}, {'input': 'Elton Brand', 'target': 'basketball'}, {'input': 'Fred Williamson', 'target': 'football'}, {'input': 'Red Grange', 'target': 'football'}, {'input': 'Yao Ming', 'target': 'basketball'}, {'input': 'Bo Jackson', 'target': 'baseball'}, {'input': 'Patrick Ewing', 'target': 'basketball'}, {'input': 'Nani', 'target': 'soccer'}, {'input': 'Abel Xavier', 'target': 'soccer'}, {'input': 'Kevin Garnett', 'target': 'basketball'}, {'input': 'Matteo Ferrari', 'target': 'soccer'}, {'input': 'Babe Ruth', 'target': 'baseball'}, {'input': 'Kobe Bryant', 'target': 'basketball'}, {'input': 'Roman Reigns', 'target': 'football'}, {'input': 'David Robinson', 'target': 'basketball'}, {'input': 'Corey Perry', 'target': 'hockey'}, {'input': 'Rick Fox', 'target': 'basketball'}, {'input': 'Tim Cahill', 'target': 'soccer'}, {'input': 'Nelson Valdez', 'target': 'soccer'}, 
{'input': 'Frank Lampard', 'target': 'soccer'}, {'input': 'Billy Sunday', 'target': 'baseball'}, {'input': 'Terry Crews', 'target': 'football'}, {'input': 'Mauricio Wright', 'target': 'soccer'}, {'input': 'Wayne Rooney', 'target': 'soccer'}, {'input': 'Cam Newton', 'target': 'football'}, {'input': 'Bernie Casey', 'target': 'football'}, {'input': 'Alex Rodriguez', 'target': 'baseball'}, {'input': 'Walter Payton', 'target': 'football'}, {'input': 'Bill Russell', 'target': 'basketball'}, {'input': 'Arne Friedrich', 'target': 'soccer'}, {'input': 'Brett Favre', 'target': 'football'}, {'input': 'Bill Bradley', 'target': 'basketball'}, {'input': 'Roberto Clemente', 'target': 'baseball'}, {'input': 'Karl Malone', 'target': 'basketball'}, {'input': 'Gary Carter', 'target': 'baseball'}, {'input': 'Roger Maris', 'target': 'baseball'}, {'input': 'Robbie Rogers', 'target': 'soccer'}, {'input': 'Chris Paul', 'target': 'basketball'}, {'input': 'Javier Hernández', 'target': 'soccer'}, {'input': 'Jim Thorpe', 'target': 'baseball'}, {'input': 'Moe Berg', 'target': 'baseball'}, {'input': 'Len Ford', 'target': 'football'}, {'input': 'Gary Payton', 'target': 'basketball'}, {'input': 'Rick Barry', 'target': 'basketball'}, {'input': 'Moses Malone', 'target': 'basketball'}, {'input': 'Steven Gerrard', 'target': 'soccer'}, {'input': 'Steve Nash', 'target': 'basketball'}, {'input': 'David Carney', 'target': 'soccer'}, {'input': 'Ernie Davis', 'target': 'football'}, {'input': 'Chuck Connors', 'target': 'baseball'}, {'input': 'Brock Lesnar', 'target': 'football'}, {'input': 'Tom Harmon', 'target': 'football'}, {'input': 'Tony Parker', 'target': 'basketball'}, {'input': 'Carlos Arroyo', 'target': 'basketball'}, {'input': 'Phil Esposito', 'target': 'hockey'}], (4, 1): [{'input': 'Ilya Kovalchuk', 'target': 'hockey'}, {'input': 'Willie Stargell', 'target': 'baseball'}, {'input': 'Jermain Defoe', 'target': 'soccer'}, {'input': 'Torsten Frings', 'target': 'soccer'}, {'input': 'Jari Kurri', 
'target': 'hockey'}, {'input': 'Marco Di Vaio', 'target': 'soccer'}, {'input': 'Dirk Nowitzki', 'target': 'basketball'}, {'input': 'Lothar Matthäus', 'target': 'soccer'}, {'input': 'Dražen Petrović', 'target': 'basketball'}, {'input': 'Antonio Nocerino', 'target': 'soccer'}, {'input': 'John Matuszak', 'target': 'football'}, {'input': 'Joe DiMaggio', 'target': 'baseball'}, {'input': 'Ichiro Suzuki', 'target': 'baseball'}, {'input': 'Ivan Hlinka', 'target': 'hockey'}, {'input': 'Júlio César', 'target': 'soccer'}, {'input': 'Colin Kaepernick', 'target': 'football'}, {'input': 'Mikaël Silvestre', 'target': 'soccer'}, {'input': 'Freddie Ljungberg', 'target': 'soccer'}, {'input': 'Zdeno Chára', 'target': 'hockey'}, {'input': 'John Olerud', 'target': 'baseball'}, {'input': 'Benny Feilhaber', 'target': 'soccer'}, {'input': 'Raimo Helminen', 'target': 'hockey'}, {'input': 'Obafemi Martins', 'target': 'soccer'}, {'input': 'Scottie Pippen', 'target': 'basketball'}, {'input': 'Igor Larionov', 'target': 'hockey'}, {'input': 'Bashkim Kadrii', 'target': 'soccer'}, {'input': 'Pavel Datsyuk', 'target': 'hockey'}, {'input': 'Jean Béliveau', 'target': 'hockey'}, {'input': 'Siem de Jong', 'target': 'soccer'}, {'input': 'Frank Mahovlich', 'target': 'hockey'}, {'input': 'Howie Morenz', 'target': 'hockey'}, {'input': 'Sergei Fedorov', 'target': 'hockey'}, {'input': 'Steve Yzerman', 'target': 'hockey'}, {'input': 'Pavol Demitra', 'target': 'hockey'}, {'input': 'Carlos Beltrán', 'target': 'baseball'}, {'input': 'Kasey Keller', 'target': 'soccer'}, {'input': 'Jürgen Locadia', 'target': 'soccer'}, {'input': 'Marián Hossa', 'target': 'hockey'}], (3, 1): [{'input': 'Boris Diaw', 'target': 'basketball'}, {'input': 'Dennis Rodman', 'target': 'basketball'}, {'input': 'Kenny Lofton', 'target': 'baseball'}, {'input': 'Luis Scola', 'target': 'basketball'}, {'input': 'Wilt Chamberlain', 'target': 'basketball'}, {'input': 'Alex Karras', 'target': 'football'}, {'input': 'Kendall Waston', 'target': 
'soccer'}, {'input': 'Don Shula', 'target': 'football'}, {'input': 'Casey Stengel', 'target': 'baseball'}, {'input': 'Paul Kariya', 'target': 'hockey'}, {'input': 'Woody Strode', 'target': 'football'}, {'input': 'Jacques Plante', 'target': 'hockey'}, {'input': 'Tony Dungy', 'target': 'football'}, {'input': 'Deion Sanders', 'target': 'baseball'}, {'input': 'Troy Aikman', 'target': 'football'}, {'input': 'Honus Wagner', 'target': 'baseball'}, {'input': 'John Elway', 'target': 'football'}, {'input': 'Emmitt Smith', 'target': 'football'}, {'input': 'Jermaine Jones', 'target': 'soccer'}, {'input': 'Eric Lindros', 'target': 'hockey'}, {'input': 'Walter Zenga', 'target': 'soccer'}, {'input': 'Gale Sayers', 'target': 'football'}, {'input': 'Stan Mikita', 'target': 'hockey'}, {'input': 'Elroy Hirsch', 'target': 'football'}, {'input': 'Efren Navarro', 'target': 'baseball'}, {'input': 'Drew Brees', 'target': 'football'}, {'input': 'Yu Darvish', 'target': 'baseball'}, {'input': 'Allen Iverson', 'target': 'basketball'}, {'input': 'DaMarcus Beasley', 'target': 'soccer'}, {'input': 'Leon Allen White', 'target': 'football'}, {'input': 'Brian Bosworth', 'target': 'football'}, {'input': 'Satchel Paige', 'target': 'baseball'}, {'input': 'Pat Tillman', 'target': 'football'}, {'input': 'Nigel de Jong', 'target': 'soccer'}, {'input': 'Warren Spahn', 'target': 'baseball'}, {'input': 'Andrea Pirlo', 'target': 'soccer'}, {'input': 'Marco Ureña', 'target': 'soccer'}, {'input': 'Carl Weathers', 'target': 'football'}, {'input': 'Danny Ainge', 'target': 'baseball'}, {'input': 'Mariano Rivera', 'target': 'baseball'}, {'input': 'Bubba Smith', 'target': 'football'}, {'input': 'Michael Umaña', 'target': 'soccer'}, {'input': 'Alexander Ovechkin', 'target': 'hockey'}, {'input': 'Jesse Hibbs', 'target': 'football'}, {'input': 'Jim Bunning', 'target': 'baseball'}, {'input': 'Larry Doby', 'target': 'baseball'}, {'input': 'Joe Namath', 'target': 'football'}, {'input': 'Roy Campanella', 'target': 
'baseball'}, {'input': 'Marcos Mondaini', 'target': 'soccer'}, {'input': 'Julius Erving', 'target': 'basketball'}, {'input': 'Stan Musial', 'target': 'baseball'}, {'input': 'Dick Butkus', 'target': 'football'}], (5, 1): [{'input': 'Edgaras Jankauskas', 'target': 'soccer'}, {'input': 'Gonzalo Higuaín', 'target': 'soccer'}, {'input': 'Jaromír Jágr', 'target': 'hockey'}, {'input': 'Lutz Pfannenstiel', 'target': 'soccer'}, {'input': "Shaquille O'Neal", 'target': 'basketball'}, {'input': 'Jarome Iginla', 'target': 'hockey'}, {'input': 'Cuauhtémoc Blanco', 'target': 'soccer'}, {'input': 'Alan Ball, Jr.', 'target': 'soccer'}, {'input': 'Bronko Nagurski', 'target': 'football'}, {'input': 'Júlio Baptista', 'target': 'soccer'}, {'input': 'Emanuel Pogatetz', 'target': 'soccer'}, {'input': 'Jozy Altidore', 'target': 'soccer'}, {'input': 'Jiří Šlégr', 'target': 'hockey'}, {'input': 'Frédéric Piquionne', 'target': 'soccer'}], (6, 1): [{'input': 'Joe Garagiola Sr.', 'target': 'baseball'}, {'input': "Raïs M'Bolhi", 'target': 'soccer'}, {'input': 'Šarūnas Jasikevičius', 'target': 'basketball'}, {'input': 'Viacheslav Fetisov', 'target': 'hockey'}, {'input': 'Teemu Sälännä', 'target': 'hockey'}, {'input': 'Florent Sinama Pongolle', 'target': 'soccer'}, {'input': 'Olumide Oyedeji', 'target': 'basketball'}, {'input': 'Egidio Arévalo Rios', 'target': 'soccer'}], (7, 1): [{'input': 'Metta Sandiford-Artest', 'target': 'basketball'}]},
            "in_beginning": {(3, 1): [{'input': 'Alessandro Nesta', 'target': 'soccer'}, {'input': 'Mickey Mantle', 'target': 'baseball'}, {'input': 'Maurice Richard', 'target': 'hockey'}, {'input': 'Pete Rose', 'target': 'baseball'}, {'input': 'Kevin Love', 'target': 'basketball'}, {'input': 'Bob Gibson', 'target': 'baseball'}, {'input': 'Jack Kemp', 'target': 'football'}, {'input': 'Bart Starr', 'target': 'football'}, {'input': 'Byron White', 'target': 'football'}, {'input': 'Phil Jackson', 'target': 'basketball'}, {'input': 'Bobby Hull', 'target': 'hockey'}, {'input': 'Kaká', 'target': 'soccer'}, {'input': 'Cy Young', 'target': 'baseball'}, {'input': 'Dean Cain', 'target': 'football'}, {'input': 'Tim Duncan', 'target': 'basketball'}, {'input': 'Claudio Reyna', 'target': 'soccer'}, {'input': 'Tom Brady', 'target': 'football'}, {'input': 'Ernie Banks', 'target': 'baseball'}, {'input': 'Jerry West', 'target': 'basketball'}, {'input': 'Ty Cobb', 'target': 'baseball'}, {'input': 'Bill Goldberg', 'target': 'football'}, {'input': 'Terry Bradshaw', 'target': 'football'}, {'input': 'David Beckham', 'target': 'soccer'}, {'input': 'Elton Brand', 'target': 'basketball'}, {'input': 'Fred Williamson', 'target': 'football'}, {'input': 'Red Grange', 'target': 'football'}, {'input': 'Yao Ming', 'target': 'basketball'}, {'input': 'Bo Jackson', 'target': 'baseball'}, {'input': 'Patrick Ewing', 'target': 'basketball'}, {'input': 'Nani', 'target': 'soccer'}, {'input': 'Abel Xavier', 'target': 'soccer'}, {'input': 'Kevin Garnett', 'target': 'basketball'}, {'input': 'Matteo Ferrari', 'target': 'soccer'}, {'input': 'Babe Ruth', 'target': 'baseball'}, {'input': 'Kobe Bryant', 'target': 'basketball'}, {'input': 'Roman Reigns', 'target': 'football'}, {'input': 'David Robinson', 'target': 'basketball'}, {'input': 'Corey Perry', 'target': 'hockey'}, {'input': 'Rick Fox', 'target': 'basketball'}, {'input': 'Tim Cahill', 'target': 'soccer'}, {'input': 'Nelson Valdez', 'target': 'soccer'}, 
{'input': 'Frank Lampard', 'target': 'soccer'}, {'input': 'Billy Sunday', 'target': 'baseball'}, {'input': 'Terry Crews', 'target': 'football'}, {'input': 'Mauricio Wright', 'target': 'soccer'}, {'input': 'Wayne Rooney', 'target': 'soccer'}, {'input': 'Cam Newton', 'target': 'football'}, {'input': 'Bernie Casey', 'target': 'football'}, {'input': 'Alex Rodriguez', 'target': 'baseball'}, {'input': 'Walter Payton', 'target': 'football'}, {'input': 'Bill Russell', 'target': 'basketball'}, {'input': 'Arne Friedrich', 'target': 'soccer'}, {'input': 'Brett Favre', 'target': 'football'}, {'input': 'Bill Bradley', 'target': 'basketball'}, {'input': 'Roberto Clemente', 'target': 'baseball'}, {'input': 'Karl Malone', 'target': 'basketball'}, {'input': 'Gary Carter', 'target': 'baseball'}, {'input': 'Roger Maris', 'target': 'baseball'}, {'input': 'Robbie Rogers', 'target': 'soccer'}, {'input': 'Chris Paul', 'target': 'basketball'}, {'input': 'Javier Hernández', 'target': 'soccer'}, {'input': 'Jim Thorpe', 'target': 'baseball'}, {'input': 'Moe Berg', 'target': 'baseball'}, {'input': 'Len Ford', 'target': 'football'}, {'input': 'Gary Payton', 'target': 'basketball'}, {'input': 'Rick Barry', 'target': 'basketball'}, {'input': 'Moses Malone', 'target': 'basketball'}, {'input': 'Steven Gerrard', 'target': 'soccer'}, {'input': 'Steve Nash', 'target': 'basketball'}, {'input': 'David Carney', 'target': 'soccer'}, {'input': 'Ernie Davis', 'target': 'football'}, {'input': 'Chuck Connors', 'target': 'baseball'}, {'input': 'Brock Lesnar', 'target': 'football'}, {'input': 'Tom Harmon', 'target': 'football'}, {'input': 'Tony Parker', 'target': 'basketball'}, {'input': 'Carlos Arroyo', 'target': 'basketball'}, {'input': 'Phil Esposito', 'target': 'hockey'}], (5, 1): [{'input': 'Ilya Kovalchuk', 'target': 'hockey'}, {'input': 'Willie Stargell', 'target': 'baseball'}, {'input': 'Jermain Defoe', 'target': 'soccer'}, {'input': 'Torsten Frings', 'target': 'soccer'}, {'input': 'Jari Kurri', 
'target': 'hockey'}, {'input': 'Marco Di Vaio', 'target': 'soccer'}, {'input': 'Dirk Nowitzki', 'target': 'basketball'}, {'input': 'Lothar Matthäus', 'target': 'soccer'}, {'input': 'Dražen Petrović', 'target': 'basketball'}, {'input': 'Antonio Nocerino', 'target': 'soccer'}, {'input': 'John Matuszak', 'target': 'football'}, {'input': 'Joe DiMaggio', 'target': 'baseball'}, {'input': 'Ichiro Suzuki', 'target': 'baseball'}, {'input': 'Ivan Hlinka', 'target': 'hockey'}, {'input': 'Júlio César', 'target': 'soccer'}, {'input': 'Colin Kaepernick', 'target': 'football'}, {'input': 'Mikaël Silvestre', 'target': 'soccer'}, {'input': 'Freddie Ljungberg', 'target': 'soccer'}, {'input': 'Zdeno Chára', 'target': 'hockey'}, {'input': 'John Olerud', 'target': 'baseball'}, {'input': 'Benny Feilhaber', 'target': 'soccer'}, {'input': 'Raimo Helminen', 'target': 'hockey'}, {'input': 'Obafemi Martins', 'target': 'soccer'}, {'input': 'Scottie Pippen', 'target': 'basketball'}, {'input': 'Igor Larionov', 'target': 'hockey'}, {'input': 'Bashkim Kadrii', 'target': 'soccer'}, {'input': 'Pavel Datsyuk', 'target': 'hockey'}, {'input': 'Jean Béliveau', 'target': 'hockey'}, {'input': 'Siem de Jong', 'target': 'soccer'}, {'input': 'Frank Mahovlich', 'target': 'hockey'}, {'input': 'Howie Morenz', 'target': 'hockey'}, {'input': 'Sergei Fedorov', 'target': 'hockey'}, {'input': 'Steve Yzerman', 'target': 'hockey'}, {'input': 'Pavol Demitra', 'target': 'hockey'}, {'input': 'Carlos Beltrán', 'target': 'baseball'}, {'input': 'Kasey Keller', 'target': 'soccer'}, {'input': 'Jürgen Locadia', 'target': 'soccer'}, {'input': 'Marián Hossa', 'target': 'hockey'}], (4, 1): [{'input': 'Boris Diaw', 'target': 'basketball'}, {'input': 'Dennis Rodman', 'target': 'basketball'}, {'input': 'Kenny Lofton', 'target': 'baseball'}, {'input': 'Luis Scola', 'target': 'basketball'}, {'input': 'Wilt Chamberlain', 'target': 'basketball'}, {'input': 'Alex Karras', 'target': 'football'}, {'input': 'Kendall Waston', 'target': 
'soccer'}, {'input': 'Don Shula', 'target': 'football'}, {'input': 'Casey Stengel', 'target': 'baseball'}, {'input': 'Paul Kariya', 'target': 'hockey'}, {'input': 'Woody Strode', 'target': 'football'}, {'input': 'Jacques Plante', 'target': 'hockey'}, {'input': 'Tony Dungy', 'target': 'football'}, {'input': 'Deion Sanders', 'target': 'baseball'}, {'input': 'Troy Aikman', 'target': 'football'}, {'input': 'Honus Wagner', 'target': 'baseball'}, {'input': 'John Elway', 'target': 'football'}, {'input': 'Emmitt Smith', 'target': 'football'}, {'input': 'Jermaine Jones', 'target': 'soccer'}, {'input': 'Eric Lindros', 'target': 'hockey'}, {'input': 'Walter Zenga', 'target': 'soccer'}, {'input': 'Gale Sayers', 'target': 'football'}, {'input': 'Stan Mikita', 'target': 'hockey'}, {'input': 'Elroy Hirsch', 'target': 'football'}, {'input': 'Efren Navarro', 'target': 'baseball'}, {'input': 'Drew Brees', 'target': 'football'}, {'input': 'Yu Darvish', 'target': 'baseball'}, {'input': 'Allen Iverson', 'target': 'basketball'}, {'input': 'DaMarcus Beasley', 'target': 'soccer'}, {'input': 'Leon Allen White', 'target': 'football'}, {'input': 'Brian Bosworth', 'target': 'football'}, {'input': 'Satchel Paige', 'target': 'baseball'}, {'input': 'Pat Tillman', 'target': 'football'}, {'input': 'Nigel de Jong', 'target': 'soccer'}, {'input': 'Warren Spahn', 'target': 'baseball'}, {'input': 'Andrea Pirlo', 'target': 'soccer'}, {'input': 'Marco Ureña', 'target': 'soccer'}, {'input': 'Carl Weathers', 'target': 'football'}, {'input': 'Danny Ainge', 'target': 'baseball'}, {'input': 'Mariano Rivera', 'target': 'baseball'}, {'input': 'Bubba Smith', 'target': 'football'}, {'input': 'Michael Umaña', 'target': 'soccer'}, {'input': 'Alexander Ovechkin', 'target': 'hockey'}, {'input': 'Jesse Hibbs', 'target': 'football'}, {'input': 'Jim Bunning', 'target': 'baseball'}, {'input': 'Larry Doby', 'target': 'baseball'}, {'input': 'Joe Namath', 'target': 'football'}, {'input': 'Roy Campanella', 'target': 
'baseball'}, {'input': 'Marcos Mondaini', 'target': 'soccer'}, {'input': 'Julius Erving', 'target': 'basketball'}, {'input': 'Stan Musial', 'target': 'baseball'}, {'input': 'Dick Butkus', 'target': 'football'}], (6, 1): [{'input': 'Edgaras Jankauskas', 'target': 'soccer'}, {'input': 'Gonzalo Higuaín', 'target': 'soccer'}, {'input': 'Jaromír Jágr', 'target': 'hockey'}, {'input': 'Lutz Pfannenstiel', 'target': 'soccer'}, {'input': "Shaquille O'Neal", 'target': 'basketball'}, {'input': 'Jarome Iginla', 'target': 'hockey'}, {'input': 'Cuauhtémoc Blanco', 'target': 'soccer'}, {'input': 'Alan Ball, Jr.', 'target': 'soccer'}, {'input': 'Bronko Nagurski', 'target': 'football'}, {'input': 'Júlio Baptista', 'target': 'soccer'}, {'input': 'Emanuel Pogatetz', 'target': 'soccer'}, {'input': 'Jozy Altidore', 'target': 'soccer'}, {'input': 'Jiří Šlégr', 'target': 'hockey'}, {'input': 'Frédéric Piquionne', 'target': 'soccer'}], (7, 1): [{'input': 'Joe Garagiola Sr.', 'target': 'baseball'}, {'input': "Raïs M'Bolhi", 'target': 'soccer'}, {'input': 'Šarūnas Jasikevičius', 'target': 'basketball'}, {'input': 'Viacheslav Fetisov', 'target': 'hockey'}, {'input': 'Teemu Sälännä', 'target': 'hockey'}, {'input': 'Florent Sinama Pongolle', 'target': 'soccer'}, {'input': 'Olumide Oyedeji', 'target': 'basketball'}, {'input': 'Egidio Arévalo Rios', 'target': 'soccer'}], (8, 1): [{'input': 'Metta Sandiford-Artest', 'target': 'basketball'}]},
        },
        "present_past": {
            "regular": {(1, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}], (1, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}], (2, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}], (2, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}], (2, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}], (3, 3): [{'input': 'persuade', 'target': 'persuaded'}], (1, 3): [{'input': 'breathe', 'target': 'breathed'}]},
            "in_beginning": {(2, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}], (2, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}], (3, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}], (3, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}], (3, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}], (4, 3): [{'input': 'persuade', 'target': 'persuaded'}], (2, 3): [{'input': 'breathe', 'target': 'breathed'}]},
        },
    }

    def get_docs_in_templates(self, limit = None, corrupted = False, tokenizer = None) -> List[Dict[str, str]]:
        eval_docs, fewshot_docs = self.get_docs(limit)
        contexts = [self.fewshot_space.join([self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in docs]) + self.fewshot_space + self.doc_to_text(eval_doc) + self.sep
                    for docs, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]
        if corrupted:
            corrupted_fewshot_docs = []
            corrupt_query = []
            corrupt_query_within_input_space = []
            special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
            tokenizer.add_tokens([special_token])
            np.random.seed(self.fewshot_seed)
            QUERY_CORR = self.WORDS_PER_NUM_OF_TOKENS
            QUERY_CORR_IN_BEGINNING = self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING
            TARGET_CORR = self.WORDS_PER_NUM_OF_TOKENS
            for context in fewshot_docs:
                num_tokens_per_query = [len(tokenizer(special_token + item["input"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in context]
                num_tokens_per_query[0] = len(tokenizer(context[0]["input"], add_special_tokens=True, return_tensors="pt")["input_ids"][0])
                num_tokens_per_target = [len(tokenizer(special_token + item["target"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in context]
                corrupted_queries = [[np.random.choice(QUERY_CORR[num_tok]) if j != 0 else np.random.choice(QUERY_CORR_IN_BEGINNING[num_tok])
                                        for j, num_tok in enumerate(num_tokens_per_query)]
                                        for i in range(self.NUM_CORRUPTIONS + 1)]
                corrupted_targets = [[np.random.choice(TARGET_CORR[num_tok]) for num_tok in num_tokens_per_target] for i in range(self.NUM_CORRUPTIONS + 1)]
                corrupted_items = [[np.random.choice(self.X_Y_CORRUPTION_PER_DATASET[self.dataset_name]["regular" if j != 0 else "in_beginning"][(t1, t2)])
                                    for j, (t1, t2) in enumerate(zip(num_tokens_per_query, num_tokens_per_target))]
                                    for i in range(self.NUM_CORRUPTIONS)]
                corrupted_fewshot_docs_per_fewshot = []
                corrupt_query_per_fewshot = []
                corrupt_query_within_input_space_per_fewshot = []
                new_items = [{"input" : corrupted_queries[-1][i], "target": corrupted_targets[-1][i]} for i, item in enumerate(context)]
                corrupted_fewshot_docs_per_fewshot.append(new_items)
                corrupt_query_per_fewshot.append(True)
                corrupt_query_within_input_space_per_fewshot.append(False)
                if self.corruption == "x_in_all_fewshots_agreed":
                    new_items = [{"input" : corrupted_queries[-1][i], "target": item["target"]} for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_in_all_fewshots_agreed":
                    new_items = [{"input" : item["input"], "target": corrupted_targets[-1][i]} for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "x_y_in_all_fewshots_agreed":
                    new_items = [corrupted_items[-1][i] for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "x_within_input_space_in_all_fewshots_but_current_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : corrupted_items[-1][i]["input"] if i < num_fewshot else item["input"],
                                      "target": item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "x_in_all_fewshots_but_current_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : corrupted_queries[-1][i] if i < num_fewshot else item["input"],
                                      "target": item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_in_all_fewshots_but_current_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_targets[-1][i] if i < num_fewshot else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_within_input_space_in_all_fewshots_but_current_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_items[-1][i]["target"] if i < num_fewshot else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_in_all_fewshots_but_current_and_previous_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_targets[-1][i] if i < num_fewshot - 1 else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_within_input_space_in_all_fewshots_but_current_and_previous_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_items[-1][i]["target"] if i < num_fewshot - 1 else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_in_all_fewshots_but_current_and_2_previous_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_targets[-1][i] if i < num_fewshot - 2 else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_within_input_space_in_all_fewshots_but_current_and_2_previous_agreed":
                    for num_fewshot in range(self.fewshot):
                        new_items = [{"input" : item["input"],
                                      "target": corrupted_items[-1][i]["target"] if i < num_fewshot - 2 else item["target"]} for i, item in enumerate(context)]
                        corrupted_fewshot_docs_per_fewshot.append(new_items)
                        corrupt_query_per_fewshot.append(False)
                        corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "x_within_input_space_in_all_fewshots_agreed":
                    new_items = [{"input" : corrupted_items[-1][i]["input"], "target": item["target"]} for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "y_within_input_space_in_all_fewshots_agreed":
                    new_items = [{"input" : item["input"], "target": corrupted_items[-1][i]["target"]} for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                elif self.corruption == "full_corruption_in_all_fewshots_agreed":
                    new_items = [{"input" : corrupted_queries[-1][i], "target": corrupted_targets[-1][i]} for i, item in enumerate(context)]
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                else:
                    assert False
                corrupted_fewshot_docs.append(corrupted_fewshot_docs_per_fewshot)
                corrupt_query.append(corrupt_query_per_fewshot)
                corrupt_query_within_input_space.append(corrupt_query_within_input_space_per_fewshot)

            num_tokens_per_query = [len(tokenizer(special_token + item["input"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in eval_docs]
            corrupted_queries = [np.random.choice(QUERY_CORR[num_tok]) for num_tok in num_tokens_per_query]
            corrupted_eval_docs = [{"input" : corrupted_queries[i], "target": item["target"]} for i, item in enumerate(eval_docs)]
            
            self.X_CORRUPTION_PER_DATASET = {
                dataset: {
                    tp: {
                        nt: [item for (nt1_, nt2_), items in self.X_Y_CORRUPTION_PER_DATASET[dataset][tp].items() for item in items if nt1_ == nt]
                        for nt in set([nt1 for (nt1, nt2) in self.X_Y_CORRUPTION_PER_DATASET[dataset][tp]])
                    }
                    for tp in self.X_Y_CORRUPTION_PER_DATASET[dataset]
                }
                for dataset in self.X_Y_CORRUPTION_PER_DATASET
            }
            corrupted_within_input_space_queries = [np.random.choice(
                self.X_CORRUPTION_PER_DATASET[self.dataset_name]["regular"][num_tok])["input"]
                if num_tok in self.X_CORRUPTION_PER_DATASET[self.dataset_name]["regular"] else 
                eval_docs[i]["input"]
                for i, num_tok in enumerate(num_tokens_per_query)]
            corrupted_within_input_space_eval_docs = [{"input" : corrupted_within_input_space_queries[i], "target": item["target"]} for i, item in enumerate(eval_docs)]

            corrupted_contexts = [
                [
                    self.fewshot_space.join([self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in docs]) + self.fewshot_space + ((self.doc_to_text(corrupted_within_input_space_eval_doc) if corrupt_within_input_space_query_per_item else self.doc_to_text(corrupted_eval_doc)) if corrupt_query_per_item else self.doc_to_text(eval_doc)) + self.sep
                    for docs, (corrupt_query_per_item, corrupt_within_input_space_query_per_item) in zip(docs_fer_fewshot, zip(corrupt_query_per_fewshot, corrupt_within_input_space_query_per_fewshot))
                ]
                for (docs_fer_fewshot, eval_doc), ((corrupt_query_per_fewshot, corrupted_eval_doc), (corrupt_within_input_space_query_per_fewshot, corrupted_within_input_space_eval_doc)) in zip(zip(corrupted_fewshot_docs, eval_docs), zip(zip(corrupt_query, corrupted_eval_docs), zip(corrupt_query_within_input_space, corrupted_within_input_space_eval_docs)))]
            return [
            {
                "context": context,
                "target": target,
                "corrupted_contexts": corrupted_context,
                "corrupted_target": target,
            }
            for context, target, corrupted_context in zip(contexts, targets, corrupted_contexts)
            ]
        return [
            {
                "context": context,
                "target": target
            }
            for context, target in zip(contexts, targets)
            ]