# export CORENLP_HOME=/gpfsdswork/projects/rech/qsq/uwi62ct/SummEval/evaluation/summ_eval/stanford-corenlp-full-2018-10-05/
# https://github.com/bheinzerling/pyrouge/issues/8
# split https://arxiv.org/pdf/1806.06422.pdf -> 18
#
import os

import itertools
import csv
import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import utils
import random
import numpy as np
import torch
from readers import *
from metric import *
from transformers import AutoTokenizer
import scipy.stats as stats
from utils import set_seed
import logging
logger = logging.getLogger(__name__)

if __name__ == '__main__':
    # Command-line interface. Every option has a default, so none of them is mandatory.
    parser = argparse.ArgumentParser(
        description="Score one data chunk with the selected metric and write the scores back to the chunk file.")

    parser.add_argument("--seed", default=1531, type=int, help="Random seed for reproducibility.")
    parser.add_argument("--chunk_id", default=8, type=int, help="Index of the data chunk to score.")
    parser.add_argument("--metric", default='bary_score', type=str, choices=available_metric.keys(),
                        help="Name of the metric to compute.")
    parser.add_argument("--use_abs", action='store_true',
                        help="Read the cnn_abs_formated_* files instead of the cnn_ext_formated_* ones.")
    parser.add_argument("--invert_support", action='store_true',
                        help="Value assigned to the metric's invert_support attribute.")
    parser.add_argument("--use_idf_weights", action='store_true',
                        help="Value assigned to the metric's use_idf_weights attribute.")
    parser.add_argument("--wasserstein_data", action='store_true',
                        help="Read the chunks from data/cnn_bis_wasserstein_new instead of data/cnn_bis.")
    # NOTE(review): the three options below are parsed but never read in this script itself;
    # presumably they are consumed through `args` elsewhere or are legacy -- confirm.
    parser.add_argument("--suffix", default='info_score', type=str)
    parser.add_argument("--use_lm", action='store_true')
    parser.add_argument("--not_use_div", action='store_true')

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    # Expose the module logger through the argument namespace.
    args.logger = logger

    # Look up the metric object registered under the requested name.
    metric = available_metric[args.metric]

    # Chunk files live in one of two directories and exist in an "abs" and an "ext" variant;
    # the remaining '{}' placeholder is filled in later with the chunk index.
    data_dir = 'data/cnn_bis_wasserstein_new' if args.wasserstein_data else 'data/cnn_bis'
    variant = 'abs' if args.use_abs else 'ext'
    DATA_PATH = data_dir + '/cnn_' + variant + '_formated_{}.json'
    NUMBER_OF_CHUNKS = 10

    # Validate explicitly instead of asserting: asserts disappear under `python -O`.
    if not 0 <= args.chunk_id < NUMBER_OF_CHUNKS:
        parser.error('--chunk_id must be in [0, {}), got {}'.format(NUMBER_OF_CHUNKS, args.chunk_id))

    metric.use_idf_weights = args.use_idf_weights
    # Best effort: a metric that rejects this attribute must not abort the run, but the
    # failure is reported and KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        metric.invert_support = args.invert_support
    except Exception as err:
        logger.warning('Could not set invert_support on metric %s: %s', args.metric, err)

    # Load the chunk that will be scored.
    with open(DATA_PATH.format(args.chunk_id), 'r') as file:
        data = json.load(file)
    model_name = 'bert-base-uncased'
    # Local copy of the model, used when it exists / when the hub lookup fails.
    local_model_dir = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'.format(model_name)

    if args.metric == 'bert_score':
        # NOTE(review): this rebinds `metric` to a new object, so attributes assigned to the
        # previous metric object earlier in the script are not carried over -- confirm intended.
        if os.path.exists(local_model_dir):
            logger.info('Using local model directory %s for bert_score', local_model_dir)
            metric = BertScoreMetric(model_type=local_model_dir)
        else:
            metric = BertScoreMetric(model_type=model_name)

    # Try the model name first and fall back to the local directory if that fails.
    # `except Exception` keeps the fallback but lets KeyboardInterrupt/SystemExit propagate.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as err:
        logger.warning('Loading tokenizer %s failed (%s); falling back to %s', model_name, err, local_model_dir)
        tokenizer = AutoTokenizer.from_pretrained(local_model_dir)
    # Gather references and candidates from every chunk (not only the one being scored)
    # so the IDF dictionaries are computed over the whole data set.
    ref_token_ids = []    # tokenizer 'input_ids' of the references, concatenated over all documents
    hyp_token_ids = []    # tokenizer 'input_ids' of the candidates, concatenated over all documents
    all_references = []   # raw reference entries, concatenated over all documents
    all_candidates = []   # raw generated entries, one per (document, system)
    for chunk_index in range(NUMBER_OF_CHUNKS):
        # DATA_PATH was already resolved from the command-line flags above; it does not
        # change between iterations.
        with open(DATA_PATH.format(chunk_index), 'r') as file:
            data_idf = json.load(file)
        for document in data_idf.values():
            doc_references = document['references_sentences']
            doc_candidates = [system['generated_sentence'] for system in document['system'].values()]
            # extend() appends in place; repeated `sum(lists, [])` copies the accumulator
            # on every addition and is quadratic in the corpus size.
            ref_token_ids.extend(tokenizer(doc_references)['input_ids'])
            hyp_token_ids.extend(tokenizer(doc_candidates)['input_ids'])
            all_references.extend(doc_references)
            all_candidates.extend(doc_candidates)

    # mover_score builds its IDF dictionaries from the raw text through the metric itself;
    # every other metric uses the token-id based helper.
    if args.metric != 'mover_score':
        idf_dict_ref = utils.ref_list_to_idf(ref_token_ids)
        idf_dict_hyp = utils.ref_list_to_idf(hyp_token_ids)
    else:
        idf_dict_ref = metric.get_idf_dict(all_references)
        idf_dict_hyp = metric.get_idf_dict(all_candidates)
    logger.info('Data Length %s', len(data))
    logger.info('Chunk id %s', args.chunk_id)

    if args.metric == 'bary_score':
        metric.prepare_idfs(all_references, all_candidates)

    # Metrics whose evaluate_batch additionally receives the hypothesis and reference IDF dictionaries.
    idf_aware_metrics = ('info_score', 'info_mover_score', 'mover_pluplu', 'wasserstein', 'info_score_4',
                         'info_score_2', 'bary_score',
                         'info_score_3', 'new_optiscore', 'mover_score')
    pass_idf_dicts = args.metric in idf_aware_metrics

    # Score every system output of every document in the chunk and merge the resulting
    # scores into that system's existing "scores" dictionary.
    for doc_id, document in tqdm(data.items(), "Progress In Chunk"):
        doc_refs = document["references_sentences"]
        for sys_name, sys_entry in document["system"].items():
            # The single generated output is repeated so it is paired with each reference.
            repeated_candidate = [sys_entry["generated_sentence"]] * len(doc_refs)
            if pass_idf_dicts:
                batch_scores = metric.evaluate_batch(repeated_candidate, doc_refs, idf_dict_hyp, idf_dict_ref)
            else:
                batch_scores = metric.evaluate_batch(repeated_candidate, doc_refs)
            sys_entry["scores"].update(batch_scores)

    # Write the scored chunk back over the file it was read from. DATA_PATH was already
    # resolved from the command-line flags earlier in the script.
    print('Saving IN')
    print(DATA_PATH)
    output_path = DATA_PATH.format(args.chunk_id)
    # Dump to a sibling temporary file first: opening the real file in 'w' mode truncates
    # it immediately, so a failure during json.dump would destroy the input data.
    tmp_path = output_path + '.tmp'
    try:
        with open(tmp_path, 'w') as file:
            json.dump(data, file)
    except Exception:
        # Do not leave a partial temporary file behind; the original chunk is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    # Swap the complete file into place in a single step.
    os.replace(tmp_path, output_path)
