import torch
from torch import nn
from division.data import token_ids


class Encoder(nn.Module):
    """Multi-head attention block with a residual feed-forward sub-layer.

    Queries are projected from ``x``; keys and values are projected from
    ``y``.  The attention result is projected back to ``from_dim``, added
    to ``x`` and layer-normalised, then passed through a GELU feed-forward
    network (4x expansion) with a second residual + layer norm.

    Args:
        from_dim: feature size of ``x`` (also the output feature size).
        to_dim: feature size of ``y``.
        dim_per_head: width of each attention head.
        n_head: number of attention heads.
    """

    def __init__(self, from_dim, to_dim, dim_per_head, n_head):
        super().__init__()
        self.dim_per_head = dim_per_head
        self.n_head = n_head
        self.out_dim = dim_per_head * n_head

        # Attention projections: queries from x, fused keys+values from y.
        attn_width = self.out_dim
        self.from_fc = nn.Linear(from_dim, attn_width)
        self.to_fc = nn.Linear(to_dim, 2 * attn_width)
        self.ctx_fc = nn.Linear(attn_width, from_dim)

        # Position-wise feed-forward network and the two layer norms.
        ffd_width = 4 * from_dim
        self.ffd_fc1 = nn.Linear(from_dim, ffd_width)
        self.ffd_fc2 = nn.Linear(ffd_width, from_dim)
        self.activation = nn.GELU()
        self.ln1 = nn.LayerNorm(from_dim)
        self.ln2 = nn.LayerNorm(from_dim)

    def forward(self, x, y, mask):
        """Attend from ``x`` to ``y``.

        Args:
            x: query-side tensor of shape (B, F, from_dim).
            y: key/value-side tensor of shape (B, T, to_dim).
            mask: ``None`` or a (B, T) tensor where 1 keeps a key position
                and 0 suppresses it (a large constant is subtracted from
                the corresponding scores).

        Returns:
            Tuple ``(out, prob)``: ``out`` has shape (B, F, from_dim) and
            ``prob`` holds the attention weights with shape (B, N, F, T).
        """
        batch, n_query, _ = x.shape
        _, n_key, _ = y.shape
        heads, head_dim = self.n_head, self.dim_per_head
        width = heads * head_dim

        queries = self.from_fc(x).view((batch, n_query, heads, head_dim))
        keys, values = torch.split(self.to_fc(y), [width, width], dim=-1)
        keys = keys.view((batch, n_key, heads, head_dim))
        values = values.view((batch, n_key, heads, head_dim))

        # Scaled dot-product scores, laid out as (B, N, F, T).
        logits = torch.einsum("bfnd,btnd->bnft", queries, keys)
        logits = logits / head_dim ** 0.5
        if mask is not None:
            keep = mask.float()[:, None, None, :]
            logits = logits - (1.0 - keep) * 1e8

        prob = torch.softmax(logits, dim=-1)
        attended = torch.einsum("bnft,btnd->bfnd", prob, values)
        attended = attended.reshape((batch, n_query, width))

        # Residual + norm around attention, then around the feed-forward net.
        x = self.ln1(x + self.ctx_fc(attended))
        x = self.ln2(x + self.ffd_fc2(self.activation(self.ffd_fc1(x))))

        return x, prob


