import numpy as np

class OLSEstimator:
    """Linear estimator for two quantities (F and G) over a feature grid.

    Coefficients ``theta_hat`` (for F) and ``beta_hat`` (for G) are obtained by
    applying the inverse design matrix to the estimates observed at the
    design points. Both quantities are then predicted on the full grid as
    ``phi[m, k] @ coef``.

    ``EXP`` must provide ``K``, ``M``, ``d``, ``phi``, ``design_phi`` and
    ``design_indices``. ``variance`` is also required, but only by
    :meth:`compute_regression_variance`.

    ``phi`` is indexed as ``phi[m, k]`` with a feature vector per cell,
    i.e. shape ``(M, K, d)``. ``design_phi`` is inverted, so it must be square
    and non-singular. Its rows are presumably the features at
    ``design_indices`` in the same order; TODO confirm against the
    experiment class.
    """

    def __init__(self, EXP):
        self.EXP = EXP
        self.K = EXP.K
        self.M = EXP.M
        self.d = EXP.d
        self.phi = EXP.phi
        # Grid features flattened to (M*K, d); not used inside this class.
        self.Phi = self.phi.reshape(-1, self.phi.shape[-1])
        self.design_phi = EXP.design_phi
        # Inverse design matrix: maps values at the design points to
        # coefficients (coef = Temp @ y). Raises LinAlgError if singular.
        self.Temp = np.linalg.inv(self.design_phi)
        # Placeholder coefficients until updateModel() is called.
        self.theta_hat = np.ones(self.d)
        self.beta_hat = np.ones(self.d)

    def _design_indices(self):
        """Split ``EXP.design_indices`` (``(m, k)`` pairs) into two index tuples."""
        m_indices, k_indices = zip(*self.EXP.design_indices)
        return m_indices, k_indices

    @staticmethod
    def _quad_form_diag(X, V):
        """Return ``x @ V @ x`` for every vector ``x`` along X's last axis.

        The result equals ``diag(X2 @ V @ X2.T)``, where ``X2`` is X flattened
        to 2-D, reshaped back to ``X.shape[:-1]``. It never materializes the
        quadratic-size ``X2 @ V @ X2.T`` matrix.
        """
        return np.sum((X @ V) * X, axis=-1)

    def updateModel(self, F_esti, G_esti):
        """Refit ``theta_hat`` and ``beta_hat`` from grid estimates.

        Parameters
        ----------
        F_esti, G_esti : ndarray indexed as ``[m, k]``
            Current estimates on the grid. Only the entries at
            ``EXP.design_indices`` are read.

        Both coefficient vectors are stored as column vectors
        (``Temp @ y`` with ``y`` reshaped to ``(-1, 1)``).
        """
        m_indices, k_indices = self._design_indices()
        F_design_esti = F_esti[m_indices, k_indices].reshape(-1, 1)
        G_design_esti = G_esti[m_indices, k_indices].reshape(-1, 1)
        self.theta_hat = self.Temp @ F_design_esti
        self.beta_hat = self.Temp @ G_design_esti

    def predict(self):
        """Return ``(f, g)`` = ``phi @ theta_hat`` and ``phi @ beta_hat`` per grid cell."""
        f_esti = self.phi @ self.theta_hat.ravel()
        g_esti = self.phi @ self.beta_hat.ravel()
        return f_esti, g_esti

    def compute_Z(self):
        """Return ``(Z, Temp)`` where ``Z[m, k] = phi[m, k] @ Temp``.

        ``Z`` has shape ``(M, K, d)`` for ``phi`` of shape ``(M, K, d)``. It is
        a fresh array. ``Temp`` is the stored inverse design matrix itself,
        not a copy.
        """
        # Batched matmul over the leading (M, K) axes replaces the per-cell loop.
        Z_mat = self.phi @ self.Temp
        return Z_mat, self.Temp

    def compute_regression_variance(self, alternative_count, opt_solution):
        """Return per-cell variances of the fitted surface under an allocation.

        The formula is::

            w[i] = variance[design_i] / (count[design_i] / sum(count[design]))
            V    = Temp @ diag(w) @ Temp.T

        From ``V`` it computes:

        * ``reg_variance[m, k] = phi[m, k] @ V @ phi[m, k]``, and
        * ``diff_reg_variance[m, k] = x @ V @ x`` with
          ``x = phi[m, k] - phi[m, opt_solution[m]]``. This is ``None`` when
          any entry of ``opt_solution`` equals ``-1``.

        Parameters
        ----------
        alternative_count : ndarray indexed as ``[m, k]``
            Sample counts. Only the design entries are read. A zero count at
            a design point yields an infinite weight (numpy divide warning).
        opt_solution : integer ndarray of length ``M``
            Reference column index for each row ``m``. ``-1`` marks it as
            unavailable.

        Returns
        -------
        (reg_variance, diff_reg_variance)
            Both have shape ``phi.shape[:-1]``, or the second is ``None``.
        """
        m_indices, k_indices = self._design_indices()
        design_var = self.EXP.variance[m_indices, k_indices]
        design_count = alternative_count[m_indices, k_indices]
        design_ratio = design_count / np.sum(design_count)
        noise = design_var / design_ratio
        # Temp @ diag(noise) is just column scaling, so no dense diag is built.
        theta_var = (self.Temp * noise) @ self.Temp.T
        # Only the diagonal of phi @ theta_var @ phi.T is needed. Computing it
        # directly avoids an (M*K, M*K) intermediate matrix.
        reg_variance = self._quad_form_diag(self.phi, theta_var)

        if np.any(opt_solution == -1):
            return reg_variance, None

        ref_phi = self.phi[np.arange(self.M), opt_solution, :]
        diff_phi = self.phi - ref_phi[:, np.newaxis, :]
        diff_reg_variance = self._quad_form_diag(diff_phi, theta_var)
        return reg_variance, diff_reg_variance
