import numpy as np
import torch as th
import torch.nn as nn

from .layers import clones


class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"

    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)

    def get(self, x, fields="qkv"):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if "q" in fields:
            ret["q"] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if "k" in fields:
            ret["k"] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if "v" in fields:
            ret["v"] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))
