import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import utils
from transformers import AutoTokenizer
import logging

logger = logging.getLogger(__name__)

# Path templates (one ``{}`` placeholder for the chunk id) of the processed
# WebNLG 2020 chunk files. The same template is used to read the chunk being
# scored, to read every chunk for the IDF statistics, and to write the scores
# back, so the scored chunk file is updated in place.
# NOTE(review): the Wasserstein location used to mix 'data/' (chunk read) and
# '../data/' (IDF read + output write); it is unified on '../data/' here — confirm.
DEFAULT_DATA_TEMPLATE = 'data/webnlg-2020_processed/webnlg_2020_en_formated_{}.json'
WASSERSTEIN_DATA_TEMPLATE = '../data/webnlg-2020_processed_wasserstein_new/webnlg_2020_en_formated_{}.json'
# Cluster-local directory holding copies of Hugging Face models.
GPFS_MODELS_DIR = '/gpfswork/rech/qsq/uwi62ct/transformers_models'
MODEL_NAME = 'bert-base-uncased'
# Metrics whose ``evaluate_batch`` additionally takes (idf_dict_hyp, idf_dict_ref).
IDF_METRICS = ('info_score', 'info_mover_score', 'mover_score', 'wasserstein', 'bary_score')


def chunk_path_template(wasserstein_data):
    """Return the chunk-file path template for the selected dataset variant."""
    return WASSERSTEIN_DATA_TEMPLATE if wasserstein_data else DEFAULT_DATA_TEMPLATE


def chunk_count(wasserstein_data):
    """Return how many chunk files the selected dataset variant is split into."""
    return 20 if wasserstein_data else 6


def collect_sentences(chunk):
    """Extract per-entry reference and candidate sentence lists from one chunk.

    Args:
        chunk: mapping entry key -> {'references_sentences': [...],
            'system': {system_name: {'generated_sentence': ..., ...}, ...}}.

    Returns:
        (references, candidates): two lists with one inner list per entry, in
        the chunk's iteration order.
    """
    references = [entry['references_sentences'] for entry in chunk.values()]
    candidates = [[system['generated_sentence'] for system in entry['system'].values()]
                  for entry in chunk.values()]
    return references, candidates


