from __future__ import annotations
from typing import List, Tuple
import torch.nn as nn
from ..tools.spec import GateSpec, GateMatrixSpec
from ..tools.masked import MaskedFeedForward, MaskedLinear


def _find_transformer_layers(model: nn.Module) -> List[nn.Module]:
    if hasattr(model, "trm_encoder") and hasattr(model.trm_encoder, "layer"):
        layers = list(model.trm_encoder.layer)
        if not layers:
            raise RuntimeError("[BERT4Rec][RecBole] encoder.layer is empty")
        for i, block in enumerate(layers):
            ok = (
                hasattr(block, "multi_head_attention")
                and hasattr(block.multi_head_attention, "dense")
                and hasattr(block, "feed_forward")
                and hasattr(block.feed_forward, "dense_1")
                and hasattr(block.feed_forward, "dense_2")
            )
            if not ok:
                raise RuntimeError(
                    f"[BERT4Rec][RecBole] layer[{i}] missing expected submodules"
                    "(multi_head_attention.dense / feed_forward.dense_1 / feed_forward.dense_2)"
                )
        return layers
    raise RuntimeError("[BERT4Rec][RecBole] Expected model.trm_encoder.layer")


def _get_attn_out_linear_and_parent(
    layer: nn.Module,
) -> Tuple[nn.Linear, nn.Module, str]:
    att = layer.multi_head_attention
    lin = att.dense
    return lin, att, "dense"


