import os
import sys
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg.linalg import matmul
from numpy.core.numeric import identity
import math
from scipy.special import gamma
from itertools import chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import copy
from torch.utils.data import DataLoader, TensorDataset

## In the experiment, all the components are sampled with the same probability 1/K.
## In the mixture of regressions with two components, the components are symmetric, i.e. beta_1^{*} = -beta_2^{*}.
## Each component is normalized to unit l2 norm.
def get_comp(k,p):
    """Sample the regression coefficients of a k-component mixture.

    Every coefficient vector is a standard-normal draw of length ``p``
    rescaled to unit l2 norm.  For ``k == 2`` only one vector is drawn and
    the second component is its negation; for any other ``k`` the function
    draws ``k`` independent vectors.

    Parameters:
    k (int): Number of mixture components.
    p (int): Dimension of each coefficient vector.

    Returns:
    list of torch.Tensor: ``k`` tensors of shape ``(p,)``.
    """
    def unit_direction():
        raw = torch.randn(p)
        return raw / torch.norm(raw)

    if k == 2:
        # Symmetric two-component case: beta_2 = -beta_1.
        first = unit_direction()
        return [first, -first]

    # Any other k: independent unit-norm directions.
    return [unit_direction() for _ in range(k)]

## Get beta^OR by taking the weighted average of beta_j^{*} for j=1,...,K.
## The component probabilities pi_j^{*} are all equal.
def get_beta_OR(beta_comp_list):
    """Return the equal-weight average of the component coefficient vectors.

    Parameters:
    beta_comp_list (list of torch.Tensor): Coefficient vectors; the length of
        the first one fixes the size of the result.

    Returns:
    torch.Tensor: Tensor of shape ``(p,)`` equal to the sum of all
    components divided by the number of components.
    """
    n_components = len(beta_comp_list)
    dim = beta_comp_list[0].size(0)

    # Accumulate in place into a fresh zero vector so the inputs are untouched.
    running_sum = torch.zeros(dim)
    for component in beta_comp_list:
        running_sum += component
    return running_sum / n_components

## Get the minimum MSE using explicit expectation formula for MoR
def get_MSE_OR(beta_list,sigma2):
    """Return the MSE attained by the averaged coefficient vector beta_OR.

    Evaluates ``(1/K) * sum_j ||beta_j - beta_OR||^2 + sigma2`` where
    ``beta_OR = get_beta_OR(beta_list)`` and ``K = len(beta_list)``.

    Parameters:
    beta_list (list of torch.Tensor): Component coefficient vectors.
    sigma2 (float): Noise variance added to the spread term.

    Returns:
    torch.Tensor: Zero-dimensional tensor holding the MSE value.
    """
    beta_OR = get_beta_OR(beta_list)
    K = len(beta_list)
    spread = 0
    # Each component contributes its own squared distance to beta_OR;
    # indexing a fixed component here would just repeat one term K times.
    for beta_j in beta_list:
        spread = spread + torch.linalg.vector_norm(beta_j - beta_OR) ** 2
    return spread / K + sigma2

## Get input sequence of the attention layers H in R^{D*(n+1)}
def _draw_component(k):
    """Draw a component index in {0, ..., k-1} from one uniform sample."""
    u = torch.rand(1)
    # Count how many of the thresholds 1/k, 2/k, ..., k/k lie strictly below u.
    # torch.rand samples from [0, 1), so the last threshold is never exceeded.
    return sum(1 for j in range(k) if u > (j + 1) / k)


def _embed_prompt(X, y, x_query, D):
    """Stack features, optional padding/indicator rows and responses into H."""
    n_samples, p = X.shape
    features = torch.cat([X, x_query], dim=0).t()              # (p, n+1)
    # The query column carries 0 in the response row.
    responses = torch.cat([y.t(), torch.zeros(1, 1)], dim=1)   # (1, n+1)
    rows = [features]
    if D >= p + 2:
        if D > p + 2:
            # Zero padding rows between the features and the indicator row.
            rows.append(torch.zeros(D - p - 2, n_samples + 1))
        # Indicator row: 1 for in-context samples, 0 for the query column.
        rows.append(torch.cat([torch.ones(1, n_samples), torch.zeros(1, 1)], dim=1))
    rows.append(responses)
    return torch.cat(rows, dim=0)