def parse_args():
    """Parse the command line and validate ``--chunk_id`` against the chunk count."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", default=1531, type=int, help="Random seed passed to set_seed.")
    parser.add_argument("--chunk_id", default=1, type=int, help="Index of the data chunk to score.")
    parser.add_argument("--metric", default='ter', type=str, choices=available_metric.keys())
    parser.add_argument("--use_idf_weights", action='store_true')
    parser.add_argument("--invert_support", action='store_true')
    parser.add_argument("--not_use_div", action='store_true')
    parser.add_argument("--wasserstein_data", action='store_true')
    args = parser.parse_args()

    n_chunks = chunk_count(args.wasserstein_data)
    # parser.error rather than assert: the check must survive ``python -O``.
    if not 0 <= args.chunk_id < n_chunks:
        parser.error('--chunk_id must be in [0, {}), got {}'.format(n_chunks, args.chunk_id))
    return args


def load_metric(args):
    """Return the metric object selected by ``args.metric`` with CLI options applied."""
    if args.metric == 'bert_score':
        logger.info("Loading TOTO")
        local_dir = '{}/{}/'.format(GPFS_MODELS_DIR, MODEL_NAME)
        # Check the filesystem for a cluster-local copy of the model.
        if os.path.isdir(local_dir):
            logger.info("Loading from gpfwork")
            metric = BertScoreMetric(model_type=local_dir)
        else:
            metric = BertScoreMetric(model_type=MODEL_NAME)
    else:
        metric = available_metric[args.metric]

    try:
        metric.invert_support = args.invert_support
    except AttributeError:
        # Best effort: some metric objects may not accept this attribute.
        logger.warning("Metric %s does not support invert_support; ignoring it", args.metric)
    metric.use_idf_weights = args.use_idf_weights
    return metric


def load_tokenizer():
    """Load the BERT tokenizer from the hub, falling back to the cluster-local copy."""
    try:
        return AutoTokenizer.from_pretrained(MODEL_NAME)
    except (OSError, ValueError):
        logger.info("Falling back to local tokenizer under %s", GPFS_MODELS_DIR)
        return AutoTokenizer.from_pretrained('{}/{}/'.format(GPFS_MODELS_DIR, MODEL_NAME))


def compute_idf_dicts(args, metric, tokenizer, template, n_chunks):
    """Build hypothesis/reference IDF dictionaries over every chunk of the dataset.

    Returns:
        (idf_dict_hyp, idf_dict_ref). For ``mover_score`` they come from
        ``metric.get_idf_dict`` on the raw sentences; otherwise from
        ``utils.ref_list_to_idf`` on the tokenizer's input ids. For
        ``bary_score`` the metric's own IDFs are also prepared.
    """
    ref_token_ids, cand_token_ids = [], []
    all_references, all_candidates = [], []
    for chunk_idx in range(n_chunks):
        with open(template.format(chunk_idx), 'r', encoding='utf-8') as file:
            chunk = json.load(file)
        references, candidates = collect_sentences(chunk)
        # extend() flattens one level in linear time (sum(lists, []) is quadratic).
        for ref in references:
            ref_token_ids.extend(tokenizer(ref)['input_ids'])
            all_references.extend(ref)
        for cand in candidates:
            cand_token_ids.extend(tokenizer(cand)['input_ids'])
            all_candidates.extend(cand)

    if args.metric == 'mover_score':
        idf_dict_ref = metric.get_idf_dict(all_references)
        idf_dict_hyp = metric.get_idf_dict(all_candidates)
    else:
        idf_dict_ref = utils.ref_list_to_idf(ref_token_ids)
        idf_dict_hyp = utils.ref_list_to_idf(cand_token_ids)

    if args.metric == 'bary_score':
        metric.prepare_idfs(all_references, all_candidates)
    return idf_dict_hyp, idf_dict_ref


def score_chunk(args, metric, data, idf_dict_hyp, idf_dict_ref):
    """Score every system output in ``data`` against its references, updating it in place."""
    logger.info('Data Length %d', len(data))
    logger.info('Chunk id %d', args.chunk_id)
    for key, entry in tqdm(data.items(), "Progress In Chunk"):
        references = entry["references_sentences"]
        for system_name, system in entry["system"].items():
            # One copy of the candidate per reference, so the batch pairs them up.
            candidate_sentences = [system["generated_sentence"]] * len(references)
            if args.metric in IDF_METRICS:
                scores = metric.evaluate_batch(candidate_sentences, references, idf_dict_hyp, idf_dict_ref)
            else:
                scores = metric.evaluate_batch(candidate_sentences, references)
            logger.debug('%s / %s previous scores: %s', key, system_name, system["scores"])
            system["scores"].update(scores)


def main():
    """Score one chunk with the selected metric and write the scores back to its file."""
    args = parse_args()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    set_seed(args)
    args.logger = logger
    metric = load_metric(args)

    template = chunk_path_template(args.wasserstein_data)
    n_chunks = chunk_count(args.wasserstein_data)
    chunk_path = template.format(args.chunk_id)
    with open(chunk_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    tokenizer = load_tokenizer()
    idf_dict_hyp, idf_dict_ref = compute_idf_dicts(args, metric, tokenizer, template, n_chunks)
    score_chunk(args, metric, data, idf_dict_hyp, idf_dict_ref)

    logger.info('Saving in %s', chunk_path)
    # Dump to a sibling temp file and atomically swap it in, so a crash
    # mid-write cannot truncate the input chunk.
    tmp_path = chunk_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as file:
        json.dump(data, file)
    os.replace(tmp_path, chunk_path)


if __name__ == '__main__':
    main()
