### procedure of the whole experiment in each single case.

import os
import sys
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg.linalg import matmul
from numpy.core.numeric import identity
import math
from scipy.special import gamma
from itertools import chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import copy
from torch.utils.data import DataLoader, TensorDataset

## In the experiment, all components are sampled with the same probability 1/K.
## In the mixture of regressions with two components, the components are symmetric, i.e. beta_1^{*} = -beta_2^{*}.
## Each component is normalized to unit l-2 norm.
def get_comp(k,p):
  """Draw the regression coefficients of a k-component mixture.

  Args:
    k: number of mixture components.
    p: dimension of each coefficient vector.

  Returns:
    A list of k tensors of shape (p,), each scaled to unit l-2 norm.
    When k == 2 the second vector is the negation of the first; for any
    other k every vector is drawn independently.
  """
  def _unit_vector():
    ## one standard-normal draw, rescaled to unit length
    direction = torch.randn(p)
    return 1*direction/torch.norm(direction)

  if k != 2:
    ## general case: k independent unit vectors
    return [_unit_vector() for _ in range(k)]

  ## symmetric two-component case: beta_2 = -beta_1
  first = _unit_vector()
  return [first, -first]

## Get beta^OR by taking the weighted average of beta_j^{*} for j=1,...,K.
## The component probabilities pi_j^{*} are all equal.
def get_beta_OR(beta_comp_list):
  """Return the equally weighted mean of the component coefficient vectors.

  Args:
    beta_comp_list: non-empty sequence of 1-D tensors of equal length.

  Returns:
    A tensor with the same length as the first component, holding the
    element-wise average over all components.
  """
  dim = beta_comp_list[0].size(0)
  total = torch.zeros(dim)
  ## accumulate in place into the zero-initialised buffer
  for beta in beta_comp_list:
    total.add_(beta)
  return total/len(beta_comp_list)

## Get the minimum MSE using explicit expectation formula for MoR
def get_MSE_OR(beta_list,sigma2):
  """Return sigma2 plus the mean squared distance of the components to beta^OR.

  Computes sigma2 + (1/K) * sum_j ||beta_j - beta^OR||^2, where beta^OR is
  the value returned by get_beta_OR(beta_list).

  Args:
    beta_list: non-empty sequence of 1-D coefficient tensors.
    sigma2: noise variance added to the averaged squared distances.

  Returns:
    A 0-dim tensor holding the value above.
  """
  beta_OR = get_beta_OR(beta_list)
  K = len(beta_list)
  ## every component contributes its own distance to beta^OR
  ## (previously only beta_list[0] was used on each pass of the loop)
  spread = sum(torch.linalg.vector_norm(beta-beta_OR)**2 for beta in beta_list)
  return spread/K+sigma2

## Get input sequence of the attention layers H in R^{D*(n+1)}
def Input_Seq(B,TB,p,k,n_samples,sigma2,beta_list,D):
  """Generate training and testing prompts for the mixture-of-regressions task.

  Each prompt is a tuple (E, y, y_query):
    * E is a D x (n_samples+1) tensor. Rows 0..p-1 hold the covariates of
      the n_samples in-context examples followed by the query covariate in
      the last column. The last row holds the in-context labels with a 0
      in the query column. When D >= p+2, row D-2 is an indicator that is
      1 for in-context columns and 0 for the query column, and rows
      p..D-3 (if any) are zero padding.
    * y is the (n_samples, 1) tensor of in-context labels.
    * y_query is the 1-element tensor with the query label.

  Every label is x.dot(beta_j) plus Gaussian noise of variance sigma2, with
  the component j drawn uniformly from the first k entries of beta_list.

  Args:
    B: number of training prompts.
    TB: number of testing prompts.
    p: covariate dimension.
    k: number of mixture components to sample from.
    n_samples: number of in-context examples per prompt.
    sigma2: noise variance (must be non-negative).
    beta_list: sequence of at least k coefficient tensors of shape (p,).
    D: embedding dimension, at least p+1.

  Returns:
    (prompt_data, test_data): two lists of prompts of length B and TB.

  Raises:
    ValueError: if D < p+1, sigma2 < 0, or beta_list has fewer than k entries.
  """
  if D < p+1:
    raise ValueError(f"D must be at least p+1={p+1}, got {D}")
  if sigma2 < 0:
    raise ValueError(f"sigma2 must be non-negative, got {sigma2}")
  if len(beta_list) < k:
    raise ValueError(f"need at least k={k} components, got {len(beta_list)}")

  ## cumulative probabilities 1/k, 2/k, ..., 1 of the uniform mixture
  thresholds = [(i+1)/k for i in range(k)]
  noise_std = math.sqrt(sigma2)

  prompt_data = [_build_prompt(p,thresholds,n_samples,noise_std,beta_list,D) for _ in range(B)]
  ## Same for the test prompts.
  test_data = [_build_prompt(p,thresholds,n_samples,noise_std,beta_list,D) for _ in range(TB)]
  return prompt_data, test_data


