import torch
from torch import nn
from one_line_addition.data import token_ids


# Positive infinity. NOTE(review): not referenced by any code in this file
# chunk; confirm nothing imports it elsewhere before removing.
R = float("inf")


class Encoder(nn.Module):
    """Multi-head cross-attention layer followed by a position-wise feed-forward layer.

    Queries are projected from ``x``; keys and values are projected from ``y``.
    Each sub-layer is wrapped as ``LayerNorm(input + sublayer(input))``.

    Args:
        from_dim: feature size of the query input ``x`` (also the output size).
        to_dim: feature size of the key/value input ``y``.
        dim_per_head: size of each attention head.
        n_head: number of attention heads.
    """

    def __init__(self, from_dim, to_dim, dim_per_head, n_head):
        super().__init__()
        hidden = n_head * dim_per_head
        self.n_head = n_head
        self.dim_per_head = dim_per_head
        self.out_dim = hidden

        # Submodules are created in this exact order so that seeded parameter
        # initialisation and state_dict key order stay stable.
        self.from_fc = nn.Linear(from_dim, hidden)        # query projection
        self.to_fc = nn.Linear(to_dim, hidden * 2)        # fused key + value projection
        self.ctx_fc = nn.Linear(hidden, from_dim)         # attention output projection

        # Feed-forward sub-layer with a 4x hidden expansion.
        self.ffd_fc1 = nn.Linear(from_dim, from_dim * 4)
        self.ffd_fc2 = nn.Linear(from_dim * 4, from_dim)
        self.activation = nn.GELU()
        self.ln1 = nn.LayerNorm(from_dim)
        self.ln2 = nn.LayerNorm(from_dim)

    def forward(self, x, y, mask):
        """Attend from ``x`` to ``y`` and apply the feed-forward sub-layer.

        Args:
            x: query input, shape (B, F, from_dim).
            y: key/value input, shape (B, T, to_dim).
            mask: optional (B, T) tensor; key positions whose entry is 0 get a
                -1e8 penalty on their logits, i.e. are effectively excluded.

        Returns:
            Tuple ``(out, prob)``: ``out`` has shape (B, F, from_dim) and
            ``prob`` holds the attention weights with shape (B, n_head, F, T).
        """
        batch, n_from, _ = x.shape
        _, n_to, _ = y.shape
        heads, head_dim = self.n_head, self.dim_per_head
        width = heads * head_dim

        query = self.from_fc(x).view(batch, n_from, heads, head_dim)
        key_flat, value_flat = torch.split(self.to_fc(y), [width, width], dim=-1)
        key = key_flat.view(batch, n_to, heads, head_dim)
        value = value_flat.view(batch, n_to, heads, head_dim)

        logits = torch.einsum("bfnd,btnd->bnft", query, key)
        logits = logits / head_dim ** 0.5
        if mask is not None:
            keep = mask.float()[:, None, None, :]  # broadcast over heads and queries
            logits = logits - (1.0 - keep) * 1e8

        prob = logits.softmax(dim=-1)
        context = torch.einsum("bnft,btnd->bfnd", prob, value)
        context = context.reshape((batch, n_from, width))

        attended = self.ln1(x + self.ctx_fc(context))
        expanded = self.ffd_fc2(self.activation(self.ffd_fc1(attended)))
        return self.ln2(attended + expanded), prob


