import pprint
import random
import torch
from math import log
from collections import defaultdict, Counter
import numpy as np

# Module-level pretty-printer with a 4-space indent; print_average_correlation
# uses it to display its averaged-correlation dicts.
pp = pprint.PrettyPrinter(indent=4)


def set_seed(args):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    RNG with ``args.seed``; when CUDA is available, all CUDA devices are
    seeded as well.

    Args:
        args: any object exposing a ``seed`` attribute.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators are separate from the CPU one and must be seeded explicitly.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def print_average_correlation(all_results):
    if isinstance(all_results, dict):
        for key, value in all_results.items():
            print('Information Metric {}'.format(key))
            corr_mat = np.array(value)
            results = dict(zip(['kendall', 'pearson', 'spearman'],
                               [np.mean(corr_mat[:, 0]),
                                np.mean(corr_mat[:, 1]),
                                np.mean(corr_mat[:, 2])]))
            pp.pprint(results)
    else:
        corr_mat = np.array(all_results)
        try:
            results = dict(zip(['kendall', 'pearson', 'spearman'],
                               [np.mean(corr_mat[:, 0]),
                                np.mean(corr_mat[:, 1]),
                                np.mean(corr_mat[:, 2])]))
        except:
            print()
        pp.pprint(results)


def ref_list_to_idf(input_refs):
    """Build a smoothed inverse-document-frequency table.

    Each element of *input_refs* is treated as one document: an iterable of
    hashable items. An item counts at most once per document no matter how
    often it repeats. For N documents and document frequency df(t)::

        idf(t) = log((N + 1) / (df(t) + 1))

    Args:
        input_refs: sized collection (``len()`` is taken) of documents.

    Returns:
        ``defaultdict`` mapping each seen item to its IDF. Items never seen
        get ``log(N + 1)`` (df = 0). Note that looking up a missing key
        inserts it, as with any ``defaultdict``.
    """
    num_docs = len(input_refs)

    # Count each document's distinct items directly; this avoids the
    # quadratic cost of concatenating every document's list via sum(..., []).
    idf_count = Counter()
    for ref in input_refs:
        idf_count.update(set(ref))

    # Value for unseen items: df = 0, i.e. log((N + 1) / 1).
    default_idf = log(num_docs + 1)
    idf_dict = defaultdict(lambda: default_idf)
    idf_dict.update({
        idx: log((num_docs + 1) / (c + 1))
        for idx, c in idf_count.items()
    })
    return idf_dict
