# coding=utf-8
import csv
import sys

import argparse
from argparse import ArgumentParser
import os
import numpy as np
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
import json

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer
import copy
from transformers import *
from models_transfert_style import *

# Module-level logger; its format and level are configured by logging.basicConfig() inside main().
logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Labelled evaluation set built from two sentiment ``.fader`` text files.

    Only the first tab-separated field of each line is tokenised. Sentences
    from ``pos_path`` get label 1 and sentences from ``neg_path`` get label 0.
    Token-id lists shorter than ``max_length`` are right-padded with the
    tokenizer's pad id. Longer ones are kept as-is (NOT truncated), so
    batching with ``batch_size > 1`` only works if every encoded line fits.
    """

    def __init__(self, args, dev, reny_ds=False,
                 pos_path='sentiment.test.0.fader', neg_path='sentiment.test.1.fader'):
        """
        Args:
            args: namespace. ``args.max_length`` (default 43) and
                ``args.tokenizer`` are used when present. Otherwise a
                ``bert-base-uncased`` BertTokenizer is loaded.
            dev: only selects the "Validation"/"Train" wording of the log
                message; the same files are loaded either way.
            reny_ds: unused; kept for call-site compatibility.
            pos_path: file whose lines are labelled 1.
            neg_path: file whose lines are labelled 0.
        """
        max_length = getattr(args, 'max_length', 43)
        tokenizer = getattr(args, 'tokenizer', None)
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        logger.info("Loading dataset {}".format('Validation' if dev else 'Train'))
        # NOTE(review): the ".0" file is labelled 1 and the ".1" file 0 here; many Yelp
        # splits use .0 = negative / .1 = positive. Kept as-is because it must match the
        # convention the checkpoint was trained with -- confirm.
        lines_pos = self._encode_file(pos_path, tokenizer, max_length)
        lines_neg = self._encode_file(neg_path, tokenizer, max_length)

        # Stored as tuples (positives first, then negatives); works for empty files too.
        self.lines = tuple(lines_pos + lines_neg)
        self.label = tuple([1] * len(lines_pos) + [0] * len(lines_neg))

    @staticmethod
    def _encode_file(path, tokenizer, max_length):
        """Tokenise the first tab-separated field of every line in *path* and right-pad to *max_length*."""
        with open(path, 'r', encoding='utf-8') as file:
            raw_lines = file.readlines()
        pad_id = tokenizer.pad_token_id
        encoded = [tokenizer.encode(line.split('\t')[0]) for line in raw_lines]
        # A negative pad count yields [], so over-long sequences are left untouched.
        return [ids + [pad_id] * (max_length - len(ids)) for ids in encoded]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        """Return ``{'line': LongTensor of token ids, 'label': LongTensor scalar}`` for index *item*."""
        return {'line': torch.tensor(self.lines[item], dtype=torch.long),
                'label': torch.tensor(self.label[item], dtype=torch.long)}


def set_seed(args):
    """Seed every RNG this script touches with ``args.seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and, when CUDA is available, all CUDA devices.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)


def clean_seq(args, outputs):
    """Replace everything from the second [SEP] token onward with padding.

    The first [SEP] is kept; presumably it is the start-of-sequence token,
    since main() sets ``sos_token`` to the tokenizer's SEP id. The second
    [SEP] and every later position become ``pad_token_id``. If a sequence
    has fewer than two SEP tokens, it is returned unchanged. Each output
    list has the same length as its input.

    Args:
        args: namespace whose ``tokenizer`` provides ``pad_token_id`` and ``sep_token_id``.
        outputs: iterable of token-id sequences (e.g. ``tensor.tolist()``).

    Returns:
        A list of cleaned token-id lists.
    """
    pad_token = args.tokenizer.pad_token_id
    sep = args.tokenizer.sep_token_id
    cleaned_seq = []
    for seq in outputs:
        sep_positions = [i for i, tok in enumerate(seq) if tok == sep]
        # Cut at the second SEP; without one, keep the whole sequence
        # (previously a magic cap of 100 silently padded longer sequences).
        cut = sep_positions[1] if len(sep_positions) > 1 else len(seq)
        cleaned_seq.append(list(seq[:cut]) + [pad_token] * (len(seq) - cut))
    return cleaned_seq


def _write_lines(path, items):
    """Write ``str(item)`` for each item in *items* to *path*, one per line, UTF-8 encoded."""
    with open(path, 'w', encoding='utf-8') as file:
        file.writelines('{}\n'.format(item) for item in items)


