# coding=utf-8
import csv
import argparse
import logging
import os
import random
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
# import fasttext
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from metric import StyleTransfertMetric
import json

# Bind ``get_linear_schedule_with_warmup``; when the direct import fails, fall
# back to ``WarmupLinearSchedule`` under the same name (presumably for older
# transformers releases -- TODO confirm).
# NOTE(review): the bare ``except`` also swallows errors other than
# ImportError, and the name is not referenced in the visible part of this file.
try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer

logger = logging.getLogger(__name__)


class TextDataset(Dataset):
    """Sentiment dataset of pre-tokenised sentences stored as CSV rows.

    Rows are read from
    ``data/data_tensor_<args.dataset>/tensor_sentiment.<split>.1`` (label 1)
    and ``...<split>.0`` (label 0), where ``<split>`` is ``dev`` or ``train``.
    Every row is a comma-separated list of integers.  The two files are
    concatenated, shuffled with a fixed seed and optionally truncated to the
    first ``args.filter`` examples.
    """

    # Seed of the deterministic shuffle applied to every split.
    _SHUFFLE_SEED = 42

    def __init__(self, args, dev):
        """Load, shuffle and truncate one split.

        Args:
            args: namespace providing ``dataset`` (used in the file path) and
                ``filter`` (maximum number of examples to keep; ``None`` or a
                negative value keeps everything).
            dev: truthy to load the ``dev`` files, falsy for ``train``.
        """
        logger.info("Loading dataset {}".format('Validation' if dev else 'Train'))
        suffix = 'dev' if dev else 'train'
        template = 'data/data_tensor_{}/tensor_sentiment.{}.{}'

        lines_pos = self._read_rows(template.format(args.dataset, suffix, 1))
        lines_neg = self._read_rows(template.format(args.dataset, suffix, 0))
        lines = lines_pos + lines_neg
        labels = [1] * len(lines_pos) + [0] * len(lines_neg)

        # Shuffle one index permutation and apply it to both lists.  This
        # yields the same order (and leaves the global RNG in the same state)
        # as seeding and shuffling the two parallel lists separately.
        order = list(range(len(lines)))
        random.seed(self._SHUFFLE_SEED)
        random.shuffle(order)

        # A negative bound (the CLI default is -1) must not be used as a
        # slice end: ``lines[:-1]`` would silently drop the last example.
        order = order[:self._resolve_limit(args.filter)]
        self.lines = [lines[index] for index in order]
        self.label = [labels[index] for index in order]

    @staticmethod
    def _resolve_limit(limit):
        """Return the slice end for *limit*: ``None`` (keep all) if it is unset or negative."""
        if limit is None or limit < 0:
            return None
        return limit

    @staticmethod
    def _read_rows(path):
        """Read a CSV file and return its rows as lists of ints."""
        # newline='' is what the csv module asks for when given a file object.
        with open(path, 'r', newline='') as handle:
            return [[int(token) for token in row]
                    for row in csv.reader(handle, delimiter=',')]

    def __len__(self):
        """Return the number of examples kept after truncation."""
        return len(self.label)

    def __getitem__(self, item):
        """Return example *item* as ``{'line': LongTensor, 'label': LongTensor}``."""
        return {'line': torch.tensor(self.lines[item], dtype=torch.long),
                'label': torch.tensor(self.label[item], dtype=torch.long)}