class BERT4RecBackend:
    """Gate-spec builders and mask-wrapper patchers for a BERT4Rec encoder.

    All methods locate the transformer blocks through
    ``_find_transformer_layers`` and therefore raise ``RuntimeError`` when
    the model does not have the expected ``trm_encoder.layer`` layout.

    Gate indexing: each weight matrix of shape ``(rows, cols)`` consumes
    ``rows + cols`` consecutive gate indices, first ``rows`` indices
    (``r_start`` / ``r_len``), then ``cols`` indices (``c_start`` /
    ``c_len``). Matrices are laid out back to back in visiting order, and
    FFN and attention specs each start counting from 0.

    NOTE(review): the patch methods replace ``feed_forward`` /
    ``multi_head_attention.dense`` with wrapper objects. Whether those
    wrappers still expose ``dense_1`` / ``dense_2`` / ``weight`` (which the
    spec builders and the layer validation read) is not visible from this
    file; confirm before calling spec builders after patching.
    """

    @staticmethod
    def _append_matrix_spec(
        mats: List[GateMatrixSpec],
        cursor: int,
        weight,
        *,
        key: str,
        layer_idx: int,
        kind: str,
    ) -> int:
        """Append the spec for one 2-D weight and return the advanced cursor.

        Args:
            mats: List that receives the new ``GateMatrixSpec``.
            cursor: First free gate index before this matrix.
            weight: Object whose ``shape`` unpacks into exactly two sizes.
            key: Identifier stored on the spec.
            layer_idx: Index of the owning transformer block.
            kind: Matrix kind tag stored on the spec.

        Returns:
            The first free gate index after this matrix's row and column
            ranges.

        Raises:
            ValueError: If ``weight.shape`` does not have exactly two
                dimensions (raised by the tuple unpacking).
        """
        rows, cols = weight.shape
        r_start = cursor
        # Column gates start immediately after this matrix's row gates.
        c_start = r_start + rows
        mats.append(
            GateMatrixSpec(
                key=key,
                shape=weight.shape,
                layer_idx=layer_idx,
                kind=kind,
                r_start=r_start,
                r_len=rows,
                c_start=c_start,
                c_len=cols,
            )
        )
        return c_start + cols

    def build_ffn_gate_spec(self, model: nn.Module) -> GateSpec:
        """Build the gate layout for every feed-forward weight matrix.

        For each block, ``feed_forward.dense_1.weight`` is registered first
        (kind ``"dense1"``) and ``feed_forward.dense_2.weight`` second (kind
        ``"dense2"``). A one-line summary is printed to stdout.

        Args:
            model: Model accepted by ``_find_transformer_layers``.

        Returns:
            A ``GateSpec`` holding two matrix specs per block and the total
            number of gate indices consumed.
        """
        layers = _find_transformer_layers(model)
        mats: List[GateMatrixSpec] = []
        cursor = 0
        for li, layer in enumerate(layers):
            ffn = layer.feed_forward
            cursor = self._append_matrix_spec(
                mats,
                cursor,
                ffn.dense_1.weight,
                key=f"layer{li}.ffn.dense1",
                layer_idx=li,
                kind="dense1",
            )
            cursor = self._append_matrix_spec(
                mats,
                cursor,
                ffn.dense_2.weight,
                key=f"layer{li}.ffn.dense2",
                layer_idx=li,
                kind="dense2",
            )
        print(
            f"[GateSpec][BERT4Rec-RecBole][FFN] layers={len(layers)} mats={len(mats)} total={cursor}"
        )
        return GateSpec(mats=mats, total_gates=cursor)

    def build_attn_gate_spec(self, model: nn.Module) -> GateSpec:
        """Build the gate layout for every attention ``dense`` weight matrix.

        One matrix spec (kind ``"attnO"``) is registered per block, taken
        from ``multi_head_attention.dense.weight``. A one-line summary is
        printed to stdout.

        Args:
            model: Model accepted by ``_find_transformer_layers``.

        Returns:
            A ``GateSpec`` holding one matrix spec per block and the total
            number of gate indices consumed.
        """
        layers = _find_transformer_layers(model)
        mats: List[GateMatrixSpec] = []
        cursor = 0
        for li, layer in enumerate(layers):
            lin, _, _ = _get_attn_out_linear_and_parent(layer)
            cursor = self._append_matrix_spec(
                mats,
                cursor,
                lin.weight,
                key=f"layer{li}.attn.W_O",
                layer_idx=li,
                kind="attnO",
            )
        print(
            f"[GateSpec][BERT4Rec-RecBole][Attn] layers={len(layers)} mats={len(mats)} total={cursor}"
        )
        return GateSpec(mats=mats, total_gates=cursor)

    def patch_with_maskable_ffn(self, model: nn.Module) -> List[MaskedFeedForward]:
        """Replace each block's ``feed_forward`` with a ``MaskedFeedForward``.

        The model is modified in place; each wrapper is constructed from the
        module it replaces. A one-line summary is printed to stdout.

        Args:
            model: Model accepted by ``_find_transformer_layers``.

        Returns:
            The installed wrappers, one per block, in encoder order.
        """
        wrappers: List[MaskedFeedForward] = []
        for layer in _find_transformer_layers(model):
            wrapper = MaskedFeedForward(layer.feed_forward)
            layer.feed_forward = wrapper
            wrappers.append(wrapper)
        print(f"[Patch][BERT4Rec-RecBole] MaskedFeedForward x{len(wrappers)}")
        return wrappers

    def patch_with_maskable_attn_out(self, model: nn.Module) -> List[MaskedLinear]:
        """Replace each block's attention ``dense`` with a ``MaskedLinear``.

        The model is modified in place; each wrapper is constructed from the
        module it replaces and installed on ``multi_head_attention`` under
        the same attribute name. A one-line summary is printed to stdout.

        Args:
            model: Model accepted by ``_find_transformer_layers``.

        Returns:
            The installed wrappers, one per block, in encoder order.
        """
        wrappers: List[MaskedLinear] = []
        for layer in _find_transformer_layers(model):
            lin, parent, name = _get_attn_out_linear_and_parent(layer)
            wrapper = MaskedLinear(lin)
            setattr(parent, name, wrapper)
            wrappers.append(wrapper)
        print(f"[Patch][BERT4Rec-RecBole] MaskedLinear(attn dense) x{len(wrappers)}")
        return wrappers