def _sample_prompt(p, k, n_samples, noise_std, beta_list, D):
    """Sample one prompt and return ``(H, y, y_query)``."""
    X = torch.randn(n_samples, p)
    y = torch.zeros(n_samples, 1)
    x_query = torch.randn(1, p)

    # Component and response of the query point.
    query_comp = _draw_component(k)
    y_query = x_query[0].dot(beta_list[query_comp]) + noise_std * torch.randn(1)

    # Each in-context sample gets its own component draw.
    for i in range(n_samples):
        comp = _draw_component(k)
        y[i] = X[i].dot(beta_list[comp]) + noise_std * torch.randn(1)

    return _embed_prompt(X, y, x_query, D), y, y_query


def Input_Seq(B,TB,p,k,n_samples,sigma2,beta_list,D):
    """Generate training and test prompts H in R^{D x (n_samples+1)}.

    Each prompt holds ``n_samples`` in-context pairs plus one query column.
    Rows are laid out as: ``p`` feature rows; for ``D > p+2`` another
    ``D-p-2`` zero rows; for ``D >= p+2`` one indicator row (1 for in-context
    columns, 0 for the query column); and finally the response row, whose
    query entry is 0.

    Parameters:
    B (int): Number of training prompts.
    TB (int): Number of test prompts.
    p (int): Feature dimension.
    k (int): Number of mixture components (indices into ``beta_list``).
    n_samples (int): Number of in-context samples per prompt.
    sigma2 (float): Noise variance; its square root scales the noise draws.
    beta_list (list of torch.Tensor): Component coefficient vectors.
    D (int): Number of rows of each prompt matrix; must be at least ``p+1``.

    Returns:
    tuple: ``(prompt_data, test_data)``, two lists of ``(H, y, y_query)``
    with ``H`` of shape ``(D, n_samples+1)``, ``y`` of shape
    ``(n_samples, 1)`` and ``y_query`` of shape ``(1,)``.

    Raises:
    ValueError: If ``D < p + 1`` (no room for the features and response row).
    """
    if D < p + 1:
        raise ValueError(f"D must be at least p + 1 = {p + 1}, got {D}")

    noise_std = math.sqrt(sigma2)
    prompt_data = [
        _sample_prompt(p, k, n_samples, noise_std, beta_list, D) for _ in range(B)
    ]
    # Test prompts are sampled the same way, after all training prompts.
    test_data = [
        _sample_prompt(p, k, n_samples, noise_std, beta_list, D) for _ in range(TB)
    ]
    return prompt_data, test_data