def _draw_component_index(thresholds):
  """Draw a component index as the number of cumulative thresholds below one uniform draw."""
  u = torch.rand(1)
  ## torch.rand is in [0,1) and the last threshold is 1, so the result is < len(thresholds)
  return sum(1 for t in thresholds if u > t)


def _build_prompt(p,thresholds,n_samples,noise_std,beta_list,D):
  """Sample one prompt (E, y, y_query); see Input_Seq for the layout of E."""
  X = torch.randn(n_samples, p)
  y = torch.zeros(n_samples, 1)
  x_query = torch.randn(1,p)

  ## Query label first, then the in-context labels; for every label the
  ## component is drawn before its noise term.
  query_comp = _draw_component_index(thresholds)
  y_query = x_query[0].dot(beta_list[query_comp]) + noise_std*torch.randn(1)
  for i in range(n_samples):
    comp = _draw_component_index(thresholds)
    y[i] = X[i].dot(beta_list[comp]) + noise_std*torch.randn(1)

  features = torch.cat([X, x_query], dim=0).t()
  ## label row: the query slot is left at 0
  label_row = torch.cat([y.t(), torch.zeros(1, 1)], dim=1)

  rows = [features]
  if D >= p+2:
    ## D-p-2 rows of zero padding (none when D == p+2), then the indicator
    ## row marking the in-context columns
    rows.append(torch.zeros(D-p-2, n_samples+1))
    rows.append(torch.cat([torch.ones(1,n_samples), torch.zeros(1,1)], dim=1))
  rows.append(label_row)
  E = torch.cat(rows, dim=0)
  return (E, y, y_query)

## Squared Loss Function
def square_loss(y_pred, y_true):
    """Return the element-wise squared difference between prediction and target."""
    residual = y_pred - y_true
    return residual**2


## Structure of multihead-attention layers and feedforward network
class MultiHeadAttentionLayer(nn.Module):
    """Residual multi-head attention layer with a ReLU score activation.

    For an input E of shape (input_dim, n+1), each head h contributes
    W_PV[h] E * relu(E^T W_KQ[h] E / sqrt(input_dim)); the contributions are
    summed over heads, divided by n, and added to E.
    """

    def __init__(self, input_dim, M_heads, L):
        """Create M_heads pairs of (input_dim x input_dim) weights.

        Weights are standard-normal draws divided by L**2.
        """
        super().__init__()
        self.dk = input_dim
        self.M = M_heads
        init_scale = L**2

        def _head_weights():
            # one square weight matrix per head
            return nn.ParameterList(
                [nn.Parameter(torch.randn(input_dim, input_dim)/init_scale) for _ in range(M_heads)]
            )

        # all W_PV matrices are drawn before any W_KQ matrix
        self.W_PV = _head_weights()
        self.W_KQ = _head_weights()
        self.activation_fun = torch.nn.ReLU()

    def forward(self, E_tau):
        """Apply the layer to a 2-D input and return a tensor of the same shape."""
        n_context = E_tau.size(1) - 1      # number of columns minus one
        score_scale = np.sqrt(self.dk)     ## attention score is normalized by square root of D
        update = torch.zeros_like(E_tau)
        for W_pv, W_kq in zip(self.W_PV, self.W_KQ):
            values = torch.mm(W_pv, E_tau)
            keys_queries = torch.mm(W_kq, E_tau)/score_scale
            scores = self.activation_fun(torch.mm(E_tau.t(), keys_queries))
            update += torch.mm(values, scores)
        # residual connection; the head sum is averaged over n_context
        return E_tau + update / n_context

class MultiLayerMultiHeadAttentionNetwork(nn.Module):
    """Stack of num_layers MultiHeadAttentionLayer modules applied in sequence."""

    def __init__(self, input_dim, M_heads, num_layers):
        """Build num_layers attention layers; num_layers is also passed as each layer's L."""
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(MultiHeadAttentionLayer(input_dim, M_heads, num_layers))
        self.layers = nn.ModuleList(stack)

    def forward(self, E_tau):
        """Feed the input through every layer in order and return the final output."""
        hidden = E_tau
        for block in self.layers:
            hidden = block(hidden)
        return hidden

