import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import utils
from transformers import AutoTokenizer
import logging

logger = logging.getLogger(__name__)

# Number of JSON chunks every language pair is split into on disk
# (identical for the regular and the "wasserstein" variants of the data).
NUMBER_OF_CHUNKS = 5

# Template of the directory holding a local copy of a pretrained model; used
# when it exists (bert_score) or as a fallback when the default load fails.
LOCAL_MODEL_DIR = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'

# Metrics whose ``evaluate_batch`` is called with the two IDF dictionaries.
IDF_METRICS = frozenset([
    'info_score', 'info_mover_score', 'mover_score', 'wasserstein',
    'info_score_4', 'info_score_2', 'info_score_3', 'new_optiscore',
    'mover_pluplu', 'bary_score',
])


def chunk_path(language: str, chunk_id: int, wasserstein_data: bool = False) -> str:
    """Return the path of the JSON file storing chunk *chunk_id* of *language*.

    Args:
        language: Language pair identifier, e.g. ``'cs-en'``.
        chunk_id: Index of the chunk.
        wasserstein_data: Select the ``*_processed_wasserstein_new`` directory
            instead of the ``*_processed`` one.
    """
    if wasserstein_data:
        template = 'data/wmt18/{}_processed_wasserstein_new/{}_formated_{}.json'
    else:
        template = 'data/wmt18/{}_processed/{}_formated_{}.json'
    return template.format(language, language, chunk_id)


def flatten_once(nested):
    """Concatenate the inner sequences of *nested* into a single new list.

    Linear-time replacement for ``sum(nested, [])``, which re-copies the
    accumulator on every addition.
    """
    return [item for group in nested for item in group]


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser of the scoring script."""
    parser = argparse.ArgumentParser()

    # All options have defaults; none is strictly required.
    parser.add_argument("--seed", default=1531, type=int, help="Random seed passed to set_seed.")
    # ``choices`` rejects out-of-range ids at parse time (unlike ``assert``,
    # this check is not stripped when Python runs with -O).
    parser.add_argument("--chunk_id", default=2, type=int, choices=range(NUMBER_OF_CHUNKS),
                        help="Index of the data chunk to score.")
    parser.add_argument("--metric", default='ter', type=str)  # choices=available_metric.keys())
    parser.add_argument("--language", default='cs-en', type=str,
                        choices=['cs-en', 'de-en', 'ru-en', 'tr-en', 'zh-en'])
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--invert_support", action='store_true')
    parser.add_argument("--use_lm", action='store_true')
    parser.add_argument("--suffix", type=str)
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--wasserstein_data", action='store_true')
    return parser


def main(argv=None):
    """Score one chunk of WMT18 system outputs with the selected metric.

    Loads chunk ``--chunk_id`` of ``--language``, builds IDF tables from every
    chunk of that language pair, evaluates each system output against its
    references and writes the scores back into the chunk's JSON file.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    args.logger = logger

    model_name = 'bert-base-uncased'
    local_model_path = LOCAL_MODEL_DIR.format(model_name)

    if args.metric == 'bert_score':
        # Prefer the local copy of the model when it is available.
        if os.path.exists(local_model_path):
            metric = BertScoreMetric(model_type=local_model_path)
        else:
            metric = BertScoreMetric(model_type=model_name)
    else:
        metric = available_metric[args.metric]
    metric.use_idf_weights = args.use_idf_weights
    try:
        metric.invert_support = args.invert_support
    except Exception as exc:
        # Best effort: the flag is simply skipped for metric objects that
        # refuse the attribute, but the reason is no longer hidden.
        logger.warning("Could not set invert_support on metric %r: %s", args.metric, exc)

    # The same file is read here and overwritten with the scores at the end.
    chunk_file = chunk_path(args.language, args.chunk_id, args.wasserstein_data)
    with open(chunk_file, 'r', encoding='utf-8') as file:
        data = json.load(file)
    logger.info('Data Length %s', len(data))
    logger.info('Chunk id %s', args.chunk_id)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as exc:
        # Fall back to the local copy when the default load fails.
        logger.warning("Loading tokenizer %r failed (%s); falling back to %s",
                       model_name, exc, local_model_path)
        tokenizer = AutoTokenizer.from_pretrained(local_model_path)

    # IDF statistics are gathered over every chunk, not only the scored one.
    l_refs = []
    l_cands = []
    f_candidates, f_references = [], []
    for chunk_index in range(NUMBER_OF_CHUNKS):
        with open(chunk_path(args.language, chunk_index, args.wasserstein_data), 'r',
                  encoding='utf-8') as file:
            data_idf = json.load(file)
        references = [entry['references_sentences'] for entry in data_idf.values()]
        for ref in references:
            l_refs.append(tokenizer(ref)['input_ids'])

        # One list of generated sentences (one per system) for each entry.
        candidates = [[system['generated_sentence'] for system in entry['system'].values()]
                      for entry in data_idf.values()]
        for cand in candidates:
            l_cands.append(tokenizer(cand)['input_ids'])

        f_candidates += flatten_once(candidates)
        f_references += references

    # mover_score derives its IDF tables from the raw sentences through the
    # metric itself; every other metric uses tables built from token ids.
    if args.metric != 'mover_score':
        idf_dict_ref = utils.ref_list_to_idf(l_refs)
        idf_dict_hyp = utils.ref_list_to_idf(flatten_once(l_cands))
    else:
        idf_dict_ref = metric.get_idf_dict(f_references)
        idf_dict_hyp = metric.get_idf_dict(f_candidates)

    if args.metric in ['bary_score']:
        metric.prepare_idfs(f_references, f_candidates)

    for entry in tqdm(data.values(), "Progress In Chunk"):
        references = entry["references_sentences"]
        for system_output in entry["system"].values():
            candidate_sentences = system_output['generated_sentence']
            if args.metric in IDF_METRICS:
                scores = metric.evaluate_batch(candidate_sentences, references, idf_dict_hyp, idf_dict_ref)
            elif args.metric != 'bert_score':
                scores = metric.evaluate_batch(candidate_sentences, [references])
            else:
                scores = metric.evaluate_batch([candidate_sentences], [references])
            # Tolerate entries that do not carry a "scores" dict yet.
            system_output.setdefault("scores", {}).update(scores)

    print('Saving IN')
    # Report the file that is actually written (previously the path of the
    # last chunk visited while building the IDF tables was printed).
    print(chunk_file)

    # NOTE(review): dumps the whole chunk to stdout; kept from the original.
    print(data)
    with open(chunk_file, 'w', encoding='utf-8') as file:
        json.dump(data, file)


if __name__ == '__main__':
    main()
