import numpy as np


# chain MDP
class ChainMDP:
    """Synthetic multi-chain environment answering noisy preference queries.

    Every chain consists of ``length`` steps and each step carries a random
    feature vector.  Step rewards are linear in those features through a
    hidden vector ``reward_param``.  A comparison between two chains is
    answered with a Bernoulli draw whose success probability is the logistic
    function of ``alpha`` times the difference of their (averaged) rewards.

    Attributes:
        length: Number of steps in every chain.
        num_chains: Number of chains.
        features: Standard-normal array of shape
            ``(num_chains, length, feature_dim)``.
        reward_param: Standard-normal hidden reward vector of shape
            ``(feature_dim,)``.
        reward: Per-step rewards ``features @ reward_param`` with shape
            ``(num_chains, length)``.
        alpha: Scale applied to the reward difference inside the logistic.
    """

    def __init__(self, length, num_chains=10, feature_dim=5, alpha=10):
        """Draw random features and a random hidden reward vector.

        Args:
            length: Number of steps per chain.
            num_chains: Number of chains.
            feature_dim: Dimension of every step's feature vector.
            alpha: Scale of the reward difference in the preference model;
                larger values make the answers less noisy.
        """
        self.length = length
        self.num_chains = num_chains
        self.features = np.random.randn(num_chains, length, feature_dim)
        self.reward_param = np.random.randn(feature_dim)
        self.reward = self.features @ self.reward_param
        self.alpha = alpha

    def _check_chain_indices(self, i, j):
        """Raise ``AssertionError`` unless both chain indices are in range."""
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type and message stay
        # the same as before for existing callers.
        if not (0 <= i < self.num_chains and 0 <= j < self.num_chains):
            raise AssertionError("invalid query")

    def _sample_preference(self, Ri, Rj):
        """Draw one Bernoulli outcome with P(1) = sigmoid(alpha * (Ri - Rj)).

        Returns:
            Integer array of shape ``(1,)`` holding 0 or 1.
        """
        margin = self.alpha * (Ri - Rj)
        # sigmoid(x) = exp(-log(1 + exp(-x))).  Going through logaddexp keeps
        # the computation free of overflow when the margin is very negative,
        # where the naive 1 / (1 + exp(-x)) overflows inside exp.
        p = np.exp(-np.logaddexp(0.0, -margin))
        return np.random.binomial(1, p, 1)

    def query(self, i, j):
        """Compare chain ``i`` against chain ``j`` using step averages.

        Args:
            i: Index of the first chain.
            j: Index of the second chain.

        Returns:
            Tuple ``(phii, phij, o)``: the step-averaged feature vectors of
            the two chains and a shape ``(1,)`` array whose entry is 1 with
            probability ``sigmoid(alpha * (Ri - Rj))``, where ``Ri`` and
            ``Rj`` are the chains' average rewards.

        Raises:
            AssertionError: If ``i`` or ``j`` is outside
                ``[0, num_chains)``.
        """
        self._check_chain_indices(i, j)
        Ri = np.average(self.reward[i, :])
        Rj = np.average(self.reward[j, :])
        phii = np.average(self.features[i, :, :], axis=0)
        phij = np.average(self.features[j, :, :], axis=0)
        return phii, phij, self._sample_preference(Ri, Rj)

    def weighted_query(self, i, j, wi, wj):
        """Compare two chains using caller-supplied per-step weights.

        Rewards and features are combined as ``dot(w, .) / length``, so
        all-ones weights reproduce the plain averages used by ``query``.

        Args:
            i: Index of the first chain.
            j: Index of the second chain.
            wi: Sequence of ``length`` weights for chain ``i``.
            wj: Sequence of ``length`` weights for chain ``j``.

        Returns:
            Tuple ``(phii, phij, o)``: the weighted feature vectors of the
            two chains and a shape ``(1,)`` array whose entry is 1 with
            probability ``sigmoid(alpha * (Ri - Rj))`` for the weighted
            rewards ``Ri`` and ``Rj``.

        Raises:
            AssertionError: If an index is outside ``[0, num_chains)`` or a
                weight sequence does not have ``length`` entries.
        """
        self._check_chain_indices(i, j)
        if not (len(wi) == self.length and len(wj) == self.length):
            raise AssertionError("invalid query")

        Ri = np.dot(wi, self.reward[i, :]) / self.length
        Rj = np.dot(wj, self.reward[j, :]) / self.length
        phii = np.dot(wi, self.features[i, :, :]) / self.length
        phij = np.dot(wj, self.features[j, :, :]) / self.length
        return phii, phij, self._sample_preference(Ri, Rj)

    def evaluate(self, theta=None):
        """Return the true total reward of a greedy step-wise chain choice.

        Args:
            theta: Optional reward estimate of shape ``(feature_dim,)``.
                When omitted, the best achievable value is returned, i.e.
                the sum over steps of the maximum true reward across chains.
                Otherwise, at every step the chain with the highest
                estimated reward ``features @ theta`` is selected and the
                true rewards of those selections are summed.

        Returns:
            The summed true reward.
        """
        if theta is None:
            return np.sum(np.max(self.reward, axis=0))

        estimated = self.features @ theta
        # policy[step] is the chain that looks best at that step under theta.
        policy = np.argmax(estimated, axis=0)
        return sum(self.reward[s, step] for step, s in enumerate(policy))
