# export CORENLP_HOME=/gpfsdswork/projects/rech/qsq/uwi62ct/SummEval/evaluation/summ_eval/stanford-corenlp-full-2018-10-05/
# https://github.com/bheinzerling/pyrouge/issues/8
# split https://arxiv.org/pdf/1806.06422.pdf -> 18
#
import os

import itertools
import csv
from summ_eval.bert_score_metric import BertScoreMetric
import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import utils
import random
import numpy as np
import torch
from readers import *
from metric import *
from transformers import AutoTokenizer
import scipy.stats as stats
from utils import set_seed
import logging

logger = logging.getLogger(__name__)

# The COCO data is stored as NUMBER_OF_CHUNKS JSON files, indexed 0..NUMBER_OF_CHUNKS - 1.
NUMBER_OF_CHUNKS = 20
# Path template of one chunk file; formatted with the chunk index.
DATA_PATH = 'data/coco/coco_formated_{}.json'
# Checkpoint used for the tokenizer and for BERTScore, and the directory template of
# its local copy (used when the checkpoint cannot be loaded by name).
MODEL_NAME = 'bert-base-uncased'
LOCAL_MODEL_DIR = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'
# Metrics whose ``evaluate_batch`` additionally receives the hypothesis and
# reference IDF dictionaries.
IDF_METRICS = frozenset([
    'info_score', 'info_mover_score', 'mover_pluplu', 'wasserstein', 'info_score_4',
    'info_score_2', 'bary_score', 'info_score_3', 'new_optiscore', 'mover_score',
])


