# coding=utf-8
import csv
import argparse
import logging
from argparse import ArgumentParser
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
from transformers import AdamW
import fasttext
from metric import StyleTransfertMetric
import json

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer
import copy
from transformers import *
from models_transfert_style import *


def clean_sentence(lines):
    """Split every line into tokens and drop the '[SEP]' / '[PAD]' markers.

    Newline characters are removed from each line, the remainder is split on
    single spaces, and every token that is exactly '[SEP]' or '[PAD]' is
    discarded. Because the split is on a single space, consecutive spaces
    (or an empty line) yield empty-string tokens, which are kept.

    Args:
        lines: iterable of strings, one sentence per item.

    Returns:
        A list with one token list per input line, in input order.
    """
    markers = ('[SEP]', '[PAD]')
    return [
        [token for token in raw.replace('\n', '').split(' ') if token not in markers]
        for raw in lines
    ]


# Explicit module logger: this file never binds ``logger`` itself, so the
# logger.info calls below could previously only resolve if one of the
# wildcard imports happened to export a name called ``logger``.
logger = logging.getLogger(__name__)


def _read_lines(path):
    """Return all lines of the text file at ``path`` (line endings kept)."""
    with open(path, 'r') as file:
        return file.readlines()


def main():
    """Evaluate generated sentences and write one result file per flip value.

    For ``flip`` in (True, False) the script reads, from
    ``<sentences>/<data_name>/``, the files ``gen_<flip>.txt``,
    ``golden_<flip>.txt`` and ``label_gen_<flip>.txt``, computes the BLEU,
    cosine-similarity, word-overlap, style-accuracy and perplexity scores
    through ``StyleTransfertMetric`` and writes them, tab separated, to
    ``category_evaluation/results_<data_name>_<flip>.txt``.

    Raises:
        FileNotFoundError: if one of the input files is missing.
        ValueError: if a line of the label file is not an integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=16, type=int,
                        help="batch size, forwarded to StyleTransfertMetric through args")
    parser.add_argument("--data_name", default='style_emb_lambda_5_baseline',
                        help="sub-folder of --sentences holding the gen/golden/label files; "
                             "also used in the results file name")
    parser.add_argument("--sentences", default='sentences_new',
                        help="root folder that contains the --data_name sub-folder")
    parser.add_argument("--lm_folder", default='sentences_new',
                        help="folder of the pretrained GPT-2 language model and tokenizer")
    parser.add_argument("--fastText_folder", default='sentences_new',
                        help="folder containing the fastText style classifier (fastText.bin)")
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Longer limit when the LM folder name mentions 'gender'.
    # NOTE(review): presumably a dataset-specific sequence length - confirm.
    args.max_length = 150 if 'gender' in args.lm_folder else 43

    # Make the INFO progress messages visible; this does nothing if the root
    # logger already has handlers installed by an imported module.
    logging.basicConfig(level=logging.INFO)

    data_name = args.data_name
    print(args.lm_folder)
    lm = GPT2LMHeadModel.from_pretrained(args.lm_folder).to(args.device)
    tokenizer = GPT2Tokenizer.from_pretrained(args.lm_folder)
    lm.eval()
    style_classifier = fasttext.load_model('{}/{}.bin'.format(args.fastText_folder, 'fastText'))

    # Create the output folder up front so a missing directory cannot discard
    # the results after every metric has already been computed.
    os.makedirs('category_evaluation', exist_ok=True)

    for flip in [True, False]:
        lines_gen = clean_sentence(
            _read_lines('{}/{}/gen_{}.txt'.format(args.sentences, data_name, flip)))
        lines_golden = clean_sentence(
            _read_lines('{}/{}/golden_{}.txt'.format(args.sentences, data_name, flip)))
        # One integer label per line.
        labels = [
            int(line.strip())
            for line in _read_lines('{}/{}/label_gen_{}.txt'.format(args.sentences, data_name, flip))
        ]

        # Rebuilt for every flip, as in the original script.
        # NOTE(review): could be hoisted if the metric object is stateless - confirm.
        metric = StyleTransfertMetric(args, use_w_overlap=True, use_style_accuracy=True, use_ppl=True,
                                      style_classifier=style_classifier, lm=lm, tokenizer=tokenizer)

        logger.info("***** Running BLEU evaluation *****")
        w_bleu = metric.compute_blue_score(lines_gen, lines_golden)

        logger.info("***** Running Cosinus Similarity evaluation *****")
        w_cosinus = metric.compute_cosinus_similarity(lines_gen, lines_golden)

        logger.info("***** Running Overlap evaluation *****")
        w_overlap = metric.compute_w_overlap(lines_gen, lines_golden)

        logger.info("***** Running style transfert evaluation *****")
        style_accuracy = metric.compute_style_accuracy(lines_gen, lines_golden, labels)

        logger.info("***** Running ppl evaluation *****")
        ppl = metric.compute_ppl(lines_gen, lines_golden)

        results = [
            ('w_bleu', w_bleu),
            ('w_cosinus', w_cosinus),
            ('w_overlap', w_overlap),
            ('style_accuracy', style_accuracy),
            ('ppl', ppl),
        ]
        with open('category_evaluation/results_{}_{}.txt'.format(data_name, flip), 'w') as file:
            for name, value in results:
                file.write('{}:\t{}\n'.format(name, value))


if __name__ == '__main__':
    main()
