import logging
import sys
from typing import Optional
import pdb

import numpy as np
import os
import transformers
from dataclasses import dataclass, field
from datasets import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version
from scipy.special import softmax
from transformers.utils.dummy_tokenizers_objects import PreTrainedTokenizerFast

from robustSSANs.models.configuration_auto import AutoConfig
from robustSSANs.models.modeling_roberta import RobertaForSequenceClassificationWithCE, RobertaForSequenceClassificationWithKL
from robustSSANs.utils.mytrainer import MyTrainer
from robustSSANs.utils.training_args import TrainingArguments


# Select and load the RoBERTa sequence-classification variant for this run.
#
# NOTE(review): the previous version of this block referenced ``model_args``
# and ``config``, neither of which is defined in this file, so the module
# raised NameError as soon as it was imported. The settings below stand in
# for those values; they are meant to be overridden (e.g. from a
# ModelArguments dataclass parsed with HfArgumentParser) — TODO confirm
# against the full training script these values originally came from.

# Truthy -> KL-loss model variant; falsy -> cross-entropy model variant.
sparse_first = 1

# Checkpoint name or local path; used for both the config and the weights.
MODEL_NAME_OR_PATH = "roberta-base"
# None -> Hugging Face's default cache directory.
CACHE_DIR: Optional[str] = None
# Hub revision (branch, tag or commit id) of the checkpoint to load.
MODEL_REVISION = "main"
# When False, no auth token is sent (``use_auth_token=None``).
USE_AUTH_TOKEN = False

model_cls = (
    RobertaForSequenceClassificationWithKL
    if sparse_first
    else RobertaForSequenceClassificationWithCE
)

# Build the config through the project's AutoConfig (imported from
# robustSSANs.models.configuration_auto), presumably so any project-specific
# RoBERTa configuration is picked up — assumption, confirm.
config = AutoConfig.from_pretrained(
    MODEL_NAME_OR_PATH,
    cache_dir=CACHE_DIR,
    revision=MODEL_REVISION,
    use_auth_token=True if USE_AUTH_TOKEN else None,
)

model = model_cls.from_pretrained(
    MODEL_NAME_OR_PATH,
    # A ".ckpt" in the path marks a TensorFlow checkpoint to be converted.
    from_tf=".ckpt" in MODEL_NAME_OR_PATH,
    cache_dir=CACHE_DIR,
    config=config,
    revision=MODEL_REVISION,
    use_auth_token=True if USE_AUTH_TOKEN else None,
)