from __future__ import absolute_import, division, print_function
import argparse
import csv
import logging
import os
from tqdm import tqdm, trange
import random
import copy
import numpy as np

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

from transformers import AutoTokenizer, RobertaConfig
from transformers import AdamW, get_linear_schedule_with_warmup

import sklearn.metrics as mtc
from scipy.stats import spearmanr

from modeling_bert import BertForSequenceClassification
from modeling_roberta import RobertaForSequenceClassification
from modeling_generator_for_roberta import Generator


# Configure logging once at import time: INFO level, with each record showing
# a timestamp, the level name, the emitting logger's name and the message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class InputExample(object):
    """One raw (untokenized) example for simple sequence classification.

    Attributes:
        guid: Unique id for the example.
        text_a: Untokenized text of the first sequence.
        text_b: Untokenized text of the second sequence, or None for
            single-sequence tasks.
        label: Label of the example, or None when no label is supplied.
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Store the given fields unchanged on the instance.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence; the
                only text needed for single-sequence tasks.
            text_b: (Optional) string. The untokenized text of the second
                sequence; only needed for sequence-pair tasks.
            label: (Optional) string. The label of the example. Expected for
                train and dev examples, but not for test examples.
        """
        self.guid, self.label = guid, label
        self.text_a, self.text_b = text_a, text_b


class InputFeatures(object):
    """Model-ready features of a single example.

    Attributes:
        input_ids: Token ids of the encoded example.
        input_mask: Attention mask matching `input_ids`.
        segment_ids: Segment (token type) ids matching `input_ids`.
        label_id: Integer id of the example's label.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        """Store the four feature fields unchanged on the instance."""
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses override the ``get_*`` methods; every one of them raises
    `NotImplementedError` here. `_read_tsv` is a shared helper for loading
    tab-separated files.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
            input_file: Path of the UTF-8 encoded file to read.
            quotechar: Forwarded unchanged to `csv.reader`.

        Returns:
            A list with one entry per record; each entry is the list of
            string fields of that record.
        """
        # newline="" leaves line endings untranslated so the csv module does
        # its own record splitting; without it, "\r\n" embedded in a quoted
        # field would be rewritten to "\n" before the reader sees it.
        with open(input_file, 'r', encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter='\t', quotechar=quotechar))


