import torch.nn as nn
import torch


class AdaptiveProjector(nn.Module):
    """Per-head chain of learnable matrices applied to multi-head features.

    ``self.proj`` holds ``num_proj`` parameters, each of shape
    ``(num_heads, rows, cols)``. ``forward`` multiplies the input by them in
    order, independently for every head. The chain always has the form
    ``in_dim -> hidden_dim -> ... -> hidden_dim -> out_dim``:

    * ``low_dim is None``: ``in_dim = out_dim = head_dim``; the matrices are
      initialized to the (rectangular) identity via ``weight_init``.
    * ``low_dim`` given and ``qkv=True``: ``in_dim = head_dim``,
      ``out_dim = low_dim``.
    * ``low_dim`` given and ``qkv=False``: ``in_dim = low_dim``,
      ``out_dim = head_dim``.

    Args:
        num_heads: Leading (head) dimension of every projection matrix.
        head_dim: Size of the full-width side of the chain.
        hidden_dim: Width of the intermediate projections.
        low_dim: Optional reduced size for one end of the chain (see above).
        qkv: Selects which end of the chain uses ``low_dim`` when it is given.
        num_proj: Number of matrices in the chain. Values below 2 still
            produce two matrices (the first and the last).
    """

    def __init__(self, num_heads, head_dim, hidden_dim, low_dim=None, qkv=True, num_proj=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.hidden_dim = hidden_dim
        self.low_dim = low_dim
        self.qkv = qkv

        # Only the two ends of the chain depend on the configuration; every
        # interior matrix is hidden_dim x hidden_dim.
        if low_dim is None:
            in_dim, out_dim = head_dim, head_dim
        elif qkv:
            in_dim, out_dim = head_dim, low_dim
        else:
            in_dim, out_dim = low_dim, head_dim

        self.proj = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(num_heads, rows, cols, dtype=torch.float32))
                for rows, cols in self._layer_shapes(in_dim, hidden_dim, out_dim, num_proj)
            ]
        )

        if low_dim is None:
            self.weight_init()
        # NOTE(review): when low_dim is given, every matrix is left at zero.
        # forward() then returns zeros, and each weight's gradient involves
        # another zero matrix, so training from this state cannot move it.
        # Presumably the weights are loaded or initialized elsewhere — confirm.

    @staticmethod
    def _layer_shapes(in_dim, hidden_dim, out_dim, num_proj):
        """Return the (rows, cols) of each matrix in the chain, in application order."""
        # Negative counts collapse to zero, so num_proj < 2 still gives two matrices.
        num_middle = max(num_proj - 2, 0)
        return (
            [(in_dim, hidden_dim)]
            + [(hidden_dim, hidden_dim)] * num_middle
            + [(hidden_dim, out_dim)]
        )

    def weight_init(self):
        """Set every head of every projection to ``torch.eye(rows, cols)``.

        For a non-square matrix the ones sit on the main diagonal. Multiplying
        by it therefore keeps the first ``min(rows, cols)`` input features and
        either zero-pads (cols > rows) or truncates (cols < rows).
        """
        with torch.no_grad():
            for proj in self.proj:
                rows, cols = proj.shape[1], proj.shape[2]
                eye = torch.eye(rows, cols, dtype=proj.dtype, device=proj.device)
                # One broadcast copy across the head axis instead of a per-head loop.
                proj.copy_(eye.expand_as(proj))

    def forward(self, x):
        """Apply each projection in order, separately for every head.

        The einsum requires ``x`` to be 4-D. Axis 2 must have size
        ``num_heads``, and the last axis must match the row count of the first
        projection. The result keeps the first three axes; its last axis is the
        column count of the final projection. (The ``b``/``s`` einsum labels
        suggest batch and sequence axes — assumption, not enforced.)
        """
        for proj in self.proj:
            x = torch.einsum("bsnh,nhd->bsnd", x, proj)
        return x