def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators with ``args.seed``.

    The CUDA generators are seeded as well when CUDA is available.
    """
    seed = args.seed
    # Same order as always: Python's random, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _collect_style_vectors(args, dataset, model):
    """Return ``(style_0, style_1)``: the style vectors of *dataset* summed per label."""
    loader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=args.batch_size, drop_last=True)
    model.eval()
    style_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Computing Style Vector"):
            inputs = batch['line'].to(args.device)
            # Keep the accumulated tensors on the CPU to spare device memory.
            style_chunks.append(model.predict_style_vector(inputs).cpu())
            label_chunks.append(batch['label'].cpu())

    if not style_chunks:
        # drop_last=True yields nothing when the dataset is smaller than a batch.
        raise ValueError(
            "No complete training batch available to compute style vectors "
            "({} examples, batch size {})".format(len(dataset), args.batch_size))

    labels = torch.cat(label_chunks)
    # Batches are concatenated on dim 1 and selected with [:, mask, :], so the
    # style tensor is assumed to be 3-D with the batch on dim 1 -- TODO confirm.
    styles = torch.cat(style_chunks, dim=1)
    style_0 = torch.sum(styles[:, labels == 0, :], dim=1)
    style_1 = torch.sum(styles[:, labels == 1, :], dim=1)
    return style_0.to(args.device), style_1.to(args.device)


def _generate_transfers(args, dataset, model, style_0, style_1):
    """Decode every eval sentence with the opposite style; return (generated, golden, labels)."""
    loader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=args.batch_size, drop_last=True)
    model.eval()
    # Sentences labelled 1 are conditioned on style 0 and vice versa.
    style_for_label_1 = style_0.unsqueeze(1)
    style_for_label_0 = style_1.unsqueeze(1)

    generated = []
    golden = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = batch['line'].to(args.device)
            batch_labels = batch['label'].tolist()

            # TODO : test both
            style = torch.cat(
                [style_for_label_1 if label == 1 else style_for_label_0 for label in batch_labels],
                dim=1)
            _, outputs = model.evaluate_for_style_transfert(inputs, style)

            # Turn token ids back into text for the metrics.
            generated += [args.tokenizer.decode(ids) for ids in outputs.tolist()]
            golden += [args.tokenizer.decode(ids) for ids in inputs.tolist()]
            all_labels += batch_labels
    return generated, golden, all_labels


def _write_metric_report(path, ppl, style_accuracy, w_overlap):
    """Write the metric summary to *path*."""
    with open(path, "w") as report:
        report.write('Evaluation for disantanglement latent space: \n')
        report.write('ppl : transfet\t:{}\n'.format(ppl[0]))
        report.write('ppl : golden\t:{}\n'.format(ppl[1]))
        report.write('style_accuracy : transfet\t:{}\n'.format(style_accuracy[0]))
        report.write('style_accuracy : golden\t:{}\n'.format(style_accuracy[1]))
        report.write('w_overlap :\t:{}\n'.format(w_overlap))


def _write_sentence_samples(path, golden, generated, limit=100):
    """Write up to *limit* golden (``G:``) / transferred (``T:``) sentence pairs to *path*."""
    pairs = list(zip(golden[:limit], generated[:limit]))
    with open(path, "w") as samples:
        for golden_sentence, generated_sentence in tqdm(pairs, desc='Sentences'):
            samples.write('G:\t {}\n'.format(golden_sentence))
            samples.write('T:\t {}\n'.format(generated_sentence))


def test(args, train_dataset, eval_dataset, model, metric):
    """Evaluate style transfer of *model* and write the results under ``args.path``.

    1. Style vectors of ``train_dataset`` are predicted and summed per label.
    2. Each sentence of ``eval_dataset`` is decoded with the summed style
       vector of the opposite label.
    3. Perplexity, word overlap and (if ``metric.use_style_accuracy``) style
       accuracy are written to ``metric_evaluation.txt``.
    4. Up to 100 golden/transferred sentence pairs are written to
       ``sentences.txt``.

    Args:
        args: namespace providing ``device``, ``batch_size``, ``tokenizer``
            and ``path``.
        train_dataset: dataset yielding ``{'line', 'label'}`` items, used for
            the style vectors.
        eval_dataset: dataset yielding ``{'line', 'label'}`` items to transfer.
        model: exposes ``predict_style_vector`` and
            ``evaluate_for_style_transfert``.
        metric: exposes ``compute_ppl``, ``compute_w_overlap``,
            ``compute_style_accuracy``, ``use_style_accuracy`` and
            ``remove_pad_and_sep``.

    Raises:
        ValueError: if ``train_dataset`` yields no complete batch.
    """
    logger.info("***** Training *****")
    logger.info("  Num examples = %5d", len(train_dataset))
    logger.info("  Batch size = %d", args.batch_size)
    style_0, style_1 = _collect_style_vectors(args, train_dataset, model)

    # Eval!
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %5d", len(eval_dataset))
    logger.info("  Batch size = %d", args.batch_size)
    sentences_generated_, sentences_golden_, labels_ = _generate_transfers(
        args, eval_dataset, model, style_0, style_1)

    ppl = metric.compute_ppl(sentences_generated_, sentences_golden_)
    style_accuracy = (0, 0)
    if metric.use_style_accuracy:
        # Pass the labels of ALL evaluated sentences (one per sentence), not
        # only the tensor of the last batch.
        style_accuracy = metric.compute_style_accuracy(sentences_generated_, sentences_golden_, labels_)
    w_overlap = metric.compute_w_overlap(sentences_generated_, sentences_golden_)

    _write_metric_report(
        os.path.join(args.path, 'metric_evaluation.txt'), ppl, style_accuracy, w_overlap)
    logger.info("  Finished for metrics")

    sentences_golden = metric.remove_pad_and_sep(sentences_golden_)
    sentences_generated = metric.remove_pad_and_sep(sentences_generated_)
    _write_sentence_samples(
        os.path.join(args.path, 'sentences.txt'), sentences_golden, sentences_generated)
    logger.info("  Finished for sentences")


def _build_arg_parser():
    """Create the command-line parser of the evaluation script."""
    parser = argparse.ArgumentParser()

    # Data and run options (every option has a default; none is required)
    parser.add_argument("--dataset", default='yelp', type=str,
                        help="Dataset name; selects data/data_tensor_<dataset>/ and gpt_2_for_<dataset>.")
    parser.add_argument("--number_of_sentence", default=10, type=int, help="Number of sentences.")
    parser.add_argument("--batch_size", default=10, type=int,
                        help="Batch size of the train and dev data loaders.")
    parser.add_argument("--max_length", default=43, type=int, help="Maximum sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--filter", type=int, default=-1,
                        help="Slice bound applied to each shuffled dataset (examples[:filter]).")

    # Architecture
    parser.add_argument("--model", default='baseline_desantanglement_h',
                        help="'baseline_desantanglement' builds BaselineDisentanglement; "
                             "any other value builds VanillaSeq2seq.")
    parser.add_argument("--path", default='models/vanilla_seq2seq_2',
                        help="Model directory: checkpoints are read from it, results are written to it.")
    parser.add_argument("--checkpoints", default='checkpoint-660000',
                        help="Checkpoint sub-directory of --path holding model.pt and training_args.txt.")
    parser.add_argument("--saving_result_file", default='saving_result_file.txt',
                        help="File name for saved results.")

    # Metrics
    parser.add_argument("--model_metrics", default='model_for_metric_evaluation',
                        help="Directory holding the pre-trained models used by the metrics.")
    return parser


def main():
    """Load a trained model and evaluate its style transfer on the dev set.

    Parses the command line, rebuilds the model from the ``training_args.txt``
    and ``model.pt`` files of the selected checkpoint, sets up the GPT-2 based
    metric and runs :func:`test` on the train/dev datasets.

    Raises:
        RuntimeError: if the probed weight matrix is unchanged after loading
            the checkpoint, i.e. the weights were apparently not loaded.
    """
    args = _build_arg_parser().parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    args.sos_token = bert_tokenizer.sep_token_id
    args.number_of_tokens = bert_tokenizer.vocab_size
    args.tokenizer = bert_tokenizer
    args.padding_idx = bert_tokenizer.pad_token_id
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)

    logger.info("------------------------------------------ ")
    logger.info("model type = %s ", args.model)
    logger.info("------------------------------------------ ")

    # Restore the hyper-parameters the model was trained with into a plain
    # namespace, then override the run-time dependent fields.
    checkpoint_dir = os.path.join(args.path, args.checkpoints)
    with open(os.path.join(checkpoint_dir, 'training_args.txt'), 'r') as f:
        args_model = argparse.Namespace(**json.load(f))

    args_model.sos_token = bert_tokenizer.sep_token_id
    args_model.batch_size = args.batch_size
    args_model.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args_model.number_of_tokens = bert_tokenizer.vocab_size
    args_model.tokenizer = bert_tokenizer
    args_model.padding_idx = bert_tokenizer.pad_token_id

    is_disentanglement = args.model == 'baseline_desantanglement'
    if is_disentanglement:
        model = BaselineDisentanglement(args_model)
        probe_weight = model.proj_content[0].weight
    else:
        model = VanillaSeq2seq(args_model)
        probe_weight = model.decoder.out.weight

    # Snapshot one weight matrix so we can verify that loading changed it.
    initial_weight = probe_weight.detach().clone()
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(os.path.join(checkpoint_dir, 'model.pt'), map_location=args.device)
    model.load_state_dict(state_dict)
    if torch.equal(initial_weight, probe_weight.detach()):
        raise RuntimeError(
            "Checkpoint weights in {} do not seem to have been loaded".format(checkpoint_dir))
    if not is_disentanglement:
        args_model.content_dim = args_model.hidden_dim

    model.to(args.device)
    model.eval()
    # Metrics:
    logger.info("------------------------------------------ ")
    logger.info("Initializing Metric")
    logger.info("------------------------------------------ ")

    # Load pre-trained language model (weights) used by the perplexity metric
    lm_dir = os.path.join(args.model_metrics, 'gpt_2_for_{}'.format(args.dataset))
    lm = GPT2LMHeadModel.from_pretrained(lm_dir).to(args.device)
    lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_dir)
    lm.eval()
    # The fastText style classifier is disabled, hence use_style_accuracy=False.
    style_classifier = None
    metric = StyleTransfertMetric(args, use_w_overlap=True, use_style_accuracy=False, use_ppl=True,
                                  style_classifier=style_classifier, lm=lm, tokenizer=lm_tokenizer)

    logger.info("Test parameters %s", args)
    train_dataset = TextDataset(args, False)
    dev_dataset = TextDataset(args, True)
    test(args, train_dataset, dev_dataset, model, metric)
    logger.info(" Test Over ")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
