# coding=utf-8
import csv
import argparse
import logging
from argparse import ArgumentParser
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from tqdm import tqdm, trange
from transformers import AdamW
import fasttext
from metric import StyleTransfertMetric
import json

try:
    from transformers import (get_linear_schedule_with_warmup)
except:
    from transformers import WarmupLinearSchedule as get_linear_schedule_with_warmup
from models import *
from tensorboardX import SummaryWriter
from transformers import BertTokenizer
import copy
from transformers import *
from models_transfert_style import *


def clean_sentence(lines, tokenizer_bert):
    """Round-trip sentences through a tokenizer and split them into word lists.

    For each sentence, newline characters are removed, the text is encoded and
    then decoded again with ``tokenizer_bert`` (decoding is called with
    ``skip_special_tokens=True``), and the decoded text is split on single
    spaces. Words equal to ``'[SEP]'`` or ``'[PAD]'`` are dropped.

    :param lines: iterable of sentence strings.
    :param tokenizer_bert: object providing ``encode(text)`` and
        ``decode(ids, skip_special_tokens=...)``.
    :return: list containing one list of word strings per input sentence.
    """
    special_markers = ('[SEP]', '[PAD]')
    word_lists = []
    for raw_sentence in lines:
        token_ids = tokenizer_bert.encode(raw_sentence.replace('\n', ''))
        round_tripped = tokenizer_bert.decode(token_ids, skip_special_tokens=True)
        # Splitting on a literal space (not arbitrary whitespace) keeps empty
        # strings produced by consecutive spaces, exactly as the text decodes.
        candidate_words = round_tripped.replace('\n', '').split(' ')
        word_lists.append([word for word in candidate_words if word not in special_markers])
    return word_lists


# Explicit module-level logger: the script previously used ``logger`` without defining it in this
# file, so it only worked if one of the star-imports happened to export that name.
logger = logging.getLogger(__name__)


def split_sentence_pairs(raw_lines, source='<input>'):
    """Split tab-separated lines into their first and second columns.

    :param raw_lines: iterable of text lines; each non-blank line must hold at least two
        tab-separated fields. Fields beyond the second are ignored.
    :param source: label used in error messages to say where the lines came from.
    :return: tuple ``(first_column, second_column)`` of two equally long lists of strings,
        with newline characters removed.
    :raises ValueError: if a non-blank line contains no tab separator.
    """
    first_column, second_column = [], []
    for line_number, raw_line in enumerate(raw_lines, start=1):
        text = raw_line.replace('\n', '')
        if '\t' not in text:
            if text.strip():
                raise ValueError(
                    '{}: line {} has no tab separator: {!r}'.format(source, line_number, text))
            # Blank line (e.g. a trailing newline at end of file): nothing to evaluate.
            continue
        fields = text.split('\t')
        first_column.append(fields[0])
        second_column.append(fields[1])
    return first_column, second_column


def main():
    """Score the outputs of several style-transfer systems and append the results to a file.

    For every system name, the files ``sentiment.test.0.<name>`` and ``sentiment.test.1.<name>``
    inside ``--data_folder`` are read as tab-separated sentence pairs, cleaned with
    ``clean_sentence`` and passed to ``StyleTransfertMetric``. The scores are appended to
    ``--output_file`` after each system, so partial results survive an interrupted run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=16, type=int,
                        help="batch size made available to the metric through args")
    parser.add_argument("--sentences", default='.',
                        help="not read by this script itself; kept for command-line compatibility")
    parser.add_argument("--lm_folder",
                        default='//Desktop/DIT/model_for_metric_evaluation_sentiment/gpt_2_for_yelp',
                        help="folder holding the GPT-2 tokenizer and language model to load")
    parser.add_argument("--fastText_folder",
                        default='//Desktop/DIT/model_for_metric_evaluation_sentiment/yelp_fastText',
                        help="folder containing the fastText.bin style classifier")
    parser.add_argument("--data_folder", default='//Desktop/DIT/data/yelp_other',
                        help="folder containing the sentiment.test.<label>.<system> files")
    parser.add_argument("--output_file", default='//Desktop/DIT/results_related_work.txt',
                        help="text file the scores are appended to")
    args = parser.parse_args()

    # Without a configured handler the INFO progress messages below would be dropped.
    logging.basicConfig(level=logging.INFO)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Stored on args, which is handed to StyleTransfertMetric; a larger value is used when the
    # language-model folder name mentions 'gender'.
    args.max_length = 150 if 'gender' in args.lm_folder else 43

    tokenizer_bert = BertTokenizer.from_pretrained('bert-base-uncased')

    print(args.lm_folder)
    print(os.listdir(args.lm_folder))
    tokenizer = GPT2Tokenizer.from_pretrained(args.lm_folder)
    lm = GPT2LMHeadModel.from_pretrained(args.lm_folder).to(args.device)
    lm.eval()
    style_classifier = fasttext.load_model('{}/{}.bin'.format(args.fastText_folder, 'fastText'))
    logger.info("***** Metric *****")
    metric = StyleTransfertMetric(args, use_w_overlap=True, use_style_accuracy=True, use_ppl=True,
                                  style_classifier=style_classifier, lm=lm, tokenizer=tokenizer,
                                  not_use_vector=True)  # TODO
    logger.info("***** Metric *****")

    # System names are used verbatim as file-name suffixes, so their spelling must not be "fixed".
    list_names_models = [
        '{}_{}'.format(model_type, mul)
        for mul in [0.1, 1]
        for model_type in ['no_reny', 'baseline', 'reny_1_3', 'reny_1_5', 'reny_1_8']
    ]
    related_work_names = ['fader', 'human', 'label', 'mit', 'multi_decoder', 'orgin', 'retrieval',
                          'rule_base']

    for data_name in tqdm(list_names_models + related_work_names):
        columns_by_label = []
        for label in (0, 1):
            path = '{}/sentiment.test.{}.{}'.format(args.data_folder, label, data_name)
            with open(path, 'r', encoding='utf-8') as file:
                first_column, second_column = split_sentence_pairs(file, source=path)
            columns_by_label.append((clean_sentence(first_column, tokenizer_bert),
                                     clean_sentence(second_column, tokenizer_bert)))
        (lines_neg, lines_neg_golden), (lines_pos, lines_pos_golden) = columns_by_label

        # Label 0 for sentences from the ".0." file, label 1 for those from the ".1." file.
        labels = [0] * len(lines_neg) + [1] * len(lines_pos)
        lines = lines_neg + lines_pos
        lines_golden = lines_neg_golden + lines_pos_golden

        logger.info("***** Running style transfert evaluation *****")
        style_accuracy = metric.compute_style_accuracy(lines, lines_golden, labels)

        logger.info("***** Running BLEU evaluation *****")
        w_bleu = metric.compute_blue_score(lines_golden, lines)

        logger.info("***** Running Cosinus Similarity evaluation *****")
        w_cosinus = metric.compute_cosinus_similarity(lines, lines_golden)

        logger.info("***** Running Overlap evaluation *****")
        # Still computed as in the original script, but not part of the written report.
        w_overlap = metric.compute_w_overlap(lines, lines_golden)

        logger.info("***** Running ppl evaluation *****")
        ppl = metric.compute_ppl(lines, lines_golden)

        report = [
            ('name', data_name),
            ('w_bleu', w_bleu),
            ('w_cosinus', w_cosinus),
            ('style_accuracy', style_accuracy),
            ('ppl', ppl),
        ]
        with open(args.output_file, 'a', encoding='utf-8') as file:
            file.writelines('{}:\t{}\n'.format(key, value) for key, value in report)


if __name__ == '__main__':
    main()