def generate_sentences(args, eval_dataset, model, flip_label):
    """Run ``model.forward_transfert`` over *eval_dataset* and dump the results to text files.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``tokenizer``,
            ``sentences`` and ``suffix``.
        eval_dataset: dataset whose items are dicts with ``'line'`` (token ids)
            and ``'label'`` tensors, e.g. ``TextDataset``.
        model: object exposing ``eval()`` and
            ``forward_transfert(inputs, labels, style_pos, style_neg)``.
        flip_label: if True, the model is conditioned on ``1 - label``;
            otherwise on the original label.

    Writes four files into ``<args.sentences>/<args.suffix>/``, each suffixed
    with ``_{flip_label}.txt``:
        gen_*          decoded model outputs (after ``clean_seq``)
        label_gen_*    labels the model was conditioned on
        golden_*       decoded input sentences
        label_golden_* original dataset labels
    """
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, shuffle=False, sampler=eval_sampler, batch_size=args.batch_size, drop_last=False)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %5d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)

    generated_ids, golden_ids, labels_ = [], [], []
    # NOTE(review): 0/0 are forwarded unchanged as forward_transfert's style_pos/style_neg;
    # their meaning is defined in the model code -- confirm.
    style_pos, style_neg = 0, 0
    model.eval()  # eval mode is set once, before iterating

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs = batch['line'].to(args.device)
            labels = batch['label'].to(args.device)
            logger.debug("inputs %s, labels %s", tuple(inputs.size()), tuple(labels.size()))
            target_labels = 1 - labels if flip_label else labels
            outputs = model.forward_transfert(inputs, target_labels, style_pos, style_neg)
            generated_ids += clean_seq(args, outputs.tolist())
            golden_ids += inputs.tolist()
            labels_ += labels.tolist()

    decode = args.tokenizer.decode
    generated = [decode(ids, skip_special_tokens=True) for ids in generated_ids]
    golden = [decode(ids, skip_special_tokens=True) for ids in golden_ids]
    # The labels written next to the generations are the ones the model was conditioned on.
    gen_labels = [1 - label for label in labels_] if flip_label else labels_

    out_dir = os.path.join(args.sentences, args.suffix)
    _write_lines(os.path.join(out_dir, 'gen_{}.txt'.format(flip_label)), generated)
    _write_lines(os.path.join(out_dir, 'label_gen_{}.txt'.format(flip_label)), gen_labels)
    _write_lines(os.path.join(out_dir, 'golden_{}.txt'.format(flip_label)), golden)
    _write_lines(os.path.join(out_dir, 'label_golden_{}.txt'.format(flip_label)), labels_)


def _build_parser():
    """Return the command-line parser for this script (flags, defaults and help text unchanged)."""
    # NOTE(review): many help strings below are copy-paste placeholders and do not
    # describe their flag; left untouched pending the author's intended wording.
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--dataset", default='yelp', type=str, help="The input training data file (a text file).")
    parser.add_argument("--suffix", default='yelp', type=str, help="The input training data file (a text file).")
    parser.add_argument("--sentences", default='sentences_new/', type=str,
                        help="The input training data file (a text file).")
    parser.add_argument("--filter", default=100, type=int, help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default='debug',
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--max_length", default=43, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--save_step", type=int, default=1000, help="random seed for initialization")
    parser.add_argument("--num_train_epochs", type=int, default=1000, help="random seed for initialization")
    parser.add_argument("--warmup_steps", type=int, default=0, help="random seed for initialization")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="random seed for initialization")
    parser.add_argument("--adam_epsilon", type=float, default=0.001, help="random seed for initialization")

    # Architecture Seq2Seq
    parser.add_argument("--model", default='style_emb', help="random seed for initialization")
    parser.add_argument("--model_path_to_load", default='checkpoint-10000',
                        help="random seed for initialization")
    parser.add_argument("--saving_result_file", default='test_transfert.txt', help="loading from path")

    # Classifier
    parser.add_argument("--path_classifier", default='classifier_vanilla_seq2seq_2_for_desantaglement_results',
                        help="loading from path")
    parser.add_argument("--checkpoints_path_classifier", default='checkpoint-173400', help="loading from path")

    # Architecture
    parser.add_argument("--style_dim", type=int, default=8, help="random seed for initialization")
    parser.add_argument("--content_dim", type=int, default=256, help="random seed for initialization")
    parser.add_argument("--number_of_layers", type=int, default=2, help="random seed for initialization")
    parser.add_argument("--hidden_dim", type=int, default=256, help="random seed for initialization")
    parser.add_argument("--dec_hidden_dim", type=int, default=136, help="random seed for initialization")
    parser.add_argument("--dropout", type=float, default=0.5, help="random seed for initialization")
    parser.add_argument("--number_of_styles", type=int, default=2, help="random seed for initialization")
    parser.add_argument("--mul_style", type=int, default=10, help="random seed for initialization")
    parser.add_argument("--adv_style", type=int, default=1, help="random seed for initialization")
    parser.add_argument("--mul_mi", type=float, default=1, help="random seed for initialization")
    parser.add_argument("--alpha", type=float, default=1.5, help="random seed for initialization")
    parser.add_argument("--ema_beta", type=float, default=0.99, help="random seed for initialization")
    parser.add_argument("--not_use_ema", action="store_true", help="random seed for initialization")
    parser.add_argument("--no_reny", action="store_true", help="random seed for initialization")
    parser.add_argument("--reny_training", type=int, default=2, help="random seed for initialization")
    parser.add_argument("--number_of_training_encoder", type=int, default=1, help="random seed for initialization")
    parser.add_argument("--add_noise", action="store_true", help="random seed for initialization")
    parser.add_argument("--noise_p", type=float, default=0.1, help="random seed for initialization")
    parser.add_argument("--number_of_perm", type=int, default=3, help="random seed for initialization")
    parser.add_argument("--alternative_hs", action="store_true", help="random seed for initialization")
    parser.add_argument("--no_minimization_of_mi_training", action="store_true")
    parser.add_argument("--special_clement", action="store_true")
    parser.add_argument("--use_gender", action="store_true")
    parser.add_argument("--complex_proj_content", action="store_true")

    # What to do
    parser.add_argument("--do_eval", action='store_true', help="loading from path")
    parser.add_argument("--do_train_classifer", action='store_true', help="loading from path")
    parser.add_argument("--do_test_reconstruction", action='store_true', help="loading from path")
    parser.add_argument("--do_test_transfer", action='store_true', help="loading from path")
    parser.add_argument("--use_complex_classifier", action='store_true', help="loading from path")

    # Metrics
    parser.add_argument("--model_metrics", default='model_for_metric_evaluation', help="loading from path")
    return parser