class ImdbProcessor(DataProcessor):
    """Processor for the IMdB data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Build the "train" examples from ``train.tsv`` in `data_dir`."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """Build the "dev" examples; they are read from ``test.tsv``."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """Build the "test" examples from ``test.tsv`` in `data_dir`."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_aug_examples(self, data_dir):
        """Build the "train_au" examples from ``train_au.tsv`` in `data_dir`."""
        rows = self._read_tsv(os.path.join(data_dir, "train_au.tsv"))
        return self._create_examples(rows, "train_au")

    def get_labels(self):
        """Return the two class labels as strings."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into single-sentence `InputExample`s.

        The first row is skipped as a header. For every other row the text is
        the second-to-last column and the label is the last column; the guid
        is ``<set_type>-<row index>``.
        """
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)
        return [
            InputExample(guid="%s-%s" % (set_type, idx),
                         text_a=row[-2],
                         text_b=None,
                         label=row[-1])
            for idx, row in numbered_rows
        ]


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from `data_dir` as the "train" split."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from `data_dir` as the "dev" split."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from `data_dir` as the "test" split."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        """Return the label strings of this task."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into sentence-pair `InputExample`s.

        The first row is skipped as a header. For "test" the two texts are the
        last two columns and the label is the fixed placeholder "entailment";
        otherwise the texts are the third- and second-to-last columns and the
        label is the last column.
        """
        is_test = set_type == "test"
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)

        examples = []
        for idx, row in numbered_rows:
            if is_test:
                text_a, text_b, label = row[-2], row[-1], "entailment"
            else:
                text_a, text_b, label = row[-3], row[-2], row[-1]
            examples.append(
                InputExample(guid="%s-%s" % (set_type, idx),
                             text_a=text_a,
                             text_b=text_b,
                             label=label))
        return examples


class SstProcessor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from `data_dir` as the "train" split."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from `data_dir` as the "dev" split."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from `data_dir` as the "test" split."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        """Return the two class labels as strings."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into single-sentence `InputExample`s.

        The first row is skipped as a header. For "test" the text is the last
        column and the label is the fixed placeholder "0"; otherwise the text
        is the second-to-last column and the label is the last column.
        """
        is_test = set_type == "test"
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)

        examples = []
        for idx, row in numbered_rows:
            if is_test:
                sentence, label = row[-1], "0"
            else:
                sentence, label = row[-2], row[-1]
            examples.append(
                InputExample(guid="%s-%s" % (set_type, idx),
                             text_a=sentence,
                             text_b=None,
                             label=label))
        return examples


class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Read ``train.tsv`` from `data_dir` as the "train" split."""
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir):
        """Read ``dev.tsv`` from `data_dir` as the "dev" split."""
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_test_examples(self, data_dir):
        """Read ``test.tsv`` from `data_dir` as the "test" split."""
        rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(rows, "test")

    def get_labels(self):
        """Return the two class labels as strings."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into sentence-pair `InputExample`s.

        The first row is skipped as a header. The two texts are always the
        last two columns. For "test" the label is the fixed placeholder "0";
        otherwise it is read from the first column.
        """
        is_test = set_type == "test"
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)

        examples = []
        for idx, row in numbered_rows:
            first_text = row[-2]
            second_text = row[-1]
            label = "0" if is_test else row[0]
            examples.append(
                InputExample(guid="%s-%s" % (set_type, idx),
                             text_a=first_text,
                             text_b=second_text,
                             label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Load the "train" split from ``train.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "train")

    def get_dev_examples(self, data_dir):
        """Load the "dev" split from ``dev.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "dev")

    def get_test_examples(self, data_dir):
        """Load the "test" split from ``test.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "test")

    def get_labels(self):
        """Return the label strings of this task."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into sentence-pair `InputExample`s.

        The first row is skipped as a header. "test" rows take their texts
        from the last two columns and get the placeholder label "entailment";
        all other rows take texts from the third- and second-to-last columns
        and the label from the last column.
        """
        unlabeled = set_type == "test"
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)

        examples = []
        for idx, row in numbered_rows:
            if unlabeled:
                text_a, text_b, label = row[-2], row[-1], "entailment"
            else:
                text_a, text_b, label = row[-3], row[-2], row[-1]
            example = InputExample(guid="%s-%s" % (set_type, idx),
                                   text_a=text_a,
                                   text_b=text_b,
                                   label=label)
            examples.append(example)
        return examples