class BatchEMMixtureRegression:
    """Batch EM estimator for a K-component mixture of linear regressions.

    The object stores the ground-truth coefficients used for synthetic data
    (``true_beta``), the current estimates (``pi``, ``beta``) and the
    responsibilities ``gamma`` of shape ``(B, n, K)``.
    """

    def __init__(self, B, n, d, K, theta):
        """
        Initialize the Batch EM algorithm for Mixture of Regression problem.

        Parameters:
        B (int): Number of batches (prompts).
        n (int): Number of samples per batch.
        d (int): Dimension of input features.
        K (int): Number of mixture components.
        theta (float): Noise standard deviation.
        """
        self.B = B
        self.n = n
        self.d = d
        self.K = K
        self.theta = theta

        # Initialize pi with a random point of the probability simplex
        self.pi = self._sample_simplex(K)

        # Ground-truth regression coefficients used for synthetic data
        self.true_beta = self._initialize_betas()

        # Initialize beta_j close to the true beta_j*
        self.beta = self._initialize_betas_close_to_true(self.true_beta)

        # Initialize gamma (responsibilities) as zeros
        self.gamma = torch.zeros(B, n, K)

    def _sample_simplex(self, K):
        """Draw a sample from the probability simplex in dimension K."""
        # Clamp away from 0 so that log() can never return -inf, which would
        # turn the normalised vector into NaNs.
        uniforms = torch.rand(K).clamp_min(torch.finfo(torch.float32).tiny)
        exp_samples = -torch.log(uniforms)
        return exp_samples / exp_samples.sum()

    def _initialize_betas(self):
        """Draw K Gaussian coefficient vectors and scale each to unit norm."""
        beta = torch.randn(self.K, self.d)
        beta /= beta.norm(dim=1, keepdim=True)  # Normalize to unit norm
        return beta

    def _initialize_betas_close_to_true(self, true_beta):
        """Return unit-norm perturbations of true_beta with cosine similarity > 0.8.

        Each row is resampled until the condition holds.
        NOTE(review): the perturbation scale (0.1) is fixed, so for large d
        the acceptance probability may become very small; confirm the
        intended range of d.
        """
        beta = torch.zeros_like(true_beta)

        for j in range(self.K):
            while True:
                # Perturb the true coefficients with small Gaussian noise
                noise = 0.1 * torch.randn_like(true_beta[j])
                beta_j = true_beta[j] + noise
                beta_j /= beta_j.norm()  # Normalize to unit norm

                # Accept only candidates that stay close in angle
                cos_sim = torch.dot(beta_j, true_beta[j]) / (beta_j.norm() * true_beta[j].norm())

                if cos_sim > 0.8:
                    beta[j] = beta_j
                    break

        return beta

    def generate_mixture_data(self):
        """Generate synthetic data from the mixture defined by ``self.true_beta``.

        Every call uses the same stored coefficients, so separate calls
        (e.g. for training and test data) come from the same model, and the
        initial ``self.beta`` really is a perturbation of the data-generating
        coefficients.

        Returns:
        tuple: ``(X_batches, Y_batches, labels, true_beta)`` where
        ``X_batches`` is a list of B tensors of shape (n, d), ``Y_batches`` a
        list of B tensors of shape (n,), ``labels`` a list of B lists with the
        component index of every sample, and ``true_beta`` is
        ``self.true_beta`` (shape (K, d)).
        """
        true_pi = torch.full((self.K,), 1.0 / self.K)  # Equal mixing proportions
        true_beta = self.true_beta

        X_batches, Y_batches, labels = [], [], []

        for _ in range(self.B):
            X_i = torch.randn(self.n, self.d)
            # One component index per sample
            comp = torch.multinomial(true_pi, self.n, replacement=True)
            noise = self.theta * torch.randn(self.n)
            # Row-wise inner product of each sample with its own component
            Y_i = (X_i * true_beta[comp]).sum(dim=1) + noise

            X_batches.append(X_i)
            Y_batches.append(Y_i)
            labels.append(comp.tolist())

        return X_batches, Y_batches, labels, true_beta

    def _stack_batches(self, X, Y):
        """Stack the first B batches (first n rows each) into dense tensors."""
        X_all = torch.stack([X[i][: self.n] for i in range(self.B)])  # (B, n, d)
        Y_all = torch.stack([Y[i][: self.n] for i in range(self.B)])
        return X_all, Y_all.reshape(self.B, self.n)                   # (B, n)

    def e_step(self, X, Y):
        """Expectation step: Update assignment probabilities gamma."""
        X_all, Y_all = self._stack_batches(X, Y)
        residual = Y_all.unsqueeze(-1) - X_all @ self.beta.T          # (B, n, K)
        # Normalise in log space: softmax(log pi_j - r_j^2 / (2 theta^2)) is
        # the same ratio as pi_j * exp(...) / sum, but does not produce 0/0
        # when every likelihood underflows.
        log_weights = torch.log(self.pi) - residual ** 2 / (2 * self.theta ** 2)
        self.gamma = torch.softmax(log_weights, dim=-1)

    def m_step(self, X, Y):
        """Maximization step: Update mixture weights pi and regression coefficients beta."""
        # Update pi
        self.pi = self.gamma.mean(dim=(0, 1))

        X_all, Y_all = self._stack_batches(X, Y)
        X_flat = X_all.reshape(-1, self.d)          # (B*n, d)
        Y_flat = Y_all.reshape(-1)                  # (B*n,)
        weights = self.gamma.reshape(-1, self.K)    # (B*n, K)

        # Update beta_j via weighted least squares
        for j in range(self.K):
            weighted_X = X_flat * weights[:, j:j + 1]
            weighted_XTX = weighted_X.T @ X_flat    # (d, d)
            weighted_XTY = weighted_X.T @ Y_flat    # (d,)

            # Solve the normal equations; fall back to least squares when the
            # determinant suggests the system is (near) singular.
            if torch.linalg.det(weighted_XTX) > 1e-6:
                self.beta[j] = torch.linalg.solve(weighted_XTX, weighted_XTY)
            else:
                solution = torch.linalg.lstsq(weighted_XTX, weighted_XTY.unsqueeze(1)).solution
                self.beta[j] = solution.squeeze(1)

    def fit(self, X, Y, max_iters=100, tol=1e-4):
        """
        Fit the Mixture of Regression Model using Batch EM.

        Parameters:
        X (list of B tensors): Feature matrices of shape (n, d).
        Y (list of B tensors): Target vectors of shape (n,).
        max_iters (int): Maximum number of EM iterations.
        tol (float): Convergence threshold for change in beta.

        Returns:
        torch.Tensor: The estimated coefficients ``self.beta`` of shape (K, d).
        """
        for _ in range(max_iters):
            beta_old = self.beta.clone()

            self.e_step(X, Y)  # E-step
            self.m_step(X, Y)  # M-step

            # Stop once the Frobenius norm of the update is below tol
            beta_change = torch.norm(self.beta - beta_old, p='fro')
            if beta_change < tol:
                break

        return self.beta