def _parse_args():
    """Parse the command line and reject an out-of-range ``--chunk_id``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits with an argparse usage error (instead of an ``assert``, which is
    stripped under ``python -O``) when ``--chunk_id`` is not a valid chunk index.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=1531, type=int, help="Random seed.")
    parser.add_argument("--chunk_id", default=8, type=int,
                        help="Index of the data chunk to score (0 to {}).".format(NUMBER_OF_CHUNKS - 1))
    parser.add_argument("--metric", default='info_score', type=str, choices=available_metric.keys())
    parser.add_argument("--suffix", default='info_score', type=str)
    parser.add_argument("--use_abs", action='store_true')
    parser.add_argument("--invert_support", action='store_true')
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--use_lm", action='store_true')
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--wasserstein_data", action='store_true')

    args = parser.parse_args()
    if not 0 <= args.chunk_id < NUMBER_OF_CHUNKS:
        parser.error("--chunk_id must be in [0, {}), got {}".format(NUMBER_OF_CHUNKS, args.chunk_id))
    return args


def _load_chunk(chunk_id):
    """Return the parsed JSON content of chunk ``chunk_id``."""
    with open(DATA_PATH.format(chunk_id), 'r', encoding='utf-8') as file:
        return json.load(file)


def _save_chunk(chunk_id, data):
    """Write ``data`` back to the file of chunk ``chunk_id``.

    The JSON is first dumped to a sibling temporary file and then moved over
    the target with ``os.replace``, so a failure while serialising (e.g. a
    non-JSON-serialisable score) cannot leave a truncated chunk file behind.
    """
    path = DATA_PATH.format(chunk_id)
    tmp_path = path + '.tmp'
    logger.info('Saving in %s', path)
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file)
    except BaseException:
        # Do not leave a partial temporary file around; the original chunk is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, path)


def _load_tokenizer():
    """Load the tokenizer by name, falling back to the local model directory."""
    try:
        return AutoTokenizer.from_pretrained(MODEL_NAME)
    except Exception:
        # Best-effort fallback kept from the original script, but no longer a bare
        # ``except`` (which would also swallow KeyboardInterrupt) and no longer silent.
        local_dir = LOCAL_MODEL_DIR.format(MODEL_NAME)
        logger.warning('Could not load tokenizer %r, falling back to %s',
                       MODEL_NAME, local_dir, exc_info=True)
        return AutoTokenizer.from_pretrained(local_dir)


def _collect_corpus(tokenizer):
    """Gather references and candidates of every chunk, raw and tokenized.

    Returns:
        A 4-tuple ``(ref_token_ids, cand_token_ids, all_references, all_candidates)``:
        the ``input_ids`` produced by ``tokenizer`` for the references and for the
        generated sentences, followed by the raw references and generated
        sentences, each accumulated over all chunks in file order.
    """
    ref_token_ids = []
    cand_token_ids = []
    all_references = []
    all_candidates = []
    for chunk_id in tqdm(range(NUMBER_OF_CHUNKS), 'Chunks IDF'):
        for entry in _load_chunk(chunk_id).values():
            references = entry['references_sentences']
            candidates = [system['generated_sentence'] for system in entry['system'].values()]
            # ``extend`` flattens as we go; the previous ``sum(lists, [])`` re-copied
            # the accumulated list on every addition (quadratic in corpus size).
            ref_token_ids.extend(tokenizer(references)['input_ids'])
            cand_token_ids.extend(tokenizer(candidates)['input_ids'])
            all_references.extend(references)
            all_candidates.extend(candidates)
    return ref_token_ids, cand_token_ids, all_references, all_candidates


def _main():
    """Score one data chunk with the selected metric and store the scores in place.

    For every entry of the chunk and every system output, the generated sentence
    is compared against each reference sentence of the entry; the scores returned
    by ``metric.evaluate_batch`` are merged into that system's ``"scores"`` dict
    and the chunk file is rewritten.
    """
    args = _parse_args()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    set_seed(args)
    args.logger = logger

    metric = available_metric[args.metric]
    metric.use_idf_weights = args.use_idf_weights
    try:
        metric.invert_support = args.invert_support
    except (AttributeError, TypeError):
        # Some metric objects may not accept this attribute; that is tolerated.
        logger.debug('Metric %r does not accept invert_support', args.metric)

    data = _load_chunk(args.chunk_id)

    if args.metric == 'bert_score':
        # NOTE(review): the attributes set above were applied to the registry
        # object, not to this freshly built BertScoreMetric - confirm this is intended.
        local_dir = LOCAL_MODEL_DIR.format(MODEL_NAME)
        if os.path.exists(local_dir):
            logger.info('Loading BERTScore model from local directory %s', local_dir)
            metric = BertScoreMetric(model_type=local_dir)
        else:
            metric = BertScoreMetric(model_type=MODEL_NAME)

    # NOTE(review): the corpus statistics below are computed for every metric,
    # although only the metrics in IDF_METRICS consume them - consider skipping
    # this pass for the others once confirmed safe.
    tokenizer = _load_tokenizer()
    ref_token_ids, cand_token_ids, all_references, all_candidates = _collect_corpus(tokenizer)

    if args.metric == 'mover_score':
        idf_dict_ref = metric.get_idf_dict(all_references)
        idf_dict_hyp = metric.get_idf_dict(all_candidates)
    else:
        idf_dict_ref = utils.ref_list_to_idf(ref_token_ids)
        idf_dict_hyp = utils.ref_list_to_idf(cand_token_ids)
    logger.info('Data Length %s', len(data))
    logger.info('Chunk id %s', args.chunk_id)

    if args.metric == 'bary_score':
        metric.prepare_idfs(all_references, all_candidates)

    needs_idf = args.metric in IDF_METRICS
    for entry in tqdm(data.values(), "Progress In Chunk"):
        references = entry["references_sentences"]
        for system in entry["system"].values():
            # The single generated sentence is paired with each reference.
            candidate_sentences = [system["generated_sentence"]] * len(references)
            if needs_idf:
                scores = metric.evaluate_batch(candidate_sentences, references, idf_dict_hyp, idf_dict_ref)
            else:
                scores = metric.evaluate_batch(candidate_sentences, references)
            # ``system`` is the dict stored inside ``data``, so this updates ``data``;
            # a missing "scores" dict is created instead of raising KeyError.
            system.setdefault("scores", {}).update(scores)

    _save_chunk(args.chunk_id, data)


if __name__ == '__main__':
    _main()
