# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
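Train a VAE-style text model (``modules.VAEClas``) with reconstruction, KL and attribute-classification losses.
Adapted from the library script for fine-tuning language models on a text file (GPT, GPT-2, BERT, RoBERTa),
in which GPT and GPT-2 are fine-tuned with a causal language modeling (CLM) loss while BERT and RoBERTa are
fine-tuned with a masked language modeling (MLM) loss.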
"""

from __future__ import absolute_import, division, print_function


import pdb
import argparse
import glob
import logging

import pickle
import random
import time
import os,sys,inspect

import numpy as np
import torch
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from collections import defaultdict

# from azure.cosmosdb.table.tableservice import TableService
# from azure.cosmosdb.table.models import Entity
from datetime import datetime

from nltk import tokenize


from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
                                  BertConfig, BertForLatentConnector, BertTokenizer,
                                  GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
                                  GPT2SequenceClassification, GPT2Encoder)

from utils import (weight_init, calc_iwnll, calc_rec, calc_mi, calc_au,
                   BucketingDataLoader, BucketingDataLoader_Semi,
                   TextDataset_Split, TextDataset_2Tokenizers, frange_cycle_linear, frange_cycle_zero_linear)


from modules import VAEClasBak as VAEClas

#currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
#parentdir = os.path.dirname(currentdir)
#sys.path.insert(0,parentdir)
#from run_lm_finetuning import sample_sequence_conditional_batch
from run_latent_generation import sample_sequence_conditional_batch

import texar.tf as tx

# logging.getLogger("azure").setLevel(logging.WARNING)
# logging.getLogger("TableService").setLevel(logging.WARNING)

# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)


# Registry: model-type key -> (config class, model class, tokenizer class).
# load_checkpoint indexes it with args.encoder_model_type / args.decoder_model_type.
MODEL_CLASSES = dict([
    ('gpt2', (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer)),
    ('openai-gpt', (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)),
    ('bert', (BertConfig, BertForLatentConnector, BertTokenizer)),
    ('roberta', (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)),
    ('gpt2-clas', (GPT2Config, GPT2SequenceClassification, GPT2Tokenizer)),
    ('gpt2-encoder', (GPT2Config, GPT2Encoder, GPT2Tokenizer)),
])


def is_integer(s):
    """Return True if ``int(s)`` succeeds, False otherwise.

    Strings such as ``"42"`` or ``" -7 "`` count as integers; ``"4.2"``,
    ``"abc"`` and objects ``int()`` cannot convert (e.g. ``None``, lists) do
    not. Numeric values such as ``3.5`` are accepted because ``int()``
    truncates them.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        # TypeError covers non-string/non-numeric inputs such as None.
        return False
    return True


def load_and_cache_examples(args, tokenizer, evaluate=False, path=None, text_split_mode='natural'):
    """Build a text dataset from a file.

    Args:
        args: namespace providing ``eval_data_file``/``train_data_file`` (used
            when ``path`` is None) and ``block_size``.
        tokenizer: a single tokenizer, or a list of tokenizers.
        evaluate: when ``path`` is None, read ``args.eval_data_file`` if True,
            otherwise ``args.train_data_file``.
        path: explicit file to read; overrides the ``args`` file fields.
        text_split_mode: forwarded to the dataset constructor.

    Returns:
        ``TextDataset_2Tokenizers`` when ``tokenizer`` is a list, otherwise
        ``TextDataset_Split``.
    """
    file_path = path
    if file_path is None:
        file_path = args.eval_data_file if evaluate else args.train_data_file

    dataset_cls = TextDataset_2Tokenizers if isinstance(tokenizer, list) else TextDataset_Split
    return dataset_cls(tokenizer, args, file_path=file_path,
                       text_split_mode=text_split_mode, block_size=args.block_size)


def build_dataload_and_cache_examples(args, tokenizer, evaluate=False, shuffle=False):
    """Create a bucketing data loader over the train or eval file.

    Sets ``args.batch_size`` to the per-GPU train/eval batch size times
    ``max(1, args.n_gpu)``.

    Args:
        args: namespace with the batch-size, data-file, ``semi`` and
            ``max_seq_length`` fields read below.
        tokenizer: must be a list of tokenizers; it is forwarded to the loader.
        evaluate: use the eval file/batch size instead of the train ones.
        shuffle: forwarded to the loader.

    Returns:
        ``BucketingDataLoader_Semi`` for semi-supervised training
        (``args.semi`` and not ``evaluate``), otherwise ``BucketingDataLoader``.

    Raises:
        NotImplementedError: if ``tokenizer`` is not a list. That path has no
            loader implementation, so fail with a clear message instead of an
            unbound-variable error.
    """
    if not isinstance(tokenizer, list):
        raise NotImplementedError(
            'build_dataload_and_cache_examples expects a list of tokenizers; '
            'single-tokenizer loading is not implemented')

    n_devices = max(1, args.n_gpu)
    if evaluate:
        args.batch_size = args.per_gpu_eval_batch_size * n_devices
        file_path = args.eval_data_file
    else:
        args.batch_size = args.per_gpu_train_batch_size * n_devices
        file_path = args.train_data_file
        file_path_2 = args.train_data_file_2

    if args.semi and not evaluate:
        # The second file/batch size is only defined for training.
        return BucketingDataLoader_Semi(file_path, file_path_2,
                                        args.batch_size, args.per_gpu_train_batch_size_2 * n_devices,
                                        args.max_seq_length, tokenizer, args, bucket=100, shuffle=shuffle)
    return BucketingDataLoader(file_path, args.batch_size, args.max_seq_length, tokenizer, args,
                               bucket=100, shuffle=shuffle)


def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch from ``args.seed``.

    When ``args.n_gpu > 0`` every CUDA device is seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def mask_tokens(inputs, tokenizer, args):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

    Each position is selected with probability ``args.mlm_probability``. Of the
    selected positions, about 80% are replaced by the tokenizer's mask token,
    about 10% by a random id in ``[0, len(tokenizer))`` and the rest are left
    unchanged. ``labels`` keeps the original ids at selected positions and is
    -1 everywhere else, so only selected positions contribute to the loss
    (assumes the downstream loss ignores label -1 — confirm against the model).

    Note: ``inputs`` is modified in place and also returned.

    Returns:
        ``(inputs, labels)``
    """
    labels = inputs.clone()

    # Bool masks: uint8 mask indexing is deprecated in PyTorch.
    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
    # We only compute loss on masked tokens: blank out the labels of the UNselected positions.
    labels[~masked_indices] = -1

    # 80% of the selected positions become the mask token.
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Half of the remaining selected positions (10% overall) get a random token id.
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels

def weights_init_rondom(model):
    """Re-initialize in place every floating-point ``state_dict`` entry whose key contains 'encoder'.

    Values are drawn with ``torch.nn.init.normal_`` using its defaults
    (mean 0, std 1). Integer buffers are skipped because they cannot hold
    normal samples. Data-parallel wrappers are unwrapped via ``.module``.
    """
    model = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    # state_dict() tensors share storage with the model, so in-place init updates the model itself.
    for key, tensor in model.state_dict().items():
        if 'encoder' in key and tensor.is_floating_point():
            init.normal_(tensor.data)

def _save_with_retry(save_fn, use_philly, what):
    """Run ``save_fn()``; when ``use_philly`` is set, retry until it succeeds.

    On Philly, storage writes can fail transiently, so each failure is logged
    and the save is retried. Only ``Exception`` subclasses are retried;
    KeyboardInterrupt/SystemExit still stop the process. Without
    ``use_philly`` any error propagates.
    """
    if not use_philly:
        save_fn()
        return
    n_attempts = 0
    while True:
        n_attempts += 1
        try:
            save_fn()
            return
        except Exception as err:
            logger.warning("Saving %s failed on attempt %d (%s); retrying", what, n_attempts, err)


def save_checkpoint(model_vae, optimizer, global_step, args):
    """Save encoder, decoder and full training state for ``global_step``.

    Writes under ``args.output_dir``:
      * ``checkpoint-encoder-<step>/``: ``encoder.save_pretrained`` output + ``training_encoder_args.bin``
      * ``checkpoint-decoder-<step>/``: ``decoder.save_pretrained`` output + ``training_decoder_args.bin``
      * ``checkpoint-full-<step>/training.bin``: dict with ``iter``,
        ``model_state_dict``, ``optimizer_state_dict``, ``beta`` and ``args``.

    Directories are created only when ``args.local_rank`` is -1 or 0. With
    ``args.use_philly`` each save is retried until it succeeds.
    """
    output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(global_step))
    output_full_dir = os.path.join(args.output_dir, 'checkpoint-full-{}'.format(global_step))
    if args.local_rank in [-1, 0]:
        os.makedirs(output_encoder_dir, exist_ok=True)
        os.makedirs(output_decoder_dir, exist_ok=True)

    logger.info("Saving encoder model checkpoint to %s", output_encoder_dir)
    logger.info("Saving decoder model checkpoint to %s", output_decoder_dir)

    # Take care of distributed/parallel training wrappers.
    model_to_save = model_vae.module if hasattr(model_vae, 'module') else model_vae
    model_encoder_to_save = model_to_save.encoder
    model_decoder_to_save = model_to_save.decoder

    # Good practice: save your training arguments together with the trained model.
    # They can then be reloaded using `from_pretrained()`.
    def save_encoder():
        model_encoder_to_save.save_pretrained(output_encoder_dir)
        torch.save(args, os.path.join(output_encoder_dir, 'training_encoder_args.bin'))

    def save_decoder():
        model_decoder_to_save.save_pretrained(output_decoder_dir)
        torch.save(args, os.path.join(output_decoder_dir, 'training_decoder_args.bin'))

    _save_with_retry(save_encoder, args.use_philly, 'encoder checkpoint')
    _save_with_retry(save_decoder, args.use_philly, 'decoder checkpoint')

    # Save the full model and optimizer state into a single checkpoint.
    checkpoint = {
        'iter': global_step,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'beta': model_to_save.args.beta,
        'args': args,
    }

    if args.local_rank in [-1, 0]:
        os.makedirs(output_full_dir, exist_ok=True)

    logger.info("Start saving full model checkpoint to %s", output_full_dir)
    full_path = os.path.join(output_full_dir, 'training.bin')

    def save_full():
        torch.save(checkpoint, full_path)

    _save_with_retry(save_full, args.use_philly, 'full checkpoint')
    logger.info("Saving full checkpoint to %s", output_full_dir)



def load_checkpoint(args, loading_step=None):
    """Load the encoder/decoder models, tokenizers and full checkpoint for one step.

    Args:
        args: namespace with ``checkpoint_dir``, model types/names, tokenizer
            names, ``latent_size``, ``do_lower_case``, ``block_size``,
            ``device`` and ``gloabl_step_eval``. Mutated in place:
            ``encoder_model_type``/``decoder_model_type`` are lowercased and
            ``block_size`` is clamped to both tokenizers' ``max_len_single_sentence``.
        loading_step: step to load; when None, ``args.gloabl_step_eval`` is used.

    Returns:
        ``(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, checkpoint)``.
        The models are moved to ``args.device``, and ``checkpoint`` is the
        ``torch.load``-ed ``training.bin``, mapped onto ``args.device``.
    """
    args.encoder_model_type = args.encoder_model_type.lower()
    args.decoder_model_type = args.decoder_model_type.lower()
    global_step = loading_step if loading_step is not None else args.gloabl_step_eval

    output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
    output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
    output_full_dir = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))

    checkpoints = [[output_encoder_dir, output_decoder_dir]]
    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    # Load a trained Encoder model and vocabulary
    encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
    model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
    tokenizer_encoder = encoder_tokenizer_class.from_pretrained(
        args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_encoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

    # Load a trained Decoder model and vocabulary
    decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
    model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
    tokenizer_decoder = decoder_tokenizer_class.from_pretrained(
        args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path,
        do_lower_case=args.do_lower_case)
    model_decoder.to(args.device)
    if args.block_size <= 0:
        args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
    args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

    # Load full model; map tensors onto the same device as the models above.
    checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'), map_location=args.device)

    return model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, checkpoint


def get_summarywriter(output_path):
    """Create a tensorboardX ``SummaryWriter`` whose run directory lives under ``runs_bias/``.

    The run directory name is
    ``<last component of output_path>_<MonDD_HH-MM-SS>_<hostname>``. A single
    trailing '/' on ``output_path`` is ignored.
    """
    import socket
    from datetime import datetime

    if output_path.endswith('/'):
        output_path = output_path[:-1]
    run_name = '_'.join([
        os.path.split(output_path)[1],
        datetime.now().strftime('%b%d_%H-%M-%S'),
        socket.gethostname(),
    ])
    return SummaryWriter(logdir=os.path.join('runs_bias', run_name))


def train(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, table_name):
    """ Train the model.

    Trains ``model_vae`` on ``train_dataloader`` and keeps its attribute
    classifier frozen (eval mode, ``requires_grad=False``) throughout.

    Each batch is ``(text0, text1, lengths, attributes)``, plus an extra
    ``attributes_2`` element when ``args.semi`` is set. ``text0`` is passed to
    the model as ``inputs`` and ``text1`` as ``labels``.

    Per-step schedules:
      * the model's ``args.beta`` follows ``frange_cycle_zero_linear`` (called
        with ``n_cycle=int(args.num_train_epochs)``). Its ``args.fb_mode`` is 2
        if ``args.use_deterministic_connect``; otherwise it is 0 when beta is
        0.0 and 1 when it is not.
      * the temperature starts at ``args.temperature``. Once per multiple of
        ``args.temperature_anneal_iters`` optimizer steps it is multiplied by
        ``args.temperature_anneal_factor``, floored at 0.001.

    Args:
        args: parsed options. ``train_batch_size`` is written back, and so is
            ``num_train_epochs`` when ``args.max_steps > 0``.
        train_dataloader: batch iterable supporting ``len()`` and exposing ``num_examples``.
        model_vae: model to train; must expose ``classifier`` and ``args``.
        encoder_tokenizer, decoder_tokenizer: forwarded to the model's forward pass.
        classifier_tokenizer: forwarded to ``evaluate`` when evaluating during training.
        table_name: unused (kept for interface compatibility).

    Returns:
        ``(global_step, average loss per optimizer step, optimizer)``. The
        average divides by ``max(global_step, 1)`` so it is defined even when
        no optimizer step was taken.
    """
    tb_writer = None
    if args.local_rank in [-1, 0]:
        tb_writer = get_summarywriter(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay); biases and
    # LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model_vae.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model_vae.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
                                                              output_device=args.local_rank,
                                                              find_unused_parameters=True)

    # DataParallel/DDP wrappers do not expose the wrapped model's own attributes
    # (classifier, args), so reach them through the underlying module.
    raw_model = model_vae.module if hasattr(model_vae, 'module') else model_vae

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_clas, logging_loss_clas = 0.0, 0.0
    tr_loss_reg_z, logging_loss_reg_z = 0.0, 0.0
    tr_loss_c, logging_loss_c = 0.0, 0.0
    tr_loss_reg_z_c, logging_loss_reg_z_c = 0.0, 0.0
    tr_loss_recon, logging_loss_recon = 0.0, 0.0
    model_vae.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    # One beta value per micro-step over the whole run.
    n_iter = int(args.num_train_epochs) * len(train_dataloader)
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=int(args.num_train_epochs),
                                           ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    # Histories used for the 10-step moving averages shown in the progress bar.
    loss_rec_avg = []
    loss_clas_avg = []
    loss_reg_z_avg = []
    loss_c_avg = []
    loss_reg_z_c_avg = []
    temp = args.temperature
    last_anneal_step = 0  # optimizer step at which the temperature was last annealed
    lambda_clas = args.lambda_clas
    lambda_recon = args.lambda_recon
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            attributes_2 = None  # only provided in semi-supervised mode
            if args.semi:
                tokenized_text0, tokenized_text1, tokenized_text_lengths, attributes, attributes_2 = batch
            else:
                tokenized_text0, tokenized_text1, tokenized_text_lengths, attributes = batch

            # MLM masking is not supported here: the encoder sees text0 and the decoder targets are text1.
            assert args.mlm == False
            inputs, labels = mask_tokens(tokenized_text0, encoder_tokenizer, args) if args.mlm else (tokenized_text0, tokenized_text1)
            labels = tokenized_text1

            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            attributes = attributes.to(args.device)
            if args.semi:
                attributes_2 = attributes_2.to(args.device)

            model_vae.train()
            # train() above also switches the classifier back to training mode,
            # so freeze it again on every step.
            raw_model.classifier.eval()
            for p in raw_model.classifier.parameters():
                p.requires_grad = False

            beta_t = beta_t_list[step + epoch * len(epoch_iterator)]
            raw_model.args.beta = beta_t
            if args.use_deterministic_connect:
                raw_model.args.fb_mode = 2
            elif beta_t == 0.0:
                raw_model.args.fb_mode = 0
            else:
                raw_model.args.fb_mode = 1

            # Anneal at most once per optimizer step. With gradient accumulation,
            # global_step stays the same across several micro-steps.
            if global_step > last_anneal_step and global_step % args.temperature_anneal_iters == 0:
                temp = max(0.001, temp * args.temperature_anneal_factor)  # TODO
                last_anneal_step = global_step

            loss_rec, loss_kl, loss_clas, loss_reg_z, loss_c, loss_reg_z_c, loss, max_length = model_vae(
                inputs, labels,
                attributes=attributes,
                attributes_2=attributes_2,
                lambda_recon=lambda_recon,
                lambda_clas=lambda_clas,
                lambda_reg_z=args.lambda_reg_z,
                lambda_c_loss=args.lambda_c_loss,
                lambda_reg_z_c=args.lambda_reg_z_c,
                temperature=temp,
                use_gumbel=args.use_gumbel,
                tokenizer_decoder=decoder_tokenizer,
                tokenizer_encoder=encoder_tokenizer,
                hard=args.gumbel_hard,
                cond_a=args.cond_a,
                cond_c=args.cond_c)

            # mean() to average on multi-gpu parallel training
            loss_rec = loss_rec.mean()
            loss_kl = loss_kl.mean()
            loss_clas = loss_clas.mean()
            loss_reg_z = loss_reg_z.mean()
            loss_c = loss_c.mean()
            loss_reg_z_c = loss_reg_z_c.mean()
            loss = loss.mean()

            if args.use_philly:
                print("PROGRESS: {}%".format(round(100 * (step + epoch * len(epoch_iterator)) / (int(args.num_train_epochs) * len(epoch_iterator)), 4)))
                print("EVALERR: {}%".format(loss_rec))

            beta_f = raw_model.args.beta

            loss_rec_avg.append(loss_rec.item())
            loss_rec_avg_ = np.mean(loss_rec_avg[-10:])
            loss_clas_avg.append(loss_clas.item())
            loss_clas_avg_ = np.mean(loss_clas_avg[-10:])
            loss_reg_z_avg.append(loss_reg_z.item())
            loss_reg_z_avg_ = np.mean(loss_reg_z_avg[-10:])
            loss_c_avg.append(loss_c.item())
            loss_c_avg_ = np.mean(loss_c_avg[-10:])
            loss_reg_z_c_avg.append(loss_reg_z_c.item())
            loss_reg_z_c_avg_ = np.mean(loss_reg_z_c_avg[-10:])
            description = (
                f'iter: {step +  epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
                f'loss_rec: {loss_rec_avg_:.3f}; loss_kl: {loss_kl.item():.3f}; '
                f'loss_clas: {loss_clas_avg_:.3f}; '
                f'loss_reg_z: {loss_reg_z_avg_:.6f}; '
                f'loss_c: {loss_c_avg_:.3f}; '
                f'loss_reg_z_c: {loss_reg_z_c_avg_:.6f}; '
                f'beta: {beta_f:.3f}; temperature: {temp:.4f}; max_length: {max_length:.1f}'
            )
            epoch_iterator.set_description(description)

            if global_step % 100 == 0:
                print(description)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_loss_clas += loss_clas.item()
            tr_loss_reg_z += loss_reg_z.item()
            tr_loss_c += loss_c.item()
            tr_loss_reg_z_c += loss_reg_z_c.item()
            tr_loss_recon += loss_rec.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_vae.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, prefix=str(global_step))
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    n_logged = min(args.logging_steps, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / n_logged, global_step)
                    tb_writer.add_scalar('loss_clas', (tr_loss_clas - logging_loss_clas) / n_logged, global_step)
                    tb_writer.add_scalar('loss_reg_z', (tr_loss_reg_z - logging_loss_reg_z) / n_logged, global_step)
                    tb_writer.add_scalar('loss_c', (tr_loss_c - logging_loss_c) / n_logged, global_step)
                    tb_writer.add_scalar('loss_reg_z_c', (tr_loss_reg_z_c - logging_loss_reg_z_c) / n_logged, global_step)
                    tb_writer.add_scalar('loss_recon', (tr_loss_recon - logging_loss_recon) / n_logged, global_step)
                    tb_writer.add_scalar('temperature', temp, global_step)
                    tb_writer.add_scalar('lambda_clas', lambda_clas, global_step)
                    logging_loss = tr_loss
                    logging_loss_clas = tr_loss_clas
                    logging_loss_reg_z = tr_loss_reg_z
                    logging_loss_c = tr_loss_c
                    logging_loss_reg_z_c = tr_loss_reg_z_c
                    logging_loss_recon = tr_loss_recon

                    print('temperature: %f, lambda_clas: %f' % (temp, lambda_clas))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_checkpoint(model_vae, optimizer, global_step, args)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if tb_writer is not None:
        tb_writer.close()

    # max(..., 1) avoids ZeroDivisionError when no optimizer step was taken.
    return global_step, tr_loss / max(global_step, 1), optimizer


def train_p_az(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, table_name):
    """Train the p(a|z) layer (``model_vae.a_layer_given_z``) with the rest of the VAE frozen.

    At every step the classifier, encoder and decoder are put in eval mode and
    have ``requires_grad`` switched off; only the parameters of
    ``a_layer_given_z`` are handed to the optimizer.

    Args:
        args: run configuration (batch sizes, step/epoch limits, optimizer,
            fp16 and distributed settings, KL-schedule values ``beta``,
            ``ratio_increase`` and ``ratio_zero``, ...).
        train_dataloader: training batches; must support ``len()`` and expose
            ``num_examples``.
        model_vae: the VAE model; it may be wrapped for fp16 / multi-GPU here.
        encoder_tokenizer, decoder_tokenizer: forwarded to the model.
        classifier_tokenizer, table_name: not used; kept for interface
            compatibility with the other ``train_*`` functions.

    Returns:
        ``(global_step, mean training loss per optimizer step, optimizer)``.

    Raises:
        ValueError: if ``args.mlm`` is set (masked-LM inputs are not supported).
    """
    # Fail fast on an unsupported configuration instead of asserting per batch.
    if args.mlm:
        raise ValueError("train_p_az does not support masked language modeling (args.mlm must be False)")

    if args.local_rank in [-1, 0]:
        tb_writer = get_summarywriter(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Only the p(a|z) layer is optimized, without weight decay.
    optimizer_grouped_parameters = [
        {'params': list(model_vae.a_layer_given_z.parameters()), 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
                                                              output_device=args.local_rank,
                                                              find_unused_parameters=True)

    # DataParallel / DistributedDataParallel expose the wrapped model only via
    # `.module`; submodules and `args` must be reached through it.
    core_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae

    # Train!
    logger.info("***** Running training p_az *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model_vae.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    # One KL-weight value per training iteration, restarting every epoch.
    n_iter = int(args.num_train_epochs) * len(train_dataloader)
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=int(args.num_train_epochs),
                                           ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Semi-supervised batches carry a second attribute tensor, unused here.
            if args.semi:
                tokenized_text0, tokenized_text1, _, attributes, _ = batch
            else:
                tokenized_text0, tokenized_text1, _, attributes = batch

            inputs = tokenized_text0.to(args.device)
            labels = tokenized_text1.to(args.device)
            attributes = attributes.to(args.device)

            # Everything except the p(a|z) layer stays frozen.
            model_vae.train()
            for frozen in (core_vae.classifier, core_vae.encoder, core_vae.decoder):
                frozen.eval()
                for p in frozen.parameters():
                    p.requires_grad = False

            beta_t = beta_t_list[step + epoch * len(epoch_iterator)]
            core_vae.args.beta = beta_t
            if args.use_deterministic_connect:
                core_vae.args.fb_mode = 2
            else:
                core_vae.args.fb_mode = 0 if beta_t == 0.0 else 1

            loss, _ = model_vae(
                inputs, labels,
                attributes=attributes,
                tokenizer_decoder=decoder_tokenizer,
                tokenizer_encoder=encoder_tokenizer,
                cond_a=args.cond_a,
                cond_c=args.cond_c,
                train_a_layer=True
            )
            if args.n_gpu > 1:
                loss = loss.mean()  # DataParallel gathers one loss per GPU

            description = (
                f'iter: {step +  epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
            )
            epoch_iterator.set_description(description)

            if global_step % 100 == 0:
                print(description)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_vae.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and (global_step % args.logging_steps == 0 or global_step == 1):
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate_a_layer(args, model_vae, encoder_tokenizer, decoder_tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/min(args.logging_steps, global_step), global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_checkpoint(model_vae, optimizer, global_step, args)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Guard against a run that never reached an optimizer step.
    return global_step, tr_loss / max(global_step, 1), optimizer


def train_latent_gan(args, train_dataloader, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, table_name, senti_classifier, categ_classifier):
    """Train the latent GAN (``latent_generator`` + ``latent_discriminator``) with the VAE frozen.

    At every step the classifier, encoder and decoder are put in eval mode and
    have ``requires_grad`` switched off; only the generator and discriminator
    parameters are handed to the optimizer (as two parameter groups).

    Args:
        args: run configuration (batch sizes, step/epoch limits, optimizer,
            fp16 and distributed settings, KL-schedule values ``beta``,
            ``ratio_increase`` and ``ratio_zero``, ...).
        train_dataloader: training batches; must support ``len()`` and expose
            ``num_examples``.
        model_vae: the VAE model; it may be wrapped for fp16 / multi-GPU here.
        encoder_tokenizer, decoder_tokenizer: forwarded to the model.
        classifier_tokenizer, senti_classifier, categ_classifier: used only by
            the periodic ``evaluate_gan`` call.
        table_name: not used; kept for interface compatibility.

    Returns:
        ``(global_step, mean training loss per optimizer step, optimizer)``.

    Raises:
        ValueError: if ``args.mlm`` is set (masked-LM inputs are not supported).
    """
    # Fail fast on an unsupported configuration instead of asserting per batch.
    if args.mlm:
        raise ValueError("train_latent_gan does not support masked language modeling (args.mlm must be False)")

    if args.local_rank in [-1, 0]:
        tb_writer = get_summarywriter(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Only the latent generator and discriminator are optimized, without weight decay.
    optimizer_grouped_parameters = [
        {'params': list(model_vae.latent_generator.parameters()), 'weight_decay': 0.0},
        {'params': list(model_vae.latent_discriminator.parameters()), 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model_vae, optimizer = amp.initialize(model_vae, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae, device_ids=range(args.n_gpu)).to(args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(model_vae, device_ids=[args.local_rank],
                                                              output_device=args.local_rank,
                                                              find_unused_parameters=True)

    # DataParallel / DistributedDataParallel expose the wrapped model only via
    # `.module`; submodules and `args` must be reached through it.
    core_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae

    # Train!
    logger.info("***** Running training latent GAN *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    # Running sums since the start of training; the logging_* values are the
    # sums at the last tensorboard write, so their difference covers one window.
    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_lsd, logging_loss_lsd = 0.0, 0.0
    tr_loss_lsg, logging_loss_lsg = 0.0, 0.0
    tr_accu_enc, logging_accu_enc = 0.0, 0.0
    tr_accu_g, logging_accu_g = 0.0, 0.0
    model_vae.zero_grad()

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])

    # One KL-weight value per training iteration, restarting every epoch.
    n_iter = int(args.num_train_epochs) * len(train_dataloader)
    beta_t_list = frange_cycle_zero_linear(n_iter, start=0.0, stop=args.beta, n_cycle=int(args.num_train_epochs),
                                           ratio_increase=args.ratio_increase, ratio_zero=args.ratio_zero)

    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Semi-supervised batches carry a second attribute tensor, unused here.
            if args.semi:
                tokenized_text0, tokenized_text1, _, attributes, _ = batch
            else:
                tokenized_text0, tokenized_text1, _, attributes = batch

            inputs = tokenized_text0.to(args.device)
            labels = tokenized_text1.to(args.device)
            attributes = attributes.to(args.device)

            # Everything except the latent GAN stays frozen.
            model_vae.train()
            for frozen in (core_vae.classifier, core_vae.encoder, core_vae.decoder):
                frozen.eval()
                for p in frozen.parameters():
                    p.requires_grad = False

            beta_t = beta_t_list[step + epoch * len(epoch_iterator)]
            core_vae.args.beta = beta_t
            if args.use_deterministic_connect:
                core_vae.args.fb_mode = 2
            else:
                core_vae.args.fb_mode = 0 if beta_t == 0.0 else 1

            loss, loss_lsd, loss_lsg, acc_encode_z_dis, acc_gen_z_dis = model_vae(
                inputs, labels,
                attributes=attributes,
                tokenizer_decoder=decoder_tokenizer,
                tokenizer_encoder=encoder_tokenizer,
                cond_a=args.cond_a,
                cond_c=args.cond_c,
                train_gan=True
            )
            if args.n_gpu > 1:
                # DataParallel gathers one value per GPU; reduce them to scalars.
                loss = loss.mean()
                loss_lsd, loss_lsg, acc_encode_z_dis, acc_gen_z_dis = (
                    t.float().mean() for t in (loss_lsd, loss_lsg, acc_encode_z_dis, acc_gen_z_dis)
                )

            description = (
                f'iter: {step +  epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
                f'loss_d: {loss_lsd.item():.3f}; '
                f'loss_g: {loss_lsg.item():.3f}; '
                f'acc_enc: {acc_encode_z_dis.item():.3f}; '
                f'acc_g: {acc_gen_z_dis.item():.3f}; '
            )
            epoch_iterator.set_description(description)

            if global_step % 100 == 0:
                print(description)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_loss_lsd += loss_lsd.item()
            tr_loss_lsg += loss_lsg.item()
            tr_accu_enc += acc_encode_z_dis.item()
            tr_accu_g += acc_gen_z_dis.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_vae.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model_vae.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and (global_step % args.logging_steps == 0):
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate_gan(args, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer,
                                               senti_classifier, categ_classifier, output_dir=os.path.join(args.output_dir, 'gan'), step=global_step)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    window = min(args.logging_steps, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/window, global_step)
                    tb_writer.add_scalar('loss_d', (tr_loss_lsd - logging_loss_lsd)/window, global_step)
                    tb_writer.add_scalar('loss_g', (tr_loss_lsg - logging_loss_lsg)/window, global_step)
                    tb_writer.add_scalar('accu_enc', (tr_accu_enc - logging_accu_enc)/window, global_step)
                    tb_writer.add_scalar('accu_g', (tr_accu_g - logging_accu_g)/window, global_step)
                    logging_loss = tr_loss
                    logging_loss_lsd = tr_loss_lsd
                    logging_loss_lsg = tr_loss_lsg
                    logging_accu_enc = tr_accu_enc
                    logging_accu_g = tr_accu_g

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_checkpoint(model_vae, optimizer, global_step, args)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Guard against a run that never reached an optimizer step.
    return global_step, tr_loss / max(global_step, 1), optimizer



def _get_clas_accuracy(classifier, text, labels, classifier_tokenizer):
    classifier.eval()
    with torch.no_grad():
        outputs = classifier(
            input_ids=text,
            end_token_id_or_embeds=classifier_tokenizer.eos_token_id,
            labels=labels
        )
        tmp_output, pooled_fea = outputs
        tmp_eval_loss, logits = tmp_output

        eval_loss = tmp_eval_loss.mean().item()

        logits = logits.detach().cpu().numpy()
        preds = np.argmax(logits, axis=1).tolist()
        out_label_ids = labels.detach().cpu().numpy().tolist()

    return eval_loss, preds, out_label_ids, logits.tolist()


def evaluate_gan(args, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer,
                 senti_classifier=None, categ_classifier=None, output_dir=None, step=None):
    """Score sentences generated from the latent prior with attribute classifiers.

    Sentences are written by ``evaluate_generation_fromp_prior`` (``gan=True``)
    into ``output_dir``, reloaded with ``load_and_cache_examples`` and
    classified. Predictions are compared with the attribute labels read back
    from the generated file.

    Args:
        senti_classifier: sentiment classifier, scored against
            ``_get_sentiment_attributes``; skipped when ``None``.
        categ_classifier: category classifier, scored against
            ``_get_category_attributes``. It is only used when
            ``args.multi_attribute`` is set, because category labels exist only
            in the multi-attribute encoding. It is skipped when ``None``.
        output_dir: directory for the generations; defaults to ``args.output_dir``.
        step: forwarded to the generation helper.

    Returns:
        dict with ``<name>_accuracy`` and ``<name>_clas_loss`` for each
        classifier evaluated (``name`` is ``senti`` or ``categ``).
    """
    eval_output_dir = output_dir if output_dir is not None else args.output_dir

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    model_vae.eval()
    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # Take care of distributed/parallel training

    ## Generate from prior
    prior_gen_fn = evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args, path=eval_output_dir, step=step, gan=True)

    eval_dataset = load_and_cache_examples(args, decoder_tokenizer, evaluate=True, path=prior_gen_fn)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation -- prior generation *****")
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # name -> (classifier, function mapping raw attributes to that classifier's labels)
    heads = {}
    if senti_classifier is not None:
        heads['senti'] = (senti_classifier, _get_sentiment_attributes)
    if categ_classifier is not None and args.multi_attribute:
        heads['categ'] = (categ_classifier, _get_category_attributes)
    stats = {name: {'loss': 0.0, 'preds': [], 'labels': []} for name in heads}

    nb_eval_steps = 0
    for batch in tqdm(eval_dataloader, desc="Evaluating GEN ACCU"):
        x, _, attributes = tuple(t.to(args.device) for t in batch)

        x = x[:, 1:]  # remove BOS
        # 50257 is the padding token for GPT2; 50259 is treated the same way
        # (presumably another added special token — confirm). Both become the
        # classifier's EOS id.
        sample_mask = (x != 50257).long() * (x != 50259).long()
        x = _mask_with_eos(x, sample_mask, classifier_tokenizer.eos_token_id)

        for name, (classifier, to_labels) in heads.items():
            tmp_eval_loss, tmp_preds, tmp_out_label_ids, _ = _get_clas_accuracy(
                classifier, x, to_labels(attributes, args.multi_attribute), classifier_tokenizer
            )
            stats[name]['loss'] += tmp_eval_loss
            stats[name]['preds'] += tmp_preds
            stats[name]['labels'] += tmp_out_label_ids

        nb_eval_steps += 1

    result = {}
    for name, acc in stats.items():
        # Labels come back as one-element rows (presumably attributes of shape
        # [batch, 1] — confirm against load_and_cache_examples); keep the scalar.
        gold = [row[0] for row in acc['labels']]
        result[name + "_accuracy"] = (np.asarray(acc['preds']) == np.asarray(gold)).mean()
        result[name + "_clas_loss"] = acc['loss'] / max(nb_eval_steps, 1)

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    return result


def evaluate_a_layer(args, model_vae, encoder_tokenizer, decoder_tokenizer):
    """
    Evaluate p(a|z): the ``train_a_layer`` head of the VAE on the evaluation split.

    Every evaluation batch goes through the model with ``train_a_layer=True``
    under ``torch.no_grad()``. The argmax of the returned logits is compared
    with the sentiment labels given by ``_get_sentiment_attributes``.

    Returns:
        dict with ``a_z_accuracy`` and ``a_z_loss`` (mean of per-batch losses).
    """
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)

    # Eval!
    logger.info("***** Running a-layer evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # Wrapped (DataParallel/DDP) models expose submodules and args only via `.module`.
    core_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae

    if args.use_deterministic_connect:
        core_vae.args.fb_mode = 2

    model_vae.eval()
    core_vae.classifier.eval()
    core_vae.encoder.eval()
    core_vae.decoder.eval()

    nb_eval_steps = 0
    eval_loss = 0.0
    preds = []
    out_label_ids = []
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating a-layer"):
            batch = tuple(t.to(args.device) for t in batch)
            tokenized_text0, tokenized_text1, _, attributes = batch

            temp_loss, temp_logits = model_vae(
                tokenized_text0, tokenized_text1,
                attributes=attributes,
                tokenizer_decoder=decoder_tokenizer,
                tokenizer_encoder=encoder_tokenizer,
                cond_a=args.cond_a,
                cond_c=args.cond_c,
                train_a_layer=True
            )

            eval_loss += temp_loss.mean().item()
            preds += np.argmax(temp_logits.detach().cpu().numpy(), axis=1).tolist()
            out_label_ids += _get_sentiment_attributes(attributes, args.multi_attribute).detach().cpu().numpy().tolist()
            nb_eval_steps += 1

    accu = (np.asarray(preds) == np.asarray(out_label_ids)).mean()

    result = {
        "a_z_accuracy": accu,
        "a_z_loss": eval_loss / max(nb_eval_steps, 1)
    }
    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info("  %s = %s", key, str(result[key]))

    return result


def evaluate_clas(args, classifier, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, n_batches=-1):
    """Measure ``classifier`` accuracy on the evaluation split.

    The second tensor of each batch is classified and compared with the
    fourth (the attribute labels).

    Args:
        n_batches: evaluate at most this many batches; a non-positive value
            evaluates the whole split.

    Returns:
        The accuracy (fraction of matching predictions). The mean per-batch
        loss is printed alongside it.
    """
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)

    # Eval!
    logger.info("***** Running classifier evaluation *****")
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    eval_loss = 0.0
    nb_eval_steps = 0
    preds = []
    out_label_ids = []

    for batch in tqdm(eval_dataloader, desc="Evaluating classifier"):
        if n_batches > 0 and nb_eval_steps >= n_batches:
            break

        batch = tuple(t.to(args.device) for t in batch)
        _, tokenized_text1, _, labels = batch

        # _get_clas_accuracy also returns the raw logits, which are not needed here.
        tmp_eval_loss, tmp_preds, tmp_out_label_ids, _ = _get_clas_accuracy(
            classifier, tokenized_text1, labels, classifier_tokenizer
        )

        eval_loss += tmp_eval_loss
        preds += tmp_preds
        out_label_ids += tmp_out_label_ids
        nb_eval_steps += 1

    eval_loss = eval_loss / max(nb_eval_steps, 1)
    result = (np.asarray(preds) == np.asarray(out_label_ids)).mean()

    print('Classifier accuracy: ', result, '; loss: ', eval_loss)

    return result


def _get_reverse_attributes(a, multi=False):
    if multi:
        # 0 -> 2
        # 1 -> 3
        # 2 -> 0
        # 3 -> 1
        return (a + 2) % 4
    else:
        return 1 - a


def _get_sentiment_attributes(a, multi=False):
    if multi:
        # 0 -> 0
        # 1 -> 0
        # 2 -> 1
        # 3 -> 1
        return (a > 1).long()
    else:
        return a


def _get_category_attributes(a, multi=False):
    if multi:
        # 0 -> 0
        # 1 -> 1
        # 2 -> 0
        # 3 -> 1
        return (a % 2).long()
    else:
        raise NotImplementedError


def _add_attributes(z, a, model_vae, concat=True):
    """
    z.shape == [batch_size, dim] or [batch_size, nsamples, dim]
    a.shape == [batch_size]
    """
    batch_size = a.shape[0]

    a_fea = model_vae.attr_layer(a.unsqueeze(-1).float())  # [batch_size, self.nattr]

    if z.dim() == 3:
        a_fea = a_fea.view([batch_size, 1, -1]).repeat(1, z.shape[1], 1)

    if concat:
        z_new = torch.cat( (z, a_fea), -1)
    else:
        z_new = torch.cat( (z[:,:-model_vae.nattr], a_fea), -1)
    return z_new

    #batch_size = a.shape[0]

    #if z.dim() == 2:
    #    a = a.unsqueeze(-1)
    #elif z.dim() == 3:
    #    a = a.view([batch_size,1,1]).repeat(1,z.shape[1],1)
    #else:
    #    raise NotImplementedError

    #if concat:
    #    z_new = torch.cat( (z, a), -1)
    #else:
    #    z_new = torch.cat( (z[:,:-1], a), -1)
    #return z_new


def _mask_with_eos(text_ids, masks, eos_token_id):
    return text_ids * masks + (1 - masks) * eos_token_id


def evaluate(args, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer, table_name=None, prefix="", subset="test", classifier=None):
    """Evaluate reconstruction and sentiment transfer on the evaluation set.

    For every batch the encoder's posterior mean is used as the latent code.
    The decoder then generates:

    * ``out``: a reconstruction conditioned on the original sentiment label;
    * ``out_2``: a transfer conditioned on the flipped label (``1 - sentiment``).

    A classifier scores either ``out`` (when ``args.eval_self_accuracy``) or
    ``out_2`` against the label it was conditioned on. Corpus BLEU is computed
    for both generations against the reference text decoded from ``x1``.

    Side effects: sets ``args.eval_batch_size``. Writes three files into
    ``args.output_dir``:

    * ``eval_recontruction_results_<prefix>.txt``
    * ``eval_attributes_results_<prefix>.txt``
    * ``eval_results.txt``

    Args:
        args: run configuration (output_dir, device, batch sizes, eval flags).
        model_vae: VAE model (possibly wrapped in DataParallel/DDP).
        encoder_tokenizer / decoder_tokenizer: tokenizers for input and output.
        classifier_tokenizer: its ``eos_token_id`` fills padded generations.
        table_name: not used (only in the commented-out table insert below).
        prefix: suffix for result file names and log messages.
        subset: only used in a log message; data always comes from
            ``build_dataload_and_cache_examples(..., evaluate=True)``.
        classifier: classifier for accuracy; defaults to ``model_vae.classifier``.

    Returns:
        dict with keys perplexity, elbo, kl, nll (all 0 unless
        ``args.eval_elbo``), accuracy, clas_loss, bleu (transfer) and bleu_ori
        (reconstruction).
    """
    # NOTE(review): stale comment inherited from GLUE evaluation; there is no MNLI handling here.
    eval_output_dir = args.output_dir

    # if subset == 'test':
    #     eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
    # elif subset == 'train':
    #     eval_dataset = load_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=False)
    logger.info("***** Running evaluation on {} dataset *****".format(subset))

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    #args.per_gpu_eval_batch_size = 1 #
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    # eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    # eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    # len(eval_dataloader) is a batch count, despite the "Num examples" label.
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model_vae.eval()

    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # Take care of distributed/parallel training

    #mi = calc_mi(model_vae, eval_dataloader, args)
    #au = calc_au(model_vae, eval_dataloader, delta=0.01, args=args)[0]
    mi = 0  # placeholder: MI computation is disabled
    au = 0  # placeholder: active-units computation is disabled
    if args.eval_elbo:
        ppl, elbo, nll, kl = calc_iwnll(model_vae, eval_dataloader, args, ns=100)
    else:
        ppl, elbo, nll, kl = 0, 0, 0, 0  # ELBO metrics skipped; reported as 0


    ## Generation accuracy / BLEU
    #gen = defaultdict(str)
    #gen_labels = defaultdict(str)
    gen = []  # per example: [reference, reconstruction, transfer] texts
    gen_labels = []  # per example labels; only used for a length check below
    preds = []
    out_label_ids = []
    logits = []
    eval_loss = 0
    nb_eval_steps = 0
    refs = []  # BLEU references (from x1)
    hyps = []  # BLEU hypotheses: transfer outputs (out_2)
    hyps_ori = []  # BLEU hypotheses: reconstructions (out)
    # NOTE(review): the loader is rebuilt here, presumably because calc_iwnll
    # may have consumed the first one; confirm before reusing it.
    eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
    for batch in tqdm(eval_dataloader, desc="Evaluating GEN ACCU"):
        batch = tuple(t.to(args.device) for t in batch)
        x0, x1, x_lengths, attributes = batch

        context_tokens = decoder_tokenizer.encode('<BOS>')

        with torch.no_grad():

            if isinstance(model_vae.encoder, BertForLatentConnector):
                pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
            elif isinstance(model_vae.encoder, GPT2Encoder):
                encoder_input = x0
                if args.multi_attribute:
                    if args.cond_a and args.cond_c:
                        raise NotImplementedError  # conditioning on both attributes is unsupported
                    elif args.cond_a:
                        encoder_input = model_vae._append_a(x0, _get_sentiment_attributes(attributes, args.multi_attribute))
                    elif args.cond_c:
                        encoder_input = model_vae._append_a(x0, _get_category_attributes(attributes, args.multi_attribute))
                else:
                    if args.cond_c:
                        raise NotImplementedError
                    if args.cond_a:
                        encoder_input = model_vae._append_a(x0, attributes)

                pooled_hidden_fea = model_vae.encoder(encoder_input,
                                                      #attention_mask=(x0 > 0).float(),
                                                      end_token_id_or_embeds=encoder_tokenizer.eos_token_id)
            else:
                raise NotImplementedError


            # Connect hidden feature to the latent space
            # latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
            # Deterministic latent: use the posterior mean (logvar is unused).
            mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
            latent_z = mean.squeeze(1)

            # Reconstruction: overwrite the attribute slot with the original sentiment.
            latent_z_new = _add_attributes(
                latent_z,
                _get_sentiment_attributes(attributes, args.multi_attribute),
                model_vae, concat=False)

            past = latent_z_new

            # NOTE(review): context is built on CPU; presumably the sampler moves it to `device`.
            out, out_mask = sample_sequence_conditional_batch(
                model=model_vae.decoder,
                context=context_tokens[0] * torch.ones(x0.shape[0], 1, dtype=torch.long),
                past=past,
                length=x_lengths[:,1].max(),
                #temperature=args.temperature,
                #top_k=args.top_k,
                #top_p=args.top_p,
                device=args.device,
                decoder_tokenizer=decoder_tokenizer
            )
            # Positions where out_mask is 0 are replaced by the classifier's EOS id.
            out = _mask_with_eos(out, out_mask, classifier_tokenizer.eos_token_id)

            ## Transfer: flip the (binary) sentiment attribute
            #latent_z_new = _add_attributes(latent_z, _get_reverse_attributes(attributes, args.multi_attribute), model_vae, concat=False)
            latent_z_new = _add_attributes(latent_z, 1-_get_sentiment_attributes(attributes, args.multi_attribute), model_vae, concat=False)

            past = latent_z_new
            out_2, out_2_mask = sample_sequence_conditional_batch(
                model=model_vae.decoder,
                context=context_tokens[0] * torch.ones(x0.shape[0], 1, dtype=torch.long),
                past=past,
                length=x_lengths[:,1].max(),
                #temperature=args.temperature,
                #top_k=args.top_k,
                #top_p=args.top_p,
                device=args.device,
                decoder_tokenizer = decoder_tokenizer
            )
            out_2 = _mask_with_eos(out_2, out_2_mask, classifier_tokenizer.eos_token_id)


            #print('*' * 10 + ' Evaluate ' + '*' * 10)
            #print('out_2')
            #print(out_2)
            #for b in range(8):
            #    text_x1_data = decoder_tokenizer.decode(out_2[b].tolist(), clean_up_tokenization_spaces=False)
            #    text_x1_data = text_x1_data.split('<EOS>')[0].strip()
            #    print(text_x1_data)
            #print('*' * 30)


            # Fall back to the model's own classifier when none was supplied.
            classifier = model_vae.classifier if classifier is None else classifier

            # Score the reconstruction against the original label, or the
            # transfer against the flipped label.
            if args.eval_self_accuracy:
                tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                    classifier, out, _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                )
            else:
                tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                    classifier, out_2, 1 - _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                )

            #print('input_ids: ' + '*' * 20)
            #print(out_2)
            #print(classifier_tokenizer.convert_ids_to_tokens([i for i in out_2[0].tolist()]))
            #print(decoder_tokenizer.convert_ids_to_tokens([i for i in out_2[0].tolist()]))
            #print(decoder_tokenizer.decode(out_2[0].tolist()))
            #print('*' * 50)

            ##x1 = torch.cat([x1, torch.ones_like(x1) * classifier_tokenizer.eos_token_id], dim=-1)
            #x1_mask = (x1 != decoder_tokenizer.pad_token_id).int()
            #x1 = _mask_with_eos(x1, x1_mask, decoder_tokenizer.eos_token_id)
            #print(x1[:, 1:])
            #tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
            #    classifier, x1[:, 1:], attributes, classifier_tokenizer
            #)

            #tmp_eval_loss, tmp_preds, tmp_out_label_ids = 0, [0], [0]
            eval_loss += tmp_eval_loss
            preds += tmp_preds
            out_label_ids += tmp_out_label_ids
            logits += tmp_logits

            nb_eval_steps += 1

            for b in range(x0.shape[0]):
                # NOTE(review): assuming decode returns a str, `[0]` keeps only
                # its first character; text_x0 is not used below anyway.
                # clean_up_tokenization_spaces=True
                text_x0 = encoder_tokenizer.decode(x0[b,:x_lengths[b,0]].tolist(), clean_up_tokenization_spaces=False)[0]

                # Reference text: x1 without its first (BOS) token, cut at <EOS>.
                text_x1_data = decoder_tokenizer.decode(x1[b,1:].tolist(), clean_up_tokenization_spaces=False)
                text_x1_data = text_x1_data.split('<EOS>')[0].strip()
                #gen[text_x0] = [text_x1_data]
                refs.append([text_x1_data.strip()])

                # Reconstruction text.
                text_x1 = decoder_tokenizer.decode(out[b,:].tolist(), clean_up_tokenization_spaces=False)
                #text_x1 = text_x1.split()[1:-1]
                #text_x1 = text_x1.split()[:-1]
                #text_x1 = ' '.join(text_x1)
                text_x1 = text_x1.split('<EOS>')[0].strip()
                #gen[text_x0] = [text_x1]
                #gen[text_x1_data] = [text_x1]
                #gen_labels[text_x1_data] = [attributes[b].item()]
                hyps_ori.append(text_x1)

                # Transfer text.
                text_x1_tst = decoder_tokenizer.decode(out_2[b,:].tolist(), clean_up_tokenization_spaces=False)
                text_x1_tst = text_x1_tst.split('<EOS>')[0].strip()
                #gen[text_x1_data].append(text_x1_tst)
                #gen_labels[text_x1_data].append(1 - attributes[b].item())

                gen.append([text_x1_data.strip(), text_x1.strip(), text_x1_tst.strip()])
                # NOTE(review): in multi-attribute mode _get_reverse_attributes
                # differs from the `1 - sentiment` label used for generation above.
                gen_labels.append([attributes[b].item(), attributes[b].item(), _get_reverse_attributes(attributes, args.multi_attribute)[b].item()])

                #refs.append([text_x0.strip()])
                hyps.append(text_x1_tst)

        #break

    # NOTE(review): raises ZeroDivisionError if the dataloader yields no batches.
    eval_loss = eval_loss / nb_eval_steps
    accu = (np.asarray(preds) == np.asarray(out_label_ids)).mean()

    # Dump reference / reconstruction / transfer triples; log the first 5.
    gen_file_name = "eval_recontruction_results_%s.txt" % prefix
    output_gen_file = os.path.join(eval_output_dir, gen_file_name)
    with open(output_gen_file, "w") as writer:
        pn = 5
        #for key in sorted(gen.keys()):
        #    if isinstance(gen[key], list):
        #        assert len(gen[key]) == 2
        #        if pn > 0:
        #            logger.info("  %s \n %s \n %s\n", key, str(gen[key][0]), str(gen[key][1]))
        #        writer.write("%s\n%s\n%s\n\n" % (key, str(gen[key][0]), str(gen[key][1])))
        #    else:
        #        if pn > 0:
        #            logger.info("  %s \n %s\n", key, str(gen[key]))
        #        writer.write("%s\n%s\n\n" % (key, str(gen[key])))
        #    pn -= 1
        for g in gen:
            assert len(g) == 3
            if pn > 0:
                logger.info("  %s \n %s \n %s\n", g[0], g[1], g[2])
            writer.write("%s\n%s\n%s\n\n" % (g[0], g[1], g[2]))
            pn -= 1

    assert len(preds) == len(gen_labels)

    # Dump per-example classifier predictions for the transfer outputs.
    gen_file_name = "eval_attributes_results_%s.txt" % prefix
    output_gen_file = os.path.join(eval_output_dir, gen_file_name)
    with open(output_gen_file, "w") as writer:
        pn = 5
    #    for key in sorted(gen.keys()):
    #        if isinstance(gen[key], list):
    #            if pn > 0:
    #                logger.info("%s\t%s", str(gen_labels[key][-1]), str(gen[key][-1]))
    #            writer.write("%s\t%s\n" % (str(gen_labels[key][-1]), str(gen[key][-1])))
    #        else:
    #            raise NotImplementedError
    #        pn -= 1
        #i = 0
        #for g, gl in zip(gen, gen_labels):
        #    assert len(g) == 3
        #    if pn > 0:
        #        logger.info("%d\t%s\n", gl[-1], g[-1])
        #    #writer.write("%d\t%d\t%s\n" % (preds[i], gl[-1], g[-1]))
        #    writer.write("%d\t%d\t%s\n" % (preds[i], out_label_ids[i], gen[i][0]))
        #    pn -= 1
        #    i += 1

        #l = list(zip([g[0] for g in gen], preds, out_label_ids, logits))
        # Sort rows by transfer text. This rebinds preds/out_label_ids/logits
        # to tuples; accu was already computed above.
        # NOTE(review): zip(*l) fails to unpack if there are no examples.
        l = list(zip([g[-1] for g in gen], preds, out_label_ids, logits))
        l.sort()
        text, preds, out_label_ids, logits = zip(*l)
        # Row: prediction, two logit values (presumably 2 classes), gold label, text.
        for p, g, t, lo in zip(preds, out_label_ids, text, logits):
            writer.write('%d\t%.5f\t%.5f\t%d\t%s\n' % (p, lo[0], lo[1], g, t))

    ## BLEU (`tx` is not imported in the visible header; presumably texar, imported elsewhere)
    bleu_ori = tx.evals.corpus_bleu_moses(refs, hyps_ori)
    bleu = tx.evals.corpus_bleu_moses(refs, hyps)


    result = {
        "perplexity": ppl,
        "elbo": elbo,
        "kl": kl,
        "nll": nll,
        #"au": au,
        #"mi": mi,
        "accuracy": accu, "clas_loss": eval_loss, "bleu": bleu, "bleu_ori": bleu_ori
    }

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    #row = {
    #        'PartitionKey': 'MILU_Rule_Rule_Template',
    #        'RowKey': str(datetime.now()),
    #        'ExpName' : args.ExpName,
    #        'test_perplexity': str( ppl ),
    #        'test_elbo': str( elbo ),
    #        'test_nll': str(nll),
    #        'test_au': str(au),
    #        'test_mi': str(mi)
    #    }
    #pdb.set_trace()
    #ts.insert_entity(table_name, row)

    return result


def _compute_p_z_given_x(model_vae, z, x):
    """Log-density of latent ``z`` under the encoder's diagonal Gaussian q(z | x).

    ``x`` is encoded with the GPT-2 encoder, and the posterior mean and
    log-variance are read from ``encoder.linear``. Their trailing
    ``model_vae.nattr`` attribute slots are dropped. The per-dimension
    log-probabilities of ``z`` are summed over the last axis.

    Attribute conditioning (cond_a) is not implemented here. Any encoder
    other than ``GPT2Encoder`` raises ``NotImplementedError``.

    Returns:
        Summed log-probabilities, presumably shape [batch_size].
    """
    if not isinstance(model_vae.encoder, GPT2Encoder):
        raise NotImplementedError

    pooled_hidden_fea = model_vae.encoder(
        x,
        end_token_id_or_embeds=model_vae.eos_token_id)
    # Hard-coded id expected for the GPT-2 end token in this setup.
    assert model_vae.eos_token_id == 50259

    nattr = model_vae.nattr
    mu, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
    mu = mu[:, :-nattr]
    logvar = logvar[:, :-nattr]

    posterior = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
    return posterior.log_prob(z).sum(dim=-1)


def sampling_importance_resampling(model_vae, decoder_tokenizer, sample_fn, args):
    """Experimental importance-resampling of generated samples (unfinished).

    Draws ``args.num_z`` latents from a standard normal prior over the
    non-attribute dimensions. For each encoded sample x it computes
    ``log q(z | x) - log p(z)`` for every drawn z. It then softmax-normalises
    these scores jointly over all (sample, z) pairs and sums per sample.

    NOTE(review): this is debug code.

    * Only the first 9 batches are scored (``ii == 10`` break).
    * The process calls ``exit()`` right after printing the per-sample
      weights, so the resampling and file output below never run.
    * The dead code also passes ``args.num_resamples`` as ``sample_shape``
      without wrapping it in a list; ``sampling_importance_resampling_new``
      uses ``[args.num_resamples]``.

    Args:
        model_vae: VAE model.
        decoder_tokenizer: tokenizer used to load ``sample_fn``.
        sample_fn: path of generated samples; generated from the prior if empty.
        args: run configuration.
    """
    if len(sample_fn) == 0:
        sample_fn = evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args)

    print('Reading samples from ', sample_fn)

    #samples = {0: [], 1: []}
    #with open(sample_fn, 'r') as fin:
    #    for line in fin:
    #        label, text = line.strip().split('\t')
    #        label = int(label)
    #        samples[label].append(text.strip())

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, decoder_tokenizer, evaluate=True, path=sample_fn)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running importance resampling *****")
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # sample z from N(0, I) over the non-attribute latent dimensions
    loc = torch.zeros([args.latent_size - model_vae.nattr]).to(args.device)
    scale = torch.ones([args.latent_size - model_vae.nattr]).to(args.device)
    prior = torch.distributions.normal.Normal(loc, scale)

    z_set = []
    z_log_probs = []
    for _ in range(args.num_z):
        latent_z = prior.sample(sample_shape=[1])
        log_prob_prior_z = prior.log_prob(latent_z).sum(dim=-1)
        # NOTE(review): tiled to the full batch size; a smaller final batch
        # would presumably fail to broadcast in _compute_p_z_given_x.
        latent_z = latent_z.repeat([args.eval_batch_size, 1])
        z_set.append(latent_z)
        z_log_probs.append(log_prob_prior_z)

    ii = 0
    log_probs_all_data = None
    for batch in tqdm(eval_dataloader, desc="Evaluating Importance Weight"):

        # Debug cap: stop after 9 batches.
        ii += 1
        if ii== 10:
            break

        batch = tuple(t.to(args.device) for t in batch)
        x, x_lengths, attributes = batch
        x = x[:, 1:]  # remove BOS
        sample_mask = (x != 50257).long() * (x != 50259).long()  # 50257 is the padding token for GPT2
        x = _mask_with_eos(x, sample_mask, model_vae.eos_token_id)
        # No-op: the result is discarded (x is already on args.device).
        x.to(args.device)

        # Columns: one score per drawn z.
        log_probs_batch = None
        for z, z_log_prob in zip(z_set, z_log_probs):
            cur_log_probs = _compute_p_z_given_x(model_vae, z, x) - z_log_prob
            cur_log_probs = cur_log_probs.detach().view(-1, 1)
            if log_probs_batch is None:
                log_probs_batch = cur_log_probs
            else:
                log_probs_batch = torch.cat([log_probs_batch, cur_log_probs], dim=-1)

        if log_probs_all_data is None:
            log_probs_all_data = log_probs_batch
        else:
            log_probs_all_data = torch.cat([log_probs_all_data, log_probs_batch], dim=0)

    #assert log_probs_all_data.shape[0] == len(eval_dataloader)
    # Softmax over the flattened scores normalises jointly over every
    # (sample, z) pair; summing the last axis gives one weight per sample.
    probs_all_data = torch.nn.Softmax()(log_probs_all_data.view(-1)).view(log_probs_all_data.shape)
    probs_all_data = probs_all_data.sum(-1)
    print(probs_all_data)
    # NOTE(review): terminates the process; everything below is unreachable.
    exit()
    categorical_sampler = torch.distributions.categorical.Categorical(probs=probs_all_data)
    resamples = categorical_sampler.sample(sample_shape=args.num_resamples)
    print("resamples")
    print(resamples.shape)
    print(resamples)
    resamples = resamples.detach().cpu().numpy().tolist()

    out_fn = sample_fn + '.resamples'
    with open(sample_fn, 'r') as fin, open(out_fn, 'w') as fout:
        print('Output resamples to ', out_fn)
        data = []
        for line in fin:
            data.append(line.strip())
        data = np.asarray(data)
        resample_lines = data[resamples]
        for line in resample_lines:
            fout.write(line + '\n')


def sampling_importance_resampling_new(model_vae, decoder_tokenizer, sample_fn, args):
    """Resample generated sentences with weight inversely proportional to p(a | z).

    Each line of ``sample_fn`` is scored as follows. If ``sample_fn`` is
    empty, the file is first generated with ``evaluate_generation_fromp_prior``.

    1. The line is encoded into its posterior (mean, or ``args.num_z``
       reparameterized samples when ``args.num_z > 1``).
    2. ``model_vae.a_layer_given_z`` scores the latent. The softmax
       probability (temperature ``args.temperature``) of the line's own
       sentiment label is averaged over the latent samples.
    3. Probabilities below ``args.trim_low`` or above ``args.trim_high`` get
       1e6 added, making their inverse weight negligible.

    ``args.num_resamples`` lines are then drawn independently (with
    replacement) with probability proportional to ``1 / (p + 1e-8)``.

    Side effects:
        * Sets ``args.eval_batch_size``.
        * Writes ``<sample_fn>.resamples`` (the drawn lines).
        * Writes ``<sample_fn>.weight`` (the trimmed probability of each line).

    Raises:
        ValueError: if ``sample_fn`` yields no examples.
        NotImplementedError: for an unsupported encoder or attribute setup.
    """
    if len(sample_fn) == 0:
        sample_fn = evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args)

    print('Reading samples from ', sample_fn)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, decoder_tokenizer, evaluate=True, path=sample_fn)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running importance resampling *****")
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # Collect per-batch results and concatenate once at the end instead of
    # re-concatenating a growing tensor on every batch.
    probs_chunks = []
    rev_probs_chunks = []
    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating Importance Weight"):
            batch = tuple(t.to(args.device) for t in batch)
            x, x_lengths, attributes = batch
            senti_attributes = _get_sentiment_attributes(attributes, args.multi_attribute)

            if isinstance(model_vae.encoder, BertForLatentConnector):
                # Uses the batch input `x` (this branch previously referenced an
                # undefined `x0`, which raised NameError).
                pooled_hidden_fea = model_vae.encoder(x, attention_mask=(x > 0).float())[1]
            elif isinstance(model_vae.encoder, GPT2Encoder):
                encoder_input = x
                if args.multi_attribute:
                    if args.cond_a and args.cond_c:
                        raise NotImplementedError
                    elif args.cond_a:
                        encoder_input = model_vae._append_a(x, senti_attributes)
                    elif args.cond_c:
                        encoder_input = model_vae._append_a(x, _get_category_attributes(attributes, args.multi_attribute))
                else:
                    if args.cond_c:
                        raise NotImplementedError
                    if args.cond_a:
                        encoder_input = model_vae._append_a(x, senti_attributes)

                pooled_hidden_fea = model_vae.encoder(encoder_input,
                                                      end_token_id_or_embeds=decoder_tokenizer.eos_token_id)
            else:
                raise NotImplementedError

            # Posterior parameters with the trailing attribute slots dropped.
            mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
            mean = mean[:, :-model_vae.nattr]
            logvar = logvar[:, :-model_vae.nattr]
            if args.num_z > 1:
                latent_z = model_vae.reparameterize(mean, logvar, nsamples=args.num_z)
            else:
                latent_z = mean.unsqueeze(1)
            logits = model_vae.a_layer_given_z(latent_z)
            # NOTE(review): assumes senti_attributes is [batch, 1], giving a
            # [batch, num_z, 1] gather index; confirm against the dataset.
            label_index = senti_attributes.unsqueeze(1).repeat([1, args.num_z, 1])
            probs_batch = torch.gather(torch.softmax(logits / args.temperature, dim=-1), dim=-1, index=label_index)
            probs_batch = probs_batch.mean(dim=1)

            # Trimming: push out-of-range probabilities up so their inverse
            # weight is negligible.
            probs_batch = probs_batch + (probs_batch < args.trim_low).float() * 1e6 + (probs_batch > args.trim_high).float() * 1e6

            rev_probs_batch = (1. / (probs_batch + 1e-8)).mean(dim=1)

            probs_chunks.append(probs_batch)
            rev_probs_chunks.append(rev_probs_batch)

    if not probs_chunks:
        raise ValueError('No samples found in %s' % sample_fn)

    probs_all_data = torch.cat(probs_chunks, dim=0).view(-1)
    rev_probs_all_data = torch.cat(rev_probs_chunks, dim=0).view(-1)
    rev_probs_all_data = rev_probs_all_data / rev_probs_all_data.sum()
    categorical_sampler = torch.distributions.categorical.Categorical(probs=rev_probs_all_data)
    resamples = categorical_sampler.sample(sample_shape=[args.num_resamples])
    resamples = resamples.detach().cpu().numpy().tolist()

    out_fn = sample_fn + '.resamples'
    with open(sample_fn, 'r') as fin, open(out_fn, 'w') as fout:
        print('Output resamples to ', out_fn)
        data = []
        for line in fin:
            data.append(line.strip())
        data = np.asarray(data)
        resample_lines = data[resamples]
        for line in resample_lines:
            fout.write(line + '\n')

    probs_all_data = probs_all_data.detach().cpu().numpy().tolist()
    with open(sample_fn + '.weight', 'w') as fout:
        for p in probs_all_data:
            fout.write(str(p) + '\n')


def evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args, path=None, step=None, gan=False):
    """Generate sentences from the prior for both sentiment labels and save them.

    For each label in {0, 1}:

    1. Latent codes of size ``args.latent_size`` are drawn from N(0, I).
    2. The label is attached with ``_add_attributes``. It overwrites the
       attribute slot, or, when ``gan`` is True, is appended after mapping the
       non-attribute part through ``model_vae.latent_generator``.
    3. Text is decoded with ``sample_sequence_conditional_batch``.
    4. Each output is cut at the first EOS token and reduced to its first
       sentence (via nltk ``sent_tokenize``). Empty results are skipped.

    Sampling continues until exactly ``args.nsamples // 2`` texts have been
    kept for the label.

    Args:
        model_vae: VAE model (decoder, attribute layer, optional latent_generator).
        decoder_tokenizer: tokenizer used for the <BOS> context and decoding.
        args: run configuration. Reads latent_size, device, nsamples,
            per_gpu_eval_batch_size, max_seq_length, temperature, top_k,
            top_p and output_dir.
        path: output directory. Defaults to ``args.output_dir``, in which case
            the sample count is also put in the file name.
        step: tag included in the output file name.
        gan: if True, map prior samples through ``model_vae.latent_generator``.

    Returns:
        Path of the written file: one tab-separated "label, text" line per
        generated sentence.
    """
    loc = torch.zeros([args.latent_size]).to(args.device)
    scale = torch.ones([args.latent_size]).to(args.device)
    prior = torch.distributions.normal.Normal(loc, scale)

    context_tokens = decoder_tokenizer.encode('<BOS>')
    batch_size = args.per_gpu_eval_batch_size
    nsamples_per_class = args.nsamples // 2

    gen = []
    for label in [0, 1]:
        senti_attributes = label * torch.ones(batch_size, dtype=torch.long)

        data_of_the_label = 0

        # Keep sampling until the per-label quota is met. Texts that are empty
        # after sentence splitting are skipped, so a fixed batch count could
        # fall short.
        # NOTE(review): never terminates if every decoded text is empty.
        while data_of_the_label < nsamples_per_class:

            ## generate
            with torch.no_grad():
                latent_z = prior.sample(sample_shape=[batch_size])

                if gan:
                    latent_z = model_vae.latent_generator(latent_z[:,:-model_vae.nattr])
                    latent_z_new = _add_attributes(latent_z.to(args.device), senti_attributes.to(args.device), model_vae, concat=True)
                else:
                    latent_z_new = _add_attributes(latent_z.to(args.device), senti_attributes.to(args.device), model_vae, concat=False)

                past = latent_z_new.to(args.device)

                out, out_mask = sample_sequence_conditional_batch(
                    model=model_vae.decoder,
                    context=context_tokens[0] * torch.ones(batch_size, 1, dtype=torch.long),
                    past=past,
                    length=args.max_seq_length,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    device=args.device,
                    decoder_tokenizer=decoder_tokenizer
                )

            for b in range(out.shape[0]):
                text_x1 = decoder_tokenizer.decode(out[b,:].tolist(), clean_up_tokenization_spaces=False)
                text_x1 = text_x1.split(decoder_tokenizer.eos_token)[0].strip()

                # Pad end punctuation with spaces (presumably to help
                # sent_tokenize split on it) and flatten newlines.
                text_x1 = text_x1.replace(' .', ' . ')
                text_x1 = text_x1.replace(' ?', ' ? ')
                text_x1 = text_x1.replace(' !', ' ! ')
                text_x1 = text_x1.replace('\n', ' ')
                sentences = tokenize.sent_tokenize(text_x1.strip())
                if not sentences:
                    continue
                text_x1 = sentences[0]

                gen.append('%d\t' % label + text_x1)

                data_of_the_label += 1
                if data_of_the_label % 100 == 0:
                    print('.', end='', flush=True)
                if data_of_the_label >= nsamples_per_class:
                    # Stop exactly at the quota instead of keeping the rest
                    # of the batch.
                    break

    print('#gen: ', len(gen))

    fn_prefix = "gen_fromp_prior_results"
    if gan:
        fn_prefix += "_gan"
    gen_file_name = "{}_{}.txt".format(fn_prefix, step)
    if path is None:
        path = args.output_dir
        gen_file_name = "{}_{}_n{}.txt".format(fn_prefix, step, args.nsamples)

    if not os.path.exists(path):
        os.makedirs(path)

    output_gen_file = os.path.join(path, gen_file_name)
    with open(output_gen_file, "w") as writer:
        pn = 3
        for i, text in enumerate(gen):
            if i < pn:
                logger.info(text)
            writer.write(text + '\n')

    return output_gen_file


def evaluate_continuous(args, model_vae, encoder_tokenizer, decoder_tokenizer, classifier_tokenizer,
                        senti_classifier=None, categ_classifier=None, output_dir=None, step=None):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = output_dir

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    #args.per_gpu_eval_batch_size = 1 #
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly

    #eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
    ## Eval!
    #logger.info("***** Running evaluation {} *****".format(prefix))
    #logger.info("  Num examples = %d", len(eval_dataloader))
    #logger.info("  Batch size = %d", args.eval_batch_size)

    model_vae.eval()
    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # Take care of distributed/parallel training

    ## Generate from prior
    if args.continuous_eval_prior:
        prior_gen_fn = evaluate_generation_fromp_prior(model_vae, decoder_tokenizer, args, path=eval_output_dir, step=step)

        eval_dataset = load_and_cache_examples(args, decoder_tokenizer, evaluate=True, path=prior_gen_fn)
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        logger.info("***** Running evaluation -- prior generation *****")
        logger.info("  Num examples = %d", len(eval_dataloader))
        logger.info("  Batch size = %d", args.eval_batch_size)

        senti_preds = []
        senti_out_label_ids = []
        senti_logits = []
        senti_eval_loss = 0
        categ_preds = []
        categ_out_label_ids = []
        categ_logits = []
        categ_eval_loss = 0
        nb_eval_steps = 0
        for batch in tqdm(eval_dataloader, desc="Evaluating GEN ACCU"):
            batch = tuple(t.to(args.device) for t in batch)
            x, x_lengths, attributes = batch

            x = x[:, 1:] # remove BOS
            sample_mask = (x != 50257).long() * (x != 50259).long() # 50257 is the padding token for GPT2
            x = _mask_with_eos(x, sample_mask, classifier_tokenizer.eos_token_id)
            x.to(args.device)
            #print(sample_mask)
            #print('x')
            #print(x[:5])
            #print(x_lengths[:5])
            #exit()

            # sentiment
            tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                senti_classifier, x, _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
            )

            senti_eval_loss += tmp_eval_loss
            senti_preds += tmp_preds
            senti_out_label_ids += tmp_out_label_ids
            senti_logits += tmp_logits

            # category
            tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                categ_classifier, x, _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
            )
            categ_eval_loss += tmp_eval_loss
            categ_preds += tmp_preds
            categ_out_label_ids += tmp_out_label_ids
            categ_logits += tmp_logits

            nb_eval_steps += 1

        senti_eval_loss = senti_eval_loss / nb_eval_steps
        senti_out_label_ids = [i[0] for i in senti_out_label_ids]
        senti_accu = (np.asarray(senti_preds) == np.asarray(senti_out_label_ids)).mean()
        #print(senti_preds)
        #print(senti_out_label_ids)

        categ_eval_loss = categ_eval_loss / nb_eval_steps
        categ_out_label_ids = [i[0] for i in categ_out_label_ids]
        categ_accu = (np.asarray(categ_preds) == np.asarray(categ_out_label_ids)).mean()
        #print(categ_preds)
        #print(categ_out_label_ids)

        result = {
            "senti_accuracy": senti_accu,
            "senti_clas_loss": senti_eval_loss,
            "categ_accuracy": categ_accu,
            "categ_clas_loss": categ_eval_loss,
        }


    else:  ## TST

        senti_preds = []
        senti_out_label_ids = []
        senti_logits = []
        senti_eval_loss = 0
        categ_preds = []
        categ_out_label_ids = []
        categ_logits = []
        categ_eval_loss = 0
        nb_eval_steps = 0

        gen = []
        gen_labels = []
        nb_eval_steps = 0
        refs = []
        hyps = []
        hyps_ori = []

        eval_dataloader = build_dataload_and_cache_examples(args, [encoder_tokenizer, decoder_tokenizer], evaluate=True)
        for batch in tqdm(eval_dataloader, desc="Evaluating TST "):
            batch = tuple(t.to(args.device) for t in batch)
            x0, x1, x_lengths, attributes = batch

            context_tokens = decoder_tokenizer.encode('<BOS>')

            with torch.no_grad():

                if isinstance(model_vae.encoder, BertForLatentConnector):
                    pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
                elif isinstance(model_vae.encoder, GPT2Encoder):
                    encoder_input = x0
                    if args.multi_attribute:
                        if args.cond_a and args.cond_c:
                            raise NotImplementedError
                        elif args.cond_a:
                            encoder_input = model_vae._append_a(x0, _get_sentiment_attributes(attributes, args.multi_attribute))
                        elif args.cond_c:
                            encoder_input = model_vae._append_a(x0, _get_category_attributes(attributes, args.multi_attribute))
                    else:
                        if args.cond_c:
                            raise NotImplementedError
                        if args.cond_a:
                            encoder_input = model_vae._append_a(x0, attributes)

                    pooled_hidden_fea = model_vae.encoder(model_vae._append_a(x0, attributes),
                                                          #attention_mask=(x0 > 0).float(),
                                                          end_token_id_or_embeds=encoder_tokenizer.eos_token_id)
                else:
                    raise NotImplementedError

                # Connect hidden feature to the latent space
                # latent_z, loss_kl = model_vae.connect(pooled_hidden_fea)
                mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
                latent_z = mean.squeeze(1)

                latent_z_new = _add_attributes(
                    latent_z,
                    _get_sentiment_attributes(attributes, args.multi_attribute),
                    model_vae, concat=False)

                past = latent_z_new

                out, out_mask = sample_sequence_conditional_batch(
                    model=model_vae.decoder,
                    context=context_tokens[0] * torch.ones(x0.shape[0], 1, dtype=torch.long),
                    past=past,
                    length=x_lengths[:,1].max(),
                    #temperature=args.temperature,
                    #top_k=args.top_k,
                    #top_p=args.top_p,
                    device=args.device,
                    decoder_tokenizer=decoder_tokenizer
                )
                out = _mask_with_eos(out, out_mask, classifier_tokenizer.eos_token_id)

                ## flip attributes, 
                #latent_z_new = _add_attributes(latent_z, _get_reverse_attributes(attributes, args.multi_attribute), model_vae, concat=False)
                latent_z_new = _add_attributes(latent_z, 1 - _get_sentiment_attributes(attributes, args.multi_attribute), model_vae, concat=False)

                past = latent_z_new
                out_2, out_2_mask = sample_sequence_conditional_batch(
                    model=model_vae.decoder,
                    context=context_tokens[0] * torch.ones(x0.shape[0], 1, dtype=torch.long),
                    past=past,
                    length=x_lengths[:,1].max(),
                    #temperature=args.temperature,
                    #top_k=args.top_k,
                    #top_p=args.top_p,
                    device=args.device,
                    decoder_tokenizer = decoder_tokenizer
                )
                out_2 = _mask_with_eos(out_2, out_2_mask, classifier_tokenizer.eos_token_id)

                # sentiment
                classifier = senti_classifier

                if args.eval_self_accuracy:
                    tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                        classifier, out, _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                    )
                else:
                    tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                        classifier, out_2, 1 - _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                    )

                senti_eval_loss += tmp_eval_loss
                senti_preds += tmp_preds
                senti_out_label_ids += tmp_out_label_ids
                senti_logits += tmp_logits

                # category
                classifier = categ_classifier

                if args.eval_self_accuracy:
                    tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                        classifier, out, _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                    )
                else:
                    tmp_eval_loss, tmp_preds, tmp_out_label_ids, tmp_logits = _get_clas_accuracy(
                        classifier, out_2, 1 - _get_sentiment_attributes(attributes, args.multi_attribute), classifier_tokenizer
                    )
                categ_eval_loss += tmp_eval_loss
                categ_preds += tmp_preds
                categ_out_label_ids += tmp_out_label_ids
                categ_logits += tmp_logits

                nb_eval_steps += 1

                for b in range(x0.shape[0]):
                    # 
                    # clean_up_tokenization_spaces=True
                    text_x0 = encoder_tokenizer.decode(x0[b,:x_lengths[b,0]].tolist(), clean_up_tokenization_spaces=False)[0]

                    text_x1_data = decoder_tokenizer.decode(x1[b,1:].tolist(), clean_up_tokenization_spaces=False)
                    text_x1_data = text_x1_data.split('<EOS>')[0].strip()
                    refs.append([text_x1_data.strip()])

                    text_x1 = decoder_tokenizer.decode(out[b,:].tolist(), clean_up_tokenization_spaces=False)
                    text_x1 = text_x1.split('<EOS>')[0].strip()
                    hyps_ori.append(text_x1)

                    text_x1_tst = decoder_tokenizer.decode(out_2[b,:].tolist(), clean_up_tokenization_spaces=False)
                    text_x1_tst = text_x1_tst.split('<EOS>')[0].strip()

                    gen.append([text_x1_data.strip(), text_x1.strip(), text_x1_tst.strip()])
                    gen_labels.append([attributes[b].item(), attributes[b].item(), _get_reverse_attributes(attributes, args.multi_attribute)[b].item()])

                    hyps.append(text_x1_tst)

        senti_eval_loss = senti_eval_loss / nb_eval_steps
        #senti_out_label_ids = [i[0] for i in senti_out_label_ids]
        senti_accu = (np.asarray(senti_preds) == np.asarray(senti_out_label_ids)).mean()

        categ_eval_loss = categ_eval_loss / nb_eval_steps
        #categ_out_label_ids = [i[0] for i in categ_out_label_ids]
        categ_accu = (np.asarray(categ_preds) == np.asarray(categ_out_label_ids)).mean()

        gen_file_name = "eval_continuous_recontruction_results_%d.txt" % step
        output_gen_file = os.path.join(eval_output_dir, gen_file_name)
        with open(output_gen_file, "w") as writer:
            pn = 5
            for g in gen:
                assert len(g) == 3
                if pn > 0:
                    logger.info("  %s \n %s \n %s\n", g[0], g[1], g[2])
                writer.write("%s\n%s\n%s\n\n" % (g[0], g[1], g[2]))
                pn -= 1

        assert len(senti_preds) == len(gen_labels)

        gen_file_name = "eval_continuous_attributes_results_%d.txt" % step
        output_gen_file = os.path.join(eval_output_dir, gen_file_name)
        with open(output_gen_file, "w") as writer:
            pn = 5
            l = list(zip([g[-1] for g in gen], senti_preds, senti_out_label_ids, senti_logits, categ_preds, categ_out_label_ids, categ_logits))
            #l.sort()
            text, senti_preds, senti_out_label_ids, senti_logits, categ_preds, categ_out_label_ids, categ_logits = zip(*l)
            for p, g, t, lo, cp, cg, clo in zip(senti_preds, senti_out_label_ids, text, senti_logits, categ_preds, categ_out_label_ids, categ_logits):
                writer.write('%d\t%.5f\t%.5f\t%d\t%d\t%.5f\t%.5f\t%d\t%s\n' % (p, lo[0], lo[1], g, cp, clo[0], clo[1], cg, t))

        bleu_ori = tx.evals.corpus_bleu_moses(refs, hyps_ori)
        bleu = tx.evals.corpus_bleu_moses(refs, hyps)

        result = {
            "senti_accuracy": senti_accu,
            "senti_clas_loss": senti_eval_loss,
            "categ_accuracy": categ_accu,
            "categ_clas_loss": categ_eval_loss,
            "bleu": bleu,
            "bleu_ori": bleu_ori
        }

        output_eval_file = os.path.join(eval_output_dir, "eval_results_%d.txt" % step)
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def evaluate_rec(args, model_vae, encoder_tokenizer, decoder_tokenizer, table_name, prefix="", subset="test"):
    """Evaluate the reconstruction loss of ``model_vae`` on the train or test split.

    Builds a dataloader for ``subset`` with ``build_dataload_and_cache_examples``,
    runs ``calc_rec(model_vae, eval_dataloader, args, ns=1)`` and writes the
    resulting metrics to ``<args.output_dir>/eval_results.txt`` (overwriting
    any previous file of that name).

    Args:
        args: Parsed command-line namespace. Reads ``output_dir``,
            ``local_rank``, ``per_gpu_eval_batch_size`` and ``n_gpu``; sets
            ``args.eval_batch_size`` as a side effect.
        model_vae: The VAE model. If it has a ``.module`` attribute (e.g. a
            DataParallel/DistributedDataParallel wrapper) the inner module is
            used for evaluation.
        encoder_tokenizer: Tokenizer passed to the dataloader builder.
        decoder_tokenizer: Tokenizer passed to the dataloader builder.
        table_name: Unused; kept for interface compatibility with callers.
        prefix: Label included in the log headers.
        subset: ``"test"`` (builds the loader with ``evaluate=True``) or
            ``"train"`` (builds it with ``evaluate=False``).

    Returns:
        dict with keys ``"rec_w"`` and ``"rec_s"`` holding the second and first
        values returned by ``calc_rec`` respectively (presumably word- and
        sentence-level NLL -- confirm against ``calc_rec``).

    Raises:
        ValueError: if ``subset`` is neither ``"test"`` nor ``"train"``.
    """
    # Validate first: an unknown subset used to leave ``eval_dataloader``
    # unbound and crash later with an UnboundLocalError.
    if subset not in ("test", "train"):
        raise ValueError("subset must be 'test' or 'train', got {!r}".format(subset))

    eval_output_dir = args.output_dir

    # Only the main process creates the output directory; exist_ok avoids the
    # race between an existence check and the actual creation.
    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # evaluate=True selects the evaluation data, evaluate=False the training data.
    eval_dataloader = build_dataload_and_cache_examples(
        args, [encoder_tokenizer, decoder_tokenizer], evaluate=(subset == "test"))
    logger.info("***** Running evaluation on {} dataset *****".format(subset))

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    # NOTE(review): this logs len() of the dataloader, which for batch-based
    # loaders is a batch count rather than an example count -- confirm.
    logger.info("  Num examples = %d", len(eval_dataloader))
    logger.info("  Batch size = %d", args.eval_batch_size)

    model_vae.eval()
    # Take care of distributed/parallel training wrappers.
    model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae
    nll_s, nll_w = calc_rec(model_vae, eval_dataloader, args, ns=1)

    result = {
        "rec_w": nll_w, "rec_s": nll_s
    }

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("%s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--checkpoint_dir", default=None, type=str,
                        help="The directory where checkpoints are saved.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--dataset", default=None, type=str, help="The dataset.")

    ## Semi supervised  # 
    parser.add_argument("--semi", action='store_true', help="semi supervised")
    parser.add_argument("--train_data_file_2", default="", type=str, required=False,
                        help="The input training data file (a text file).")
    parser.add_argument("--per_gpu_train_batch_size_2", default=2, type=int,
                        help="Batch size per GPU/CPU for training.")

    ## Other parameters
    parser.add_argument("--eval_data_file", default=None, type=str,
                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
    parser.add_argument("--ExpName", default="", type=str,
                        help="The experiment name used in Azure Table.")
    parser.add_argument("--save_bert_gpt_init", action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument("--length_weighted_loss", action='store_true',
                        help="Use sentence length re-weight the reconstruction loss.")


    ## Encoder options
    parser.add_argument("--encoder_model_type", default="bert", type=str,
                        help="The encoder model architecture to be fine-tuned.")
    parser.add_argument("--encoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The encoder model checkpoint for weights initialization.")
    parser.add_argument("--encoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--encoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    # 
    parser.add_argument("--cond_a", action='store_true', help="z conditioning on a")
    parser.add_argument("--cond_c", action='store_true', help="z conditioning on c")

    ## Decoder options
    parser.add_argument("--decoder_model_type", default="gpt2", type=str,
                        help="The decoder model architecture to be fine-tuned.")
    parser.add_argument("--decoder_model_name_or_path", default="bert-base-cased", type=str,
                        help="The decoder model checkpoint for weights initialization.")
    parser.add_argument("--decoder_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--decoder_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    ## Classifier options
    parser.add_argument("--classifier_model_type", default="gpt2-clas", type=str,
                        help="The classifier model architecture to be fine-tuned.")
    parser.add_argument("--classifier_model_name_or_path", default="gpt2", type=str,
                        help="The classifier model checkpoint for weights initialization.")
    parser.add_argument("--classifier_config_name", default="", type=str,
                        help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--classifier_tokenizer_name", default="", type=str,
                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")

    parser.add_argument("--classifier_dir", default="", type=str, help="")

    ## Gumbel softmax
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="temperature init")
    parser.add_argument("--temperature_anneal_factor", type=float, default=1.0,
                        help="annealing factor")
    parser.add_argument("--temperature_anneal_iters", type=float, default=4000,
                        help="temperature annealling frequency")
    parser.add_argument("--lambda_clas", type=float, default=1.0,
                        help="classification loss weight")
    parser.add_argument("--lambda_recon", type=float, default=1.0,
                        help="recon loss weight")
    parser.add_argument("--gumbel_hard", action='store_true',
                        help="Gumbel softmax hard sample.")
    parser.add_argument("--use_gumbel", action='store_true',
                        help="Gumbel softmax")

    parser.add_argument("--lambda_c_loss", type=float, default=0,
                        help="c classification loss weight")
    parser.add_argument("--lambda_reg_z", type=float, default=0,
                        help="regularization loss weight")
    parser.add_argument("--lambda_reg_z_c", type=float, default=0,
                        help="regularization c z loss weight")


    ## Variational auto-encoder
    parser.add_argument("--latent_size", default=32, type=int, help="Latent space dimension.")
    parser.add_argument("--use_deterministic_connect", action='store_true',
                        help="Use deterministic inference to generate latent codes, i.e., standard auto-encoders.")
    parser.add_argument("--use_pretrained_model", action='store_true',
                        help="Use pre-trained auto-encoder models as the initialization")
    parser.add_argument("--latent_as_gpt_memory", default=1, type=int, help="Latent vector as memery for GPT2 to attend.")
    parser.add_argument("--latent_as_gpt_emb", default=1, type=int, help="Latent vector as embeddings for GPT2.")

    ## Attributes
    parser.add_argument("--attribute_dim", default=1, type=int, help="Latent space dimension.")
    parser.add_argument("--multi_attribute", action='store_true', help="")

    ## Objective functions
    parser.add_argument("--mlm", action='store_true',
                        help="Train with masked-language modeling loss instead of language modeling.")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--beta", type=float, default=1.0,
                        help="The weighting hyper-parameter of the KL term in VAE")


    parser.add_argument("--cache_dir", default="", type=str,
                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="Optional input sequence length before tokenization. The sequence will be dropped if it is longer the max_seq_length")
    parser.add_argument("--block_size", default=-1, type=int,
                        help="Optional input sequence length after tokenization."
                             "The training dataset will be truncated in block of this size for training."
                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_eval_rec", action='store_true',
                        help="Whether to run eval reconstruction on a set of models.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")


    parser.add_argument("--eval_elbo", action='store_true', help="ELBO")
    parser.add_argument("--eval_self_accuracy", action='store_true', help="Self accuracy")
    parser.add_argument("--do_continuous_eval", action='store_true', help="")
    parser.add_argument("--continuous_eval_sorted", action='store_true', help="")
    parser.add_argument("--continuous_eval_sorted_reverse", action='store_true', help="")
    parser.add_argument("--do_infer_resampling", action='store_true', help="")
    parser.add_argument("--do_train_a_layer", action='store_true', help="")
    parser.add_argument("--do_train_gan", action='store_true', help="")
    parser.add_argument("--do_eval_a_layer", action='store_true', help="")
    parser.add_argument("--continuous_eval_prior", action='store_true', help="")
    parser.add_argument("--eval_gender", action='store_true', help="")
    parser.add_argument("--eval_priority_list_file", default="", type=str, help="")
    parser.add_argument("--eval_list_file", default="", type=str, help="")
    parser.add_argument("--nsamples", type=int, default=100)
    parser.add_argument("--do_gen_from_prior", action='store_true',
                        help="Whether to generate from prior.")
    parser.add_argument("--gen_with_gan", action='store_true')
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0)

    # backdoor adjustment
    parser.add_argument("--num_z", type=int, default=1)
    parser.add_argument("--num_resamples", type=int, default=1000)
    parser.add_argument('--sample_fn', type=str, default='', help="")
    parser.add_argument("--trim_low", type=float, default=0)
    parser.add_argument("--trim_high", type=float, default=10)


    # Training Schedules
    parser.add_argument("--ratio_increase", default=0.25, type=float,
                        help="Learning schedule, the percentage for the annealing stage.")
    parser.add_argument("--ratio_zero", default=0.25, type=float,
                        help="Learning schedule, the percentage for the pure auto-encoding stage.")
    parser.add_argument("--fb_mode", default=0, type=int,
                        help="free bit training mode.")
    parser.add_argument("--dim_target_kl", default=3.0, type=float,
                        help="dim_target_kl free bit training mode.")
    parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=1.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_philly", action='store_true',
                        help="Use Philly for computing.")
    parser.add_argument("--use_pretrained_vae", action='store_true',
                        help="Use use_pretrained_vae as initialization, where beta value is specified in the folder")
    parser.add_argument("--use_random_weight", action='store_true',
                        help="Use random weights as initialization")


    ## IO: Logging and Saving
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gloabl_step_eval', type=int, default=661,
                        help="Evaluate the results at the given global step")

    # Precision & Distributed Training
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    args = parser.parse_args()

    if args.decoder_model_type in ["bert", "roberta"] and not args.mlm:
        raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                         "flag (masked language modeling).")
    if args.eval_data_file is None and args.do_eval:
        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
                         "or remove the --do_eval argument.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    args.ExpName = 'Vae_' + args.dataset + '_Nz_' + str(args.latent_size)  + '_Beta_'  + str(args.beta) + '_Dkl_' + str(args.dim_target_kl) + '_Ra_' + str(args.ratio_increase) + '_R0_' + str(args.ratio_zero)
    table_name = 'Vae' + args.dataset + 'Nz' + str(args.latent_size)
    try:
        ts.create_table(table_name)
    except:
        pass


    # Set seed
    set_seed(args)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab


    # 
    if args.classifier_dir is not "":
        output_classifier_dir = args.classifier_dir
    else:
        #output_classifier_dir = os.path.join(args.checkpoint_dir, 'checkpoint-classifier-{}'.format(global_step))
        #output_classifier_dir = os.path.join(args.checkpoint_dir, 'checkpoint-classifier-{}-biased'.format(global_step))

        #output_classifier_dir = os.path.join('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_sentiment_bias_gpt2_5ep/checkpoint-28000')
        output_classifier_dir = "/mnt/efs/fs2/hzt/causal/Optimus/data/multi_yelp_tst/full_5_15_clean_semi/classifier_senti_finetune_weighted_01_09/checkpoint-180"

        #output_classifier_dir = os.path.join('/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_lm/yelp_na50_10ep/checkpoint-31250-69000/checkpoint-classifier-69000')
        #output_classifier_dir = os.path.join('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_sentiment_unbias_large_gpt2_5ep/checkpoint-112000')
        #output_classifier_dir = os.path.join('/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_lm/s3-bias_yelp_na50_10ep_bz64_vae/checkpoint-31250-81000/checkpoint-classifier-81000')


    # Load Optimius pre-trained model and tokenizer
    if args.use_pretrained_model:
        args.encoder_model_type = args.encoder_model_type.lower()
        args.decoder_model_type = args.decoder_model_type.lower()

        global_step = args.gloabl_step_eval

        output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))
        output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))

        #checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
        checkpoints = [ [output_encoder_dir, output_decoder_dir, output_classifier_dir] ]  # 
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        # Load a trained Encoder model and vocabulary
        encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
        model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
        tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)

        model_encoder.to(args.device)
        if args.block_size <= 0:
            args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

        # Load a trained Decoder model and vocabulary
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
        model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
        model_decoder.to(args.device)
        if args.block_size <= 0:
            args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

        # 
        # Load a trained Classifier model and vocabulary
        classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type]
        model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
        tokenizer_classifier = classifier_tokenizer_class.from_pretrained(args.classifier_tokenizer_name if args.classifier_tokenizer_name else args.classifier_model_name_or_path, do_lower_case=args.do_lower_case)
        model_classifier.to(args.device)

        # Load full model
        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))


    else:

        # Load BERT and GPT weights (As an alternaive, one may train a VAE for this small)

        ## Encoder
        encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type.lower()]
        encoder_config = encoder_config_class.from_pretrained(args.encoder_config_name if args.encoder_config_name else args.encoder_model_name_or_path)
        tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)
        model_encoder = encoder_model_class.from_pretrained(args.encoder_model_name_or_path, from_tf=bool('.ckpt' in args.encoder_model_name_or_path), config=encoder_config, latent_size=args.latent_size)
        model_encoder.to(args.device)  # 

        ## Decoder
        decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type.lower()]
        decoder_config = decoder_config_class.from_pretrained(args.decoder_config_name if args.decoder_config_name else args.decoder_model_name_or_path)
        tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
        if args.block_size <= 0:
            args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
        args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

        if args.latent_as_gpt_emb + args.latent_as_gpt_memory == 0:
            return # latent vector should pass into GPT to decode
        else:
            latent_as_gpt_emb = True if args.latent_as_gpt_emb == 1 else False
            latent_as_gpt_memory = True if args.latent_as_gpt_memory == 1 else False

        setattr(decoder_config, "latent_size", args.latent_size)
        model_decoder = decoder_model_class.from_pretrained(args.decoder_model_name_or_path, from_tf=bool('.ckpt' in args.decoder_model_name_or_path), config=decoder_config, latent_size=args.latent_size, latent_as_gpt_emb=latent_as_gpt_emb, latent_as_gpt_memory=latent_as_gpt_memory)
        model_decoder.to(args.device)  # 

        # 
        # Load a trained Classifier model and vocabulary
        classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type.lower()]
        model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
        tokenizer_classifier = classifier_tokenizer_class.from_pretrained(args.classifier_tokenizer_name if args.classifier_tokenizer_name else args.classifier_model_name_or_path, do_lower_case=args.do_lower_case)
        model_classifier.to(args.device)


    # Save the init weights of BERT and GPT-2, so that we can load from local (Some infra requires so)
    if args.save_bert_gpt_init:

        raise NotImplementedError  # 

        encoder_path = os.path.join(args.output_dir, f"initial-models-tokenization-enoder-{args.latent_size}")
        if not os.path.exists(encoder_path): os.makedirs(encoder_path)
        model_encoder.save_pretrained(encoder_path)
        tokenizer_encoder.save_pretrained(encoder_path)

        decoder_path = os.path.join(args.output_dir, f"initial-models-tokenization-decoder-{args.latent_size}")
        if not os.path.exists(decoder_path): os.makedirs(decoder_path)
        model_decoder.save_pretrained(decoder_path)
        tokenizer_decoder.save_pretrained(decoder_path)

        return


    # : Adding Padding token to GPT2 encoder
    if isinstance(tokenizer_encoder, GPT2Tokenizer):
        special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
        num_added_toks = tokenizer_encoder.add_special_tokens(special_tokens_dict)
        print('We have added', num_added_toks, 'tokens to GPT2 encoder')
        model_encoder.resize_token_embeddings(len(tokenizer_encoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        assert tokenizer_encoder.pad_token == '<PAD>'

    # Chunyuan: Add Padding token to GPT2 decoder
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2')
    model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_decoder.pad_token == '<PAD>'

    # : Add Padding token to GPT2 classifier
    special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
    num_added_toks = tokenizer_classifier.add_special_tokens(special_tokens_dict)
    print('We have added', num_added_toks, 'tokens to GPT2 Classifier')
    model_classifier.resize_token_embeddings(len(tokenizer_classifier))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    assert tokenizer_classifier.pad_token == '<PAD>'
    assert tokenizer_decoder.eos_token_id == tokenizer_classifier.eos_token_id
    print('eos token id: ', tokenizer_classifier.eos_token_id)

    # model_decoder.to(args.device)

    #model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args)
    model_vae = VAEClas(model_encoder, model_decoder, model_classifier,
                        tokenizer_encoder, tokenizer_decoder, tokenizer_classifier, args)  # 

    # pdb.set_trace()
    if args.use_random_weight:
        model_vae.apply(weights_init_rondom)

    if args.use_pretrained_model:
        #model_vae.load_state_dict(checkpoint['model_state_dict'])
        model_vae.load_state_dict(checkpoint['model_state_dict'], strict=False)  #  ======================

        # 
        classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type]
        model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
        model_classifier.to(args.device)
        model_classifier.resize_token_embeddings(len(tokenizer_classifier))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model_vae.classifier = model_classifier

        logger.info("Pre-trained Optimus is successfully loaded")
    model_vae.to(args.device) #

    # on_gpu = next(model_vae.parameters()).is_cuda


    if args.local_rank == 0:
        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab

    logger.info("Training/evaluation parameters %s", args)

    ##############################
    # Training
    global_step= 0
    if args.do_train:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False, shuffle=True)
        #train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=True, shuffle=False) #TODO #evaluate=False, shuffle=True)

        if args.local_rank == 0:
            torch.distributed.barrier()

        #global_step, tr_loss, optimizer = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, table_name)
        global_step, tr_loss, optimizer = train(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, tokenizer_classifier, table_name)  # 
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    #return

    ## Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    #if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
    #    save_checkpoint(model_vae, optimizer, global_step, args)

    #classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type]
    #model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
    #model_classifier.to(args.device)
    #model_classifier.resize_token_embeddings(len(tokenizer_classifier))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    #model_vae.classifier = model_classifier

    #save_checkpoint(model_vae, optimizer, global_step, args)  #  =================
    #exit()


    ##############################
    # Evaluation the metrics of VAE models, including PPL, MI, AU
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))  # 
        output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoint_dir = [output_encoder_dir, output_decoder_dir, output_full_dir]

        logger.info("Evaluate the following checkpoint: %s", checkpoint_dir[-1])
        global_step = checkpoint_dir[-1].split('-')[-1] if len(checkpoint_dir) > 1 else ""

        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        model_vae.load_state_dict(checkpoint['model_state_dict'])  #  ======================
        logger.info(f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        model_vae.to(args.device)

        # 
        #classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_sentiment_unbias_large_gpt2_5ep/checkpoint-112000')
        #classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs//classifier/yelp_sentiment_bias_gpt2_5ep/checkpoint-40000')
        #classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/finetune_lm/s3-bias_yelp_na50_10ep_bz64_vae/checkpoint-31250-81000/checkpoint-classifier-81000')
        #classifier.to(args.device)

        result = evaluate(args, model_vae, tokenizer_encoder, tokenizer_decoder, tokenizer_classifier, table_name, prefix=global_step, subset='test')
                          #classifier=model_vae.classifier)
                          #classifier=classifier)
        result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
        results.update(result)

        output_eval_file = os.path.join(args.output_dir, "eval_vae_results_test.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(f"The testing results are successfully saved: {output_eval_file}")


    # Generation
    if args.do_gen_from_prior:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))  # 
        output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoint_dir = [output_encoder_dir, output_decoder_dir, output_full_dir]

        logger.info("Evaluate the following checkpoint: %s", checkpoint_dir[-1])
        global_step = checkpoint_dir[-1].split('-')[-1] if len(checkpoint_dir) > 1 else ""

        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        model_vae.load_state_dict(checkpoint['model_state_dict'], strict=False)  #  ======================

        classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type]
        model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
        model_classifier.to(args.device)
        model_classifier.resize_token_embeddings(len(tokenizer_classifier))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model_vae.classifier = model_classifier

        logger.info(f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        model_vae.to(args.device)

        evaluate_generation_fromp_prior(model_vae, tokenizer_decoder, args, step=global_step, gan=args.gen_with_gan)


    # Continuous Evaluation
    if args.do_continuous_eval:
        checkpoints_evaluated = {}
        path_to_tbwriter = {}

        senti_classifier, categ_classifier = None, None
        #if args.continuous_eval_prior:
        if args.eval_gender:
            senti_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender/classifiers/profession_full_classifier_weighted_01_09_lr2e5/checkpoint-2200')
            senti_classifier.to(args.device)
            categ_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender/classifiers/gender_full_classifier_weighted_01_09_lr2e5/checkpoint-3000')
            categ_classifier.to(args.device)
        else:
            senti_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_sentiment_unbias_large_gpt2_5ep/checkpoint-112000')
            senti_classifier.to(args.device)
            categ_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_category_unbias_large_gpt2_5ep/checkpoint-78000')
            categ_classifier.to(args.device)

        def _get_ckpt_key(path, step):
            return path.strip('/') + '/%d' % step

        def get_conti_summarywriter(output_path):
            return SummaryWriter(logdir=output_path)

        def _eval_checkpoint(path, step):
            # Load checkpoint
            output_encoder_dir = os.path.join(path, 'checkpoint-encoder-{}'.format(step))
            output_decoder_dir = os.path.join(path, 'checkpoint-decoder-{}'.format(step))
            output_full_dir    = os.path.join(path, 'checkpoint-full-{}'.format(step))
            checkpoint_dir = [output_encoder_dir, output_decoder_dir, output_full_dir]

            logger.info("Evaluate the following checkpoint: %s", checkpoint_dir[-1])

            checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
            model_vae.load_state_dict(checkpoint['model_state_dict'], strict=False)
            logger.info(f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
            model_vae.to(args.device)

            results = evaluate_continuous(args, model_vae, tokenizer_encoder, tokenizer_decoder, tokenizer_classifier,
                                          senti_classifier=senti_classifier, categ_classifier=categ_classifier, output_dir=path, step=step)
            tb_writer = path_to_tbwriter.get(path, None)
            if tb_writer is None:
                if args.continuous_eval_prior:
                    tb_writer_path = os.path.join(path, 'tb_logs')
                else:
                    tb_writer_path = os.path.join(path, 'tst_tb_logs')
                tb_writer = get_conti_summarywriter(tb_writer_path)
                path_to_tbwriter[path] = tb_writer
            for key, value in results.items():
                tb_writer.add_scalar('c_eval_{}'.format(key), value, step)
                print(str(step), '\t', 'c_eval_{}'.format(key), '\t', value)

        while True:
            with open(args.eval_priority_list_file, 'r') as fin:
                for line in fin:
                    parts = line.strip().split(',')
                    path = parts[0]

                    if args.continuous_eval_prior:
                        ckpt_evaled_fn = os.path.join(path, 'checkpoints_evaluated.txt')
                    else:
                        ckpt_evaled_fn = os.path.join(path, 'tst_checkpoints_evaluated.txt')
                    if os.path.exists(ckpt_evaled_fn):
                        with open(ckpt_evaled_fn, 'r') as ckpt_fin:
                            for line in ckpt_fin:
                                checkpoints_evaluated[line.strip()] = 1

                    steps = [int(s) for s in parts[1:]]
                    for step in steps:
                        ckpt_key = _get_ckpt_key(path, step)
                        if ckpt_key in checkpoints_evaluated:
                            continue

                        _eval_checkpoint(path, step)

                        checkpoints_evaluated[ckpt_key] = 1
                        with open(ckpt_evaled_fn, 'a+') as ckpt_fout:
                            ckpt_fout.write(ckpt_key + '\n')
                            ckpt_fout.flush()

            with open(args.eval_list_file, 'r') as fin:
                for line in fin:
                    path = line.strip()

                    if args.continuous_eval_prior:
                        ckpt_evaled_fn = os.path.join(path, 'checkpoints_evaluated.txt')
                    else:
                        ckpt_evaled_fn = os.path.join(path, 'tst_checkpoints_evaluated.txt')
                    if os.path.exists(ckpt_evaled_fn):
                        with open(ckpt_evaled_fn, 'r') as ckpt_fin:
                            for line in ckpt_fin:
                                checkpoints_evaluated[line.strip()] = 1

                    sub_paths = [f.name for f in os.scandir(path) if f.is_dir()]
                    steps = []
                    for sp in sub_paths:
                        step = sp.split('-')[-1]
                        if step == '' or not is_integer(step):
                            print('"%s" is not a step' % step)
                            continue
                        steps.append(int(step))

                    if args.continuous_eval_sorted:
                        steps.sort()
                    if args.continuous_eval_sorted_reverse:
                        steps.sort(reverse=True)

                    for step in steps:
                        ckpt_key = _get_ckpt_key(path, step)
                        if ckpt_key in checkpoints_evaluated:
                            continue

                        _eval_checkpoint(path, step)

                        checkpoints_evaluated[ckpt_key] = 1
                        with open(ckpt_evaled_fn, 'a+') as ckpt_fout:
                            ckpt_fout.write(ckpt_key + '\n')
                            ckpt_fout.flush()

                        break  # beak immediately
                    break

            time.sleep(60)


    if args.do_infer_resampling:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))  # 
        output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
        checkpoint_dir = [output_encoder_dir, output_decoder_dir, output_full_dir]

        logger.info("Evaluate the following checkpoint: %s", checkpoint_dir[-1])
        global_step = checkpoint_dir[-1].split('-')[-1] if len(checkpoint_dir) > 1 else ""

        checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        model_vae.load_state_dict(checkpoint['model_state_dict'], strict=False)  #  ======================

        classifier_config_class, classifier_model_class, classifier_tokenizer_class = MODEL_CLASSES[args.classifier_model_type]
        model_classifier = classifier_model_class.from_pretrained(output_classifier_dir)
        model_classifier.to(args.device)
        model_classifier.resize_token_embeddings(len(tokenizer_classifier))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model_vae.classifier = model_classifier

        logger.info(f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        model_vae.to(args.device)

        #sampling_importance_resampling(model_vae, tokenizer_decoder, args.sample_fn, args)
        sampling_importance_resampling_new(model_vae, tokenizer_decoder, args.sample_fn, args)

    global_step = 0
    if args.do_train_a_layer:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False, shuffle=True)

        if args.local_rank == 0:
            torch.distributed.barrier()

        global_step, tr_loss, optimizer = train_p_az(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, tokenizer_classifier, table_name)  # 
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    results = {}
    if args.do_eval_a_layer:
        if global_step == 0:
            global_step = args.gloabl_step_eval

        #output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(global_step))  # 
        #output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(global_step))
        #output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(global_step))
        #checkpoint_dir = [output_encoder_dir, output_decoder_dir, output_full_dir]

        #logger.info("Evaluate the following checkpoint: %s", checkpoint_dir[-1])
        #global_step = checkpoint_dir[-1].split('-')[-1] if len(checkpoint_dir) > 1 else ""

        #checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))
        #model_vae.load_state_dict(checkpoint['model_state_dict'])  #  ======================
        #logger.info(f"Pre-trained Optimus is successfully loaded: {output_full_dir}")
        #model_vae.to(args.device)

        result = evaluate_a_layer(args, model_vae, tokenizer_encoder, tokenizer_decoder)
        result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
        results.update(result)

        output_eval_file = os.path.join(args.output_dir, "eval_vae_a_layer_results_test.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("%s = %s", key, str(results[key]))
                writer.write("%s = %s\n" % (key, str(results[key])))
        logger.info(f"The testing results are successfully saved: {output_eval_file}")


    global_step = 0
    if args.do_train_gan:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache

        train_dataloader = build_dataload_and_cache_examples(args, [tokenizer_encoder, tokenizer_decoder], evaluate=False, shuffle=True)

        if args.local_rank == 0:
            torch.distributed.barrier()

        senti_classifier, categ_classifier = None, None
        #if args.continuous_eval_prior:
        if args.eval_gender:
            senti_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender/classifiers/profession_full_classifier_weighted_01_09_lr2e5/checkpoint-2200')
            senti_classifier.to(args.device)
            categ_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs_gender/classifiers/gender_full_classifier_weighted_01_09_lr2e5/checkpoint-3000')
            categ_classifier.to(args.device)
        else:
            senti_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_sentiment_unbias_large_gpt2_5ep/checkpoint-112000')
            senti_classifier.to(args.device)
            categ_classifier = classifier_model_class.from_pretrained('/mnt/efs/fs2/hzt/causal/Optimus/outputs/classifier/yelp_category_unbias_large_gpt2_5ep/checkpoint-78000')
            categ_classifier.to(args.device)

        global_step, tr_loss, optimizer = train_latent_gan(args, train_dataloader, model_vae, tokenizer_encoder, tokenizer_decoder, tokenizer_classifier, table_name,
                                                           senti_classifier, categ_classifier)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)



    ###############################
    ##  Evaluate the reconstruction loss for each checkpoints;
    ## This is used in studying two different latent vector injection schemes
    #results = {}
    #if args.do_eval_rec and args.local_rank in [-1, 0]:
    #    if global_step == 0:
    #        global_step = args.gloabl_step_eval
    #        # eval_steps = range(500, 13500, 500)
    #        # eval_steps = range(1000, 2000, 500)
    #        eval_steps = range(2000, 32000, 2000)

    #    checkpoints = []
    #    for e in eval_steps:
    #        output_encoder_dir = os.path.join(args.output_dir, 'checkpoint-encoder-{}'.format(e))
    #        output_decoder_dir = os.path.join(args.output_dir, 'checkpoint-decoder-{}'.format(e))
    #        checkpoints.append([output_encoder_dir, output_decoder_dir])


    #    logger.info("Evaluate the following checkpoints: %s", checkpoints)
    #    for checkpoint in checkpoints:
    #        global_step = checkpoint[0].split('-')[-1] if len(checkpoints) > 1 else ""

    #        model_encoder = encoder_model_class.from_pretrained(checkpoint[0], latent_size=args.latent_size)
    #        model_encoder.to(args.device)

    #        model_decoder = decoder_model_class.from_pretrained(checkpoint[1])
    #        model_decoder.to(args.device)

    #        model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args).to(args.device)

    #        result = evaluate_rec(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='test')
    #        result = dict((k + '_test_{}'.format(global_step), v) for k, v in result.items())
    #        results.update(result)

    #        result = evaluate_rec(args, model_vae, tokenizer_encoder, tokenizer_decoder, table_name, prefix=global_step, subset='train')
    #        result = dict((k + '_train_{}'.format(global_step), v) for k, v in result.items())
    #        results.update(result)

    #        # pdb.set_trace()

    #    output_eval_file = os.path.join(args.output_dir, "eval_rec_results.txt")
    #    with open(output_eval_file, "w") as writer:
    #        logger.info("***** Eval results *****")
    #        for key in sorted(results.keys()):
    #            logger.info("%s = %s", key, str(results[key]))
    #            writer.write("%s = %s\n" % (key, str(results[key])))
    #    logger.info(f"The testing results are successfully saved: {output_eval_file}")


    return results


if __name__ == "__main__":
    # Script entry point: invoke main() (defined earlier in this file) only when
    # the module is executed directly, not when it is imported.
    main()
