#!/usr/bin/env python3
"""Approximate information gain using a cooccurrence matrix to create a Dirichlet distribution over transition matrices."""

import pickle
import sys
import numpy as np
import tqdm
from scipy.special import digamma, loggamma

def log_beta(alpha_non_one, num_ones, alpha_sum):
    """Return log B(alpha), the log multivariate Beta function of a sparse alpha.

    B(alpha) = prod_i Gamma(alpha_i) / Gamma(sum_i alpha_i).  The parameter
    vector is split into the explicitly stored entries ``alpha_non_one`` and
    ``num_ones`` further entries equal to 1 (each adds log Gamma(1) = 0).
    ``alpha_sum`` is the precomputed total of all entries, implicit ones
    included.
    """
    log_numerator = loggamma(alpha_non_one).sum() + num_ones * loggamma(1)
    return log_numerator - loggamma(alpha_sum)

def dirichlet_entropy(alpha_minus_one, size):
    """Return the entropy of a Dirichlet distribution given in sparse form.

    ``alpha_minus_one`` maps a category to ``alpha_i - 1`` for each listed
    category.  Each of the ``size - len(alpha_minus_one)`` unlisted categories
    implicitly has ``alpha_i = 1``.  The value computed is

        log B(alpha) + (alpha_0 - K) * digamma(alpha_0)
            - sum_i (alpha_i - 1) * digamma(alpha_i)

    where ``alpha_0`` is the sum of all alpha_i and ``K`` is the number of
    categories.  The implicit alpha_i = 1 terms drop out of the final sum, so
    only the listed entries are evaluated.
    """
    # Dict values iterate in the same order as its keys.
    listed_alpha = np.array(list(alpha_minus_one.values())) + 1
    implicit_ones = size - len(alpha_minus_one)
    alpha_total = listed_alpha.sum() + implicit_ones
    category_count = len(listed_alpha) + implicit_ones
    return (
        log_beta(listed_alpha, implicit_ones, alpha_total)
        + (alpha_total - category_count) * digamma(alpha_total)
        - np.dot(listed_alpha - 1, digamma(listed_alpha))
    )

if len(sys.argv) < 2:
    sys.exit(f'usage: {sys.argv[0]} COOCCURRENCE_PICKLE')

print('Loading cooccurrence matrix')
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only pass pickles you produced yourself, never untrusted input.
with open(sys.argv[1], 'rb') as f:
    # Unpacked below as a mapping (obs, act, next_obs) -> count; observation
    # ids are presumably non-negative ints (they size the Dirichlet) — TODO confirm.
    cooccurrence_counts = pickle.load(f)

print('Generating observation matrix')
# obs_matrix[(obs, act)] is a sparse row mapping next_obs -> transition count.
obs_matrix = {}
max_obs = -1  # largest observation id seen as either source or destination
total_transitions = 0
for (obs, act, next_obs), count in cooccurrence_counts.items():
    total_transitions += count

    row = obs_matrix.setdefault((obs, act), {})
    row[next_obs] = row.get(next_obs, 0) + count

    # Source and destination ids share one observation space; both must be
    # counted, or a destination-only id would fall outside the Dirichlet size.
    max_obs = max(max_obs, obs, next_obs)

if total_transitions == 0:
    # The mean below would divide by zero.
    sys.exit('No transitions found in cooccurrence matrix')

print('Calculating Dirichlet entropy')
num_obs = max_obs + 1  # assumes observation ids are 0..max_obs — TODO confirm
# Entropy of the Dirichlet with every alpha_i = 1 (no counts added).
initial_row_entropy = dirichlet_entropy({}, num_obs)
# Per-row gain: that baseline entropy minus the entropy of the Dirichlet whose
# alpha_i are 1 + the row's transition counts.
total_info_gain = sum(tqdm.tqdm(
    (initial_row_entropy - dirichlet_entropy(row, num_obs)
     for row in obs_matrix.values()),
    total=len(obs_matrix),
))
mean_info_gain = total_info_gain / total_transitions
print('Mean information gain:', mean_info_gain)