class RelEncoder(nn.Module):
    """Windowed multi-head attention with a learned relative-position term.

    Works like a standard attention block (queries from ``x``, keys/values
    from ``y``, residual + layer norm, GELU feed-forward, residual + layer
    norm) with two additions:

    * a small MLP maps the scaled relative offset ``f - t`` to a per-head
      vector that is dotted with the query and added to the content score;
    * attention is restricted to key positions with ``|f - t| <= 1``.

    Args:
        from_dim: feature size of ``x`` (also the output feature size).
        to_dim: feature size of ``y``.
        dim_per_head: width of each attention head.
        n_head: number of attention heads.
    """

    def __init__(self, from_dim, to_dim, dim_per_head, n_head):
        super().__init__()
        out_dim = dim_per_head * n_head
        self.out_dim = out_dim
        self.dim_per_head = dim_per_head
        self.n_head = n_head

        # MLP that embeds a scalar relative offset into the attention space.
        self.rpos_fc = nn.Linear(1, out_dim)
        self.rpos_fc2 = nn.Linear(out_dim, out_dim)
        self.relu = nn.ReLU()

        self.from_fc = nn.Linear(from_dim, out_dim)
        self.to_fc = nn.Linear(to_dim, 2 * out_dim)
        self.ctx_fc = nn.Linear(out_dim, from_dim)

        self.ffd_fc1 = nn.Linear(from_dim, 4 * from_dim)
        self.ffd_fc2 = nn.Linear(4 * from_dim, from_dim)
        self.activation = nn.GELU()
        self.ln1 = nn.LayerNorm(from_dim)
        self.ln2 = nn.LayerNorm(from_dim)

    def forward(self, x, y, mask):
        """Attend from ``x`` to neighbouring positions of ``y``.

        Args:
            x: query-side tensor of shape (B, F, from_dim).
            y: key/value-side tensor of shape (B, T, to_dim).
            mask: ``None`` or a (B, T) tensor where 1 keeps a key position
                and 0 suppresses it.

        Returns:
            Tuple ``(out, prob)``: ``out`` has shape (B, F, from_dim) and
            ``prob`` holds the attention weights with shape (B, N, F, T).
        """
        B, F, _ = x.shape
        _, T, _ = y.shape
        N, D = self.n_head, self.dim_per_head

        q = self.from_fc(x)    # B, F, C
        kv = self.to_fc(y)    # B, T, C
        k, v = torch.split(kv, [N * D, N * D], dim=-1)   # B, T, C
        q = q.view((B, F, N, D))
        k = k.view((B, T, N, D))
        v = v.view((B, T, N, D))
        score = torch.einsum("bfnd,btnd->bnft", q, k)   # B, N, F, T

        # Integer offsets f - t.  They are created on the input's device so
        # the module also works off-CPU; torch.arange defaults to the CPU.
        device = x.device
        orig_rpos = (torch.arange(F, device=device)[:, None]
                     - torch.arange(T, device=device)[None, :])  # F, T

        # Feed the offsets to the MLP in the MLP's own dtype; a hard-coded
        # float32 input breaks modules converted with .double()/.half().
        rpos_in = 0.02 * orig_rpos[:, :, None].to(self.rpos_fc.weight.dtype)
        rpos = self.rpos_fc2(self.relu(self.rpos_fc(rpos_in)))   # F, T, C
        rpos = rpos.reshape((F, T, N, D))  # F, T, N, D
        score_rpos = torch.einsum("bfnd,ftnd->bnft", q, rpos)
        score = score + score_rpos

        score = score / D ** 0.5
        if mask is not None:
            score = score - (1.0 - mask.float()[:, None, None, :]) * 1e8

        # Local window: only keys with |f - t| <= 1 stay unpenalised.  The
        # offsets are integers, so the exact comparison is safe.
        r_mask = (orig_rpos.abs() <= 1).float()   # F, T
        score = score - (1.0 - r_mask[None, None, :, :]) * 1e8

        prob = torch.softmax(score, dim=-1)   # B, N, F, T
        context = torch.einsum("bnft,btnd->bfnd", prob, v)   # B, F, N, D

        context = context.reshape((B, F, N * D))
        x = self.ln1(x + self.ctx_fc(context))

        ffd = self.ffd_fc2(self.activation(self.ffd_fc1(x)))
        x = self.ln2(x + ffd)

        return x, prob


