from packaging import version
import torch
from torch import nn
from typing import Any, Dict, List, Optional, Tuple, Union
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from transformers import Seq2SeqTrainer
from .trainer import BaseTrainer
import sys
import inspect
import os
import pickle
from transformers import AutoTokenizer
# NOTE(review): module-level side effect — the "t5-base" tokenizer is loaded
# (and, if not cached locally, presumably downloaded) every time this module is
# imported. `tokenizer` is not referenced in the visible part of this file;
# confirm it is used elsewhere before keeping it at import time.
tokenizer = AutoTokenizer.from_pretrained('t5-base')

class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    On entry the current ``sys.stdout`` is remembered and replaced by a handle
    to ``os.devnull``. On exit the stream that is installed at that moment is
    closed and the remembered stream is put back. ``__enter__`` returns
    ``None`` and exceptions raised inside the ``with`` body propagate.
    """

    def __enter__(self):
        previous = sys.stdout
        self._original_stdout = previous
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the sink that is active right now (normally the devnull handle
        # opened in __enter__), then restore the saved stream.
        active = sys.stdout
        active.close()
        sys.stdout = self._original_stdout


class HLoss(nn.Module):
    """Sum of ``p * log(p)`` over all elements, scaled by ``1 / unit``.

    ``p`` is ``softmax(x)`` along the last dimension, so the result is the
    negative entropy of each last-dimension distribution, summed over every
    leading position and divided by ``unit``.
    """

    def __init__(self, unit=1):
        super().__init__()
        self.unit = unit

    def forward(self, x):
        probs = F.softmax(x, dim=-1)
        log_probs = F.log_softmax(x, dim=-1)
        summed = (probs * log_probs).sum()
        return summed / self.unit


def layer_norm(x):
    """Standardize ``x`` with one mean and one (population) variance over all elements.

    Unlike ``nn.LayerNorm``, the statistics are taken over the whole tensor
    (not per last dimension) and no affine transform is applied.

    A constant input (zero variance) returns zeros instead of the NaN that
    ``0 / 0`` would produce.
    """
    centered = x - torch.mean(x)
    std = torch.sqrt(torch.square(centered).mean())
    # For a constant input `centered` is all zeros; dividing by 1 keeps it zero
    # instead of producing 0/0 = NaN. Non-degenerate inputs are unaffected.
    std = torch.where(std > 0, std, torch.ones_like(std))
    return centered / std


# Make the `self_train` helpers importable.
# NOTE(review): this is an absolute, machine-specific path (see the todo); on
# any other machine the `from self_train import ...` below raises
# ModuleNotFoundError unless `self_train` is on sys.path some other way.
sys.path.append(
    "/export/home/OpenPrompt/mixture_prompt/attempt/seq2seq"
)  # todo: change to relative path

from self_train import find_majority, map_vote

# `autocast` is only defined on torch >= 1.6. Seq2SeqTrainer.prediction_step
# uses it when `self.use_amp` is set, so AMP evaluation on an older torch would
# fail with NameError.
if version.parse(torch.__version__) >= version.parse("1.6"):
    from torch.cuda.amp import autocast


class Seq2SeqTrainer(Seq2SeqTrainer, BaseTrainer):
    """HuggingFace ``Seq2SeqTrainer`` variant that forwards a ``task`` to ``generate``.

    The class intentionally reuses the name of the imported
    ``transformers.Seq2SeqTrainer``. The base-class expression is evaluated
    before the new name is bound, so the base is the HuggingFace class.
    ``BaseTrainer`` is mixed in from the local ``.trainer`` module.
    """

    def __init__(
        self,
        train_dataset_sizes=None,
        shared=False,
        adapter_config=None,
        *args,
        **kwargs,
    ):
        """Store the extra configuration and forward everything else to the bases.

        NOTE(review): the three extra parameters come *before* ``*args``, so a
        positional first argument (e.g. ``model``) binds to
        ``train_dataset_sizes``. Pass the HuggingFace arguments by keyword.
        """
        super().__init__(*args, **kwargs)
        self.adapter_config = adapter_config
        self.train_dataset_sizes = train_dataset_sizes
        self.shared = shared

    def evaluate(
        self,
        eval_dataset: Optional[Dict[str, Dataset]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        """Remember the generation settings for ``prediction_step``, then evaluate.

        ``max_length`` / ``num_beams`` of ``None`` make ``prediction_step`` fall
        back to ``self.model.config``.
        """
        # TODO: this also needs to be set per dataset
        self._max_length = max_length
        self._num_beams = num_beams
        return super().evaluate(
            eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        When ``predict_with_generate`` is off (or only the loss is requested)
        this defers to the parent implementation. Otherwise it generates with
        ``self.model.generate``, passing ``inputs["task"]`` (or ``"all"``) as a
        ``task`` keyword, and computes the loss with a separate forward pass.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to the parent implementation on the non-generate path.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, generated
            tokens and labels (each being optional). Loss and labels are ``None`` for batches without ``labels``.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model,
                inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)
        gen_kwargs = {
            "max_length": self._max_length
            if self._max_length is not None
            else self.model.config.max_length,
            "num_beams": self._num_beams
            if self._num_beams is not None
            else self.model.config.num_beams,
            # Custom keyword understood by this project's model.generate.
            "task": inputs["task"] if "task" in inputs else "all",
        }

        generated_tokens = self.model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"], **gen_kwargs,
        )

        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(
                generated_tokens, gen_kwargs["max_length"]
            )

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None:
                    loss = (
                        self.label_smoother(outputs, inputs["labels"]).mean().detach()
                    )
                else:
                    loss = (
                        (outputs["loss"] if isinstance(outputs, dict) else outputs[0])
                        .mean()
                        .detach()
                    )
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        # Unlabelled batches (e.g. pure prediction) have no targets to pad;
        # reading inputs["labels"] here would raise KeyError.
        if not has_labels:
            return (loss, generated_tokens, None)

        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return (loss, generated_tokens, labels)


class MyModel(nn.Module):
    """
    A model to weight the output logits.
    type: logit: use attention_net to project logits and then weight logits.
          uw: use uw style of attention_net to project logits and then weight logits.
          uw-input: use uw style of attention_net to project inputs and then weight logits.
    """

    def __init__(
        self,
        embed_size=None,
        vocab_size=None,
        type="uw-input",
        target_task=None,
        mapping=True,
        file_path=None,
        seed_id=0,
        neural_deal=True,
        idx_lst=None,
        attention_size_input_key=100,
        attention_size_input_query=100,
        attention_size_output=100,
        hloss=0,
        dropout=0,
        update_idx=False,
    ):
        """Store the configuration and build the networks selected by ``type``.

        Networks built (each moved to CUDA with ``.cuda()``):

        * ``"logit"``: ``attention_net``, an MLP ``vocab_size -> 1024 -> 1024 -> 256 -> 1``.
        * ``"uw-input-attention_at"``: ``attention_net_key`` and
          ``attention_net_query``, both projecting from ``embed_size``.
        * any other ``type`` containing ``"uw"``: the same pair, except the
          query projects from ``vocab_size``.
        * anything else: no sub-networks.

        Args:
            dropout: if truthy, a ``Dropout(0.1 * dropout)`` follows the first
                activation of each key/query projection.
            idx_lst: indices of the pretrained soft prompts (teachers) to use.
            update_idx: if true, replace ``idx_lst`` with
                ``pickle.load(file_path)[seed_id]``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.device = "cuda"
        self.type = type
        self.target_task = target_task
        self.mapping = mapping
        self.debug = None
        self.file_path = file_path
        self.seed_id = seed_id
        self.neural_deal = neural_deal
        self.idx_lst = idx_lst  # which pretrained soft prompts (teachers) to use
        self.hloss = hloss

        if update_idx:
            # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
            with open(self.file_path, "rb") as handle:
                self.idx_lst = pickle.load(handle)[self.seed_id]

        def projection(in_features, hidden_features, activation_cls):
            # Linear -> activation -> [Dropout] -> Linear -> LayerNorm, bias-free linears.
            stack = [nn.Linear(in_features, hidden_features, bias=False), activation_cls()]
            if dropout:
                stack.append(nn.Dropout(0.1 * dropout))
            stack.append(nn.Linear(hidden_features, attention_size_output, bias=False))
            stack.append(nn.LayerNorm(attention_size_output))
            return nn.Sequential(*stack).cuda()

        if self.type == "logit":
            widths = [vocab_size, 1024, 1024, 256, 1]
            stack = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                stack.append(nn.Linear(fan_in, fan_out))
                stack.append(nn.ReLU())
            stack.pop()  # no activation after the final projection
            self.attention_net = nn.Sequential(*stack).cuda()
        elif "uw" in self.type:
            # Only "uw-input-attention_at" projects the query from the embedding space.
            if self.type == "uw-input-attention_at":
                query_features = embed_size
            else:
                query_features = vocab_size
            self.attention_net_key = projection(embed_size, attention_size_input_key, nn.ReLU)
            self.attention_net_query = projection(query_features, attention_size_input_query, nn.SiLU)

    def _mapping_logits(self, mul_logits_weight):
        # mul_logits_weight = mul_logits_weight[:, :, 0, :]  # torch.Size([32, 6, 32100])
        # map teachers logits first
        swipe_couples = None
        # replace_couples = None
        # if self.target_task in ["wnli", "superglue-multirc"]:
        #     replace_couples = [[209, 3]]

        if self.target_task in ["superglue-wsc.fixed", "wnli", "wnli_ppt", 'wsc_ppt', 'superglue-rte']:
            neural_deal = False
        else:
            neural_deal = self.neural_deal

        if self.target_task in ["superglue-wsc.fixed", "wnli"]:
            if neural_deal:
                swipe_couples = [[3, 10747], [209, 10998], [209, 59]]
            else:
                swipe_couples = [[3, 10747], [3, 204], [209, 10998], [209, 59]]

        elif self.target_task in ["wnli_ppt", "mrpc_ppt", "superglue-boolq_ppt", "superglue-multirc_ppt", "superglue-wic_ppt", "superglue-wsc.fixed_ppt"]:
            '''
            wnli_ppt -> 71 -> A -> yes -> entailment -> 1 -> 209
            wnli_ppt -> 205 -> C -> no -> 0 -> 3
            
            mrpc_ppt -> equivalent -> 1-> yes -> A -> 209 -> 71
            mrpc_ppt -> not equivalent -> 0 -> no -> C -> 3 -> 205
            
            boolq -> True ->  1 -> 209 -> A -> 71
            boolq -> False ->  0 -> 3 -> 205
            
            multirc -> 0 > False > 3 > no > C  > 205
            multirc -> 1 > True > 209 > yes > A  > 71
            '''
            swipe_couples = [[3, 205], [209, 71]]

        elif self.target_task in ["superglue-rte_ppt"]:
            '''
            rte -> 0 -> entailemnt -> yes -> A -> 71 -> 3
            rte -> 1 -> not entailemnt -> no -> C -> 205 -> 209
            '''
            swipe_couples = [[3, 71], [209, 205]]

        elif self.target_task in [
            "superglue-boolq",
            "superglue-multirc",
            "superglue-wic",
        ]:
            # swipe_couples = [[3, 10747], [3, 204], [209, 10998]]
            if neural_deal:
                swipe_couples = [[3, 10747], [209, 10998]]
            else:
                swipe_couples = [[3, 10747], [209, 204], [209, 10998]]

        elif self.target_task in ["superglue-rte"]:
            if neural_deal:
                swipe_couples = [[209, 27252]]
            else:
                swipe_couples = [[209, 27252], [209, 7163], [209, 204]]

        elif self.target_task in ["superglue-cb"]:  # 3 classes, 3; 209; 204
            swipe_couples = [[209, 27252], [204, 7163]]

        elif self.target_task in ["superglue-cb_ppt"]:  # 3 classes, 3; 209; 204
            '''
            cb -> contradiction ->  1 > 209 > no > C. 205
            cb -> entailment ->  0 > 3 > no > A. 71
            cb -> neutral -> B > 272 > 204
            '''
            swipe_couples = [[209, 205], [204, 272], [3, 71]]

        elif self.target_task in ["mrpc"]:
            swipe_couples = [[209, 7072], [3, 59]]

        if swipe_couples:
            for idx in range(mul_logits_weight.size(1)):  # 6 teachers
                print(idx, '>>>', mul_logits_weight[:, idx, :].argmax(-1))

                for num in range(mul_logits_weight.size(0)):  # batch size

                    # todo: swipe the number in mul_logits_weight,
                    # for replace in replace_couples:
                    #     left, right = mul_logits_weight[num, idx, replace[0]].clone(), mul_logits_weight[num, idx, replace[1]].clone()
                    #     mul_logits_weight[num, idx, replace[0]] = right
                    #     mul_logits_weight[num, idx, replace[1]] = left

                    for swipe in swipe_couples:  # clone is important!
                        max_num, min_num = max(
                                mul_logits_weight[num, idx, swipe[0]],
                                mul_logits_weight[num, idx, swipe[1]],
                            ).clone(), min(
                                mul_logits_weight[num, idx, swipe[0]],
                                mul_logits_weight[num, idx, swipe[1]],
                            ).clone()

                        mul_logits_weight[num, idx, swipe[0]] = max_num
                        mul_logits_weight[num, idx, swipe[1]] = min_num

                print('After>>>', mul_logits_weight[:, idx, :].argmax(-1))

                print(idx,'>>>', mul_logits_weight[:, idx, 3], mul_logits_weight[:, idx, 209],  mul_logits_weight[:, idx, 204])
            # layer_norm = nn.LayerNorm(mul_logits_weight.size(-1)).cuda()
            # print('>>>>>>', layer_norm(mul_logits_weight).size())
            # print(idx, mul_logits_weight[0, :, 3], mul_logits_weight[0, :, 209])
        return mul_logits_weight

    def _mapping(self, mul_logits_weight, weights_attn):
        logits = torch.einsum(
            "bp, bpd -> bd", weights_attn, mul_logits_weight
        )  # torch.Size([96, 32100])
        # logits = F.softmax(logits, -1)  # ([32, 32100])
        return logits

    def _mapping_gen(self, logits):
        """
        map first token to the label
        """
        if self.neural_deal:
            res = []
            if self.target_task in ["wnli", "superglue-wsc.fixed"]:
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 10998, 597, 27252]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 10747, 204, 7163]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][3] > logits[idx][209]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])
            elif self.target_task in ["wnli_ppt", "mrpc_ppt", "superglue-boolq_ppt", "superglue-multirc_ppt", "superglue-wic_ppt", "superglue-wsc.fixed_ppt"]:
                '''
                wnli_ppt -> 71 -> A -> yes -> entailmentt -> 1 -> 209
                wnli_ppt -> 205 -> C -> no -> 0 -> 3
                mrpc_ppt -> equivalent -> 1-> yes -> A -> 209 -> 71
                mrpc_ppt -> not equivalent -> 0 -> no -> C -> 3 -> 205
                boolq -> True ->  1 -> 209 -> A -> 71
                boolq -> False ->  0 -> 3 -> 205
                multirc -> 0 > False > 3 > no > C  > 205
                multirc -> 1 > True > 209 > yes > A  > 71
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 71]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 205]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][205] > logits[idx][71]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])

            elif self.target_task in ["superglue-rte_ppt"]:
                '''
                rte -> 0 -> entailemnt -> yes -> A -> 71 -> 3
                rte -> 1 -> not entailemnt -> no -> C -> 205 -> 209
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 205]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 71]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][71] > logits[idx][205]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])

            elif self.target_task in [
                "superglue-boolq",
                "superglue-rte",
                "superglue-multirc",
                "superglue-wic",
            ]:
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 10998, 597, 27252, 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 10747]:  # , 204, 7163]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][3] > logits[idx][209]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])
            elif self.target_task in ["superglue-cb"]:
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 10998, 597, 27252]:
                        res.append([209, 1])
                    elif int(next_word) in [3]:
                        res.append([3, 632, 1])
                    else:
                        res.append([204, 1])
            elif self.target_task in ["superglue-cb_ppt"]:
                '''
                cb -> contradiction ->  1 > 209 > no > C. 205
                cb -> entailment ->  0 > 3 > no > A. 71
                cb -> neutral -> B > 272 > 204
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 205]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 71]:
                        res.append([3, 632, 1])
                    else:
                        res.append([204, 1])

            elif self.target_task in ["mrpc"]:
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 7072]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 59]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][3] > logits[idx][209]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])
            else:
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 7072]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 59]:
                        res.append([3, 632, 1])
                    else:
                        if logits[idx][3] > logits[idx][209]:
                            res.append([3, 632, 1])
                        else:
                            res.append([209, 1])

        else:
            next_word = logits.argmax(dim=-1)
            res = []
            if self.target_task in ["wnli", "superglue-wsc.fixed"]:
                for idx in range(next_word.size(0)):
                    # print('next_word[idx]', int(next_word[idx]))
                    if int(next_word[idx]) in [209, 10998, 597, 27252]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word[idx]) in [3, 10747, 204, 7163]:
                        res.append([3, 632, 1])
                    else:
                        # print(next_word[idx])
                        res.append([204, 1])

            elif self.target_task in ["wnli_ppt", "mrpc_ppt", "superglue-boolq_ppt", "superglue-multirc_ppt", "superglue-wic_ppt", "superglue-wsc.fixed_ppt"]:
                '''
                wnli_ppt -> 71 -> A -> yes -> entailmentt -> 1 -> 209
                wnli_ppt -> 205 -> C -> no -> 0 -> 3
                mrpc_ppt -> equivalent -> 1-> yes -> A -> 209 -> 71
                mrpc_ppt -> not equivalent -> 0 -> no -> C -> 3 -> 205
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 71]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 205]:
                        res.append([3, 632, 1])
                    else:
                        res.append([204, 1])

            elif self.target_task in ["superglue-rte_ppt"]:
                '''
                rte -> 0 -> entailemnt -> yes -> A -> 71 -> 3
                rte -> 1 -> not entailemnt -> no -> C -> 205 -> 209
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 205]:  # , 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 71]:
                        res.append([3, 632, 1])
                    else:
                        res.append([204, 1])

            elif self.target_task in [
                "superglue-boolq",
                "superglue-rte",
                "superglue-multirc",
                "superglue-wic",
            ]:
                for idx in range(next_word.size(0)):
                    # print('next_word[idx]', int(next_word[idx]))
                    if int(next_word[idx]) in [209, 10998, 597, 27252, 204, 7163]:
                        res.append([209, 1])
                    elif int(next_word[idx]) in [3, 10747]:  # , 204, 7163]:
                        res.append([3, 632, 1])
                    else:
                        # print(next_word[idx])
                        res.append([204, 1])

            elif self.target_task in ["superglue-cb"]:
                for idx in range(next_word.size(0)):
                    # print('next_word[idx]', int(next_word[idx]))
                    if int(next_word[idx]) in [209, 10998, 597, 27252]:
                        res.append([209, 1])
                    elif int(next_word[idx]) in [3]:
                        res.append([3, 632, 1])
                    else:
                        # print(next_word[idx])
                        res.append([204, 1])

            elif self.target_task in ["superglue-cb_ppt"]:
                '''
                cb -> contradiction ->  1 > 209 > no > C. 205
                cb -> entailment ->  0 > 3 > no > A. 71
                cb -> neutral -> B > 272 > 204
                '''
                for idx in range(logits.size(0)):
                    next_word = logits[idx].argmax(dim=-1)
                    if int(next_word) in [209, 205]:
                        res.append([209, 1])
                    elif int(next_word) in [3, 71]:
                        res.append([3, 632, 1])
                    else:
                        res.append([204, 1])

            elif self.target_task in ["mrpc"]:
                for idx in range(next_word.size(0)):

                    if int(next_word[idx]) in [209, 7072]:
                        res.append([209, 1])
                    elif int(next_word[idx]) in [3, 59]:
                        res.append([3, 632, 1])
                    else:

                        res.append([204, 1])
            else:
                print("!!!!!!!>>>>>>>!!!!!<<<<<<<%%%%%%%%")
        return res

    def generate(self, t5_model, inputs, gen_kwargs, mapping=True):
        """Produce label token ids for a batch by combining the teachers' predictions.

        The mode is chosen by ``self.mapping`` (the ``mapping`` argument is
        currently ignored):

        * ``self.mapping`` truthy: mix the teacher logits returned by
          ``get_weights`` with the attention weights and map each row's argmax
          to a label sequence via ``_mapping_gen``.
        * otherwise: greedy step-by-step decoding. At each step every teacher in
          ``self.idx_lst`` (default ``range(6)``) is run through ``t5_model``'s
          encoder (with ``vote_idx``), decoder and ``lm_head``. Their logits are
          mixed with the attention weights, and the next token is restricted to
          tokens some teacher predicted at this step. A row finishes when it
          emits token id 1; its tokens (including that 1) are passed to
          ``map_vote``. Decoding stops when every row finished or after 11 steps.

        Args:
            t5_model: T5-style model; assumes ``encoder`` accepts a
                project-specific ``vote_idx`` keyword (as called below).
            inputs: batch dict; only ``input_ids`` is read here (and by
                ``get_weights``).
            gen_kwargs: unused.
            mapping: unused; ``self.mapping`` selects the mode.

        Returns:
            ``torch.tensor`` of rows ``[0] + tokens + [0] * (9 - len(tokens))``.
            Rows are length 10 when a row has at most 9 tokens.
            NOTE(review): longer rows are not truncated, and ``torch.tensor``
            then fails on the ragged list.
        """
        # with HiddenPrints():
        done_update = set()  # indices of rows that already emitted token 1
        _, _, weights_attn, mul_logits_weight, _ = self.get_weights(
            t5_model, inputs, predict=True
        )  # mul_logits_weight: (batch, teachers, vocab)

        batch_size = inputs["input_ids"].size(0)
        decoder_start_token_id = t5_model.config.decoder_start_token_id
        decoder_input_ids = (
            torch.ones((batch_size, 1)).long() * decoder_start_token_id
        )
        count = 0  # decoding step; also selects which decoder position to weight
        res = [[] for _ in range(batch_size)]
        output_tokens = []

        if self.mapping:
            res = self._mapping_gen(self._mapping(mul_logits_weight, weights_attn))
        else:
            idx_lst = self.idx_lst if self.idx_lst else range(6)
            while True:  # exits via the break at the end of each step
                step_logits = []
                remain_idx = set()  # tokens predicted by at least one teacher this step
                for idx in idx_lst:
                    encoder_outputs = t5_model.encoder(
                        inputs["input_ids"], vote_idx=idx
                    )
                    sequence_output = t5_model.decoder(
                        input_ids=decoder_input_ids.cuda(),
                        encoder_hidden_states=encoder_outputs[0],
                    )
                    lm_logits = t5_model.lm_head(sequence_output[0])  # (batch, steps, vocab)
                    next_word = F.softmax(lm_logits, -1).argmax(dim=-1)  # (batch, steps)
                    if len(next_word[0]) > 1:
                        remain_idx.update(int(n[-1]) for n in next_word)
                    else:
                        remain_idx.update(int(n) for n in next_word)
                    step_logits.append(lm_logits)

                stacked = torch.stack(step_logits, dim=1)  # (batch, teachers, steps, vocab)
                if stacked.size(2) > 1:
                    stacked = stacked[:, :, count, :]  # logits of the newest position
                else:
                    stacked = stacked[:, :, 0, :]
                count += 1

                with torch.no_grad():
                    logits = torch.einsum(
                        "bp, bpd -> bd", weights_attn, stacked
                    )  # (batch, vocab)

                    # Keep only the tokens some teacher predicted. Mask with -inf:
                    # multiplying by a 0/1 mask lets a masked-out token (value 0)
                    # beat allowed tokens whose mixed logits are negative.
                    allowed = torch.zeros_like(logits, dtype=torch.bool)
                    allowed[:, sorted(remain_idx)] = True
                    logits = logits.masked_fill(~allowed, float("-inf"))

                    next_word = logits.argmax(dim=-1)  # (batch,)
                    output_tokens.append(next_word)
                    decoder_input_ids = torch.cat(
                        (decoder_input_ids.cuda(), next_word.unsqueeze(-1)), dim=-1
                    )

                    for idx in range(batch_size):
                        if idx not in done_update and int(next_word[idx]) == 1:
                            # All tokens generated for this row, ending with the 1.
                            tokens = [int(step[idx]) for step in output_tokens]
                            res[idx] = map_vote(tokens)
                            done_update.add(idx)
                    if len(done_update) == batch_size or count > 10:
                        break

        res = [[0] + r + [0] * (10 - len(r) - 1) for r in res]
        print("res", res)
        return torch.tensor(res)

    def forward(self, t5_model, inputs, predict=False):
        """Combine the teachers' logits with the weights returned by ``get_weights``.

        * ``self.type`` containing ``"uw"``, ``"test"`` or ``"load"`` (this
          includes ``"test_majority"``): weighted sum over the teacher axis
          via ``einsum("bp, bpd -> bd")``. ``self.mapping`` selects which pair
          of ``get_weights`` outputs is used.
        * ``self.type == "logit"``: ``sum(mul_logits * weights, dim=0)``.

        Returns:
            ``(logits, hloss)``. Both are ``None`` when ``self.type`` matches
            none of the branches (previously this raised UnboundLocalError).
        """
        logits, hloss = None, None
        if "uw" in self.type or "test" in self.type or "load" in self.type:
            if not self.mapping:
                mul_logits, weights, _, _, hloss = self.get_weights(
                    t5_model, inputs, predict=predict
                )
                logits = torch.einsum("bp, bpd -> bd", weights, mul_logits)
            else:
                _, _, weights_attn, mul_logits_weight, hloss = self.get_weights(
                    t5_model, inputs, predict=predict
                )
                logits = torch.einsum(
                    "bp, bpd -> bd", weights_attn, mul_logits_weight
                )
        elif self.type == "logit":
            mul_logits, weights, _, _, hloss = self.get_weights(
                t5_model, inputs, predict=predict
            )
            logits = torch.sum(torch.mul(mul_logits, weights), dim=0)
        # The former `elif self.type == 'test_majority'` branch was unreachable:
        # "test" in "test_majority" is true, so the first branch handles it.
        return logits, hloss


    def get_weights(self, t5_model, inputs, predict=False):
        """Run every teacher prompt and compute the weights used to mix them.

        ``self.type`` is checked in this order:

        * ``"uw"``: scores from ``self.attention_net`` applied to the max over
          teachers of the flattened logits, softmaxed over teachers.
        * ``"uw-input"`` / ``"uw-input-attention_at"``: input-conditioned
          attention (see ``_input_attention``). The queries are the mapped
          first-token teacher logits or the averaged prefix embeddings,
          respectively.
        * ``"logit"``: ``self.attention_net`` on the teacher-first stacked
          logits, softmaxed over the teacher axis (dim 0).
        * contains ``"test"``: uniform attention over the teachers.
        * contains ``"load"``: attention read from ``self.file_path``
          (see ``_load_attention``).
        * anything else: every returned entry is ``None``.

        Set ``self.verbose = True`` to re-enable the per-batch debug prints.

        Args:
            t5_model: called once per teacher as ``t5_model(**inputs, vote_idx=idx)``.
            inputs: batch dict unpacked into ``t5_model``.
            predict: when true, the input-attention branches print ``weights_attn``.

        Returns:
            ``(mul_logits, weights, weights_attn, mul_logits_weight, hloss)``.
            Entries a branch does not compute are ``None``.
        """
        verbose = getattr(self, "verbose", False)
        mul_logits, weights = None, None
        weights_attn, mul_logits_weight, hloss = None, None, None

        if self.type == "uw":
            flat_logits, full_logits = self._run_teachers(t5_model, inputs)
            # Teacher axis goes to position 1 (the original loops intended this
            # transpose but called ``.size`` on a Python list).
            mul_logits = torch.stack(flat_logits, dim=1)
            mul_logits_weight = torch.stack(full_logits, dim=1)
            # The max over the teacher axis is the attention query.
            max_over_teachers = torch.max(mul_logits, 1)[0]
            scores = mul_logits.bmm(
                self.attention_net(max_over_teachers).unsqueeze(-1)
            )
            weights = F.softmax(scores.squeeze(-1), 1)

        elif self.type in ("uw-input", "uw-input-attention_at"):
            if self.type == "uw-input" and verbose:
                self._print_batch_debug(inputs)
            mul_logits, mul_logits_weight = self._stack_teacher_logits(
                t5_model, inputs
            )
            weights, weights_attn, hloss = self._input_attention(
                t5_model,
                inputs,
                mul_logits,
                mul_logits_weight,
                use_prefix_emb=self.type == "uw-input-attention_at",
                predict=predict,
            )

        elif self.type == "logit":
            flat_logits, _ = self._run_teachers(t5_model, inputs)
            mul_logits = torch.stack(flat_logits)  # teacher axis first
            # Explicit dim=0 (teacher axis), which ``forward`` sums over. This
            # assumes attention_net keeps the input rank, as its use in the
            # "uw" branch implies. The old implicit-dim softmax picked dim 0
            # for a 3-D input.
            weights = F.softmax(self.attention_net(mul_logits), dim=0)

        elif "test" in self.type:
            mul_logits, mul_logits_weight = self._stack_teacher_logits(
                t5_model, inputs
            )
            if verbose:
                print('mul_logits_weight', mul_logits_weight.size())
            num_teachers = mul_logits_weight.size(1)
            # Uniform vote: every teacher gets 1 / num_teachers.
            weights_attn = torch.full(
                (mul_logits_weight.size(0), num_teachers),
                1.0 / num_teachers,
                device=mul_logits_weight.device,
            )

        elif "load" in self.type:
            mul_logits, mul_logits_weight = self._stack_teacher_logits(
                t5_model, inputs
            )
            weights, weights_attn, x = self._load_attention(mul_logits_weight)
            # x is None in "idx" mode; HLoss(None) used to crash there.
            # NOTE(review): unit is 1 here, unlike the input-attention branches
            # which use HLoss(self.hloss). Kept as-is.
            if self.hloss and x is not None:
                hloss = HLoss()(x)

        if verbose:
            print('weights_attn', weights_attn)

        return mul_logits, weights, weights_attn, mul_logits_weight, hloss

    def _run_teachers(self, t5_model, inputs):
        """Call ``t5_model`` once per teacher and collect its logits.

        The teachers are ``self.idx_lst`` when it is truthy, else ``range(6)``.
        Returns two lists with one entry per teacher:
        ``output.logits.view(-1, t5_model.vocab_size)`` and the raw ``output.logits``.
        """
        teacher_ids = self.idx_lst if self.idx_lst else range(6)
        flat_logits, full_logits = [], []
        for teacher in teacher_ids:
            output = t5_model(**inputs, vote_idx=teacher)  # vote_idx picks the teacher
            full_logits.append(output.logits)
            flat_logits.append(output.logits.view(-1, t5_model.vocab_size))
        return flat_logits, full_logits

    def _stack_teacher_logits(self, t5_model, inputs):
        """Return the flattened logits and the mapped first-token logits, both teacher-axis-1."""
        flat_logits, full_logits = self._run_teachers(t5_model, inputs)
        mul_logits = torch.stack(flat_logits, dim=1)
        first_token = torch.stack([lg[:, 0, :] for lg in full_logits], dim=1)
        return mul_logits, self._mapping_logits(first_token)

    def _input_attention(
        self, t5_model, inputs, mul_logits, mul_logits_weight, use_prefix_emb, predict
    ):
        """Compute input-conditioned attention over teachers.

        Keys are ``self.attention_net_key`` applied to the averaged input
        embeddings from ``t5_model.encoder.embed_input``. Queries are
        ``self.attention_net_query`` applied to the averaged prefix embeddings
        (``use_prefix_emb``) or to ``mul_logits_weight``.

        Returns ``(weights, weights_attn, hloss)``:

        * ``weights_attn`` is the per-example softmax of the raw scores.
        * ``weights`` is the same scores repeated along dim 0 to match the rows
          of ``mul_logits``, then softmaxed.
        * ``hloss`` is ``HLoss(self.hloss)`` of the raw scores, or ``None``.
        """
        avg_mul_prefix_emb, avg_inputs_embeds = t5_model.encoder.embed_input(**inputs)
        x = self.attention_net_key(avg_inputs_embeds).unsqueeze(-1)
        query_src = avg_mul_prefix_emb if use_prefix_emb else mul_logits_weight
        # bmm: (b, n, m) x (b, m, 1) -> (b, n, 1), i.e. one raw score per teacher.
        # Computed once and reused for both weights_attn and weights.
        scores = self.attention_net_query(query_src).bmm(x).squeeze(-1)

        hloss = HLoss(self.hloss)(scores) if self.hloss else None

        weights_attn = F.softmax(scores, -1)
        self.debug = (mul_logits_weight, x)

        rows_per_example = mul_logits.size(0) // x.size(0)
        weights = F.softmax(
            torch.repeat_interleave(scores, rows_per_example, dim=0), 1
        )

        if predict:
            print(">>>", weights_attn)
        return weights, weights_attn, hloss

    def _load_attention(self, mul_logits_weight):
        """Build teacher attention from the pickle at ``self.file_path``.

        Without ``"idx"`` in ``self.type``:

        * The pickle is read as ``[self.target_task][source][self.seed_id]``
          for six fixed source tasks.
        * Those values are layer-normalised into ``x`` and softmaxed into a
          ``(1, 6)`` ``weights`` row, which is repeated per example.

        With ``"idx"``:

        * The pickle is indexed by ``self.seed_id`` and gives the selected
          teacher indices, which share the weight uniformly.
        * ``weights`` and ``x`` are ``None``.

        Returns ``(weights, weights_attn, x)``.
        """
        device = mul_logits_weight.device
        batch_size = mul_logits_weight.size(0)
        # NOTE(review): pickle.load can execute arbitrary code; file_path must be trusted.
        # The file is re-read on every call, as before.
        with open(self.file_path, "rb") as handle:
            loaded = pickle.load(handle)

        if "idx" not in self.type:
            source_dataset_names = ["mnli", "sst2", "qnli", "qqp", "squad", "record"]
            per_source = [
                loaded[self.target_task][name][self.seed_id]
                for name in source_dataset_names
            ]
            x = layer_norm(torch.tensor([per_source], device=device))
            weights = torch.softmax(x, dim=-1)
            return weights, weights.repeat(batch_size, 1), x

        selected = loaded[self.seed_id]
        weights_attn = torch.zeros((batch_size, 6), device=device)
        for teacher_idx in selected:
            weights_attn[:, teacher_idx] = 1.0 / len(selected)
        return None, weights_attn, None

    def _print_batch_debug(self, inputs):
        """Debug helper: print the decoded inputs and labels of a batch."""
        print('inputs', inputs.keys())
        for j in range(inputs['input_ids'].size(0)):
            print('>>>', j, tokenizer.decode(inputs['input_ids'][j]).replace('<pad>', ''))
            print('>>>', j, tokenizer.decode([t for t in inputs['labels'][j].tolist() if t >= 0]))


class Seq2SeqTrainer_vote(Seq2SeqTrainer, BaseTrainer):
    """Trainer for weighted voting over several prompted T5 teachers.

    Compared with the base trainer:

    1. ``compute_loss`` is rewritten. ``model`` is the voting module, whose
       ``forward(t5_model, inputs, predict=...)`` returns the mixed logits and
       an optional entropy regulariser (``hloss``).
    2. ``prediction_step`` generates with
       ``model.generate(t5_model, inputs, gen_kwargs)`` rather than the wrapped
       model's own ``generate``.

    The T5 backbone is held separately in ``self.t5_model``.
    """

    def __init__(
        self,
        t5_model,
        train_dataset_sizes=None,
        shared=False,
        adapter_config=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.adapter_config = adapter_config
        self.train_dataset_sizes = train_dataset_sizes
        self.shared = shared
        # NOTE: the backbone is pinned to CUDA; this class assumes a GPU is available.
        self.device = torch.device("cuda")
        self.t5_model = t5_model.to(self.device)

    def evaluate(
        self,
        eval_dataset: Optional[Dict[str, Dataset]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        max_length: Optional[int] = None,
        num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        """Store the generation settings read by ``prediction_step``, then evaluate."""
        # TODO: this also needs to be set per dataset
        self._max_length = max_length
        self._num_beams = num_beams
        return super().evaluate(
            eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

    def compute_loss(self, model, inputs, return_outputs=False, predict=False):
        """Compute the cross-entropy loss of the voted logits.

        Steps:

        1. ``model.forward(self.t5_model, inputs, predict=predict)`` mixes the
           per-teacher logits and returns ``(logits, hloss)``.
        2. Apply cross-entropy, ignoring label ``-100`` (the
           ``CrossEntropyLoss`` default). When ``model.mapping`` is set, only
           the first label token is scored.
        3. If ``hloss`` is not ``None``, add it to the loss. This is now applied
           in both modes; previously it was silently dropped when ``mapping``
           was false.

        Returns:
            ``loss``, or ``(loss, {"logits": logits})`` when ``return_outputs``
            is true (the ``Trainer.compute_loss`` contract).
        """
        logits, hloss = model.forward(self.t5_model, inputs, predict=predict)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        if not model.mapping:
            loss = loss_fct(logits, inputs["labels"].view(-1))
        else:
            # Mapped mode predicts only the first target token.
            loss = loss_fct(
                logits.view(-1, self.t5_model.vocab_size),
                inputs["labels"][:, 0].view(-1),
            )

        if hloss is not None:
            loss = loss + hloss

        return (loss, {"logits": logits}) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Generation goes through ``model.generate(self.t5_model, inputs, gen_kwargs)``.
        The loss is the voting loss from ``compute_loss``.

        Args:
            model (:obj:`nn.Module`):
                The voting module to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model (targets under ``labels``).
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, optional):
                Forwarded to the base implementation.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: the loss, the generated tokens
            and the labels. Each may be ``None``; labels are ``None`` when the batch has none.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            # NOTE(review): the base implementation calls
            # self.compute_loss(self.t5_model, ...), which in turn calls
            # self.t5_model.forward(self.t5_model, ...). Confirm this path
            # works, or pass ``model`` instead.
            return super().prediction_step(
                self.t5_model,
                inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)
        max_length = (
            self._max_length
            if self._max_length is not None
            else self.t5_model.config.max_length
        )
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": self._num_beams
            if self._num_beams is not None
            else self.t5_model.config.num_beams,
            "task": inputs["task"] if "task" in inputs else "all",
        }

        self.t5_model.eval()
        with torch.no_grad():
            generated_tokens = model.generate(self.t5_model, inputs, gen_kwargs)

        if generated_tokens.shape[-1] < max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, max_length)

        # The reported loss is the voting loss, not the backbone's own loss.
        loss = None
        if has_labels:
            with torch.no_grad():
                loss = self.compute_loss(model, inputs, predict=True).mean().detach()

        if self.args.prediction_loss_only:
            return (loss, None, None)

        # Previously raised KeyError for label-free batches.
        labels = None
        if has_labels:
            labels = inputs["labels"].to(self.device)
            if labels.shape[-1] < max_length:
                labels = self._pad_tensors_to_max_len(labels, max_length)

        return (loss, generated_tokens, labels)

    def _remove_unused_columns(
        self, dataset: "datasets.Dataset", description: Optional[str] = None
    ):
        """Drop dataset columns that ``self.t5_model.forward`` does not accept.

        The columns are checked against the T5 backbone's ``forward`` signature,
        not the voting module's. ``"label"`` and ``"label_ids"`` are always
        kept. ``description`` is accepted for interface compatibility but is
        unused (logging was disabled).
        """
        if not self.args.remove_unused_columns:
            return dataset
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.t5_model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids, the default data collator handles that.
            self._signature_columns += ["label", "label_ids"]
        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
        return dataset.remove_columns(ignored_columns)
