#### Implements a multi-layer, multi-head attention network.


import os
import sys
import time
import math
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg.linalg import matmul
from numpy.core.numeric import identity
import math
from scipy.special import gamma
from itertools import chain
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import copy
from torch.utils.data import DataLoader, TensorDataset

## Squared Loss Function
def square_loss(y_pred, y_true):
    """Return the elementwise squared error ``(y_pred - y_true) ** 2``.

    No reduction (sum/mean) is applied: the result has the broadcast shape of
    the two operands, which may be Python numbers, NumPy arrays, or torch
    tensors.
    """
    residual = y_pred - y_true
    return residual ** 2


## Structure of multihead-attention layers and feedforward network
class MultiHeadAttentionLayer(nn.Module):
    """Residual multi-head attention layer with ReLU-activated scores.

    Every head ``i`` owns two learnable ``(input_dim, input_dim)`` matrices,
    ``W_PV[i]`` and ``W_KQ[i]``.  Given an input ``E_tau`` whose columns are
    vectors of size ``input_dim``, :meth:`forward` returns::

        E_tau + (1 / N_samp) * sum_i W_PV[i] @ E_tau
                                     @ ReLU(E_tau.T @ W_KQ[i] @ E_tau / sqrt(input_dim))

    with ``N_samp = E_tau.size(1) - 1``.

    Args:
        input_dim: Number of rows of ``E_tau``. It is also the side length of
            every weight matrix and the ``D`` in the ``sqrt(D)`` score
            normalization.
        M_heads: Number of attention heads.
        L: Initialization scale. Each weight matrix is drawn from a standard
            normal and divided by ``L**2``.
            (MultiLayerMultiHeadAttentionNetwork passes its layer count here.)
    """

    def __init__(self, input_dim, M_heads, L):
        super().__init__()
        self.dk = input_dim  # attention scores are divided by sqrt(dk)
        self.M = M_heads
        # Weights start small (scaled by 1/L**2), presumably so that a deeper
        # stack starts closer to the identity map (each layer is residual).
        self.W_PV = nn.ParameterList([nn.Parameter(torch.randn(input_dim, input_dim)/(L**2)) for _ in range(M_heads)])
        self.W_KQ = nn.ParameterList([nn.Parameter(torch.randn(input_dim, input_dim)/(L**2)) for _ in range(M_heads)])
        self.activation_fun = torch.nn.ReLU()

    def forward(self, E_tau):
        """Apply one residual attention update to ``E_tau``.

        Args:
            E_tau: 2-D tensor of shape ``(input_dim, n_columns)``. It must share
                the dtype and device of the layer's parameters, as required by
                ``torch.mm``.

        Returns:
            Tensor with the same shape as ``E_tau``.

        Raises:
            ValueError: If ``E_tau`` has fewer than two columns. In that case
                the ``1 / N_samp`` normalization would divide by zero.
        """
        n_columns = E_tau.size(1)
        # NOTE(review): the "- 1" presumably excludes a query column from the
        # sample count; confirm with the caller.
        N_samp = n_columns - 1
        if N_samp < 1:
            raise ValueError(
                f"E_tau must have at least 2 columns, got {n_columns}"
            )

        # Attention score is normalized by the square root of D (hoisted: loop-invariant).
        score_scale = math.sqrt(self.dk)
        E_t = E_tau.t()
        interaction = torch.zeros_like(E_tau)
        for W_PV, W_KQ in zip(self.W_PV, self.W_KQ):
            PV_E = torch.mm(W_PV, E_tau)
            KQ_E = torch.mm(W_KQ, E_tau) / score_scale
            activated = self.activation_fun(torch.mm(E_t, KQ_E))
            interaction = interaction + torch.mm(PV_E, activated)
        return E_tau + interaction / N_samp

class MultiLayerMultiHeadAttentionNetwork(nn.Module):
    """Sequential stack of ``num_layers`` MultiHeadAttentionLayer modules.

    Args:
        input_dim: Row dimension of the input, forwarded to every layer.
        M_heads: Number of attention heads per layer.
        num_layers: Number of stacked layers. It is also passed to each layer
            as its initialization scale ``L``.
    """

    def __init__(self, input_dim, M_heads, num_layers):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(MultiHeadAttentionLayer(input_dim, M_heads, num_layers))
        self.layers = nn.ModuleList(stack)

    def forward(self, E_tau):
        """Feed ``E_tau`` through every layer in order and return the result.

        With zero layers the input object itself is returned unchanged.
        """
        out = E_tau
        for block in self.layers:
            out = block(out)
        return out