
import itertools
import csv
import utils
import argparse
import logging
import os
import json
from tqdm import tqdm
import copy
import random
import numpy as np
import torch
from summ_eval.mover_score_metric import MoverScoreMetric
from readers import *
from metric import *
import scipy.stats as stats
from utils import set_seed
import logging

logger = logging.getLogger(__name__)

# Template of the per-chunk data files; `{}` is replaced by the chunk id.
DATA_PATH = 'data/msr_final/test_formated_{}.json'
# The dataset is split into this many chunk files, with ids 0 .. NUMBER_OF_CHUNKS - 1.
NUMBER_OF_CHUNKS = 5
# Local directory checked for a pre-downloaded copy of a model; `{}` is the model name.
LOCAL_MODEL_TEMPLATE = '/gpfswork/rech/qsq/uwi62ct/transformers_models/{}/'
# Metrics whose `evaluate_batch` is called with hypothesis/reference IDF dictionaries.
IDF_METRICS = ('info_score', 'mover_info_score')


def _parse_args():
    """Parse the command line and validate ``--chunk_id``.

    Returns:
        argparse.Namespace with ``seed``, ``chunk_id``, ``metric``,
        ``use_idf_weights`` and ``not_use_div``.
    """
    parser = argparse.ArgumentParser(
        description='Score one chunk of the MSR data with the selected metric '
                    'and write the scores back into the chunk file.')
    parser.add_argument("--seed", default=1531, type=int, help="Random seed passed to set_seed.")
    parser.add_argument("--chunk_id", default=0, type=int,
                        help="Index of the data chunk to score, in [0, {}).".format(NUMBER_OF_CHUNKS))
    parser.add_argument("--metric", default='info_score', type=str, choices=available_metric.keys())
    parser.add_argument("--use_idf_weights", action='store_true',
                        help="Value assigned to the metric's use_idf_weights attribute.")
    # NOTE(review): --not_use_div is parsed but not read anywhere in this script.
    parser.add_argument("--not_use_div", action='store_true')
    args = parser.parse_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 <= args.chunk_id < NUMBER_OF_CHUNKS:
        parser.error('--chunk_id must be in [0, {}), got {}'.format(NUMBER_OF_CHUNKS, args.chunk_id))
    return args


def _chunk_path(chunk_id):
    """Return the JSON file path of chunk ``chunk_id``."""
    return DATA_PATH.format(chunk_id)


def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _save_json(data, path):
    """Write ``data`` as JSON to ``path`` without risking a truncated file.

    The document is written to a sibling temporary file first and then moved
    over ``path`` with ``os.replace``, so a crash mid-write leaves the
    original file intact.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as file:
        json.dump(data, file)
    os.replace(tmp_path, path)


def _build_metric(args):
    """Instantiate the metric selected by ``--metric`` and set its IDF flag."""
    if args.metric == 'bert_score':
        model_name = 'bert-base-uncased'
        local_path = LOCAL_MODEL_TEMPLATE.format(model_name)
        # os.path.isdir, not os._exists: the latter is a private helper that only
        # checks whether a *name* is defined in the os module, so it was always False.
        if os.path.isdir(local_path):
            metric = BertScoreMetric(model_type=local_path)
        else:
            metric = BertScoreMetric(model_type=model_name)
    else:
        metric = available_metric[args.metric]
    metric.use_idf_weights = args.use_idf_weights
    return metric


def _load_tokenizer(model_name):
    """Load ``model_name``'s tokenizer, falling back to the local model copy on failure."""
    try:
        return AutoTokenizer.from_pretrained(model_name)
    except Exception as err:  # was a bare `except:`; no longer swallows KeyboardInterrupt
        logger.warning('Could not load tokenizer %s (%s); using local copy.', model_name, err)
        return AutoTokenizer.from_pretrained(LOCAL_MODEL_TEMPLATE.format(model_name))


def _build_reference_idf():
    """Build the reference-side IDF dictionary from the references of ALL chunks.

    Every chunk's ``references_sentences`` entries are tokenized with the
    bert-base-uncased tokenizer, and the resulting ``input_ids`` are passed
    to ``utils.ref_list_to_idf``.
    """
    references = []
    for chunk_id in range(NUMBER_OF_CHUNKS):
        chunk = _load_json(_chunk_path(chunk_id))
        references.extend(value['references_sentences'] for value in chunk.values())
    tokenizer = _load_tokenizer('bert-base-uncased')
    token_ids = tokenizer(references)['input_ids']
    return utils.ref_list_to_idf(token_ids)


def _score_chunk(data, metric, metric_name, idf_dict_ref=None):
    """Score every system output in ``data`` and merge the results in place.

    Args:
        data: loaded chunk; each value must have ``references_sentences``,
            ``system`` (system name -> candidate sentences) and ``scores`` (dict).
        metric: object exposing ``evaluate_batch``.
        metric_name: the ``--metric`` value; selects the IDF-aware call.
        idf_dict_ref: reference IDF dictionary, required for ``IDF_METRICS``.
    """
    for value in tqdm(data.values(), "Progress In Chunk"):
        references = value["references_sentences"]
        for candidate_sentences in value["system"].values():
            # The same references are paired with every candidate sentence.
            batch_refs = [references] * len(candidate_sentences)
            if metric_name in IDF_METRICS:
                # NOTE(review): the reference IDF is built from tokenizer input_ids,
                # while this one is built from the raw candidate sentences; confirm
                # utils.ref_list_to_idf is meant to receive both forms.
                idf_dict_hyp = utils.ref_list_to_idf(candidate_sentences)
                scores = metric.evaluate_batch(candidate_sentences, batch_refs,
                                               idf_dict_hyp, idf_dict_ref)
            else:
                scores = metric.evaluate_batch(candidate_sentences, batch_refs)
            # NOTE(review): every system of this entry is merged into the same
            # `scores` dict, so identical keys from later systems overwrite earlier
            # ones; confirm this is intended.
            value["scores"].update(scores)


def main():
    """Score the chunk selected by ``--chunk_id`` and overwrite its file with the scores."""
    args = _parse_args()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    set_seed(args)
    args.logger = logger
    metric = _build_metric(args)

    data_path = _chunk_path(args.chunk_id)
    data = _load_json(data_path)
    logger.info('Data Length %d', len(data))
    logger.info('Chunk id %d', args.chunk_id)

    idf_dict_ref = _build_reference_idf() if args.metric in IDF_METRICS else None
    _score_chunk(data, metric, args.metric, idf_dict_ref)

    print('Saving IN')
    # Previously printed a stale DATA_PATH (chunk 4 or the raw template).
    print(data_path)
    _save_json(data, data_path)


if __name__ == '__main__':
    main()
