import numpy as np

from scipy.sparse.csgraph import minimum_spanning_tree
from utils import sigmoid, build_query_set, best_basis_h

class Environment:
    """Randomly generated finite-horizon tabular MDP.

    Attributes set by the constructor / ``_initialize``:
        S, A, H: number of states, actions and stages.
        r_max, r_min: bounds used when sampling instantaneous rewards.
        min_r_stage: whether each stage's rewards are shifted so that the
            smallest reward of the stage equals ``r_min``.
        initial_state: fixed start state, or None for a random start
            distribution.
        P: array of shape (S, A, H, S); ``P[s, a, h, :]`` is a probability
            vector over next states.
        r: array of shape (S, A, H) of instantaneous rewards.
        mu: column vector of shape (S, 1) with the initial state distribution.

    State-action pairs are addressed by a flat index ``s * A + a`` (the
    C-order flattening of an (S, A) array) wherever a single integer is used.
    """

    def __init__(self, S, A, H, r_max, r_min, min_r_stage=True, initial_state=None):
        """Validate the sizes, store them and sample a random model.

        Args:
            S: number of states (> 0).
            A: number of actions (> 0).
            H: horizon, i.e. number of stages (> 0).
            r_max: upper bound of the sampled rewards.
            r_min: lower bound of the sampled rewards (strictly below r_max).
            min_r_stage: if True, shift the rewards of every stage so that the
                stage minimum is exactly ``r_min``.
            initial_state: a state in [0, S) to start from, or None to draw a
                random initial distribution.

        Raises:
            AssertionError: if any of the arguments is out of range.
        """
        assert S > 0, "Number of states cannot be zero"
        assert A > 0, "Number of action cannot be zero"
        assert H > 0, "Horizon cannot be zero"
        assert r_max > r_min, "Maximum instantaneous reward must be greater than the minimum"
        assert initial_state is None or (initial_state >=0 and initial_state < S), "Initial state distribution must be either None (i.e., random distribution) or a valid state in [0, S)"

        self.S = S
        self.A = A
        self.H = H
        self.r_max = r_max
        self.r_min = r_min
        self.min_r_stage = min_r_stage
        self.initial_state = initial_state  # None corresponds to a random distribution

        self._initialize()
        self._reset()

    def _initialize(self):
        """Sample the transition kernel, the rewards and the initial distribution."""
        # Draw unnormalised weights, then normalise over the next-state axis so
        # that every P[s, a, h, :] sums to one.
        self.P = np.random.uniform(size=(self.S, self.A, self.H, self.S))
        self.P /= self.P.sum(axis=-1, keepdims=True)

        self.r = np.random.rand(self.S, self.A, self.H) * (self.r_max - self.r_min) + self.r_min

        if self.min_r_stage:
            # Set the min reward (for each stage independently) to r_min, for
            # ease of evaluation. Subtracting the minimum first keeps the
            # smallest entry of each stage exactly equal to r_min.
            stage_min = self.r.min(axis=(0, 1), keepdims=True)
            self.r = self.r - stage_min + self.r_min

        if self.initial_state is not None:
            # Deterministic start: all mass on the requested state.
            self.mu = np.zeros((self.S, 1))
            self.mu[int(self.initial_state)] = 1
        else:
            self.mu = np.random.uniform(size=self.S).reshape(-1, 1)
            self.mu = self.mu / np.sum(self.mu)

    def _reset(self):
        """Placeholder: currently does nothing."""
        # Restore initial position, either to fixed state or sampled from distribution
        pass

    def get_cumulative_reward(self, traj):
        """Return the sum of rewards collected along a trajectory.

        Args:
            traj: integer array of shape (H,) or (H, 1); entry ``h`` is the
                flat state-action index ``s * A + a`` visited at stage ``h``.

        Returns:
            The scalar ``sum_h r[s_h, a_h, h]``.

        Raises:
            AssertionError: if the shape is wrong or an index is outside
                [0, S*A).
        """
        traj = np.asarray(traj)
        assert traj.shape == (self.H,) or traj.shape == (self.H, 1), "Trajectory must be an array of length H"
        assert np.all((traj >= 0) & (traj < self.S * self.A)), "Trajectory must contain valid state-action pair in range [0, SA)"

        # Flatten so that a (H, 1) column does not broadcast against the stage
        # indices into an (H, H) grid.
        flat = traj.reshape(-1)
        # Decode with the same C-order layout used when r[:, :, h] is
        # flattened elsewhere in this class: index = s * A + a.
        states, actions = np.divmod(flat, self.A)
        return self.r[states, actions, np.arange(self.H)].sum()

    def compare_policy_value_true_est(self, r_est):
        """Evaluate the optimal policy and the policy greedy w.r.t. ``r_est``.

        Backward induction is run three times in lock-step: for the optimal
        policy under the true rewards, for the policy that is optimal under
        the estimated rewards, and for that same estimated-optimal policy
        evaluated under the true rewards.

        Args:
            r_est: estimated rewards, same shape as ``self.r`` (S, A, H).

        Returns:
            A pair ``(V_star, V_star_est)`` of arrays of shape (S, H + 1):
            the true optimal values and the true values of the policy that is
            greedy with respect to ``r_est``. Column ``H`` is all zeros.

        Raises:
            AssertionError: if ``r_est`` has a different shape than ``self.r``.
        """
        assert r_est.shape == self.r.shape

        V_star = np.zeros((self.S, self.H+1))
        V_est = np.zeros((self.S, self.H+1))
        V_star_est = np.zeros((self.S, self.H+1))
        all_states = np.arange(self.S)

        for h in range(self.H-1, -1, -1):
            P_h = self.P[:, :, h, :]  # (S, A, S) transitions of stage h

            Q_star = self.r[:, :, h] + P_h @ V_star[:, h+1]          # \pi^* wrt true rewards
            Q_est = r_est[:, :, h] + P_h @ V_est[:, h+1]             # \hat{\pi} (optimal wrt est rewards)
            Q_star_est = self.r[:, :, h] + P_h @ V_star_est[:, h+1]  # \hat{\pi} wrt true rewards

            V_star[:, h] = Q_star.max(axis=1)
            V_est[:, h] = Q_est.max(axis=1)
            # Follow the action that is greedy for the estimated rewards but
            # score it with the true-reward Q-values.
            greedy = np.argmax(Q_est, axis=1)
            V_star_est[:, h] = Q_star_est[all_states, greedy]

        return V_star, V_star_est

    def get_variances_h(self, h):
        """Return the pairwise matrix ``p * (1 - p)`` for stage ``h``.

        With ``r_i`` the reward of flat state-action index ``i`` at stage
        ``h``, entry ``[i, j]`` is ``p * (1 - p)`` where
        ``p = sigmoid(r_i - r_j)``.
        NOTE(review): assumes ``utils.sigmoid`` is the element-wise logistic
        function, which would make this the variance of a Bernoulli
        comparison outcome -- confirm.

        Args:
            h: stage in [0, H).

        Returns:
            Array of shape (S*A, S*A).
        """
        assert h>=0 and h<self.H, "Requested stage must be in [0, H)."

        column = self.r[:, :, h].reshape(-1, 1)
        # Broadcasting a column against its transpose gives every r_i - r_j.
        delta_r = column - column.T
        p = sigmoid(delta_r)
        return np.multiply(p, 1-p)

    def get_optimal_basis(self, h):
        """Select a basis of pairwise queries for stage ``h``.

        Every row of the query set returned by ``build_query_set`` is weighted
        by the stage-``h`` variance of the pair found in its columns ``h`` and
        ``H + h``; ``best_basis_h`` then picks the rows to keep.
        NOTE(review): the selection criterion lives in ``utils.best_basis_h``
        and is not visible here -- confirm before relying on "optimal".

        Args:
            h: stage in [0, H).

        Returns:
            Integer array of shape (n_selected, 2*H), zero everywhere except
            columns ``h`` and ``H + h``, which hold the two flat state-action
            indices of each selected pair.
        """
        assert h>=0 and h<self.H, "Requested stage must be in [0, H)."

        G = self.get_variances_h(h)
        Q = build_query_set(self.S, self.A, self.H, h)

        pairs = Q[:,[h,self.H+h]]
        W = np.zeros((Q.shape[0], 1))
        W[:, 0] = G[pairs[:, 0], pairs[:, 1]]

        B = best_basis_h(Q, W, self.H, h)

        basis = np.zeros((B.shape[0], 2*self.H), dtype=int)
        basis[:,h] = B[:,0]
        basis[:,self.H+h] = B[:,1]

        return basis

    def get_true_index(self, basis, h):
        """Return the smallest stage-``h`` variance among the pairs of a basis.

        Args:
            basis: integer array of shape (S*A - 1, 2*H); columns ``h`` and
                ``H + h`` hold the flat state-action indices of each pair.
            h: stage in [0, H).

        Returns:
            The minimum of ``get_variances_h(h)`` over the basis pairs.

        Raises:
            AssertionError: on a wrong shape, a non-integer dtype or an
                out-of-range stage.
        """
        assert basis.shape == (self.S*self.A-1, 2*self.H)
        assert np.issubdtype(basis.dtype, np.integer), "Basis must contain only integers."
        assert h>=0 and h<self.H, "Requested stage must be in [0, H)."

        G = self.get_variances_h(h)
        pairs = basis[:, [h, self.H+h]]

        return G[pairs[:,0], pairs[:,1]].min()