import math
import torch
from torch import nn
from torch.nn import init
import torch as T
import torch.nn.functional as F


class CellLayer(nn.Module):
    """Gated composition cell that merges a left and a right child vector.

    An MLP over the concatenation ``[l; r]`` (Linear -> GELU -> dropout ->
    Linear) produces four ``hidden_dim``-sized chunks: three sigmoid gates
    ``f1``, ``f2``, ``i`` and a candidate ``parent``. The output is
    ``LayerNorm(f1 * l + f2 * r + i * parent)``.

    Args:
        hidden_dim: size ``D`` of the child and parent vectors.
        cell_hidden_dim: width of the intermediate MLP layer.
        dropout: dropout probability applied to the intermediate activations,
            only while the module is in training mode.
    """

    def __init__(self, hidden_dim, cell_hidden_dim, dropout):
        super(CellLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.wcell1 = nn.Linear(2 * hidden_dim, cell_hidden_dim)
        self.wcell2 = nn.Linear(cell_hidden_dim, 4 * hidden_dim)
        self.LN2 = nn.LayerNorm(hidden_dim)
        self.dropout = dropout

    def forward(self, l=None, r=None):
        """Compose children ``l`` and ``r`` into parent vectors.

        Args:
            l: left children of shape ``(..., D)`` with ``D == hidden_dim``.
                Any number of leading dimensions is accepted; ``(N, B, S, D)``
                is the layout used elsewhere in this file.
            r: right children, same shape as ``l``.

        Returns:
            Parent vectors with the same shape as ``l``.
        """
        D = l.size(-1)
        concated = torch.cat([l, r], dim=-1)
        intermediate = F.gelu(self.wcell1(concated))
        intermediate = F.dropout(intermediate, p=self.dropout, training=self.training)
        contents = self.wcell2(intermediate)

        # Split the 4*D output features into 4 chunks of size D along a new
        # axis: gates f1, f2, i (chunks 0..2) and the candidate parent (chunk 3).
        contents = contents.reshape(*l.shape[:-1], 4, D)
        f1, f2, i = torch.sigmoid(contents[..., 0:3, :]).unbind(dim=-2)
        parent = contents[..., 3, :]
        return self.LN2(f1 * l + f2 * r + i * parent)


class BeamGumbelDisentangledTreeCell(nn.Module):
    """Beam-search binary-tree encoder with Gumbel top-k merge selection.

    Starting from one node per input position, each iteration of ``forward``
    merges one pair of adjacent nodes, shrinking the node sequence by one,
    until a single root vector remains per beam. Two parallel state streams
    are kept:

    * form state    -- composed by ``controller_layer`` and scored by
      ``decide_linear`` to decide *which* adjacent pair to merge;
    * content state -- composed by ``treecell_layer`` for the chosen pair and
      used for the final representation.

    Up to ``beam_size`` merge hypotheses (beams) are kept per example. Both the
    merge choice and beam pruning use ``st_gumbel_softmax`` (Gumbel top-k with
    a straight-through gradient). When ``rao`` is set, the soft probabilities
    are averaged over conditionally resampled Gumbel noise
    (``conditional_gumbel``).

    Config keys read by this class: ``hidden_size``, ``rao``, ``beam_size``,
    ``diffop1``, ``diffop2``, ``temperature`` (optional, default 1),
    ``dropout``, ``conv_decision``, ``rao_k``, ``stochastic`` and
    ``test_time_stochastic``.
    """
    def __init__(self, config):
        super(BeamGumbelDisentangledTreeCell, self).__init__()
        self.config = config
        # Input feature size; word_linear maps it to 2 * hidden_size.
        self.word_dim = config["hidden_size"]
        self.hidden_dim = config["hidden_size"]
        # If truthy, st_gumbel_softmax (when training) averages rao_k softmaxes
        # over conditionally resampled Gumbel noise.
        self.rao = config["rao"]
        # Maximum number of beams kept; also caps the top-k merges per beam.
        self.beam_size = config["beam_size"]
        # If falsy, the merge selection in select_composition is made with
        # training=False (hard one-hot, no straight-through gradient).
        self.diffop1 = config["diffop1"]
        # Same as diffop1, but for the beam-pruning selection.
        self.diffop2 = config["diffop2"]
        # Softmax temperature used by st_gumbel_softmax (it overrides that
        # method's own ``temperature`` argument).
        self.temperature = config["temperature"] if "temperature" in config else 1

        # Projects every token to 2 * hidden_dim; forward splits the result
        # into a form half and a content half.
        self.word_linear = nn.Linear(in_features=self.word_dim,
                                     out_features=2 * self.hidden_dim)
        # self.compress_linear = nn.Linear(in_features=self.hidden_dim, out_features=self.small_d)

        # Composes content states of the chosen pair.
        self.treecell_layer = CellLayer(self.hidden_dim, 2 * self.hidden_dim, config["dropout"])
        # Composes form states of every adjacent pair; these are scored to pick merges.
        self.controller_layer = CellLayer(self.hidden_dim, 2 * self.hidden_dim, config["dropout"])
        # NOTE(review): with conv_decision=True, decide_linear expects
        # 5 * hidden_dim input features. However, select_composition feeds it
        # new_form_state, whose last dim is hidden_dim (a CellLayer output), so
        # this branch appears to fail with a shape mismatch. Confirm whether a
        # feature-building step is missing.
        if self.config["conv_decision"]:
            self.decide_linear = nn.Linear(5 * self.hidden_dim, 1)
        else:
            self.decide_linear = nn.Linear(self.hidden_dim, 1)
        # self.comp_query = nn.Parameter(torch.FloatTensor(self.hidden_dim))
        # Normalize the form (LN1) and content (LN2) halves of word_linear's output.
        self.LN1 = nn.LayerNorm(self.hidden_dim)
        self.LN2 = nn.LayerNorm(self.hidden_dim)

    @staticmethod
    def update_state(old_content_state, new_content_state, old_form_state, new_form_state, done_mask):
        """Blend newly composed states with the previous states, per example.

        ``done_mask`` holds one value per example; it is reshaped to
        ``(N, 1, 1, 1)``, so the states are expected to be 4-D. Where it is 1,
        the new state is taken. Where it is 0, the old state is kept with its
        last position dropped, so both branches have the same length.

        ``forward`` passes ``length_mask[:, i + 1]`` here. Despite its name,
        ``done_mask`` is therefore 1 when position ``i + 1`` is marked valid,
        i.e. when the example still has nodes to merge (assuming 1 marks a
        real token -- TODO confirm).

        Returns:
            ``(content_state, form_state)``.
        """
        N = old_content_state.size(0)
        done_mask = done_mask.view(N, 1, 1, 1)
        content_state = done_mask * new_content_state + (1 - done_mask) * old_content_state[..., :-1, :]
        form_state = done_mask * new_form_state + (1 - done_mask) * old_form_state[..., :-1, :]
        return content_state, form_state

    def masked_softmax(self, logits, mask=None, dim=-1):
        """Softmax over ``dim`` with masked-out entries suppressed.

        Masked entries (``mask == 0``) are not set to exactly zero: after
        multiplying by ``mask``, a small ``eps`` is added everywhere and the
        result is renormalised, so they end up with a tiny positive
        probability. With ``mask=None`` this is a plain softmax.
        """
        eps = 1e-20
        probs = F.softmax(logits, dim=dim)
        if mask is not None:
            mask = mask.float()
            probs = probs * mask + eps
            probs = probs / probs.sum(dim, keepdim=True)
        return probs

    @T.no_grad()
    def conditional_gumbel(self, logits, mask, D, k=10):
        """Sample ``k`` Gumbel perturbations consistent with a given selection.

        Builds ``k`` sets of perturbed logits from i.i.d. exponential noise. In
        each set, the entry marked by the one-hot ``D`` receives
        ``log Z - log E_i``. Every other entry receives
        ``-log(E_j / exp(logit_j) + E_i / Z)``, which is never larger, so the
        marked entry stays the row maximum. The method returns the
        perturbation ``adjusted - logits`` rather than the perturbed logits.
        Runs under ``no_grad``, so the samples are constants for autograd.

        Args:
            logits: ``(N, sk, S)`` unnormalised scores.
            mask: ``(N, sk, S)``; only used when computing the partition
                function ``Z``.
            D: ``(N, sk, S)`` one-hot rows giving the selection to condition on.
            k: number of samples.

        Returns:
            Tensor of shape ``(k, N, sk, S)``.
        """
        eps = 1e-20
        N, sk, S = logits.size()
        assert mask.size() == (N, sk, S)
        # iid. exponential
        E = T.distributions.exponential.Exponential(rate=T.ones_like(logits)).sample([k])
        assert E.size() == (k, N, sk, S)

        logits = logits.unsqueeze(0)
        D = D.unsqueeze(0)
        mask = mask.unsqueeze(0)
        assert logits.size() == (1, N, sk, S)
        assert D.size() == (1, N, sk, S)
        # E of the chosen class
        Ei = (D * E).sum(dim=-1, keepdim=True)
        assert Ei.size() == (k, N, sk, 1)
        # partition function (normalization constant) over unmasked entries
        Z = T.sum(mask * T.exp(logits), dim=-1, keepdim=True)
        assert Z.size() == (1, N, sk, 1)
        Z = Z + eps
        # Sampled gumbel-adjusted logits
        adjusted = (D * (-T.log(Ei + eps) + T.log(Z)) +
                    (1 - D) * -T.log(E / (T.exp(logits) + eps) + (Ei / Z) + eps))
        assert adjusted.size() == (k, N, sk, S)
        return adjusted - logits

    def st_gumbel_softmax(self, logits, select_k=1, temperature=1.0, mask=None, training=True):
        """Gumbel top-k selection over the last dimension, with optional straight-through.

        Args:
            logits: ``(N, S)`` scores.
            select_k: number of distinct entries to select per row (via ``topk``).
            temperature: ignored -- it is overwritten by ``self.temperature``.
            mask: ``(N, S)`` validity mask passed to ``masked_softmax``. It
                must not be None when ``training`` and ``self.rao`` are both set.
            training: if False, return hard one-hot selections with no
                gradient path through them.

        Returns:
            ``(y_hard, y)``, both of shape ``(N, select_k, S)``. ``y_hard`` has
            one one-hot row per selected entry; ``y`` holds the soft
            probabilities. When ``training`` is True, ``y_hard`` is
            ``(y_hard - y).detach() + y``: forward value one-hot, gradient of ``y``.
        """
        eps = 1e-20
        # NOTE(review): read unconditionally, so config must contain "rao_k"
        # even when rao is disabled.
        rao_k = self.config["rao_k"]
        N, S = logits.size()
        temperature = self.temperature

        # Gumbel noise is added when stochastic mode is on (during training,
        # or at test time if test_time_stochastic), and always when the
        # ``training`` argument is True.
        if (self.config["stochastic"] and (self.training or self.config["test_time_stochastic"])) or training:
            u = logits.data.new(*logits.size()).uniform_()
            gumbel_noise = -torch.log(-torch.log(u + eps) + eps)
        else:
            gumbel_noise = 0

        y = logits + gumbel_noise

        y_ = self.masked_softmax(logits=y / temperature, mask=mask)
        # Indices of the select_k largest (perturbed, masked) probabilities.
        y_argmax = T.topk(y_, dim=-1, k=select_k)[1]
        assert y_argmax.size() == (N, select_k)

        y_hard = F.one_hot(y_argmax, num_classes=S).float()
        assert y_hard.size() == (N, select_k, S)

        if not training:
            return y_hard, y_.unsqueeze(1).repeat(1, select_k, 1)
        else:
            if self.rao:
                # Give every selected row its own copy of the logits/mask, draw
                # rao_k noise samples conditioned on that row's hard choice, and
                # average the resulting softmaxes.
                logits = logits.unsqueeze(1).repeat(1, select_k, 1)
                mask = mask.unsqueeze(1).repeat(1, select_k, 1)
                assert logits.size() == (N, select_k, S)
                assert mask.size() == (N, select_k, S)

                y = logits.unsqueeze(0) + self.conditional_gumbel(logits, mask, y_hard, rao_k)
                assert y.size() == (rao_k, N, select_k, S)

                y = self.masked_softmax(logits=y / temperature, mask=mask.unsqueeze(0))
                assert y.size() == (rao_k, N, select_k, S)
                y = T.mean(y, dim=0)
                assert y.size() == (N, select_k, S)
            else:
                y = y_.unsqueeze(1).repeat(1, select_k, 1)
                assert y.size() == (N, select_k, S)

            assert y.size() == (N, select_k, S)
            assert y_hard.size() == (N, select_k, S)

            # Straight-through: forward pass sees y_hard, backward pass sees y.
            y_hard = (y_hard - y).detach() + y
            return y_hard, y

    def select_composition(self, old_form_state, new_form_state,
                           old_content_state, mask, accu_scores, beam_mask):
        """Choose which adjacent pair to merge, expanding and pruning beams.

        For each of the ``B`` beams, the ``S`` candidate merges are scored with
        ``decide_linear``, and ``topk = min(S, beam_size)`` of them are
        selected, giving ``B * topk`` hypotheses. The chosen pair's content is
        composed with ``treecell_layer``. Each hypothesis's score grows by the
        log-probability of its merge under the noise-free masked softmax of the
        scores. If ``B * topk > beam_size``, hypotheses are pruned back to
        ``beam_size`` by a Gumbel top-k over the accumulated scores. Finally,
        the merged node is spliced into the sequence at the chosen position.

        Args:
            old_form_state: ``(N, B, S + 1, D)`` form states before merging.
            new_form_state: ``(N, B, S, D)`` composed form state of every
                adjacent pair.
            old_content_state: ``(N, B, S + 1, D)`` content states before merging.
            mask: ``(N, S)`` validity mask for the ``S`` candidate positions.
            accu_scores: ``(N, B)`` accumulated log-scores of the beams.
            beam_mask: ``(N, B)``, 1 for active beams.

        Returns:
            ``(old_form_state, new_form_state, old_content_state,
            new_content_state, accu_scores, beam_mask)``, where
            ``B2 = min(B * topk, beam_size)``. The old states are
            ``(N, B2, S + 1, D)`` (expanded/pruned to match the new beams), the
            new states are ``(N, B2, S, D)``, and the last two are ``(N, B2)``.
        """

        N, B, S, D = new_form_state.size()
        assert accu_scores.size() == (N, B)
        assert mask.size() == (N, S)
        assert beam_mask.size() == (N, B)

        # One merge score per candidate position: (N, B, S).
        comp_weights = self.decide_linear(new_form_state).squeeze(-1)  # / math.sqrt(self.hidden_dim)

        topk = min(S, self.beam_size)  # beam_size
        training = self.training if self.diffop1 else False

        select_mask, soft_scores = self.st_gumbel_softmax(logits=comp_weights.view(N * B, S),
                                                          temperature=1,
                                                          mask=mask.view(N, 1, S).repeat(1, B, 1).view(N * B, S),
                                                          select_k=topk,
                                                          training=training)

        # The soft_scores returned above are discarded; the scores used for
        # beam accounting are a noise-free masked softmax of comp_weights.
        soft_scores = F.softmax(comp_weights, dim=-1)
        assert soft_scores.size() == (N, B, S)
        soft_scores = mask.unsqueeze(1) * soft_scores + 1e-20
        soft_scores = soft_scores / T.sum(soft_scores, dim=-1, keepdim=True)
        soft_scores = soft_scores.unsqueeze(2)

        assert select_mask.size() == (N * B, topk, S)
        select_mask = select_mask.view(N, B, topk, S)

        # Gather the selected left/right children (one-hot matmul) and compose them.
        l = old_content_state[:, :, :-1, :]
        r = old_content_state[:, :, 1:, :]
        l = T.matmul(select_mask, l)
        r = T.matmul(select_mask, r)
        new_content_state = self.treecell_layer(l, r).unsqueeze(-2)
        assert new_content_state.size() == (N, B, topk, 1, D)

        assert new_form_state.size() == (N, B, S, D)
        new_form_state = T.matmul(select_mask, new_form_state).unsqueeze(-2)
        assert new_form_state.size() == (N, B, topk, 1, D)

        assert soft_scores.size() == (N, B, 1, S)
        # soft_scores = soft_scores.view(N, B, topk, S)
        # Log-probability of each selected merge.
        new_scores = T.log(T.sum(select_mask * soft_scores, dim=-1) + 1e-20)
        assert new_scores.size() == (N, B, topk)

        # mask[:, 0] == 0 means the example has no merge at this step. Then only
        # the first of the topk children of each beam stays active, so
        # finished examples do not multiply their beams.
        done_mask = 1 - mask[:, 0].view(N, 1, 1).repeat(1, B, 1)
        if topk == 1:
            done_topk = T.ones(N, B, topk).float().to(mask.device)
        else:
            done_topk = T.cat([T.ones(N, B, 1).float().to(mask.device),
                               T.zeros(N, B, topk - 1).float().to(mask.device)], dim=-1)
        assert done_topk.size() == (N, B, topk)

        not_done_topk = T.ones(N, B, topk).float().to(mask.device)
        new_beam_mask = done_mask * done_topk + (1 - done_mask) * not_done_topk
        beam_mask = beam_mask.unsqueeze(-1) * new_beam_mask

        assert beam_mask.size() == (N, B, topk)
        beam_mask = beam_mask.view(N, B * topk)

        accu_scores = accu_scores.view(N, B, 1) + new_scores
        accu_scores = accu_scores.view(N, B * topk)
        # accu_scores = T.clip(accu_scores, min=-999999)

        select_mask = select_mask.view(N, B * topk, S)

        # Replicate the parent beams' previous states for each of their topk children.
        old_form_state = old_form_state.unsqueeze(2).repeat(1, 1, topk, 1, 1)
        assert old_form_state.size() == (N, B, topk, S + 1, D)
        old_content_state = old_content_state.unsqueeze(2).repeat(1, 1, topk, 1, 1)
        assert old_content_state.size() == (N, B, topk, S + 1, D)
        assert new_form_state.size() == (N, B, topk, 1, D)
        assert new_content_state.size() == (N, B, topk, 1, D)

        new_form_state = new_form_state.view(N, B * topk, 1, D)
        new_content_state = new_content_state.view(N, B * topk, 1, D)
        old_form_state = old_form_state.view(N, B * topk, S + 1, D)
        old_content_state = old_content_state.view(N, B * topk, S + 1, D)

        # Prune to beam_size hypotheses. All per-beam tensors are gathered
        # with the same one-hot (straight-through) selection matrix.
        if (B * topk) > self.beam_size:
            B2 = self.beam_size
            assert accu_scores.size() == beam_mask.size()
            training = self.training if self.diffop2 else False
            beam_select_mask, _ = self.st_gumbel_softmax(logits=accu_scores,
                                                         temperature=1,
                                                         mask=beam_mask,
                                                         select_k=B2,
                                                         training=training)
            assert beam_select_mask.size() == (N, B2, B * topk)

            newstates = T.cat([new_form_state, new_content_state], dim=-1)
            assert newstates.size() == (N, B * topk, 1, 2 * D)
            oldstates = T.cat([old_form_state, old_content_state], dim=-1)
            assert oldstates.size() == (N, B * topk, S + 1, 2 * D)

            newstates = T.matmul(beam_select_mask, newstates.view(N, B * topk, -1))
            newstates = newstates.view(N, B2, 1, 2 * D)

            new_form_state = newstates[..., 0:D]
            new_content_state = newstates[..., D:]

            oldstates = T.matmul(beam_select_mask, oldstates.view(N, B * topk, -1))
            oldstates = oldstates.view(N, B2, S + 1, 2 * D)

            old_form_state = oldstates[..., 0:D]
            old_content_state = oldstates[..., D:]

            select_mask = T.matmul(beam_select_mask, select_mask)
            assert select_mask.size() == (N, B2, S)

            accu_scores = T.matmul(beam_select_mask, accu_scores.unsqueeze(-1)).squeeze(-1)
            assert accu_scores.size() == (N, B2)

            beam_mask = T.matmul(beam_select_mask, beam_mask.unsqueeze(-1)).squeeze(-1)
            assert beam_mask.size() == (N, B2)
        else:
            B2 = B * topk

        # For a one-hot select_mask at index j: left_mask is 1 at positions < j
        # and right_mask is 1 at positions > j. The new node goes to position j,
        # nodes left of j keep their own left states, and nodes right of j take
        # the (shifted) right states.
        select_mask_expand = select_mask.unsqueeze(-1)
        select_mask_cumsum = select_mask.cumsum(-1)

        left_mask = 1 - select_mask_cumsum
        left_mask_expand = left_mask.unsqueeze(-1)

        right_mask = select_mask_cumsum - select_mask
        right_mask_expand = right_mask.unsqueeze(-1)

        olc, orc = old_content_state[..., :-1, :], old_content_state[..., 1:, :]
        olf, orf = old_form_state[..., :-1, :], old_form_state[..., 1:, :]

        assert select_mask_expand.size() == (N, B2, S, 1)
        assert left_mask_expand.size() == (N, B2, S, 1)
        assert right_mask_expand.size() == (N, B2, S, 1)
        assert new_content_state.size() == (N, B2, 1, D)
        assert olc.size() == (N, B2, S, D)
        assert orc.size() == (N, B2, S, D)

        new_content_state = (select_mask_expand * new_content_state
                             + left_mask_expand * olc
                             + right_mask_expand * orc)

        assert new_form_state.size() == (N, B2, 1, D)
        assert olf.size() == (N, B2, S, D)
        assert orf.size() == (N, B2, S, D)

        new_form_state = (select_mask_expand * new_form_state
                          + left_mask_expand * olf
                          + right_mask_expand * orf)

        return old_form_state, new_form_state, old_content_state, new_content_state, accu_scores, beam_mask

    def forward(self, input, input_mask):
        """Encode a batch of sequences into one vector per example.

        Args:
            input: ``(N, S, hidden_size)`` token features (three dims are
                unpacked after ``word_linear``).
            input_mask: ``(N, S)`` mask, presumably 1 for real tokens and 0 for
                padding (TODO confirm). It should be a numeric tensor, because
                ``1 - mask`` is computed from slices of it.

        Returns:
            dict with:
            ``"sequence"`` -- the unchanged ``input``;
            ``"global_state"`` -- ``(N, hidden_size)``, the root vectors of all
            beams weighted by a softmax over their accumulated scores (inactive
            beams excluded);
            ``"input_mask"`` -- ``input_mask`` with a trailing singleton dim;
            ``"aux_loss"`` -- always None.
        """

        max_depth = input.size(1)
        length_mask = input_mask

        # Split the projection into form (first half) and content (second half).
        state = self.word_linear(input)
        form_state = self.LN1(state[..., 0:self.hidden_dim])
        content_state = self.LN2(state[..., self.hidden_dim:])
        N, S, D = form_state.size()
        B = 1
        form_state = form_state.unsqueeze(1)
        assert form_state.size() == (N, B, S, D)

        content_state = content_state.unsqueeze(1)
        assert content_state.size() == (N, B, S, D)

        # Start with a single active beam with log-score 0.
        accu_scores = T.zeros(N, B).float().to(state.device)
        beam_mask = T.ones(N, B).float().to(state.device)

        # Each iteration merges one adjacent pair (sequence length S -> S - 1),
        # so max_depth - 1 iterations leave a single root node.
        for i in range(max_depth - 1):
            S = content_state.size(-2)
            B = content_state.size(1)

            # Choose the merge via select_composition except in the last
            # iteration, where only two nodes remain and there is a single option.
            if i < max_depth - 2:
                # NOTE: the string literal below is disabled code
                # (compress_linear is commented out in __init__); it has no effect.
                """
                lc = self.compress_linear(l)
                rc = self.compress_linear(r)
                """
                old_content_state = content_state.clone()
                old_form_state = form_state.clone()
                l = form_state[:, :, :-1, :]
                r = form_state[:, :, 1:, :]
                assert l.size() == (N, B, S - 1, D)
                assert r.size() == (N, B, S - 1, D)
                # Compose the form state of every adjacent pair; these are the merge candidates.
                new_form_state = self.controller_layer(l=l, r=r)
                old_form_state, new_form_state, \
                old_content_state, new_content_state, \
                accu_scores, beam_mask = self.select_composition(old_content_state=old_content_state,
                                                                 old_form_state=old_form_state,
                                                                 new_form_state=new_form_state,
                                                                 mask=length_mask[:, i + 1:],
                                                                 accu_scores=accu_scores,
                                                                 beam_mask=beam_mask)
            else:
                # Last step: merge the two remaining content nodes directly; the
                # form state is just truncated to the new length.
                old_form_state = form_state.clone()
                old_content_state = content_state.clone()
                l = content_state[:, :, :-1, :]
                r = content_state[:, :, 1:, :]
                assert l.size() == (N, B, S - 1, D)
                assert r.size() == (N, B, S - 1, D)
                new_content_state = self.treecell_layer(l=l, r=r)
                new_form_state = form_state[..., :-1, :]
            # Examples whose position i + 1 is masked out keep their old states
            # (minus the last slot) instead of the composed ones.
            done_mask = length_mask[:, i + 1]
            content_state, form_state = self.update_state(old_content_state=old_content_state,
                                                          new_content_state=new_content_state,
                                                          old_form_state=old_form_state, new_form_state=new_form_state,
                                                          done_mask=done_mask)
        h = content_state
        sequence = input
        input_mask = input_mask.unsqueeze(-1)
        aux_loss = None

        N, B, S, D = h.size()
        assert S == 1
        h = h.squeeze(-2)
        assert h.size() == (N, B, D)
        assert accu_scores.size() == (N, B)
        assert beam_mask.size() == (N, B)
        # Inactive beams get a large negative score so the softmax gives them ~0 weight.
        normed_scores = F.softmax(beam_mask * accu_scores + (1 - beam_mask) * -999999, dim=-1)

        global_state = T.sum(normed_scores.unsqueeze(-1) * h, dim=1)
        assert global_state.size() == (N, self.config["hidden_size"])

        return {"sequence": sequence, "global_state": global_state, "input_mask": input_mask, "aux_loss": aux_loss}