# Specification of parameters: number of batches B, samples per batch n,
# feature dimension d, number of mixture components K, noise std theta.
B, n, d, K, theta = 64, 50, 32, 3, 1.0

# Initialize model (runs at import time; the constructor draws random numbers,
# so the position of this statement affects the global torch RNG stream)
model = BatchEMMixtureRegression(B, n, d, K, theta)

# Generate data from the mixture of regressions (each sample is assigned to a component).
# true_beta is the coefficient matrix reported by generate_mixture_data.
X, Y, true_labels, true_beta = model.generate_mixture_data()

# Fit the model using Batch EM (default max_iters and tol); returns the estimated betas
final_betas = model.fit(X, Y)

def evaluate_mse_weighted(model, X_test, Y_test):
    """
    Evaluate the MSE of the predictions using the weighted beta estimate:
    beta_hat = sum(pi_i * beta_i) from Batch EM.

    Parameters:
    model (BatchEMMixtureRegression): Trained model containing estimated betas and pis.
    X_test (list of tensors): Test feature matrices, each of shape (n_test, d).
    Y_test (list of tensors): True target values, each of shape (n_test,).

    Returns:
    float: Mean Squared Error (MSE) of the predictions, pooled over all
    test samples of all batches.

    Raises:
    ZeroDivisionError: If the test set contains no samples.
    """
    # Collapse the mixture into one coefficient vector, shape (d,)
    beta_weighted = (model.pi.unsqueeze(1) * model.beta).sum(dim=0)

    squared_error_sum = 0.0
    total_samples = 0

    for i, X_i in enumerate(X_test):
        Y_i = Y_test[i]  # True responses (n_test,)

        # Predict y_hat = X @ beta_weighted, shape (n_test,)
        Y_pred = X_i @ beta_weighted

        # Accumulate the sum (not the mean) so batches of different size
        # are weighted by their sample count.
        squared_error_sum += ((Y_pred - Y_i) ** 2).sum().item()
        total_samples += len(Y_i)

    return squared_error_sum / total_samples

