import torch

class BatchEMMixtureRegression:
    """Batch EM for a K-component mixture of linear regressions with shared noise.

    Model: each sample (x, y) is assigned to component j with probability pi_j and
    y = x @ beta_j + eps with eps ~ N(0, theta^2). Data arrive as B batches of n
    samples each. The instance also owns ground-truth coefficients ``true_beta``
    that ``generate_mixture_data`` uses to synthesize data.
    """

    def __init__(self, B, n, d, K, theta):
        """
        Initialize the Batch EM algorithm for Mixture of Regression problem.

        Parameters:
        B (int): Number of batches (prompts).
        n (int): Number of samples per batch.
        d (int): Dimension of input features.
        K (int): Number of mixture components.
        theta (float): Noise standard deviation.
        """
        self.B = B
        self.n = n
        self.d = d
        self.K = K
        self.theta = theta

        # Initial mixing weights: a uniformly random point on the K-simplex.
        self.pi = self._sample_simplex(K)

        # Ground-truth regression coefficients, shape (K, d), used for synthetic data.
        self.true_beta = self._initialize_betas()

        # Start EM from unit vectors with cosine similarity > 0.8 to the truth.
        self.beta = self._initialize_betas_close_to_true(self.true_beta)

        # Responsibilities: gamma[i, ell, j] = P(sample ell of batch i is from component j).
        self.gamma = torch.zeros(B, n, K)

    def _sample_simplex(self, K):
        """Draw a sample uniformly from the probability simplex in dimension K.

        Normalized i.i.d. Exp(1) variables follow Dirichlet(1, ..., 1), which is
        the uniform distribution on the simplex.
        """
        # -log1p(-U) with U in [0, 1) is Exp(1) and always finite, whereas -log(U)
        # is +inf when torch.rand returns exactly 0 and would make pi NaN.
        exp_samples = -torch.log1p(-torch.rand(K))
        return exp_samples / exp_samples.sum()

    def _initialize_betas(self):
        """Draw K regression coefficient vectors uniformly on the unit sphere in R^d."""
        beta = torch.randn(self.K, self.d)
        # Normalized isotropic Gaussian vectors are uniform on the sphere.
        return beta / beta.norm(dim=1, keepdim=True)

    def _initialize_betas_close_to_true(self, true_beta, *, noise_scale=0.1, min_cos=0.8):
        """Return unit vectors each with cosine similarity > ``min_cos`` to a row of ``true_beta``.

        Each row is the true vector plus Gaussian noise (std ``noise_scale``),
        renormalized to unit length and rejection-sampled on the cosine condition.

        Parameters:
        true_beta (tensor): Target coefficients, one per row.
        noise_scale (float): Initial standard deviation of the per-coordinate noise.
        min_cos (float): Required cosine similarity to the target row (must be < 1).

        Returns:
        tensor: Same shape as ``true_beta``.
        """
        beta = torch.empty_like(true_beta)

        for j, target in enumerate(true_beta):
            scale = noise_scale
            while True:
                candidate = target + scale * torch.randn_like(target)
                candidate = candidate / candidate.norm()

                cos_sim = torch.dot(candidate, target) / (candidate.norm() * target.norm())
                if cos_sim > min_cos:
                    beta[j] = candidate
                    break

                # The noise norm grows like scale * sqrt(d), so a fixed scale can make
                # acceptance essentially impossible in high dimension; shrinking it on
                # each rejection guarantees termination.
                scale *= 0.5

        return beta

    def _stack_batches(self, X, Y):
        """Stack batches 0..B-1 (first n samples of each) into dense tensors.

        Assumes each X[i] is (n, d) and each Y[i] is (n,), giving X_all of shape
        (B, n, d) and Y_all of shape (B, n). Both are cast to ``self.beta``'s dtype
        so the E/M steps can use single batched operations.
        """
        dtype = self.beta.dtype
        X_all = torch.stack([torch.as_tensor(X[i])[: self.n] for i in range(self.B)])
        Y_all = torch.stack([torch.as_tensor(Y[i])[: self.n] for i in range(self.B)])
        return X_all.to(dtype), Y_all.to(dtype)

    def generate_mixture_data(self):
        """Generate synthetic data from this instance's ground-truth mixture.

        Every call draws fresh features, labels and noise but reuses
        ``self.true_beta``. Repeated calls (e.g. train and test sets) therefore
        come from the same true model that ``self.beta`` was initialized near.

        Returns:
        X_batches (list of B tensors): Feature matrices of shape (n, d), entries N(0, 1).
        Y_batches (list of B tensors): Responses of shape (n,).
        labels (list of B lists of int): Component index of every sample.
        true_beta (tensor): Copy of the ground-truth coefficients, shape (K, d).
        """
        true_pi = torch.full((self.K,), 1.0 / self.K)  # True mixing proportions (uniform)
        true_beta = self.true_beta
        component = torch.distributions.Categorical(probs=true_pi)

        X_batches, Y_batches, labels = [], [], []

        for _ in range(self.B):
            X_i = torch.randn(self.n, self.d)
            # Component index of each sample, drawn i.i.d. from true_pi.
            labels_i = component.sample((self.n,))
            noise = self.theta * torch.randn(self.n)
            # y_ell = x_ell @ beta_{label_ell} + noise_ell for every ell at once.
            Y_i = (X_i * true_beta[labels_i]).sum(dim=1) + noise

            X_batches.append(X_i)
            Y_batches.append(Y_i)
            labels.append(labels_i.tolist())

        return X_batches, Y_batches, labels, true_beta.clone()

    def e_step(self, X, Y):
        """Expectation step: update the responsibilities ``self.gamma``.

        gamma[i, ell, j] is proportional to pi_j * exp(-r^2 / (2 theta^2)), where
        r = Y[i][ell] - X[i][ell] @ beta_j, normalized over j. The Gaussian
        normalizing constant is shared by all components and cancels.
        """
        X_all, Y_all = self._stack_batches(X, Y)
        residual = Y_all.unsqueeze(-1) - X_all @ self.beta.T  # (B, n, K)

        # Normalize in log space. exp(-r^2 / 2theta^2) underflows to 0 for all
        # components when residuals are large, and 0/0 would yield NaN.
        log_weights = torch.log(self.pi) - residual ** 2 / (2 * self.theta ** 2)
        self.gamma = torch.softmax(log_weights, dim=-1)

    def m_step(self, X, Y):
        """Maximization step: update mixture weights pi and regression coefficients beta.

        pi_j is the mean responsibility of component j. beta_j solves the
        gamma-weighted normal equations (sum_w x x^T) beta_j = sum_w x y.
        """
        X_all, Y_all = self._stack_batches(X, Y)
        X_flat = X_all.reshape(-1, X_all.shape[-1])  # (B * n, d)
        Y_flat = Y_all.reshape(-1)  # (B * n,)
        weights = self.gamma.reshape(-1, self.K)  # (B * n, K)

        # Update pi
        self.pi = self.gamma.mean(dim=(0, 1))

        # Per-component weighted Gram matrices and moment vectors.
        weighted_XTX = torch.einsum("nk,nd,ne->kde", weights, X_flat, X_flat)  # (K, d, d)
        weighted_XTY = torch.einsum("nk,nd,n->kd", weights, X_flat, Y_flat)  # (K, d)

        # lstsq copes with singular or ill-conditioned systems (e.g. a component
        # with ~zero total responsibility). A determinant threshold is scale
        # dependent and overflows or underflows easily in d dimensions.
        solution = torch.linalg.lstsq(weighted_XTX, weighted_XTY.unsqueeze(-1)).solution
        # Update in place so references to self.beta (e.g. fit's return) stay live.
        self.beta.copy_(solution.squeeze(-1))

    def fit(self, X, Y, max_iters=100, tol=1e-4):
        """
        Fit the Mixture of Regression Model using Batch EM.

        Parameters:
        X (list of B tensors): Feature matrices of shape (n, d).
        Y (list of B tensors): Target vectors of shape (n,).
        max_iters (int): Maximum number of EM iterations.
        tol (float): Convergence threshold for change in beta.

        Returns:
        tensor: The fitted coefficients ``self.beta``, shape (K, d).
        """
        for _ in range(max_iters):
            beta_old = self.beta.clone()

            self.e_step(X, Y)  # E-step
            self.m_step(X, Y)  # M-step

            # Stop once the Frobenius norm of the coefficient update is below tol.
            beta_change = torch.norm(self.beta - beta_old, p="fro")
            if beta_change < tol:
                break

        return self.beta

