### Input sequence generator for mixture-of-regression (MoR) models.


import os
import sys
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg.linalg import matmul
from numpy.core.numeric import identity
import math
from scipy.special import gamma
from itertools import chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import copy
from torch.utils.data import DataLoader, TensorDataset

## In the experiments, all components are sampled with the same probability 1/K.
## In the two-component mixture-of-regression experiment, the components are symmetric, i.e. beta_1^{*} = -beta_2^{*}.
## Each component is normalized to unit l-2 norm.
def get_comp(k,p):
  """Draw k ground-truth MoR coefficient vectors in R^p, each of unit l-2 norm.

  When k == 2, one Gaussian direction v is drawn, normalized, and the
  symmetric pair [v, -v] is returned. For any other k, k independent
  Gaussian directions are drawn and each is normalized.

  Args:
    k: number of mixture components.
    p: dimension of each coefficient vector.

  Returns:
    list of k 1-D tensors of length p.
  """
  def _unit_gaussian():
    # One torch.randn(p) draw, rescaled to unit Euclidean norm.
    direction = torch.randn(p)
    return direction / torch.norm(direction)

  if k == 2:
    anchor = _unit_gaussian()
    return [anchor, -anchor]
  return [_unit_gaussian() for _ in range(k)]

## Get beta^OR by taking the weighted average of beta_j^{*} for j=1,...,K.
## All component probabilities pi_j^{*} are equal (1/K), so this is a plain average.
def get_beta_OR(beta_comp_list):
  """Return the equally weighted average of the component coefficient vectors.

  With all mixing probabilities pi_j^* equal to 1/K this is
  beta^OR = (1/K) * sum_j beta_j^*.

  Args:
    beta_comp_list: non-empty list of K 1-D tensors, all of length p.

  Returns:
    1-D tensor of length p. The accumulator is created with torch.zeros
    (default dtype) and updated in place, so the result has the default dtype.
  """
  p = beta_comp_list[0].size(0)
  K = len(beta_comp_list)
  avg = torch.zeros(p)
  for beta in beta_comp_list:
    avg += beta
  return avg / K

## Minimum MSE over linear predictors x^T b for MoR (attained at b = beta^OR),
## using the closed-form expectation for x ~ N(0, I) covariates:
##   E[(y - x^T beta^OR)^2] = (1/K) * sum_j ||beta_j^* - beta^OR||^2 + sigma2.
def get_MSE_OR(beta_list,sigma2):
  """Return the oracle MSE (1/K) * sum_j ||beta_j - beta^OR||^2 + sigma2.

  Args:
    beta_list: non-empty list of K 1-D component coefficient tensors.
    sigma2: observation-noise variance.

  Returns:
    0-dim tensor holding the oracle MSE.
  """
  beta_OR = get_beta_OR(beta_list)
  K = len(beta_list)
  MSE_OR = 0
  # Each component contributes its own squared distance to beta^OR.
  # (Previously every term used beta_list[0], which is only correct when all
  # components are equidistant from beta^OR, e.g. the symmetric K=2 case.)
  for beta in beta_list:
    MSE_OR = MSE_OR + torch.linalg.vector_norm(beta - beta_OR)**2
  MSE_OR = MSE_OR / K
  return MSE_OR + sigma2

## Get input sequence of the attention layers H in R^{D*(n+1)}
def Input_Seq(B,TB,p,k,n_samples,sigma2,beta_list,D):
  """Generate training and test prompts for a uniform k-component MoR.

  Each prompt has n_samples context pairs (x_i, y_i) and one query x_query.
  Every pair (including the query) picks its component uniformly from
  beta_list[0..k-1], with y = x^T beta_j + sqrt(sigma2) * N(0, 1) noise and
  x ~ N(0, I_p).

  Each prompt is embedded as a matrix E of shape (D, n_samples + 1). The last
  column holds the query, and its label slot is set to 0. Row layout:
    D == p+1 : [x ; y]
    D == p+2 : [x ; train-indicator ; y]
    D >  p+2 : [x ; zeros(D-p-2) ; train-indicator ; y]
  The train-indicator row is 1 for context columns and 0 for the query column.

  Args:
    B: number of training prompts.
    TB: number of test prompts.
    p: covariate dimension.
    k: number of mixture components (beta_list must hold at least k entries).
    n_samples: number of context pairs per prompt.
    sigma2: observation-noise variance.
    beta_list: list of 1-D component coefficient tensors of length p.
    D: embedding dimension, must be at least p+1.

  Returns:
    (prompt_data, test_data): lists of length B and TB whose entries are
    (E, y, y_query), with E of shape (D, n_samples+1), y of shape
    (n_samples, 1) and y_query of shape (1,).

  Raises:
    ValueError: if D < p+1 (no valid embedding layout exists).
  """
  if D < p + 1:
    raise ValueError(f"D must be at least p+1={p + 1}, got D={D}")
  # Cumulative probabilities (i+1)/k of the uniform mixture.
  thresholds = torch.tensor([(i + 1) / k for i in range(k)])
  # Training prompts are drawn first, then test prompts, so a fixed seed
  # always yields the same split.
  prompt_data = [_build_prompt(p, n_samples, sigma2, beta_list, D, thresholds)
                 for _ in range(B)]
  test_data = [_build_prompt(p, n_samples, sigma2, beta_list, D, thresholds)
               for _ in range(TB)]
  return prompt_data, test_data


def _sample_component(thresholds):
  """Draw a component index with one torch.rand(1) call.

  The index is the number of cumulative thresholds strictly exceeded by the
  uniform draw. Because the last threshold is 1.0 and rand() < 1, the result
  lies in {0, ..., k-1}, each value with probability 1/k.
  """
  return int((torch.rand(1) > thresholds).sum())


def _build_prompt(p, n_samples, sigma2, beta_list, D, thresholds):
  """Sample one prompt and return (E, y, y_query) as described in Input_Seq.

  The random draws happen in a fixed order: X, x_query, the query's component,
  the query's noise, then the component and noise of each context sample.
  Keeping this order fixed makes seeded runs reproducible.
  """
  X = torch.randn(n_samples, p)
  y = torch.zeros(n_samples, 1)
  x_query = torch.randn(1, p)
  noise_std = np.sqrt(sigma2)
  query_comp = _sample_component(thresholds)
  y_query = x_query[0].dot(beta_list[query_comp]) + noise_std*torch.randn(1)
  for i in range(n_samples):
    comp = _sample_component(thresholds)
    y[i] = X[i].dot(beta_list[comp]) + noise_std*torch.randn(1)
  features = torch.cat([X, x_query], dim=0).t()           # (p, n+1)
  # Label row. The query's label slot is 0 so the model must predict it.
  y_row = torch.cat([y.t(), torch.zeros(1, 1)], dim=1)    # (1, n+1)
  if D == p + 1:
    E = torch.cat([features, y_row], dim=0)
  else:
    # For D == p+2 the padding block has zero rows and cat ignores it.
    padding = torch.zeros(D - p - 2, n_samples + 1)
    indicator = torch.cat([torch.ones(1, n_samples), torch.zeros(1, 1)], dim=1)
    E = torch.cat([features, padding, indicator, y_row], dim=0)
  return E, y, y_query