# export CORENLP_HOME=/gpfsdswork/projects/rech/qsq/uwi62ct/SummEval/evaluation/summ_eval/stanford-corenlp-full-2018-10-05/
# https://github.com/bheinzerling/pyrouge/issues/8
# split https://arxiv.org/pdf/1806.06422.pdf -> 18

import os

import itertools
import csv
import argparse
from summ_eval.bert_score_metric import BertScoreMetric
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import logging
import utils

logger = logging.getLogger(__name__)

if __name__ == '__main__':
    # Script entry point: score every (q1, q2) pair of one data chunk with the
    # selected metric and write the scores back into that chunk's JSON file.

    # The dataset is pre-split into this many JSON files, indexed 0..N-1.
    NUMBER_OF_CHUNKS = 20
    # Metrics that need corpus-level IDF dictionaries for both sides of a pair.
    INFO_METRICS = ('info_score', 'info_mover_score')

    parser = argparse.ArgumentParser()

    # Command-line options (all optional, defaults given below)
    parser.add_argument("--seed", default=1531, type=int, help="Random seed.")
    parser.add_argument("--chunk_id", default=0, type=int,
                        help="Index of the data chunk to score (0 to {}).".format(NUMBER_OF_CHUNKS - 1))
    parser.add_argument("--metric", default='info_score', type=str, choices=available_metric.keys())
    parser.add_argument("--lang", default='en', type=str, choices=['en', 'es', 'fr', 'ja', 'ko', 'zh'])
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--use_normal_duplicate", action='store_true')
    parser.add_argument("--multilingual", action='store_true')
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped under -O.
    if not 0 <= args.chunk_id < NUMBER_OF_CHUNKS:
        parser.error("--chunk_id must be in [0, {}), got {}".format(NUMBER_OF_CHUNKS, args.chunk_id))

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    args.logger = logger

    metric = available_metric[args.metric]
    metric.use_idf_weights = args.use_idf_weights
    if args.metric == 'bert_score':
        # Prefer the locally stored checkpoints; fall back to the plain model
        # names when the local paths cannot be loaded.
        # NOTE(review): `use_idf_weights` was set on the registry object above
        # and is not carried over to this freshly built BertScoreMetric.
        try:
            metric = BertScoreMetric(lang='/gpfswork/rech/qsq/uwi62ct/transformers_models/bert-base-uncased',
                                     model_type='/gpfswork/rech/qsq/uwi62ct/transformers_models/bert-base-uncased' if args.use_normal_duplicate else '/gpfswork/rech/qsq/uwi62ct/transformers_models/bert-base-multilingual-uncased')
        except Exception as err:
            logger.warning("Could not build BertScoreMetric from the local model path (%s); "
                           "falling back to the model name.", err)
            metric = BertScoreMetric(
                model_type='bert-base-uncased' if args.use_normal_duplicate else 'bert-base-multilingual-uncased')
    elif args.metric in INFO_METRICS and not args.use_normal_duplicate:
        # Compare the metric *name*: `metric` is the metric object, so testing
        # it against the list of names would never select this branch.
        metric.update_multilingual_model()

    lang = args.lang
    DATA_PATH = 'data/duplicated_questions/duplicate_questions_formated_{}_{}.json' if args.use_normal_duplicate else "data/x-final/test_{}_formated_{}.json"
    chunk_path = DATA_PATH.format(lang, args.chunk_id)
    with open(chunk_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    if args.metric in INFO_METRICS:
        # The IDF statistics are computed over *all* chunks, not only the one
        # being scored, so every chunk sees the same weights.
        all_references = []
        all_candidates = []
        for i in range(NUMBER_OF_CHUNKS):
            with open(DATA_PATH.format(lang, i), 'r', encoding='utf-8') as file:
                data_idf = json.load(file)
            all_references.extend(value['q1'] for value in data_idf.values())
            all_candidates.extend(value['q2'] for value in data_idf.values())
        model_name = 'bert-base-uncased' if args.use_normal_duplicate else 'bert-base-multilingual-uncased'
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as err:
            logger.warning("Could not load tokenizer '%s' by name (%s); trying the local copy.", model_name, err)
            tokenizer = AutoTokenizer.from_pretrained(
                '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'.format(model_name))
        idf_dict_ref = utils.ref_list_to_idf(tokenizer(all_references)['input_ids'])
        idf_dict_hyp = utils.ref_list_to_idf(tokenizer(all_candidates)['input_ids'])
    logger.info('Data Length %s', len(data))
    logger.info('Chunk id %s', args.chunk_id)

    for value in tqdm(data.values(), "Progress In Chunk"):
        reference = value["q1"]
        candidate = value["q2"]
        if args.metric in INFO_METRICS:
            scores = metric.evaluate_batch(candidate, reference, idf_dict_hyp, idf_dict_ref)
        else:
            # NOTE(review): the argument order here (reference first) is the
            # opposite of the info-metric call above (candidate first) -
            # confirm against the evaluate_batch signature of these metrics.
            scores = metric.evaluate_batch([reference], [candidate])
        # Create the per-entry score dict if the input file does not have one.
        value.setdefault("scores", {}).update(scores)

    # Write back to the very file that was read, so that a non-English --lang
    # combined with --use_normal_duplicate cannot overwrite the 'en' chunk.
    with open(chunk_path, 'w', encoding='utf-8') as file:
        json.dump(data, file)
