# coding=utf-8
import csv
import argparse
import logging
import os
import random
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
# import fasttext
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from metric import StyleTransfertMetric
import json

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer

# Module-level logger; output format and level are configured by logging.basicConfig in main().
logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Labelled sentiment dataset of pre-tokenised sentences stored as CSV token ids.

    For split ``train`` or ``dev`` it reads ``tensor_sentiment.<split>.1`` (label 1) and
    ``tensor_sentiment.<split>.0`` (label 0) from ``data/data_tensor_<args.dataset>/``
    below the first root in ``_DATA_ROOTS`` where both files can be opened. Each CSV row
    is one sentence given as integer token ids. Examples are shuffled with a fixed seed
    (42) and optionally truncated to ``args.filter`` examples.

    Items are dicts ``{'line': LongTensor of token ids, 'label': LongTensor scalar}``.
    """

    # Roots tried in order; '' resolves relative to the current working directory.
    _DATA_ROOTS = ('', '/data/DIT', '/USER/DIT')

    def __init__(self, args, dev):
        """Load, shuffle and optionally truncate one split.

        Args:
            args: namespace providing ``dataset`` (data folder suffix) and ``filter``
                (maximum number of examples to keep; ``None`` or negative keeps all).
            dev: load the ``dev`` split if true, otherwise ``train``.

        Raises:
            FileNotFoundError: if no root holds both files of the split.
            ValueError: if a CSV cell is not an integer.
        """
        logger.info("Loading dataset {}".format('Validation' if dev else 'Train'))
        suffix = 'dev' if dev else 'train'
        lines_pos, lines_neg = self._load_split(args.dataset, suffix)

        pairs = [(line, 1) for line in lines_pos] + [(line, 0) for line in lines_neg]
        # A private fixed-seed RNG yields the same order as ``random.seed(42);
        # random.shuffle(...)`` (the permutation depends only on length and seed), but it
        # does not reset the global RNG, which would override the user's --seed.
        random.Random(42).shuffle(pairs)

        # ``filter=-1`` (the CLI default) means "keep everything"; slicing with [:-1]
        # would silently drop the last example.
        limit = args.filter
        if limit is not None and limit >= 0:
            pairs = pairs[:limit]
        self.lines = [line for line, _ in pairs]
        self.label = [label for _, label in pairs]

    @classmethod
    def _load_split(cls, dataset, suffix):
        """Return ``(positive_rows, negative_rows)`` from the first root holding both files."""
        relative = 'data/data_tensor_{}/tensor_sentiment.{}.'.format(dataset, suffix)
        tried = []
        last_error = None
        for root in cls._DATA_ROOTS:
            prefix = os.path.join(root, relative)
            try:
                return cls._read_rows(prefix + '1'), cls._read_rows(prefix + '0')
            except OSError as err:
                # Only "cannot open here" falls through to the next root; parse errors
                # (ValueError) propagate instead of being masked.
                tried.append(prefix + '{0,1}')
                last_error = err
        raise FileNotFoundError(
            'No readable tensor_sentiment files for dataset {!r}, split {!r}; tried: {}'.format(
                dataset, suffix, ', '.join(tried))) from last_error

    @staticmethod
    def _read_rows(path):
        """Parse a CSV file in which each row is one sentence of integer token ids."""
        with open(path, 'r', newline='') as file:
            return [[int(index) for index in row] for row in csv.reader(file, delimiter=',')]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        return {'line': torch.tensor(self.lines[item], dtype=torch.long),
                'label': torch.tensor(self.label[item], dtype=torch.long)}


def set_seed(args):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``args.seed``.

    When CUDA is available, every visible CUDA device is seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _run_reconstruction(args, eval_dataset, model):
    """Reconstruct every full batch of *eval_dataset* with *model*.

    Returns:
        ``(generated, golden, labels)``: generated and input token-id lists (one list
        per sentence) and the integer labels, all in dataset order.
    """
    # drop_last=True: a trailing partial batch is skipped. Kept as before, presumably
    # because the models are configured with a fixed batch_size — TODO confirm.
    eval_dataloader = DataLoader(
        eval_dataset, sampler=SequentialSampler(eval_dataset), batch_size=args.batch_size, drop_last=True)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %5d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)

    # Models exposing ``forward_reconstruction`` (the 'reny' model) use it; the others are
    # called as ``model(inputs, labels, 0)``. Deciding this by attribute lookup, rather than
    # a bare ``except``, keeps genuine errors raised inside forward_reconstruction visible.
    reconstruct = getattr(model, 'forward_reconstruction', None)
    generated, golden, all_labels = [], [], []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs = batch['line'].to(args.device)
            labels = batch['label'].to(args.device)
            if reconstruct is not None:
                _, outputs = reconstruct(inputs)
            else:
                _, outputs = model(inputs, labels, 0)
            generated += outputs.tolist()
            golden += inputs.tolist()
            all_labels += labels.tolist()
    return generated, golden, all_labels


def _write_metrics(path, ppl, style_accuracy, w_overlap):
    """Write the reconstruction metrics report to *path*, overwriting it."""
    with open(path, "w") as f:
        f.write('Evaluation for reconstruction: \n')
        f.write('ppl : transfet\t:{}\n'.format(ppl[0]))
        f.write('ppl : golden\t:{}\n'.format(ppl[1]))
        f.write('style_accuracy : transfet\t:{}\n'.format(style_accuracy[0]))
        f.write('style_accuracy : golden\t:{}\n'.format(style_accuracy[1]))
        f.write('w_overlap :\t:{}\n'.format(w_overlap))


def _write_sentences(path, golden, generated, limit=100):
    """Write up to *limit* golden ('G:') / generated ('T:') sentence pairs to *path*."""
    with open(path, "w") as f:
        for index in tqdm(range(min(limit, len(generated))), desc='Sentences'):
            f.write('G:\t {}\n'.format(golden[index]))
            f.write('T:\t {}\n'.format(generated[index]))


def test(args, eval_dataset, model, metric):
    """Evaluate reconstruction on *eval_dataset*, then write metrics and sample sentences.

    Args:
        args: run configuration; reads ``device``, ``batch_size``, ``tokenizer``,
            ``path`` and ``saving_result_file``.
        eval_dataset: dataset yielding dicts with ``'line'`` and ``'label'`` tensors.
        model: model exposing either ``forward_reconstruction(inputs)`` or
            ``model(inputs, labels, 0)``, each returning ``(dict_loss, outputs)``.
        metric: object providing ``compute_ppl``, ``compute_style_accuracy``,
            ``compute_w_overlap``, ``remove_pad_and_sep`` and ``use_style_accuracy``.

    Side effects:
        Writes ``<path>/<saving_result_file>`` (metrics) and ``<path>/sentences.txt``
        (at most 100 golden/generated pairs).
    """
    generated_ids, golden_ids, labels_ = _run_reconstruction(args, eval_dataset, model)
    sentences_generated_ = [args.tokenizer.decode(ids) for ids in generated_ids]
    sentences_golden_ = [args.tokenizer.decode(ids) for ids in golden_ids]

    logger.info("***** Running ppl evaluation *****")
    ppl = metric.compute_ppl(sentences_generated_, sentences_golden_)
    style_accuracy = (0, 0)
    if metric.use_style_accuracy:
        logger.info("***** Running style transfert evaluation *****")
        # Labels of every evaluated sentence, aligned with the sentence lists.
        style_accuracy = metric.compute_style_accuracy(sentences_generated_, sentences_golden_, labels_)
    logger.info("***** Running BLEU evaluation *****")
    w_overlap = metric.compute_w_overlap(sentences_generated_, sentences_golden_)

    _write_metrics(os.path.join(args.path, args.saving_result_file), ppl, style_accuracy, w_overlap)
    logger.info("  Finished for metrics")

    _write_sentences(os.path.join(args.path, 'sentences.txt'),
                     metric.remove_pad_and_sep(sentences_golden_),
                     metric.remove_pad_and_sep(sentences_generated_))
    logger.info("  Finished for sentences")


def _load_checkpoint(model, checkpoint_dir, device, probe):
    """Load ``<checkpoint_dir>/model.pt`` into *model* and check that the weights changed.

    Args:
        model: freshly constructed ``torch.nn.Module``.
        checkpoint_dir: directory containing ``model.pt``.
        device: ``map_location`` passed to ``torch.load``.
        probe: callable ``probe(model)`` returning a parameter tensor, used as a sanity
            check that loading actually overwrote the initial values.

    Raises:
        RuntimeError: if the probed parameter is unchanged after loading.
    """
    before = probe(model).tolist()
    state_dict = torch.load(os.path.join(checkpoint_dir, 'model.pt'), map_location=device)
    model.load_state_dict(state_dict)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if probe(model).tolist() == before:
        raise RuntimeError('Loading {} left the probed weights unchanged'.format(checkpoint_dir))


def main():
    """Evaluate the reconstruction quality of a trained checkpoint on the dev split.

    Loads the model from ``<path>/<checkpoints>/`` (weights ``model.pt`` and training
    configuration ``training_args.txt``), builds the metric from the GPT-2 model found in
    ``<model_metrics>/gpt_2_for_<dataset>``, and runs ``test`` on the dev split.
    """
    parser = argparse.ArgumentParser()
    arch_help = ("Architecture hyper-parameter; only used to build the model for --model reny "
                 "(the other models read training_args.txt from the checkpoint).")

    # Required parameters
    parser.add_argument("--dataset", default='yelp', type=str,
                        help="Dataset name; selects data/data_tensor_<dataset>/ and the metric models.")
    parser.add_argument("--number_of_sentence", default=10, type=int,
                        help="Number of sentences (not read directly in this file).")
    parser.add_argument("--batch_size", default=256, type=int, help="Evaluation batch size.")
    parser.add_argument("--max_length", default=43, type=int,
                        help="Maximum sequence length (not read directly in this file).")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--filter", type=int, default=-1,
                        help="Number of shuffled examples to keep from the dev split.")

    # Architecture
    parser.add_argument("--model", default='baseline_desantanglement_h',
                        help="Model type: 'reny', 'baseline_desantanglement', anything else loads VanillaSeq2seq.")
    parser.add_argument("--path", default='/data/DIT/models/vanilla_seq2seq_h256',
                        help="Run directory holding the checkpoint folders; results are written here.")
    parser.add_argument("--checkpoints", default='checkpoint-130000',
                        help="Checkpoint folder under --path (contains model.pt and training_args.txt).")
    parser.add_argument("--saving_result_file", default='test_reconstruction.txt',
                        help="Metrics file name, prefixed with '<checkpoints>_' and written under --path.")

    ## Architecture Details
    parser.add_argument("--style_dim", type=int, default=8, help=arch_help)
    parser.add_argument("--content_dim", type=int, default=128, help=arch_help)
    parser.add_argument("--number_of_layers", type=int, default=2, help=arch_help)
    parser.add_argument("--hidden_dim", type=int, default=256, help=arch_help)
    parser.add_argument("--dec_hidden_dim", type=int, default=136, help=arch_help)
    parser.add_argument("--dropout", type=float, default=0.5, help=arch_help)
    parser.add_argument("--number_of_styles", type=int, default=2, help=arch_help)
    parser.add_argument("--mul_style", type=int, default=10, help=arch_help)
    parser.add_argument("--adv_style", type=int, default=1, help=arch_help)
    parser.add_argument("--mul_mi", type=float, default=1, help=arch_help)
    parser.add_argument("--alpha", type=float, default=1.5, help=arch_help)
    parser.add_argument("--ema_beta", type=float, default=0.99, help=arch_help)
    parser.add_argument("--not_use_ema", action="store_true", help=arch_help)
    parser.add_argument("--no_reny", action="store_true", help=arch_help)
    parser.add_argument("--reny_training", type=int, default=2, help=arch_help)
    parser.add_argument("--number_of_training_encoder", type=int, default=1, help=arch_help)
    parser.add_argument("--add_noise", action="store_true", help=arch_help)
    parser.add_argument("--noise_p", type=float, default=0.1, help=arch_help)
    parser.add_argument("--number_of_perm", type=int, default=3, help=arch_help)
    parser.add_argument("--alternative_hs", action="store_true", help=arch_help)
    parser.add_argument("--no_minimization_of_mi_training", action="store_true", help=arch_help)
    parser.add_argument("--special_clement", action="store_true", help=arch_help)
    parser.add_argument("--complex_proj_content", action="store_true", help=arch_help)
    parser.add_argument("--use_complex_classifier", action='store_true', help=arch_help)

    # Metrics
    parser.add_argument("--model_metrics", default='/data/DIT/model_for_metric_evaluation',
                        help="Directory containing gpt_2_for_<dataset>/ (GPT-2 model and tokenizer for the metric).")

    args = parser.parse_args()

    # Setup logging first so that everything below (tokenizer loading included) is logged.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

    args.saving_result_file = '{}_{}'.format(args.checkpoints, args.saving_result_file)
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    args.sos_token = tokenizer.sep_token_id
    args.number_of_tokens = tokenizer.vocab_size
    args.tokenizer = tokenizer
    args.padding_idx = tokenizer.pad_token_id
    # Set seed
    set_seed(args)

    logger.info("------------------------------------------ ")
    logger.info("model type = %s ", args.model)
    logger.info("------------------------------------------ ")

    checkpoint_dir = os.path.join(args.path, args.checkpoints)
    # Rebuild the model configuration that was saved (as JSON) at training time.
    with open(os.path.join(checkpoint_dir, 'training_args.txt'), 'r') as f:
        args_model = argparse.Namespace(**json.load(f))

    args_model.sos_token = tokenizer.sep_token_id
    args_model.batch_size = args.batch_size
    args_model.device = args.device
    args_model.number_of_tokens = tokenizer.vocab_size
    args_model.tokenizer = tokenizer
    args_model.padding_idx = tokenizer.pad_token_id

    if args.model == 'reny':
        # NOTE(review): unlike the other branches this model is built from the CLI ``args``
        # rather than the saved training configuration — confirm this is intended.
        model = RenySeq2Seq(args, None)
        _load_checkpoint(model, checkpoint_dir, args.device, lambda m: m.proj_content[0].weight)
    elif args.model == 'baseline_desantanglement':
        model = BaselineDisentanglement(args_model)
        _load_checkpoint(model, checkpoint_dir, args.device, lambda m: m.proj_content[0].weight)
    else:
        model = VanillaSeq2seq(args_model)
        _load_checkpoint(model, checkpoint_dir, args.device, lambda m: m.decoder.out.weight)
        args_model.content_dim = args_model.hidden_dim
    model.to(args.device)
    model.eval()

    # Metrics:
    logger.info("------------------------------------------ ")
    logger.info("Initializing Metric")
    logger.info("------------------------------------------ ")

    # Load pre-trained language model (weights) and its own tokenizer; the BERT tokenizer
    # stays on ``args.tokenizer`` for decoding model outputs.
    metric_model_dir = os.path.join(args.model_metrics, 'gpt_2_for_{}'.format(args.dataset))
    lm = GPT2LMHeadModel.from_pretrained(metric_model_dir).to(args.device)
    lm_tokenizer = GPT2Tokenizer.from_pretrained(metric_model_dir)
    lm.eval()
    # The fastText style classifier is disabled, hence use_style_accuracy=False below.
    style_classifier = None
    metric = StyleTransfertMetric(args, use_w_overlap=True, use_style_accuracy=False, use_ppl=True,
                                  style_classifier=style_classifier, lm=lm, tokenizer=lm_tokenizer)

    logger.info("Test parameters %s", args)
    dev_dataset = TextDataset(args, True)
    test(args, dev_dataset, model, metric)
    logger.info(" Test Over ")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
