### training loops for models

import os
import sys
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg.linalg import matmul
from numpy.core.numeric import identity
import math
from scipy.special import gamma
from itertools import chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import copy
from torch.utils.data import DataLoader, TensorDataset

def Exp_length_MSE(B, TB, length, p, D, prompt_data, test_data, epochs,
                   *, n_heads=4, n_layers=4, lr=1e-5):
    """Train a fresh attention model and record the test MSE after every epoch.

    A new ``MultiLayerMultiHeadAttentionNetwork(D, n_heads, n_layers)`` is
    built and optimized with Adam on ``square_loss``. After each epoch the
    model is evaluated without gradients on ``test_data``, and the average
    training/testing errors are printed.

    Args:
        B: number of training batches/prompts; the divisor for the average
            training loss.
        TB: number of testing batches/prompts; the divisor for the average
            testing loss.
        length: number of samples in the training prompt; used as the column
            index of the prediction read from the model output in training.
        p: row index of the prediction read from the model output in training.
        D: first constructor argument of ``MultiLayerMultiHeadAttentionNetwork``
            (presumably the embedding dimension — TODO confirm).
        prompt_data: iterable of ``(E_tau, y_tau, y_query)`` training triples;
            ``y_query`` must hold a single element (``.item()`` is called on it).
        test_data: iterable of ``(E_psai, y_psai, y_q)`` testing triples;
            ``y_q`` must hold a single element.
        epochs: number of passes over ``prompt_data``.
        n_heads: number of attention heads (default 4).
        n_layers: number of layers (default 4).
        lr: Adam learning rate (default 1e-5).

    Returns:
        A list with one entry per epoch: the summed squared test error
        divided by ``TB``.
    """
    n_samples_train = length  # number of samples in the training prompt

    TFmodel = MultiLayerMultiHeadAttentionNetwork(D, n_heads, n_layers)
    optimizer = optim.Adam(TFmodel.parameters(), lr=lr)
    MSE_test_list = []

    for epoch in range(epochs):
        # ---- training pass ----
        TFmodel.train()
        total_loss = 0.0
        for E_tau, y_tau, y_query in prompt_data:
            optimizer.zero_grad()
            E_tau_prime = TFmodel(E_tau)
            y_pred_query = E_tau_prime[p, n_samples_train]
            loss = square_loss(y_pred_query, y_query.item())
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        # ---- evaluation pass ----
        # Switch to eval mode so dropout / batch-norm layers (if the model has
        # any) behave deterministically while measuring test error.
        TFmodel.eval()
        test_loss = 0.0
        with torch.no_grad():
            for E_psai, y_psai, y_q in test_data:
                E_psai_prime = TFmodel(E_psai)
                # NOTE(review): training reads the prediction at [p, length],
                # testing at [-1, -1]; these coincide only if the model output
                # has exactly p + 1 rows and length + 1 columns — confirm.
                y_pred_q = E_psai_prime[-1, -1]
                test_loss += (y_pred_q.item() - y_q.item()) ** 2

        train_avg_loss = total_loss / B
        test_avg_loss = test_loss / TB
        MSE_test_list.append(test_avg_loss)
        print(f"Epoch {epoch+1}: Training MSE = {train_avg_loss:.4f}, Testing MSE = {test_avg_loss:.4f}")
    return MSE_test_list

# Sweep over prompt lengths 1..max_length. For each length, run `rep`
# independent trainings and store the final-epoch test MSE of each run
# in MSE_result_SNR10[length - 1, repeat].
rep = 20          ## number of repeats per length
sigma2 = 0.01     ## variance of the noise
max_length = 64   ## lengths swept: 1..max_length (rows of the result arrays)
MSE_result_SNR10 = np.zeros((max_length, rep))
# NOTE(review): MSE_result_OR and beta_OR are computed but never used below.
MSE_result_OR = np.zeros((max_length, rep))
epo = 100000      ## number of training epochs
beta_comp_list = get_comp(k=5, p=32)
beta_OR = get_beta_OR(beta_comp_list)
for l in range(max_length):
    print(f"length {l+1}")
    for repeat in range(rep):
        print(f"repeat {repeat+1}")
        # Passes sigma2 (same value as the previous literal 0.01); assumes the
        # 6th positional argument of Input_Seq is the noise variance — TODO confirm.
        # NOTE(review): none of these arguments depend on `l`, so the data
        # configuration is identical for every swept length — confirm intended.
        Prompt_dataset, Test_dataset = Input_Seq(64, 64, 32, 5, 64, sigma2, beta_comp_list, 64)
        result = Exp_length_MSE(64, 64, l + 1, 32, 64, Prompt_dataset, Test_dataset, epo)
        # result holds one test MSE per epoch; keep the final epoch's value.
        MSE_result_SNR10[l, repeat] = result[-1]