class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Load the "train" split from ``train.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "train")

    def get_dev_examples(self, data_dir):
        """Load ``dev_mismatched.tsv``; rows are tagged "dev_matched"."""
        tsv_path = os.path.join(data_dir, "dev_mismatched.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "dev_matched")

    def get_test_examples(self, data_dir):
        """Load the "test" split from ``test_mismatched.tsv``."""
        tsv_path = os.path.join(data_dir, "test_mismatched.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "test")

    def get_labels(self):
        """Return the three label strings of this task."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into sentence-pair `InputExample`s.

        The first row is skipped as a header. The guid is built from
        `set_type` and the first column of each row. Which columns hold the
        two texts depends on `set_type`; "test" and "diag" rows get the
        placeholder label "contradiction", every other row takes its label
        from the last column.
        """
        # (text_a column, text_b column) per split; any split not listed
        # here uses the final fallback pair.
        text_columns = {
            "test": (-3, -2),
            "diag": (-2, -1),
            "dev_matched": (-8, -7),
        }
        col_a, col_b = text_columns.get(set_type, (-4, -3))
        has_label_column = set_type not in ("test", "diag")

        rows = iter(lines)
        next(rows, None)  # drop the header row

        examples = []
        for row in rows:
            guid = "%s-%s" % (set_type, row[0])
            label = row[-1] if has_label_column else "contradiction"
            examples.append(
                InputExample(guid=guid,
                             text_a=row[col_a],
                             text_b=row[col_b],
                             label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Load the "train" split from ``train.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "train")

    def get_dev_examples(self, data_dir):
        """Load the "dev" split from ``dev.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "dev")

    def get_test_examples(self, data_dir):
        """Load the "test" split from ``test.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "test")

    def get_labels(self):
        """Return the two class labels as strings."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into sentence-pair `InputExample`s.

        The first row is skipped as a header. Rows that are too short to
        index are dropped silently, as are non-test rows whose label is not
        "0" or "1" and two hard-coded row numbers of the "train" split.
        "test" rows get the placeholder label "0".
        """
        is_test = set_type == "test"
        is_train = set_type == "train"
        numbered_rows = enumerate(lines)
        next(numbered_rows, None)  # drop row 0 (header)

        examples = []
        for idx, row in numbered_rows:
            # Two specific training rows are always excluded.
            if is_train and idx in (310122, 362226):
                continue
            try:
                if is_test:
                    text_a, text_b, label = row[-2], row[-1], "0"
                else:
                    text_a, text_b, label = row[-3], row[-2], row[-1]
            except IndexError:
                # Row has fewer columns than needed; skip it.
                continue
            if not is_test and label not in ("0", "1"):
                continue
            examples.append(
                InputExample(guid="%s-%s" % (set_type, idx),
                             text_a=text_a,
                             text_b=text_b,
                             label=label))
        return examples


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """Load the "train" split from ``train.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "train")

    def get_dev_examples(self, data_dir):
        """Load the "dev" split from ``dev.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "dev")

    def get_test_examples(self, data_dir):
        """Load the "test" split from ``test.tsv`` under `data_dir`."""
        tsv_path = os.path.join(data_dir, "test.tsv")
        return self._create_examples(self._read_tsv(tsv_path), "test")

    def get_labels(self):
        """Return the two class labels as strings."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn parsed rows into single-sentence `InputExample`s.

        Only the "test" split has its first row skipped; other splits use
        every row. The text is always the last column. "test" rows get the
        placeholder label "0"; other rows take the label from the
        third-to-last column.
        """
        is_test = set_type == "test"
        return [
            InputExample(guid="%s-%s" % (set_type, idx),
                         text_a=row[-1],
                         text_b=None,
                         label="0" if is_test else row[-3])
            for idx, row in enumerate(lines)
            if not (is_test and idx == 0)
        ]


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Encode examples into fixed-length `InputFeatures`.

    Args:
        examples: Iterable of objects exposing `guid`, `text_a`, `text_b`
            and `label` attributes.
        label_list: Sequence of labels; a label's position becomes its id.
        max_seq_length: Length passed to the tokenizer as `max_length`; every
            encoded sequence must come back with exactly this length.
        tokenizer: Callable invoked as ``tokenizer(text_a[, text_b],
            padding="max_length", truncation=True, max_length=...)`` whose
            result is indexed with "input_ids" and "attention_mask". Its
            `convert_ids_to_tokens` method is used for logging only.

    Returns:
        A list with one `InputFeatures` per example, in input order.

    Raises:
        KeyError: If an example's label is not in `label_list`.
        AssertionError: If the tokenizer output is not `max_seq_length` long.
    """
    label_map = {label: i for i, label in enumerate(label_list)}
    features = []
    for i, example in enumerate(examples):
        # Single sentences and sentence pairs differ only in the positional
        # texts handed to the tokenizer; a falsy text_b means "no pair".
        texts = (example.text_a, example.text_b) if example.text_b else (example.text_a,)
        encoded_input = tokenizer(*texts,
                                  padding="max_length",
                                  truncation=True,
                                  max_length=max_seq_length)
        input_ids = encoded_input["input_ids"]
        input_mask = encoded_input["attention_mask"]
        # Segment ids are all zeros for single and paired inputs alike.
        segment_ids = [0] * len(input_ids)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        if i < 5:
            # Tokens are needed only for this preview, so decode them here
            # instead of for every example.
            tokens = tokenizer.convert_ids_to_tokens(input_ids)
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join(tokens))
            logger.info("input_ids: %s", " ".join(str(x) for x in input_ids))
            logger.info("input_mask: %s", " ".join(str(x) for x in input_mask))
            logger.info("segment_ids: %s", " ".join(str(x) for x in segment_ids))
            logger.info("label: %s (id = %s)", example.label, label_id)

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id)
        )
    return features


def tensify(features):
    """Pack a sequence of feature objects into a `TensorDataset`.

    Each of the four attributes below is collected across `features` into
    one `torch.long` tensor; the dataset holds them in this order:
    input ids, input mask, segment ids, label ids.
    """
    field_names = ("input_ids", "input_mask", "segment_ids", "label_id")
    columns = [
        torch.tensor([getattr(feature, name) for feature in features], dtype=torch.long)
        for name in field_names
    ]
    return TensorDataset(*columns)


def accuracy(out, labels):
    """Count the rows of `out` whose argmax along axis 1 equals the label.

    Note that this returns the number of matches, not a ratio.
    """
    predicted = np.argmax(out, axis=1)
    matches = predicted == labels
    return np.sum(matches)


def mcc(out, labels):
    """Return the Matthews correlation coefficient of `out` and `labels`.

    Both arguments are forwarded positionally, in this order, to
    `sklearn.metrics.matthews_corrcoef`.
    """
    coefficient = mtc.matthews_corrcoef(out, labels)
    return coefficient


def spc(out, labels):
    """Return the Spearman correlation between `out` and `labels`.

    Only the first element of the `scipy.stats.spearmanr` result is returned.
    """
    result = spearmanr(out, labels)
    return result[0]


def gumbel_sigmoid(logits, training=True, tau=1e-10):
    """Apply a temperature-scaled sigmoid, with random noise when training.

    Args:
        logits: Tensor of scores.
        training: When False, return ``sigmoid(logits / tau)`` with no noise.
            When True, two independent noise tensors of the same shape are
            drawn (each ``-log`` of an exponential sample) and their
            difference is added to `logits` before scaling.
        tau: Temperature the (noisy) logits are divided by.

    Returns:
        A tensor with the same shape as `logits`.
    """
    if training:
        # Draw order matters for reproducibility under a fixed seed.
        noise_plus = -torch.empty_like(logits).exponential_().log()
        noise_minus = -torch.empty_like(logits).exponential_().log()
        perturbed = (logits + noise_plus - noise_minus) / tau
        return torch.sigmoid(perturbed)

    return torch.sigmoid(logits / tau)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_dir", type=str, default="data/",
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_type", type=str, default="roberta-base",
                        help="Bert pre-trained model.")
    parser.add_argument("--task_name", type=str, default="sst",
                        help="The name of the task to train.")
    parser.add_argument("--output_dir", type=str, default="model/",
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--best_epochs", type=float, default=1.0,
                        help="Best training epochs for prediction.")
    parser.add_argument("--cache_dir", type=str, default="cache/",
                        help="Directory to store the pre-trained models downloaded from s3.")
    parser.add_argument("--max_seq_length", type=int, default=128,
                        help="The maximum total input sequence length after WordPiece tokenization.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', default=True,
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", type=int, default=32,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", type=int, default=128,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", type=float, default=3.0,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_rate", type=float, default=0.06,
                        help="Proportion of training to perform linear learning rate warmup for. ")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="L2 weight decay for training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=2021,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.")
    parser.add_argument("--step", type=int, default=15,
                        help="Optimizing steps for the generator.")
    parser.add_argument("--generator_learning_rate", type=float, default=5e-5,
                        help="Learning_rate for the generator.")
    parser.add_argument("--share_embedding", action='store_true',
                        help="Whether to share embeddings for g_net and a_net.")

    args = parser.parse_args()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst": SstProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "imdb": ImdbProcessor,
    }

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}, 16-bits training: {}".format(device, n_gpu, args.fp16))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)
    processor = processors[task_name]()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    cache_dir = args.cache_dir if args.cache_dir else None
    tokenizer = AutoTokenizer.from_pretrained(args.model_type, do_lower_case=args.do_lower_case, cache_dir=cache_dir)

    if args.do_eval:
        eval_examples = processor.get_dev_examples(os.path.join(args.data_dir, args.task_name))
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
        eval_data = tensify(eval_features)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    if args.model_type == "bert-base-uncased":
        d_net = BertForSequenceClassification.from_pretrained(args.model_type,
                                                              num_labels=num_labels,
                                                              return_dict=True,
                                                              cache_dir=cache_dir)
    else:
        d_net = RobertaForSequenceClassification.from_pretrained(args.model_type,
                                                                 num_labels=num_labels,
                                                                 return_dict=True,
                                                                 cache_dir=cache_dir)
    d_net.to(device)
    if n_gpu > 1:
        d_net = torch.nn.DataParallel(d_net)

    if args.do_train:
        train_examples = processor.get_train_examples(os.path.join(args.data_dir, args.task_name))
        train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer)
        train_data = tensify(train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        warmup_steps = num_train_optimization_steps * args.warmup_rate
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        param_optimizer = list(d_net.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            }
        ]
        d_optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
        d_scheduler = get_linear_schedule_with_warmup(d_optimizer, num_warmup_steps=warmup_steps,
                                                      num_training_steps=num_train_optimization_steps)

        a_net = copy.deepcopy(d_net)
        a_net.to(device)
        param_optimizer = list(a_net.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            }
        ]
        a_optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
        a_scheduler = get_linear_schedule_with_warmup(a_optimizer, num_warmup_steps=warmup_steps,
                                                      num_training_steps=num_train_optimization_steps)

        config = RobertaConfig().from_pretrained(args.model_type)
        config.hidden_size = 768
        config.num_attention_heads = 1
        g_net = Generator(config)
        if args.share_embedding:
            g_net.set_input_embeddings(a_net.module.roberta.embeddings.word_embeddings)\
                if hasattr(d_net, "module") else g_net.set_input_embeddings(a_net.roberta.embeddings.word_embeddings)
        g_net.to(device)
        if n_gpu > 1:
            g_net = torch.nn.DataParallel(g_net)
        g_optimizer = AdamW(g_net.parameters(), lr=args.generator_learning_rate)

    if args.do_train:
        t = args.step
        global_step = 0
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            d_net.train()
            a_net.train()
            g_net.train()
            pool = []
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                outputs = d_net(input_ids=input_ids,
                                attention_mask=input_mask,
                                token_type_ids=segment_ids,
                                labels=label_ids)
                loss = outputs.loss
                if n_gpu > 1:
                    loss = loss.mean()
                loss.backward()

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    d_optimizer.step()
                    d_optimizer.zero_grad()
                    d_scheduler.step()
                    global_step += 1

                pool.append(batch)

                if (step + 1) % t == 0 or step == len(train_dataloader) - 1:
                    for batch in pool:
                        input_ids, input_mask, segment_ids, label_ids = batch
                        with torch.no_grad():
                            g_scores = g_net(input_ids=input_ids,
                                             attention_mask=input_mask,
                                             token_type_ids=segment_ids,
                                             output_attentions=True)
                        g_mask = [gumbel_sigmoid(g) for g in g_scores]
                        outputs = a_net(input_ids=input_ids,
                                        attention_mask=input_mask,
                                        token_type_ids=segment_ids,
                                        labels=label_ids,
                                        generator_output=g_mask)
                        loss = outputs.loss
                        if n_gpu > 1:
                            loss = loss.mean()
                        loss.backward()

                        if (step + 1) % args.gradient_accumulation_steps == 0:
                            a_optimizer.step()
                            a_optimizer.zero_grad()
                            a_scheduler.step()

                    # Evaluate and reward.
                    d_score, a_score = 0, 0
                    num_train_examples = 0
                    probs = []
                    del pool[:]
                    sample_ids = set(random.sample(list(range(len(eval_dataloader))), args.step))
                    pool = []
                    for i, batch in enumerate(eval_dataloader):
                        if i in sample_ids:
                            batch = tuple(t.to(device) for t in batch)
                            pool.append(batch)
                    for batch in pool:
                        # Evaluate defender.
                        input_ids, input_mask, segment_ids, label_ids = batch
                        with torch.no_grad():
                            outputs = d_net(input_ids=input_ids,
                                            attention_mask=input_mask,
                                            token_type_ids=segment_ids,
                                            labels=label_ids)
                            logits = outputs.logits

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to("cpu").numpy()
                        d_score += accuracy(logits, label_ids)
                        num_train_examples += input_ids.size(0)

                        # Evaluate attacker.
                        input_ids, input_mask, segment_ids, label_ids = batch
                        g_scores = g_net(input_ids=input_ids,
                                         attention_mask=input_mask,
                                         token_type_ids=segment_ids,
                                         output_attentions=True)
                        g_mask = [gumbel_sigmoid(g) for g in g_scores]
                        with torch.no_grad():
                            outputs = a_net(input_ids=input_ids,
                                            attention_mask=input_mask,
                                            token_type_ids=segment_ids,
                                            labels=label_ids)
                            logits = outputs.logits

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to("cpu").numpy()
                        a_score += accuracy(logits, label_ids)

                        layer_probs = []
                        for s, m in zip(g_scores, g_mask):
                            layer_probs.append([s.sigmoid(), m])
                        probs.append(layer_probs)

                    rewards = [0] * args.step
                    rewards[-1] = -(d_score - a_score)
                    returns = []
                    ret = 0
                    for r in rewards[::-1]:
                        ret = r + .99 * ret
                        returns.insert(0, ret)

                    g_loss = 0
                    for i, (lp, r) in enumerate(zip(probs, returns)):
                        layer_loss = []
                        for p, m in lp:
                            l_1 = -r * (p + 1e-10).log() * (m > 0.5)
                            l_0 = -r * (1 - p + 1e-10).log() * (m <= 0.5)
                            layer_loss += [l_1 + l_0]
                        extended_attention_mask = pool[i][1][:, None, None, :]
                        g_loss += sum([(ll * extended_attention_mask).sum() / ll.numel() for ll in layer_loss]) / 12
                    g_loss /= args.step

                    g_loss.backward()
                    g_optimizer.step()
                    g_optimizer.zero_grad()

                    del rewards[:]
                    del returns[:]
                    del probs[:]
                    del pool[:]

                    # print(d_score, a_score)
                    if random.random() >= d_score / (d_score + a_score):
                        d_net.load_state_dict(a_net.state_dict())
                    else:
                        a_net.load_state_dict(d_net.state_dict())
                    pool = []

            if args.do_eval and (epoch + 1) % args.best_epochs == 0:
                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)
                d_net.eval()
                g_net.eval()
                eval_accuracy = 0
                eval_loss = 0
                num_eval_examples = 0
                total_pred, total_labels = [], []
                for batch in tqdm(eval_dataloader, desc="Evaluating"):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch
                    with torch.no_grad():
                        outputs = d_net(input_ids=input_ids,
                                        attention_mask=input_mask,
                                        token_type_ids=segment_ids,
                                        labels=label_ids)
                        logits = outputs.logits
                        tmp_eval_loss = outputs.loss

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to("cpu").numpy()
                    eval_accuracy += accuracy(logits, label_ids)
                    eval_loss += tmp_eval_loss.mean().item()
                    num_eval_examples += input_ids.size(0)
                    if task_name == "cola":
                        total_pred.extend(np.argmax(logits, axis=1).squeeze().tolist())
                        total_labels.extend(label_ids.squeeze().tolist())

                eval_accuracy = eval_accuracy / num_eval_examples * 100
                if task_name == "cola":
                    eval_mcc = mcc(total_pred, total_labels) * 100

                # Saving G-Net.
                generator_to_save = g_net.module if hasattr(g_net, "module") else g_net
                output_generator_file = os.path.join(args.output_dir, "{}_generator.bin".format(int(epoch)))
                torch.save(generator_to_save.state_dict(), output_generator_file)
                # Saving D-Net.
                model_to_save = d_net.module if hasattr(d_net, "module") else d_net
                output_model_file = os.path.join(args.output_dir, "{}_defender.bin".format(int(epoch)))
                torch.save(model_to_save.state_dict(), output_model_file)

                output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
                with open(output_eval_file, 'a') as writer:
                    logger.info("***** Eval results *****")
                    if task_name != "cola":
                        logger.info("Epoch %s: accuracy = %.2f | loss = %.3f\n" %
                                    (str(epoch), eval_accuracy, eval_loss))
                        writer.write("Epoch %s: accuracy = %.2f | loss = %.3f\n" %
                                     (str(epoch), eval_accuracy, eval_loss))
                    else:
                        logger.info("Epoch %s: accuracy = %.2f | mcc = %.2f | loss = %.3f\n" %
                                    (str(epoch), eval_accuracy, eval_mcc, eval_loss))
                        writer.write("Epoch %s: accuracy = %.2f | mcc = %.2f | loss = %.3f\n" %
                                     (str(epoch), eval_accuracy, eval_mcc, eval_loss))


# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not trigger it.
if __name__ == "__main__":
    main()