def _load_model_args(args, tokenizer):
    """Rebuild the model hyper-parameters from ``<model_path_to_load>/training_args.txt`` (JSON).

    Runtime fields (device, tokenizer-derived ids, batch size) are then
    patched in. ``use_complex_classifier`` defaults to False when the file
    lacks the key or stores null.
    """
    with open(os.path.join(args.model_path_to_load, 'training_args.txt'), 'r', encoding='utf-8') as f:
        args_model = argparse.Namespace(**json.load(f))

    args_model.sos_token = tokenizer.sep_token_id
    args_model.device = torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args_model.number_of_tokens = tokenizer.vocab_size
    args_model.tokenizer = tokenizer
    args_model.batch_size = args.batch_size
    args_model.padding_idx = tokenizer.pad_token_id
    if getattr(args_model, 'use_complex_classifier', None) is None:
        args_model.use_complex_classifier = False
    return args_model


def _load_model(args, args_model):
    """Build ``StyleEmdedding`` and load ``<model_path_to_load>/model.pt``.

    Returns the model in eval mode on ``args.device``.

    Raises:
        RuntimeError: if the encoder embedding is unchanged after loading,
            i.e. the checkpoint apparently did not overwrite the weights.
    """
    model = StyleEmdedding(args_model, None)

    # Snapshot a tensor copy (not a Python list) so the sanity check below is cheap.
    weight_before = model.encoder.embedding.weight.detach().clone()
    model.load_state_dict(torch.load(os.path.join(args.model_path_to_load, 'model.pt'),
                                     map_location=torch.device(args.device)))
    if torch.equal(weight_before, model.encoder.embedding.weight.detach()):
        raise RuntimeError("Encoder embedding unchanged after loading the checkpoint")

    # NOTE(review): assigned after construction, so this only matters if the model keeps a
    # reference to args_model and reads content_dim later -- confirm.
    args_model.content_dim = args_model.hidden_dim
    model.to(args.device)
    model.eval()
    return model


def main():
    """Entry point: load a trained checkpoint and write style-transferred test sentences.

    Parses the CLI and seeds the RNGs. Rebuilds ``StyleEmdedding`` from
    ``--model_path_to_load``. Then calls ``generate_sentences`` on
    ``TextDataset`` with ``flip_label=True``. Everything runs on CPU.
    """
    args = _build_parser().parse_args()
    args.device = torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    args.sos_token = tokenizer.sep_token_id
    args.number_of_tokens = tokenizer.vocab_size
    args.tokenizer = tokenizer
    # NOTE(review): overrides --use_complex_classifier unconditionally; the model is built
    # from args_model, not args -- confirm intent.
    args.use_complex_classifier = True
    args.padding_idx = tokenizer.pad_token_id
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    set_seed(args)
    os.makedirs(os.path.join(args.sentences, args.suffix), exist_ok=True)
    os.makedirs(os.path.join('results_desantanglement', args.suffix), exist_ok=True)

    args_model = _load_model_args(args, tokenizer)

    logger.info("------------------------------------------ ")
    logger.info("model type = %s ", args.model)
    logger.info("------------------------------------------ ")

    model = _load_model(args, args_model)

    logger.info("Model parameters %s", args)

    dev_dataset = TextDataset(args, True)

    generate_sentences(args, dev_dataset, model, True)
    logger.info(" Test Over ")

    logger.info(" Program Over ")


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