# Generate a second data set from the same model object to serve as test data.
# NOTE(review): whether these samples share the coefficients of the training
# data depends on how generate_mixture_data picks its coefficients — confirm.
X_test, Y_test, _, _ = model.generate_mixture_data()

# Evaluate MSE on test data using the pi-weighted average of the fitted betas
mse_weighted = evaluate_mse_weighted(model, X_test, Y_test)

# Report the result with four decimals
print(f"Mean Squared Error (MSE) using weighted beta: {mse_weighted:.4f}")

## Squared Loss Function
def square_loss(y_pred, y_true):
    """Return the squared difference between prediction and target."""
    error = y_pred - y_true
    return error ** 2


## Structure of multihead-attention layers and feedforward network
class MultiHeadAttentionLayer(nn.Module):
    """One ReLU-attention block followed by a two-matrix ReLU feed-forward block.

    Both sub-blocks are residual.  All weights are standard-normal draws
    divided by ``L**2``.
    """

    def __init__(self, input_dim, M_heads, L, hidden_D):
        """
        Parameters:
        input_dim (int): Number of rows of the input matrix.
        M_heads (int): Number of attention heads.
        L (int): Value whose square divides every initial weight.
        hidden_D (int): Width of the feed-forward hidden layer.
        """
        super().__init__()
        self.dk = input_dim
        self.M = M_heads
        init_divisor = L**2

        def scaled_parameter(rows, cols):
            return nn.Parameter(torch.randn(rows, cols) / init_divisor)

        # Creation order matters for reproducibility under a fixed seed:
        # all W_PV heads, then all W_KQ heads, then W1, then W2.
        self.W_PV = nn.ParameterList(
            [scaled_parameter(input_dim, input_dim) for _ in range(M_heads)]
        )
        self.W_KQ = nn.ParameterList(
            [scaled_parameter(input_dim, input_dim) for _ in range(M_heads)]
        )
        self.W1 = scaled_parameter(hidden_D, input_dim)
        self.W2 = scaled_parameter(input_dim, hidden_D)
        self.activation_fun = torch.nn.ReLU()

    def forward(self, E_tau):
        """Apply the attention and feed-forward blocks to a 2-D input matrix.

        The summed head outputs are divided by the number of columns minus one
        before being added back to the input.
        """
        context_len = E_tau.size(1) - 1
        # Scores are scaled by the square root of the input dimension.
        score_scale = np.sqrt(self.dk)

        head_sum = torch.zeros_like(E_tau)
        for value_proj, key_query in zip(self.W_PV, self.W_KQ):
            values = value_proj.mm(E_tau)
            scores = E_tau.t().mm(key_query.mm(E_tau) / score_scale)
            head_sum += values.mm(self.activation_fun(scores))

        attended = E_tau + head_sum / context_len
        hidden = self.activation_fun(self.W1.mm(attended))
        return attended + self.W2.mm(hidden)

class MultiLayerMultiHeadAttentionNetwork(nn.Module):
    """A stack of ``num_layers`` MultiHeadAttentionLayer blocks applied in sequence."""

    def __init__(self, input_dim, M_heads, num_layers, num_Dp):
        """
        Parameters:
        input_dim (int): Number of rows of the input matrix.
        M_heads (int): Number of attention heads per layer.
        num_layers (int): Number of stacked layers; also passed to every
            layer as its ``L`` argument.
        num_Dp (int): Hidden width passed to every layer.
        """
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(MultiHeadAttentionLayer(input_dim, M_heads, num_layers, num_Dp))
        self.layers = nn.ModuleList(stack)

    def forward(self, E_tau):
        """Feed the input through every layer in order and return the result."""
        hidden = E_tau
        for block in self.layers:
            hidden = block(hidden)
        return hidden