def Exp_length_MSE(B,TB,length,p,D,prompt_data,test_data,epochs):
  """Train a 4-layer, 4-head attention network and return its test MSE per epoch.

  A fresh MultiLayerMultiHeadAttentionNetwork of input dimension D is trained
  with Adam (lr=1e-5), one optimizer step per training prompt. The prediction
  for a prompt is the last-row, last-column entry of the network output, i.e.
  the slot where the input holds the label row's query placeholder. The same
  entry is used for training and for evaluation.

  Args:
    B: number of training prompts; used to average the training loss.
    TB: number of testing prompts; used to average the testing loss.
    length: number of in-context samples per prompt. Kept for interface
      compatibility; the query column is located as the last column of each
      prompt, so this value is not used for indexing.
    p: covariate dimension. Kept for interface compatibility; the label row
      is located as the last row of each prompt.
    D: embedding dimension of the prompts (number of rows of each E).
    prompt_data: iterable of (E, y, y_query) training prompts.
    test_data: iterable of (E, y, y_query) testing prompts.
    epochs: number of passes over prompt_data.

  Returns:
    A list with the average squared test error after each epoch.
  """
  Mheads=4
  Layers=4
  TFmodel = MultiLayerMultiHeadAttentionNetwork(D,Mheads,Layers)
  optimizer = optim.Adam(TFmodel.parameters(), lr=0.00001)
  MSE_train_list = []
  MSE_test_list = []
  # Training loop over epochs
  for epoch in range(epochs):
    total_loss = 0.0
    TFmodel.train()
    for E_tau, _, y_query in prompt_data:
      # Forward pass
      optimizer.zero_grad()
      E_tau_prime = TFmodel(E_tau)
      # Read the prediction from the label row's query slot. This is the
      # same entry the evaluation loop reads; indexing with [p, length]
      # only hits it when D == p+1 and the prompt has exactly `length`
      # in-context samples.
      y_pred_query = E_tau_prime[-1, -1]
      loss = square_loss(y_pred_query, y_query.item())

      # Accumulate the loss, then backpropagate and take one optimizer step
      total_loss += loss.item()
      loss.backward()
      optimizer.step()

    # Evaluation pass: no gradients are needed
    test_loss = 0.0
    TFmodel.eval()
    with torch.no_grad():
      for E_psai, _, y_q in test_data:
        y_pred_q = TFmodel(E_psai)[-1, -1]
        test_loss += (y_pred_q.item()-y_q.item())**2

    # Calculate and report the average losses for the epoch
    train_avg_loss = total_loss/B
    test_avg_loss = test_loss/TB
    MSE_train_list.append(train_avg_loss)
    MSE_test_list.append(test_avg_loss)
    print(f"Epoch {epoch+1}: Training MSE = {train_avg_loss:.4f}, Testing MSE = {test_avg_loss:.4f}")
  return MSE_test_list

## Driver: for every prompt length 1..64, repeat the experiment `rep` times and
## record the test MSE reached after the final training epoch.
rep=20   ## number of repeats
sigma2=0.01  ## variance of the noise
MSE_result_SNR10 = np.zeros((64,rep))   ## final test MSE, indexed by [length-1, repeat]
MSE_result_OR = np.zeros((64,rep))      ## allocated here; not filled in this section
epo=100000  ## number of training epochs
beta_comp_list = get_comp(k=5,p=32)
beta_OR = get_beta_OR(beta_comp_list)
for l in range(64):
  print(f"length {l+1}")
  for repeat in range(rep):
    print(f"repeat {repeat+1}")
    ## Generate prompts with exactly l+1 in-context samples so the data matches
    ## the length passed to Exp_length_MSE (previously always 64), and reuse
    ## sigma2 instead of repeating the literal 0.01.
    ## Arguments: B=64, TB=64, p=32, k=5, n_samples=l+1, sigma2, components, D=64
    Prompt_dataset, Test_dataset = Input_Seq(64,64,32,5,l+1,sigma2,beta_comp_list,64)
    result = Exp_length_MSE(64,64,l+1,32,64,Prompt_dataset,Test_dataset,epo)
    ## keep the test MSE of the last epoch
    MSE_result_SNR10[l,repeat]=result[-1]