import numpy as np

from scipy.sparse.csgraph import minimum_spanning_tree
from utils import build_query_set, best_basis_h, inverse_sigmoid

class SERE:
    """Estimate a per-step reward table from pairwise preference feedback.

    For every step ``h`` the learner (i) samples all candidate queries in
    batches while discarding the ones that cannot belong to the best basis,
    (ii) extracts a basis of ``S*A - 1`` queries via a spanning tree,
    (iii) tops the basis queries up to a computed sample budget and
    (iv) solves a linear system to recover the rewards, stored in
    ``rew_est`` with shape ``(S, A, H)``.

    Args:
        S: First dimension of the reward table (``rew_est.shape[0]``).
        A: Second dimension of the reward table (``rew_est.shape[1]``).
        H: Number of steps; queries are rows of length ``2*H``.
        e: Epsilon, used in the sample budget computed by ``run``.
        l: Lambda, tolerance on the optimistic/pessimistic ratio used as a
            stopping rule in ``_lambda_SQE``.
        d: Delta, used in the confidence radius and the sample budget.
        labeler: Object exposing ``generate_preference(traj1, traj2, n=...)``.
            NOTE(review): presumably returns the number of times ``traj1``
            was preferred out of ``n`` comparisons -- confirm against the
            labeler implementation.
        extra_args: Optional dict overriding ``t_batch``, ``t_expl_max``
            (legacy alias ``t_max``), ``beta_const``, ``sample_const`` and
            ``r_min``.
    """

    def __init__(self, S, A, H, e, l, d, labeler, extra_args=None):
        self.S = S
        self.A = A
        self.H = H
        self.e = e  # epsilon
        self.l = l  # lambda
        self.d = d  # delta
        self.labeler = labeler  # labeler instance

        # Per-step bookkeeping filled in by run().
        self.t_stop = np.zeros(H)  # samples per query when exploration stopped
        self.T = np.zeros(H)       # sample budget per basis query
        self.basis = np.zeros((S*A-1, 2*H, H), dtype=int)

        self.rew_est = np.zeros((self.S, self.A, self.H))

        # Auxiliary arguments
        self.t_batch = 1000
        self.t_expl_max = int(1e6)
        self.beta_const = 4
        self.sample_const = 8 * np.e
        self.r_min = 0

        if extra_args is not None:
            for key in ("t_batch", "beta_const", "sample_const", "r_min"):
                if key in extra_args:
                    setattr(self, key, extra_args[key])

            # The original code tested for "t_max" but read "t_expl_max",
            # which raised KeyError for "t_max" and ignored a lone
            # "t_expl_max". Accept both; the attribute's own name wins.
            if "t_expl_max" in extra_args:
                self.t_expl_max = extra_args["t_expl_max"]
            elif "t_max" in extra_args:
                self.t_expl_max = extra_args["t_max"]

    def _beta(self, t):
        """Return the confidence radius after ``t`` samples.

        Computes ``sqrt(log(beta_const * S*A * (S*A - 1) * t**2 / d) / (2*t))``.
        ``t`` may be a scalar or an array; the result has the same shape.
        """
        # Work in floating point so t**2 cannot overflow a narrow integer type.
        t = np.asarray(t, dtype=float)
        ln_arg = self.beta_const * self.S * self.A * (self.S*self.A - 1) * t**2 / self.d
        return np.sqrt(np.log(ln_arg) / (2 * t))

    def _variance_conf_bounds(self, p_est, beta):
        """Bound the Bernoulli variance ``p*(1-p)`` over ``[p_est-beta, p_est+beta]``.

        Args:
            p_est: Column of estimated probabilities (the callers in this
                class pass shape ``(n, 1)``).
            beta: Half-width of the confidence interval around ``p_est``.

        Returns:
            ``(Var_LB, Var_UB)``: element-wise lower and upper bounds on the
            variance, same shape as ``p_est``.
        """
        # Extremes of the p_est confidence interval, clipped to [0, 1].
        p_lo = np.maximum(p_est - beta, 0.0)
        p_hi = np.minimum(p_est + beta, 1.0)

        # p*(1-p) peaks at 0.5: when the interval straddles 0.5, move the
        # endpoint lying on the far side of 0.5 from p_est onto 0.5 so that
        # the maximum variance (0.25) is attained by one of the two points.
        p_lo = np.where((p_est > 0.5) & (p_lo < 0.5), 0.5, p_lo)
        p_hi = np.where((p_est <= 0.5) & (p_hi > 0.5), 0.5, p_hi)

        var_at_lo = p_lo * (1 - p_lo)
        var_at_hi = p_hi * (1 - p_hi)

        return np.minimum(var_at_lo, var_at_hi), np.maximum(var_at_lo, var_at_hi)

    def _get_weights_from_basis(self, Q, B, W):
        """Return the rows of ``W`` belonging to the queries listed in ``B``.

        Args:
            Q: Query matrix, one query per row.
            B: Basis queries; every row must also be a row of ``Q``.
            W: Array whose i-th row is associated with ``Q[i]``.

        Returns:
            ``W`` restricted to the basis, with row ``k`` corresponding to
            ``B[k]``. (The previous mask-based version returned the rows in
            ``Q`` order, which misaligned them with ``B`` for callers that
            index both by position.)

        Raises:
            IndexError: If a row of ``B`` does not occur in ``Q``.
        """
        rows = []
        for b in B:
            hits = np.flatnonzero(np.all(Q == b, axis=1))
            if hits.size == 0:
                raise IndexError("basis query not found in query set")
            rows.append(hits[0])

        return W[np.asarray(rows, dtype=int)]

    def _lambda_SQE(self, Q, h):
        """Sample all queries of step ``h`` in batches, pruning suboptimal ones.

        Each round draws ``t_batch`` more labels for every surviving query,
        recomputes variance confidence bounds and removes every query whose
        optimistic basis value is below the current pessimistic value. The
        loop ends when exactly ``S*A - 1`` queries remain, when the
        optimistic/pessimistic ratio drops to ``1 + l`` or below, or when the
        per-query sample count exceeds ``t_expl_max``.

        Args:
            Q: Candidate queries, one row of length ``2*H`` per query.
            h: Step index being explored.

        Returns:
            ``(X, p_est, feedbacks, n)``: surviving queries, their estimated
            probabilities, their accumulated feedback and the number of
            samples drawn per query.
        """
        X = Q.copy()
        feedbacks = np.zeros((X.shape[0], 1))
        t = 1

        while True:
            # NOTE(review): the radius uses the sample count *before* this
            # round's batch, while p_est below uses the count after it.
            beta_t = self._beta(t)

            # round robin (batch) samples
            for i in range(X.shape[0]):
                pair = X[i, :]
                traj1 = pair[:self.H]
                traj2 = pair[self.H:]

                feedbacks[i] = feedbacks[i] + self.labeler.generate_preference(traj1, traj2, n=self.t_batch)

            # t + t_batch - 1 is the number of samples drawn so far per query.
            p_est = feedbacks/(t+self.t_batch-1)

            L, U = self._variance_conf_bounds(p_est, beta_t)

            I_L = best_basis_h(X, L, self.H, h)

            V_pes = np.min(self._get_weights_from_basis(X, I_L, L))

            mask = np.ones(X.shape[0], dtype=bool)

            # find suboptimal queries; removals take effect immediately, so
            # later queries are judged against the already reduced set
            for i in range(X.shape[0]):
                I_U_q = best_basis_h(X[mask, :], U[mask], self.H, h, X[i, :])

                V_opt_q = np.min(self._get_weights_from_basis(X[mask, :], I_U_q, U[mask, :]))

                if V_opt_q < V_pes:
                    mask[i] = 0

            # discard suboptimal queries
            X = X[mask, :]
            feedbacks = feedbacks[mask]
            p_est = p_est[mask]
            L = L[mask]
            U = U[mask]

            I_L = best_basis_h(X, L, self.H, h)
            I_U = best_basis_h(X, U, self.H, h)

            V_pes = np.min(self._get_weights_from_basis(X, I_L, L))
            V_opt = np.min(self._get_weights_from_basis(X, I_U, U))

            if V_pes == 0:
                ratio = 2+self.l  # ensure no stopping
            else:
                ratio = V_opt / V_pes

            # round increment
            t = t + self.t_batch

            # Stopping condition
            if (X.shape[0] == (self.S * self.A - 1)) or \
                    (ratio <= (1+self.l)) or \
                    (t > self.t_expl_max):
                break

        return X, p_est, feedbacks, t-1

    def _estimate_optimal_basis_prob_feed(self, X, p_est, F, h):
        """Pick a basis among the surviving queries via a spanning tree.

        Builds a symmetric ``(S*A, S*A)`` weight matrix whose entry for the
        pair ``(X[i, h], X[i, H+h])`` is ``-p_est[i] * (1 - p_est[i])`` and
        takes its minimum spanning tree, i.e. the tree with the largest total
        estimated variance.

        Args:
            X: Surviving queries (integer rows of length ``2*H``).
            p_est: Estimated probability per query, aligned with ``X``.
            F: Accumulated feedback per query, aligned with ``X``.
            h: Step index.

        Returns:
            ``(basis, p_B, f_B)``: the basis queries (non-zero only in
            columns ``h`` and ``H+h``) and the rows of ``p_est`` and ``F``
            matching them, in basis order.
        """
        W = np.zeros((self.S*self.A, self.S*self.A))

        # Flatten to plain floats: assigning a size-1 array into a scalar
        # cell relies on an implicit array-to-scalar conversion.
        edge_var = np.asarray(p_est * (1 - p_est)).reshape(-1)

        for i in range(X.shape[0]):
            u = X[i, h]
            v = X[i, self.H+h]
            W[u, v] = -float(edge_var[i])
            W[v, u] = -float(edge_var[i])

        mst_est = minimum_spanning_tree(W).toarray()

        rows, cols = np.where(mst_est != 0)

        # Normalise every edge to (smaller index, larger index).
        B = np.sort(np.column_stack((rows, cols)), axis=1)

        basis = np.zeros((B.shape[0], X.shape[1]), dtype=int)
        basis[:, h] = B[:, 0]
        basis[:, self.H+h] = B[:, 1]

        return basis, self._get_weights_from_basis(X, basis, p_est), self._get_weights_from_basis(X, basis, F)

    def run(self):
        """Run the procedure for every step and fill ``rew_est``.

        Side effects: writes ``t_stop[h]``, ``basis[:, :, h]``, ``T[h]`` and
        ``rew_est[:, :, h]`` for each ``h`` and queries ``labeler``.
        """
        n_sa = self.S * self.A

        for h in range(self.H):
            Q = build_query_set(self.S, self.A, self.H, h)
            X, p_est, feedbacks, self.t_stop[h] = self._lambda_SQE(Q, h)

            self.basis[:, :, h], p_B, f_B = self._estimate_optimal_basis_prob_feed(X, p_est, feedbacks, h)

            beta_stop = self._beta(self.t_stop[h])

            # NOTE(review): if var_B equals beta_stop the budget is infinite
            # and the int() conversion below raises OverflowError.
            var_B = np.min(np.multiply(p_B, 1-p_B))
            temp = (n_sa + 1)**2 * (n_sa - 1) / (np.pi**2 * self.e**2 * (var_B - beta_stop)**2)
            self.T[h] = int(np.ceil(self.sample_const * temp * np.log(2 * self.sample_const * temp / self.d)))

            # Exploration may already have drawn more than the budget; never
            # request a negative number of labels, and pass a real int.
            n_drawn = int(self.t_stop[h])
            t_remaining = max(int(self.T[h]) - n_drawn, 0)

            if t_remaining > 0:
                for i in range(self.basis.shape[0]):
                    pair = self.basis[i, :, h]
                    traj1 = pair[:self.H]
                    traj2 = pair[self.H:]

                    f_B[i] = f_B[i] + self.labeler.generate_preference(traj1, traj2, n=t_remaining)

            # Divide by the number of samples actually collected per query
            # (equals T[h] whenever the budget exceeds the exploration count).
            p_B = f_B / (n_drawn + t_remaining)

            # One equation r[a] - r[b] = inverse_sigmoid(p) per basis query,
            # plus a final row pinning the last reward to 0.
            delta_r = np.append(inverse_sigmoid(p_B), 0)
            design = np.zeros((n_sa, n_sa), dtype=int)
            x = np.arange(n_sa-1)
            design[x, self.basis[:, h, h]] = 1
            design[x, self.basis[:, self.H+h, h]] = -1
            design[-1, -1] = 1

            r_est = np.linalg.solve(design, delta_r).reshape(-1, 1)
            # The pinned reward is 0, so min(r_est) <= 0 and adding its
            # absolute value shifts the smallest reward to r_min.
            r_est = r_est + np.abs(np.min(r_est)) + self.r_min
            self.rew_est[:, :, h] = r_est.reshape(self.S, self.A)