def Exp_length_MSE(B,TB,length,p,D,prompt_data,test_data,epochs):
    """Train an attention network on the prompts and track the test MSE per epoch.

    A 4-layer, 4-head network with hidden width 34 is trained with Adam
    (lr 1e-5), one optimizer step per training prompt.  The prediction for
    the query is read from the last row and last column of the network
    output.  That is where Input_Seq in this file leaves the zeroed query
    response, and it is the same entry the evaluation loop reads, so training
    and evaluation optimise and measure the same quantity for every D.

    Parameters:
    B (int): Number of training prompts; divisor of the reported training loss.
    TB (int): Number of test prompts; divisor of the reported test loss.
    length (int): Prompt length of the experiment.  Kept for interface
        compatibility; the prediction position is taken from the data itself.
    p (int): Feature dimension.  Kept for interface compatibility.
    D (int): Number of rows of each prompt matrix (network input dimension).
    prompt_data (list): Training tuples ``(H, y, y_query)``.
    test_data (list): Test tuples ``(H, y, y_query)``.
    epochs (int): Number of passes over the training prompts.

    Returns:
    list of float: Average test squared error after each epoch.
    """
    Mheads = 4
    Layers = 4
    hidden_Dim = 34
    TFmodel = MultiLayerMultiHeadAttentionNetwork(D, Mheads, Layers, hidden_Dim)
    optimizer = optim.Adam(TFmodel.parameters(), lr=0.00001)
    MSE_test_list = []

    # Training loop over epochs
    for epoch in range(epochs):
        total_loss = 0.0
        TFmodel.train()
        for E_tau, _, y_query in prompt_data:
            optimizer.zero_grad()
            E_tau_prime = TFmodel(E_tau)
            # Query prediction: response row (last) of the query column (last).
            y_pred_query = E_tau_prime[-1, -1]
            loss = square_loss(y_pred_query, y_query.item())

            total_loss += loss.item()
            # Backpropagation and optimization
            loss.backward()
            optimizer.step()

        # Evaluation on the held-out prompts, without building a graph.
        test_loss = 0.0
        TFmodel.eval()
        with torch.no_grad():
            for E_psai, _, y_q in test_data:
                y_pred_q = TFmodel(E_psai)[-1, -1]
                test_loss += (y_pred_q.item() - y_q.item()) ** 2

        # Report the average losses for the epoch
        train_avg_loss = total_loss / B
        test_avg_loss = test_loss / TB
        MSE_test_list.append(test_avg_loss)
        print(f"Epoch {epoch+1}: Training MSE = {train_avg_loss:.4f}, Testing MSE = {test_avg_loss:.4f}")
    return MSE_test_list

# Experiment driver: for each prompt-length index l and each repetition, sample
# fresh prompts, train a network, and record the test MSE of the final epoch.
rep=20   ## number of repetitions per length
sigma2=0.01  ## variance of the noise
# Result tables: one row per length index (64 of them), one column per repetition.
MSE_result_SNR10 = np.zeros((64,rep))
# NOTE(review): MSE_result_OR is allocated but never written in the visible code.
MSE_result_OR = np.zeros((64,rep))
epo=100000  ## number of training epochs
# Five unit-norm components of dimension 32, shared by all lengths and repetitions.
beta_comp_list = get_comp(k=5,p=32)
# NOTE(review): beta_OR is computed but not used in the visible code.
beta_OR = get_beta_OR(beta_comp_list)
for l in range(64):
  print(f"length {l+1}")
  for repeat in range(rep):
    print(f"repeat {repeat+1}")
    # Positional arguments are (B, TB, p, k, n_samples, sigma2, beta_list, D).
    # NOTE(review): n_samples is the literal 64 for every l, and the noise
    # variance is the literal 0.01 rather than the sigma2 variable — confirm
    # whether the prompt length was meant to vary with l.
    Prompt_dataset, Test_dataset = Input_Seq(64,64,32,5,64,0.01,beta_comp_list,64)
    # Positional arguments are (B, TB, length, p, D, prompt_data, test_data, epochs).
    result = Exp_length_MSE(64,64,l+1,32,64,Prompt_dataset,Test_dataset,epo)
    # Keep only the test MSE of the last epoch.
    MSE_result_SNR10[l,repeat]=result[epo-1]