
import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from summ_eval.mover_score_metric import MoverScoreMetric
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import logging
import utils
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

# Directory template of the on-disk model mirror. It is preferred for
# `bert_score` when it exists, and used as a fallback for the tokenizer when
# loading by model name fails.
_LOCAL_MODEL_ROOT = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'
_BERT_MODEL_NAME = 'bert-base-uncased'

# Input/output JSON templates; `{}` is replaced by the language pair.
_DATA_PATH_WMT15 = 'data/very_new/wmt15/{}_formated.json'
_DATA_PATH_WMT16 = 'data/very_new/wmt16/{}_formated.json'

# Metrics whose `evaluate_batch` is called with the two corpus-level IDF dicts.
_IDF_METRICS = frozenset([
    'info_score', 'info_mover_score', 'wasserstein', 'info_score_4', 'info_score_2',
    'info_score_3', 'new_optiscore', 'mover_pluplu', 'bary_score', 'mover_score',
])

# Metrics for which both the candidate and the references are wrapped in a list.
_WRAPPED_CANDIDATE_METRICS = frozenset(['meteor', 'bert_score'])


def _build_parser():
    """Return the command-line parser of the scoring script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=1531, type=int,
                        help="Random seed (the parsed arguments are passed to set_seed).")
    parser.add_argument("--chunk_id", default=2, type=int,
                        help="Chunk identifier (only logged by this script).")
    parser.add_argument("--metric", default='meteor', type=str, choices=available_metric.keys())
    parser.add_argument("--language", default='cs-en', type=str,
                        choices=['cs-en', 'de-en', 'ru-en', 'fi-en', 'ro-en', 'tr-en'])
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--invert_support", action='store_true')
    parser.add_argument("--use_lm", action='store_true')
    parser.add_argument("--use_six", action='store_true')
    parser.add_argument("--suffix", type=str)
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--wasserstein_data", action='store_true')
    return parser


def _load_metric(args):
    """Instantiate/select the metric named by `args.metric` and configure it.

    `bert_score` is built explicitly (from the local model mirror when that
    directory exists); every other metric is looked up in `available_metric`.
    """
    if args.metric == 'bert_score':
        local_dir = _LOCAL_MODEL_ROOT.format(_BERT_MODEL_NAME)
        if os.path.exists(local_dir):
            logger.info('Using local BERT checkpoint %s', local_dir)
            metric = BertScoreMetric(model_type=local_dir)
        else:
            metric = BertScoreMetric(model_type=_BERT_MODEL_NAME)
    else:
        metric = available_metric[args.metric]

    metric.use_idf_weights = args.use_idf_weights
    # Best-effort: the assignment was already allowed to fail silently; keep
    # that tolerance but no longer swallow KeyboardInterrupt/SystemExit, and
    # leave a trace of what went wrong.
    try:
        metric.invert_support = args.invert_support
    except Exception:
        logger.warning('Could not set invert_support on metric %s', args.metric, exc_info=True)
    return metric


def _load_tokenizer():
    """Load the BERT tokenizer by name, falling back to the local mirror."""
    try:
        return AutoTokenizer.from_pretrained(_BERT_MODEL_NAME)
    except Exception:
        local_dir = _LOCAL_MODEL_ROOT.format(_BERT_MODEL_NAME)
        logger.warning('Loading tokenizer %s failed, falling back to %s',
                       _BERT_MODEL_NAME, local_dir, exc_info=True)
        return AutoTokenizer.from_pretrained(local_dir)


def _evaluate(metric, metric_name, candidate, references, idf_dict_hyp, idf_dict_ref):
    """Call `metric.evaluate_batch` with the argument layout used for `metric_name`.

    Three call shapes exist:
      * IDF-based metrics receive the raw candidate/references plus both IDF dicts;
      * `meteor` and `bert_score` receive `[candidate]` and `[references]`;
      * everything else receives the raw candidate and `[references]`.
    """
    if metric_name in _IDF_METRICS:
        return metric.evaluate_batch(candidate, references, idf_dict_hyp, idf_dict_ref)
    if metric_name in _WRAPPED_CANDIDATE_METRICS:
        return metric.evaluate_batch([candidate], [references])
    scores = metric.evaluate_batch(candidate, [references])
    if metric_name == 'sent_mover':
        logger.info('Score %s', scores)
    return scores


def main():
    """Score every system output of one WMT language pair and save the result.

    The JSON file is read, each `system` entry gets the metric output merged
    into its `scores` dict, and the whole structure is written back to the
    same path.
    """
    args = _build_parser().parse_args()
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    args.logger = logger

    metric = _load_metric(args)

    # Resolve the path once: it is both the input and the output file.
    data_path = (_DATA_PATH_WMT16 if args.use_six else _DATA_PATH_WMT15).format(args.language)
    with open(data_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    logger.info('Data Length {}'.format(len(data)))
    logger.info('Chunk id {}'.format(args.chunk_id))

    # NOTE(review): the tokenizer is not used anywhere below; the load is kept
    # because it may be relied upon to warm the model cache - confirm and drop.
    tokenizer = _load_tokenizer()  # noqa: F841

    # Corpus-level lists used only to build the IDF statistics.
    all_references = [segment['references_sentences'] for segment in data.values()]
    # Single flat comprehension instead of sum(list_of_lists, []), which
    # re-copies the accumulated list at every step.
    all_candidates = [system['generated_sentence']
                      for segment in data.values()
                      for system in segment['system'].values()]

    if args.metric == 'mover_score':
        idf_dict_ref = metric.get_idf_dict(all_references)
        idf_dict_hyp = metric.get_idf_dict(all_candidates)
    else:
        idf_dict_ref = utils.ref_list_to_idf(all_references)
        idf_dict_hyp = utils.ref_list_to_idf(all_candidates)

    if args.metric == 'bary_score':
        metric.prepare_idfs(all_references, all_candidates)

    for segment in tqdm(data.values(), "Progress In Chunk"):
        segment_references = segment["references_sentences"]
        for system_entry in segment["system"].values():
            scores = _evaluate(metric, args.metric, system_entry['generated_sentence'],
                               segment_references, idf_dict_hyp, idf_dict_ref)
            # Entries are updated in place, so `data` carries the new scores.
            system_entry.setdefault("scores", {}).update(scores)

    # Serialise BEFORE opening the file for writing: the output overwrites the
    # input, so a serialisation error must not leave a truncated file behind.
    serialized = json.dumps(data)
    logger.info('Saving %d scored entries to %s', len(data), data_path)
    with open(data_path, 'w', encoding='utf-8') as file:
        file.write(serialized)


if __name__ == '__main__':
    main()
