
import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import utils
from transformers import AutoTokenizer
import logging

logger = logging.getLogger(__name__)

if __name__ == '__main__':
    # Script purpose: score one chunk of the TAC2009 JSON data with the
    # selected metric and write the scores back into that chunk's file.
    # IDF statistics are built from every chunk, not only the scored one.
    parser = argparse.ArgumentParser()

    # Command-line options (all optional; defaults apply when omitted)
    parser.add_argument("--seed", default=1531, type=int, help="Random seed.")
    parser.add_argument("--chunk_id", default=0, type=int, help="Index of the data chunk to score.")
    parser.add_argument("--metric", default='bary_score', type=str, choices=available_metric.keys())
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--invert_support", action='store_true')
    parser.add_argument("--use_lm", action='store_true')
    parser.add_argument("--suffix", type=str)
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--wasserstein_data", action='store_true')

    args = parser.parse_args()

    # The two datasets live in different folders and are split differently.
    if args.wasserstein_data:
        data_path_template = 'data/tac2009_wasserstein_new/tac2009_formated_{}.json'
        NUMBER_OF_CHUNKS = 50
    else:
        data_path_template = 'data/tac2009/tac2009_formated_{}.json'
        NUMBER_OF_CHUNKS = 10

    # Validate explicitly instead of with `assert`, which is stripped by -O.
    if not 0 <= args.chunk_id < NUMBER_OF_CHUNKS:
        parser.error('--chunk_id must be in [0, {}), got {}'.format(NUMBER_OF_CHUNKS, args.chunk_id))

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    # Set seed
    set_seed(args)
    args.logger = logger

    metric = available_metric[args.metric]
    metric.use_idf_weights = args.use_idf_weights
    try:
        metric.invert_support = args.invert_support
    except AttributeError:
        # Best-effort: a metric object may refuse this attribute. Only report
        # it when the user actually asked for the option.
        if args.invert_support:
            logger.warning('Metric %s does not accept invert_support; option ignored', args.metric)

    chunk_path = data_path_template.format(args.chunk_id)
    with open(chunk_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    logger.info('Data Length %s', len(data))
    logger.info('Chunk id %s', args.chunk_id)

    model_name = 'bert-base-uncased'
    # Local copy of the model, used when the hub is unreachable.
    local_model_dir = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'.format(model_name)
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except (OSError, ValueError):
        logger.warning('Could not load tokenizer %s by name; falling back to %s', model_name, local_model_dir)
        tokenizer = AutoTokenizer.from_pretrained(local_model_dir)

    if args.metric == 'bert_score':
        # NOTE(review): this replaces the metric object configured above, so
        # the use_idf_weights / invert_support attributes set on the previous
        # object are not carried over -- confirm this is intended.
        if os.path.exists(local_model_dir):
            logger.info('Using local model directory %s', local_model_dir)
            metric = BertScoreMetric(model_type=local_model_dir)
        else:
            metric = BertScoreMetric(model_type=model_name)

    # Collect, over all chunks, the raw sentences (f_*) and their token ids
    # (l_*) for references and candidates; these feed the IDF computation.
    l_refs = []
    l_cands = []
    f_candidates, f_references = [], []
    for i in range(NUMBER_OF_CHUNKS):
        with open(data_path_template.format(i), 'r', encoding='utf-8') as file:
            data_idf = json.load(file)
        for value in data_idf.values():
            refs = value['references_sentences']
            cands = value['generated_sentence']
            # extend() flattens one level in linear time, replacing the
            # quadratic sum(list_of_lists, []) idiom.
            l_refs.extend(tokenizer(refs)['input_ids'])
            l_cands.extend(tokenizer(cands)['input_ids'])
            f_references.extend(refs)
            f_candidates.extend(cands)

    # mover_score builds its IDF tables from raw sentences through the metric
    # itself; every other metric uses the token-id based helper.
    if args.metric == 'mover_score':
        idf_dict_ref = metric.get_idf_dict(f_references)
        idf_dict_hyp = metric.get_idf_dict(f_candidates)
    else:
        idf_dict_ref = utils.ref_list_to_idf(l_refs)
        idf_dict_hyp = utils.ref_list_to_idf(l_cands)

    if args.metric == 'bary_score':
        metric.prepare_idfs(f_references, f_candidates)

    # Metrics whose evaluate_batch takes (candidates, references, idf_hyp, idf_ref).
    idf_metrics = {'info_score', 'info_mover_score', 'wasserstein', 'new_optiscore', 'bary_score',
                   'mover_score'}
    uses_idf_dicts = args.metric in idf_metrics

    for key, value in tqdm(data.items(), "Progress In Chunk"):
        references = value["references_sentences"]
        candidate_sentences = value["generated_sentence"]

        if uses_idf_dicts:
            scores = metric.evaluate_batch(candidate_sentences, references, idf_dict_hyp, idf_dict_ref)
        else:
            # NOTE(review): argument order here is (references, candidates),
            # the reverse of the branch above -- kept as in the original.
            scores = metric.evaluate_batch(references, candidate_sentences)
        data[key]["scores"].update(scores)

    # Write to a temporary file and swap it in, so a failure while serialising
    # (e.g. a non-JSON-serialisable score) cannot truncate the input chunk.
    tmp_path = chunk_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as file:
        json.dump(data, file)
    os.replace(tmp_path, chunk_path)
