from typing import Iterable, Optional
from enum import Enum
import torch
import torch.nn.functional as F
from torch import Tensor, nn

import numpy as np
from .utils import LayerNorm, Linear, make_pad_mask


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``att`` selects the source of keys and values: in ``"self"`` mode they are
    projected from the query input ``x``; in ``"cross"`` mode they are
    projected from the separate input ``xa``.
    """

    _ATT_MODES = ("self", "cross")

    def __init__(self, n_state: int, n_head: int, att: str):
        """Build the query/key/value/output projections.

        Args:
            n_state: Width of the input and output features.
            n_head: Number of attention heads; must divide ``n_state``.
            att: Either ``"self"`` or ``"cross"``.

        Raises:
            ValueError: If ``att`` is not a supported mode or ``n_state`` is
                not divisible by ``n_head``.
        """
        super().__init__()
        # Validate eagerly with real exceptions: ``assert`` is stripped under
        # ``python -O`` and would let a mistyped mode fall into the wrong branch.
        if att not in self._ATT_MODES:
            raise ValueError(f"att must be one of {self._ATT_MODES}, got {att!r}")
        if n_head <= 0 or n_state % n_head != 0:
            raise ValueError(f"n_state ({n_state}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.att = att
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        """Attend from ``x`` to itself (``"self"``) or to ``xa`` (``"cross"``).

        Args:
            x: Query input, 3-D with the feature axis last.
            xa: Key/value input; required in ``"cross"`` mode, ignored in
                ``"self"`` mode.
            mask: Optional additive mask, see :meth:`qkv_attention`.

        Returns:
            The output projection of the attended values, same shape as ``x``.

        Raises:
            ValueError: If ``xa`` is missing in ``"cross"`` mode or ``self.att``
                holds an unsupported value.
        """
        if self.att == "self":
            kv_source = x
        elif self.att == "cross":
            if xa is None:
                raise ValueError('xa must be provided when att == "cross"')
            kv_source = xa
        else:
            raise ValueError(f"att must be one of {self._ATT_MODES}, got {self.att!r}")

        q = self.query(x)
        k = self.key(kv_source)
        v = self.value(kv_source)
        return self.out(self.qkv_attention(q, k, v, mask))

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """Compute softmax(q kᵀ / sqrt(head_dim) + mask) v over all heads.

        Args:
            q: Projected queries, ``(batch, n_query, n_state)``.
            k: Projected keys, ``(batch, n_key, n_state)``.
            v: Projected values, ``(batch, n_key, n_state)``.
            mask: Optional additive mask (``0`` keeps, ``-inf`` blocks) of rank
                2 ``(n_query, n_key)``, rank 3 ``(batch, n_query, n_key)`` or
                rank 4 ``(batch, head, n_query, n_key)``. Missing leading axes
                are broadcast.

        Returns:
            Attended values with heads merged back, ``(batch, n_query, n_state)``.

        Raises:
            ValueError: If ``mask`` is not 2-, 3- or 4-dimensional.
        """
        head_dim = q.shape[-1] // self.n_head
        # Each of q and k gets head_dim ** -0.25 so that their product carries
        # the usual head_dim ** -0.5 factor.
        scale = head_dim ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        # Keys are laid out as (batch, head, head_dim, n_key): already transposed
        # for the matmul below.
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if mask.dim() == 2:
                mask = mask[None, None]
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)  # broadcast over heads
            elif mask.dim() != 4:
                raise ValueError(f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}")
            # Only the last n_query rows of the mask are applied.
            # NOTE(review): presumably so a mask built for a longer sequence can
            # be reused when fewer query positions are passed — confirm.
            qk += mask[:, :, -qk.shape[2] :]
        # Softmax runs in float32 and the weights are cast back afterwards.
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class SimulBlock(nn.Module):
    """Decoder block with a routed mix of cross-attention and an oracle expert.

    The block applies three residual stages, each preceded by its own
    LayerNorm:

    1. attention from ``x`` to ``hidden`` under a causal mask;
    2. a per-position blend of cross-attention over ``xa`` (weight
       ``1 - route_score``) and an MLP over ``[x, xa_oracle]`` (weight
       ``route_score``);
    3. a feed-forward MLP.
    """

    def __init__(self, n_state: int, n_head: int, n_enc_dim: int):
        """Create the sub-layers.

        Args:
            n_state: Width of the decoder features.
            n_head: Number of attention heads.
            n_enc_dim: Width of the ``xa_oracle`` vector concatenated to the
                decoder features before the oracle expert.
        """
        super().__init__()

        n_mlp = n_state * 4
        # Kept so the "no oracle" placeholder can be built with the right width.
        self.n_enc_dim = n_enc_dim

        # Built in "cross" mode because keys/values come from ``hidden``,
        # not from the block's running state ``x``.
        self.attn = MultiHeadAttention(n_state, n_head, att="cross")
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head, att="cross")
        self.oracle_expert = nn.Sequential(Linear(n_state + n_enc_dim, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.cross_attn_ln = LayerNorm(n_state)

        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        hidden: Tensor,
        xa: Optional[Tensor] = None,
        xa_lens: Optional[Tensor] = None,
        xa_oracle: Optional[Tensor] = None,
        route_score: Optional[Tensor] = None,
    ):
        """Run the block.

        Args:
            x: Running decoder state, ``(bsz, n_ctx, n_state)``.
            hidden: Key/value source for the first (causally masked) attention.
            xa: Key/value source for the cross-attention.
            xa_lens: Optional lengths used to mask padded ``xa`` positions; no
                padding mask is applied when ``None``.
            xa_oracle: Optional ``(bsz, n_enc_dim)`` vector broadcast over
                every position; zeros are used when ``None``.
            route_score: Per-position weight of the oracle expert,
                ``(bsz, n_ctx)``. Required despite the ``None`` default.

        Returns:
            The updated decoder state, same shape as ``x``.

        Raises:
            ValueError: If ``route_score`` is ``None``.
        """
        if route_score is None:
            raise ValueError("route_score must be provided to blend cross-attention with the oracle expert")

        n_ctx = x.shape[1]
        param_dtype = next(self.parameters()).dtype
        # Additive mask: -inf strictly above the diagonal, 0 elsewhere, so
        # position i only sees positions <= i of ``hidden``.
        causal_mask = (
            torch.full((n_ctx, n_ctx), float("-inf"), device=x.device, dtype=param_dtype).triu_(1).unsqueeze(0)
        )

        x = x + self.attn(self.attn_ln(x), xa=hidden, mask=causal_mask)

        residual = x
        x = self.cross_attn_ln(x)

        if xa_oracle is None:
            # The placeholder must be n_enc_dim wide (not n_state wide) to match
            # the input size of ``oracle_expert``.
            xa_oracle = x.new_zeros(*x.shape[:2], self.n_enc_dim)
        else:
            xa_oracle = xa_oracle.unsqueeze(1).expand(-1, n_ctx, -1)

        gate = route_score.unsqueeze(-1)  # broadcast over the feature axis
        cross_att_mask = self.make_cross_mask(x, xa, xa_lens)
        attended = self.cross_attn(x, xa, mask=cross_att_mask)
        expert_out = self.oracle_expert(torch.cat([x, xa_oracle], dim=-1))
        x = attended * (1 - gate) + expert_out * gate

        x = residual + x
        x = x + self.mlp(self.mlp_ln(x))
        return x

    def make_cross_mask(self, x, xa, xa_lens):
        """Build the additive padding mask for cross-attention over ``xa``.

        Returns ``None`` when ``xa_lens`` is ``None``. Otherwise returns a
        tensor with ``xa``'s dtype and device, expanded over the ``x.shape[1]``
        query positions, holding ``-inf`` wherever ``make_pad_mask`` is true
        and ``0`` elsewhere.
        """
        if xa_lens is None:
            return None

        # NOTE(review): the mask width is xa_lens.max(); this assumes xa has
        # exactly that many frames — confirm against callers.
        pad_mask = make_pad_mask(xa_lens, xa_lens.max()).unsqueeze(1).expand(-1, x.shape[1], -1)
        additive = torch.zeros_like(pad_mask, dtype=xa.dtype).masked_fill(pad_mask, float("-inf"))
        return additive.to(xa.device)


class SimulDecoder(nn.Module):
    """Stack of :class:`SimulBlock` layers driven by a learned router.

    A small MLP router scores every position of the decoder input; the
    resulting ``route_score`` is shared by all blocks and decides how much each
    position relies on the oracle expert instead of cross-attention.
    """

    def __init__(self, n_state: int, n_head: int, n_enc_dim: int, n_layer: int):
        """Create ``n_layer`` blocks, the final LayerNorm and the router.

        Args:
            n_state: Width of the decoder features.
            n_head: Number of attention heads in every block.
            n_enc_dim: Feature width of the oracle vector taken from ``global_xa``.
            n_layer: Number of stacked blocks.
        """
        super().__init__()
        self.n_layer = n_layer
        self.blocks = nn.ModuleList([SimulBlock(n_state, n_head, n_enc_dim) for _ in range(n_layer)])
        self.ln = LayerNorm(n_state)
        self.router = nn.Sequential(Linear(n_state, n_state // 2), nn.GELU(), Linear(n_state // 2, 1))

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        xa_lens: Optional[Tensor] = None,
        global_xa: Optional[Tensor] = None,
        global_xa_lens: Optional[Tensor] = None,
        noise_scale: float = 1.0,
    ):
        """Decode ``x`` and return the routing scores alongside the output.

        Args:
            x: Decoder input, ``(bsz, n_ctx, n_state)``. It is also passed
                unchanged to every block as ``hidden``.
            xa: Key/value source for the blocks' cross-attention.
            xa_lens: Optional lengths of ``xa`` for padding masks.
            global_xa: Optional ``(bsz, n_frames, n_enc_dim)`` tensor; its last
                valid frame per sample becomes the oracle vector. When
                ``None``, the blocks receive no oracle vector.
            global_xa_lens: Optional per-sample valid lengths of ``global_xa``.
                When ``None``, the final frame of every sample is used.
            noise_scale: Standard deviation of the Gaussian noise added to the
                router logits in training mode; unused in eval mode.

        Returns:
            A tuple ``(output, route_score)``: the layer-normalised decoder
            state and the sigmoid routing scores of shape ``(bsz, n_ctx)``.
        """
        hidden = x
        route_logits = self.router(hidden).squeeze(-1)  # bsz, n_ctx
        if self.training:
            # Noise is injected on the logits, before the sigmoid.
            route_logits = route_logits + torch.randn_like(route_logits) * noise_scale
        route_score = torch.sigmoid(route_logits)

        xa_oracle = self._last_valid_frame(global_xa, global_xa_lens)
        for block in self.blocks:
            x = block(x, hidden, xa=xa, xa_lens=xa_lens, xa_oracle=xa_oracle, route_score=route_score)
        return self.ln(x), route_score

    @staticmethod
    def _last_valid_frame(global_xa: Optional[Tensor], global_xa_lens: Optional[Tensor]) -> Optional[Tensor]:
        """Return frame ``len - 1`` of each sample in ``global_xa``, or ``None`` without input."""
        if global_xa is None:
            return None
        if global_xa_lens is None:
            return global_xa[:, -1]
        # One gather instead of a Python loop over the batch; lengths are moved
        # to the data's device and cast to int64 so they are valid indices.
        last_idx = torch.as_tensor(global_xa_lens, device=global_xa.device).long() - 1
        batch_idx = torch.arange(global_xa.size(0), device=global_xa.device)
        return global_xa[batch_idx, last_idx]