class ParseModel(nn.Module):
    """Sequence model with 12 input channels and 8 classification heads.

    Each of the 12 inputs is embedded with a shared bias-free linear layer,
    the embeddings are concatenated and projected to a 64-d hidden state,
    and three stacked stages of mixed windowed/global attention refine it.
    Eight independent linear heads then produce logits over
    ``len(token_ids)`` classes.

    Args:
        vocab_size: size of the last dimension of every ``inputN`` tensor.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self._vocab_size = vocab_size

        # NOTE: attribute names and creation order are kept stable so that
        # existing checkpoints and seeded initialisation are unaffected.
        self.embed = nn.Linear(vocab_size, 16, bias=False)
        self.fc = nn.Linear(16 * 12 + 0, 64)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(128, 64)
        self.rel_all_enc1 = RelEncoder(64, 64, 32, 4)
        self.rel_all_enc2 = RelEncoder(64, 64, 32, 4)
        self.rel_all_enc3 = RelEncoder(64, 64, 32, 4)
        self.rel_enc1 = RelEncoder(64, 64, 32, 4)
        self.rel_enc2 = RelEncoder(64, 64, 32, 4)
        self.rel_enc3 = RelEncoder(64, 64, 32, 4)
        self.all_rel_enc1 = Encoder(64, 64, 32, 4)
        self.all_rel_enc2 = Encoder(64, 64, 32, 4)
        self.all_rel_enc3 = Encoder(64, 64, 32, 4)
        self.all_enc1 = Encoder(64, 64, 32, 4)
        self.all_enc2 = Encoder(64, 64, 32, 4)
        self.all_enc3 = Encoder(64, 64, 32, 4)
        self.out0 = nn.Linear(64, len(token_ids))
        self.out1 = nn.Linear(64, len(token_ids))
        self.out2 = nn.Linear(64, len(token_ids))
        self.out3 = nn.Linear(64, len(token_ids))
        self.out4 = nn.Linear(64, len(token_ids))
        self.out5 = nn.Linear(64, len(token_ids))
        self.out6 = nn.Linear(64, len(token_ids))
        self.out7 = nn.Linear(64, len(token_ids))

    def forward(self,
                input0, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11,
                mask=None,
                output0=None, output1=None, output2=None, output3=None,
                output4=None, output5=None, output6=None, output7=None):
        """Run the model and optionally accumulate the training loss.

        Args:
            input0 .. input11: float tensors whose last dimension is
                ``vocab_size`` (presumably one-hot encodings of shape
                (B, T, vocab_size) -- confirm against the data pipeline).
            mask: optional (B, T) tensor, 1 for valid positions and 0 for
                padding.  It is forwarded to every attention layer and,
                when given, also weights the per-position loss.
            output0 .. output7: optional target class indices for the
                corresponding head; a head whose target is ``None`` does
                not contribute to the loss.

        Returns:
            A 17-tuple: the eight logit tensors ``out0..out7``, the eight
            matching ``Categorical`` distributions ``dist0..dist7`` and
            ``loss``.  ``loss`` is the un-reduced sum of per-position
            negative log-likelihoods, or the float ``0.0`` when no target
            was supplied.
        """
        inputs = (input0, input1, input2, input3, input4, input5,
                  input6, input7, input8, input9, input10, input11)
        h = torch.cat([self.embed(channel) for channel in inputs], dim=-1)
        h = self.fc(h)

        # Every stage reads its self-attention branches from the initial
        # projection `oh`; only `h` is updated from stage to stage.
        oh = h
        stages = (
            (self.rel_all_enc1, self.all_enc1, self.all_rel_enc1, self.rel_enc1, self.fc1),
            (self.rel_all_enc2, self.all_enc2, self.all_rel_enc2, self.rel_enc2, self.fc2),
            (self.rel_all_enc3, self.all_enc3, self.all_rel_enc3, self.rel_enc3, self.fc3),
        )
        for rel_all_enc, all_enc, all_rel_enc, rel_enc, merge in stages:
            rah, _ = rel_all_enc(oh, oh, mask)    # windowed self-attention
            ah, _ = all_enc(h, rah, mask)         # global attention onto it
            arh, _ = all_rel_enc(oh, oh, mask)    # global self-attention
            rh, _ = rel_enc(h, arh, mask)         # windowed attention onto it
            h = merge(torch.cat([ah, rh], dim=-1))

        heads = (self.out0, self.out1, self.out2, self.out3,
                 self.out4, self.out5, self.out6, self.out7)
        logits = tuple(head(h) for head in heads)   # each B, T, n
        dists = tuple(torch.distributions.Categorical(logits=head_logits)
                      for head_logits in logits)

        targets = (output0, output1, output2, output3,
                   output4, output5, output6, output7)
        loss = 0.0
        for dist, target in zip(dists, targets):
            if target is None:
                continue
            nll = -dist.log_prob(target)
            # `mask` is optional: without it every position counts fully
            # instead of failing on a multiplication by None.
            if mask is not None:
                nll = nll * mask
            loss = loss + nll

        return (*logits, *dists, loss)