class RelEncoder(nn.Module):
    """Multi-head cross-attention layer with an additive relative-position score.

    Structured like ``Encoder`` (attention, residual + LayerNorm, 4x GELU
    feed-forward, residual + LayerNorm). The difference is that every attention
    logit between query position ``f`` and key position ``t`` also receives the
    term ``q_f . r(f, t)``. Here ``r`` is a two-layer ReLU MLP applied to the
    scalar offset ``(f - t) / 50``.

    Args:
        from_dim: feature size of the query input ``x`` (also the output size).
        to_dim: feature size of the key/value input ``y``.
        dim_per_head: size of each attention head.
        n_head: number of attention heads.
        max_len: largest query/key length covered by the precomputed
            relative-offset table (default 150, the previous fixed size).
    """

    def __init__(self, from_dim, to_dim, dim_per_head, n_head, max_len=150):
        super().__init__()
        out_dim = dim_per_head * n_head
        self.out_dim = out_dim
        self.dim_per_head = dim_per_head
        self.n_head = n_head

        pos = torch.tensor([i / 50 for i in range(max_len)], dtype=torch.float)
        # (max_len, max_len) table of offsets pos[f] - pos[t]. It is registered
        # as a buffer so it follows .to()/.cuda()/.double() with the module; a
        # plain tensor attribute stayed float32 on CPU and broke rpos_fc there.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load unchanged.
        self.register_buffer("rpos", pos[:, None] - pos[None, :], persistent=False)
        self.rpos_fc = nn.Linear(1, out_dim)
        self.rpos_fc2 = nn.Linear(out_dim, out_dim)
        self.relu = nn.ReLU()

        self.from_fc = nn.Linear(from_dim, out_dim)        # query projection
        self.to_fc = nn.Linear(to_dim, 2 * out_dim)        # fused key + value projection
        self.ctx_fc = nn.Linear(out_dim, from_dim)         # attention output projection

        # Feed-forward sub-layer with a 4x hidden expansion.
        self.ffd_fc1 = nn.Linear(from_dim, 4 * from_dim)
        self.ffd_fc2 = nn.Linear(4 * from_dim, from_dim)
        self.activation = nn.GELU()
        self.ln1 = nn.LayerNorm(from_dim)
        self.ln2 = nn.LayerNorm(from_dim)

    def forward(self, x, y, mask):
        """Attend from ``x`` to ``y`` with relative-position scores, then apply the FFN.

        Args:
            x: query input, shape (B, F, from_dim).
            y: key/value input, shape (B, T, to_dim).
            mask: optional (B, T) tensor; key positions whose entry is 0 get a
                -1e8 penalty on their logits, i.e. are effectively excluded.

        Returns:
            Tuple ``(out, prob)``: ``out`` has shape (B, F, from_dim) and
            ``prob`` holds the attention weights with shape (B, n_head, F, T).

        Raises:
            ValueError: if F or T exceeds the ``max_len`` given at construction.
        """
        B, F, _ = x.shape
        _, T, _ = y.shape
        N, D = self.n_head, self.dim_per_head

        max_len = self.rpos.shape[0]
        if F > max_len or T > max_len:
            raise ValueError(
                f"sequence lengths (F={F}, T={T}) exceed max_len={max_len}"
            )

        q = self.from_fc(x)    # B, F, C
        kv = self.to_fc(y)    # B, T, 2C
        k, v = torch.split(kv, [N * D, N * D], dim=-1)   # B, T, C each
        q = q.view((B, F, N, D))
        k = k.view((B, T, N, D))
        v = v.view((B, T, N, D))
        score = torch.einsum("bfnd,btnd->bnft", q, k)   # B, N, F, T

        # Content-independent position embedding for each (query, key) offset.
        rpos = self.rpos_fc(self.rpos[:F, :T, None])    # F, T, C
        rpos = self.rpos_fc2(self.relu(rpos))
        rpos = rpos.reshape((F, T, N, D))  # F, T, N, D
        score = score + torch.einsum("bfnd,ftnd->bnft", q, rpos)

        score = score / D ** 0.5
        if mask is not None:
            score = score - (1.0 - mask.float()[:, None, None, :]) * 1e8

        prob = torch.softmax(score, dim=-1)   # B, N, F, T
        context = torch.einsum("bnft,btnd->bfnd", prob, v)   # B, F, N, D

        context = context.reshape((B, F, N * D))
        x = self.ln1(x + self.ctx_fc(context))

        ffd = self.ffd_fc2(self.activation(self.ffd_fc1(x)))
        x = self.ln2(x + ffd)

        return x, prob


class ParseModel(nn.Module):
    """Three stacked ``RelEncoder`` self-attention layers with a per-position classifier.

    Each input position is projected to 32 features and concatenated with one
    scalar absolute-position feature ``t / 50``, giving the 33-wide encoder
    input. The output head scores ``len(token_ids)`` classes per position.

    Args:
        vocab_size: size of the last dimension of the input ``x``.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self._vocab_size = vocab_size

        # Bias-free linear map over the vocab dimension.
        # NOTE(review): presumably x is one-hot (or soft) token vectors — confirm.
        self.embed = nn.Linear(vocab_size, 32, bias=False)
        # Absolute position feature t / 50 for t < 150, shape (150,). It is
        # registered as a buffer so it follows .to()/.cuda()/.double() with the
        # model; persistent=False keeps state_dict keys unchanged.
        self.register_buffer(
            "pos",
            torch.tensor([i / 50 for i in range(150)], dtype=torch.float),
            persistent=False,
        )
        n = 33  # 32 embedding features + 1 position feature
        self.enc1 = RelEncoder(n, n, 32, 4)
        self.enc2 = RelEncoder(n, n, 32, 4)
        self.enc3 = RelEncoder(n, n, 32, 4)
        self.out = nn.Linear(n, len(token_ids))

    def forward(self, x, mask=None, y=None):
        """Encode ``x`` and optionally compute a per-position loss against ``y``.

        Args:
            x: input of shape (B, T, vocab_size).
            mask: optional (B, T) tensor; positions whose entry is 0 are
                excluded as attention keys and zero out their loss terms.
            y: optional (B, T) target class indices for the output head.

        Returns:
            Tuple ``(out, dist, loss)``:
            - ``out``: logits of shape (B, T, len(token_ids));
            - ``dist``: a ``Categorical`` distribution over those logits;
            - ``loss``: per-position negative log-likelihood of ``y``, shape
              (B, T), multiplied by ``mask`` when one is given; ``0.0`` when
              ``y`` is None.

        Raises:
            ValueError: if T exceeds the 150 positions in the position table.
        """
        h = self.embed(x)   # B, T, 32
        B, T, _ = h.shape
        if T > self.pos.shape[0]:
            raise ValueError(
                f"sequence length {T} exceeds the supported maximum {self.pos.shape[0]}"
            )
        pos = self.pos[None, :T, None].expand(B, T, 1)
        h = torch.cat([h, pos], dim=-1)   # B, T, 33
        h, _ = self.enc1(h, h, mask)
        h, _ = self.enc2(h, h, mask)
        h, _ = self.enc3(h, h, mask)
        out = self.out(h)   # B, T, len(token_ids)
        dist = torch.distributions.Categorical(logits=out)

        loss = 0.0
        if y is not None:
            nll = -dist.log_prob(y)   # B, T
            # Previously `nll * mask` raised TypeError when mask was None.
            loss = nll if mask is None else nll * mask
        return out, dist, loss