# Problem size: B batches of n samples each, d features, K mixture components,
# and noise standard deviation theta.
B = 64
n = 50
d = 32
K = 3
theta = 1.0

# Build the EM model for these dimensions.
model = BatchEMMixtureRegression(B=B, n=n, d=d, K=K, theta=theta)

# Draw a synthetic training set; every sample carries a mixture-component label.
X, Y, true_labels, true_beta = model.generate_mixture_data()

# Run batch EM on the training set; the fitted coefficients come back as a (K, d) tensor.
final_betas = model.fit(X, Y)

def evaluate_mse_weighted(model, X_test, Y_test):
    """
    Evaluate the MSE of predictions made with the mixture-weighted coefficients.

    Predictions use the single coefficient vector
    hat(beta) = sum_j pi_j * beta_j, built from the model's mixing weights
    ``model.pi`` (shape (K,)) and coefficients ``model.beta`` (shape (K, d)).

    Parameters:
    model (BatchEMMixtureRegression): Trained model containing estimated betas and pis.
    X_test (list of tensors): Test feature matrices, each of shape (n_test, d).
    Y_test (list of tensors): True target values, each of shape (n_test,).

    Returns:
    float: Mean squared error pooled over every sample of every test batch.

    Raises:
    ValueError: If X_test and Y_test contain a different number of batches.
    ZeroDivisionError: If the test set contains no samples.
    """
    if len(X_test) != len(Y_test):
        raise ValueError(
            f"X_test has {len(X_test)} batches but Y_test has {len(Y_test)}"
        )

    # (K,) @ (K, d) -> (d,): the pi-weighted average of the component coefficients.
    beta_weighted = model.pi @ model.beta

    total_sq_err = 0.0
    total_samples = 0

    for X_i, Y_i in zip(X_test, Y_test):
        Y_pred = X_i @ beta_weighted  # (n_test,)
        total_sq_err += ((Y_pred - Y_i) ** 2).sum().item()
        total_samples += len(Y_i)

    # Pool errors over all samples rather than averaging per-batch MSEs.
    return total_sq_err / total_samples

# Draw a fresh test set from the model's synthetic-data generator.
X_test, Y_test, _, _ = model.generate_mixture_data()

# Score the pi-weighted coefficient estimate on the held-out samples.
mse_weighted = evaluate_mse_weighted(model=model, X_test=X_test, Y_test=Y_test)

print(f"Mean Squared Error (MSE) using weighted beta: {mse_weighted:.4f}")