import json
import os
from functools import partial
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph
from ml_dtypes import bfloat16
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Register einops operations with torch.compile so that `rearrange` calls can be
# traced inside compiled graphs (see the einops torch-specific docs).
allow_ops_in_compiled_graph()
# Rebind the module-level name to a compiled wrapper: every later reference to
# `flex_attention` in this file goes through torch.compile.
flex_attention = torch.compile(flex_attention)


class MaskedAttention(nn.Module):
    def __init__(self, d_model, num_heads, use_qk_norm: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // self.num_heads
        self.use_qk_norm = use_qk_norm
        if self.use_qk_norm:
            self.q_norm = nn.RMSNorm(self.head_dim)
            self.k_norm = nn.RMSNorm(self.head_dim)

        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, block_mask):
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if block_mask is None:
            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                x = F.scaled_dot_product_attention(q, k, v)
        else:
            x = flex_attention(q, k, v, block_mask=block_mask)

        x = rearrange(x, "b h s d -> b s (h d)")
        x = self.wo(x)
        return x


class SparseAttention(nn.Module):
    """Multi-head attention where each query attends to an explicit list of keys.

    Rather than scoring every query against every key, each query position
    gathers the ``K`` key/value rows named by ``attn_idx`` and runs softmax
    attention over only those.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads

        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t):
        """Reshape (B, S, H*D) to (B, H, S, D); same layout as "b s (h d) -> b h s d"."""
        return t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)

    def forward(self, x, attn_idx, attn_mask=None):
        """Attend from every position to its listed keys.

        Args:
            x: (B, S, D) input features.
            attn_idx: (B, S, K) integer indices, along the sequence axis, of the
                keys each query attends to.
            attn_mask: Optional (B, S, K) mask, true for valid entries of
                ``attn_idx``. ``None`` (the default) treats every entry as valid.

        Returns:
            (B, S, D) tensor. A query whose ``K`` entries are all masked out
            receives zero attention weights (so its output row is zero, as
            ``wo`` has no bias) instead of the NaN that a softmax over all
            ``-inf`` scores would give.
        """
        q = self._split_heads(self.wq(x))
        k = self._split_heads(self.wk(x))
        v = self._split_heads(self.wv(x))

        B, H, S, D = q.shape

        # Advanced indexing with broadcast index grids:
        # (B,1,1,1) x (1,H,1,1) x (B,1,S,K) -> gathered rows of shape (B, H, S, K, D).
        b_idx = torch.arange(B, device=k.device).view(B, 1, 1, 1)
        h_idx = torch.arange(H, device=k.device).view(1, H, 1, 1)
        key_idx = attn_idx[:, None, :, :]
        k_sel = k[b_idx, h_idx, key_idx]
        v_sel = v[b_idx, h_idx, key_idx]

        # Scaled dot-product scores over the K gathered keys only.
        attn_scores = torch.einsum("bhsd,bhskd->bhsk", q, k_sel) / (D**0.5)

        valid = None
        if attn_mask is not None:
            # (B, S, K) -> (B, 1, S, K); broadcasts over the head axis.
            valid = attn_mask.bool()[:, None, :, :]
            attn_scores = attn_scores.masked_fill(~valid, float("-inf"))

        attn_probs = attn_scores.softmax(dim=-1)

        if valid is not None:
            # Rows with no valid key softmax to NaN; give them zero weight instead.
            has_valid_key = valid.any(dim=-1, keepdim=True)
            attn_probs = attn_probs.masked_fill(~has_valid_key, 0.0)

        out = torch.einsum("bhsk,bhskd->bhsd", attn_probs, v_sel)

        # (B, H, S, D) -> (B, S, H*D); same layout as "b h s d -> b s (h d)".
        out = out.transpose(1, 2).flatten(2)
        return self.wo(out)


class FFN(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    Args:
        d_model: Width of the input and output features.
        d_ff: Width of the hidden (gated) representation.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()

        # All three projections are bias-free; w1 and w3 expand, w2 contracts.
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        self.w2 = nn.Linear(in_features=d_ff, out_features=d_model, bias=False)
        self.w3 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transform to the last dimension of ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)


class RelationalBlock(nn.Module):
    """One transformer block: a chain of masked attentions followed by an FFN.

    The attention sub-layers run in the order "col", "feat", "nbr" (plus "sw"
    when ``use_sw_attn`` is set). Each sub-layer, and the final FFN, is a
    pre-norm residual update ``x = x + layer(RMSNorm(x))``.

    Args:
        d_model: Feature width.
        num_heads: Number of heads in every attention sub-layer.
        d_ff: Hidden width of the feed-forward layer.
        use_sw_attn: Add the extra "sw" attention sub-layer.
        use_qk_norm: Forwarded to every ``MaskedAttention``.
    """

    def __init__(
        self,
        d_model,
        num_heads,
        d_ff,
        use_sw_attn=False,
        use_qk_norm=False
    ):
        super().__init__()
        self.use_sw_attn = use_sw_attn

        # Names of the attention sub-layers, in execution order.
        attn_names = ["col", "feat", "nbr"]
        if use_sw_attn:
            attn_names.append("sw")
        self.attn_types = attn_names

        # One RMSNorm per attention sub-layer, plus one in front of the FFN.
        self.norms = nn.ModuleDict(
            {name: nn.RMSNorm(d_model) for name in [*attn_names, "ffn"]}
        )

        # Every attention type uses the same MaskedAttention implementation.
        self.attns = nn.ModuleDict()
        for name in attn_names:
            self.attns[name] = MaskedAttention(d_model, num_heads, use_qk_norm)

        self.ffn = FFN(d_model, d_ff)

    def forward(self, x, block_masks, feat_attn_idx, feat_attn_mask):
        """Run all attention sub-layers and the FFN on ``x``.

        Args:
            x: Input features.
            block_masks: Mapping from attention type name to the block mask
                handed to that sub-layer.
            feat_attn_idx: Accepted but not used by this implementation.
            feat_attn_mask: Accepted but not used by this implementation.

        Returns:
            The updated features, same layout as ``x``.
        """
        for name in self.attn_types:
            normed = self.norms[name](x)
            x = x + self.attns[name](normed, block_mask=block_masks[name])

        return x + self.ffn(self.norms["ffn"](x))


def _make_block_mask(mask, batch_size, seq_len, device):
    """Convert a dense attention mask into a flex-attention block mask.

    Args:
        mask: Tensor indexed as ``mask[b, q_idx, kv_idx]``; a true entry lets
            query ``q_idx`` attend to key ``kv_idx`` in batch element ``b``.
        batch_size: Batch dimension passed to ``create_block_mask``.
        seq_len: Used for both the query length and the key/value length.
        device: Device on which the block mask is created.

    Returns:
        The result of ``create_block_mask`` for this mask.
    """

    def lookup(b, h, q_idx, kv_idx):
        # The head index is ignored: one mask is shared by all heads (H=None).
        return mask[b, q_idx, kv_idx]

    block_mask_kwargs = dict(
        mask_mod=lookup,
        B=batch_size,
        H=None,
        Q_LEN=seq_len,
        KV_LEN=seq_len,
        device=device,
        _compile=True,
    )
    return create_block_mask(**block_mask_kwargs)


class RelationalTransformer(nn.Module):
    """Transformer over sequences of relational cells, trained to predict masked cells.

    Each sequence position is one cell. Its embedding is the encoded column
    name plus either the encoded cell value or, when the cell is masked, a
    learned per-type mask embedding. A stack of ``RelationalBlock`` layers mixes
    cells under several structural attention masks, and per-type decoder heads
    predict the values of the masked cells.

    Args:
        num_blocks: Number of ``RelationalBlock`` layers.
        d_model: Model width.
        d_text: Width of the text embeddings (text values and column names).
        num_heads: Attention heads per block.
        d_ff: Hidden width of each block's feed-forward layer.
        use_temporal_mask: If set and the batch contains "timestamps", a query
            may only attend to keys whose timestamp is <= its own.
        use_sw_attn: Add a sliding-window ("sw") attention type to every block.
        sw_len: Window size for "sw" attention; required when ``use_sw_attn``.
        use_qk_norm: Forwarded to every block.
    """

    # The position in this tuple is the integer code compared against
    # batch["sem_types"].
    _SEM_TYPES = ("number", "text", "datetime", "boolean")

    def __init__(
        self,
        num_blocks,
        d_model,
        d_text,
        num_heads,
        d_ff,
        use_temporal_mask=False,
        use_sw_attn=False,
        sw_len=None,
        use_qk_norm=False
    ):
        super().__init__()
        self.use_temporal_mask = use_temporal_mask
        self.use_sw_attn = use_sw_attn
        self.sw_len = sw_len

        # Per-type value encoders into model space ("col_name" encodes the
        # column-name embedding that is added to every cell).
        self.enc_dict = nn.ModuleDict(
            {
                "number": nn.Linear(1, d_model, bias=True),
                "text": nn.Linear(d_text, d_model, bias=True),
                "datetime": nn.Linear(1, d_model, bias=True),
                "col_name": nn.Linear(d_text, d_model, bias=True),
                "boolean": nn.Linear(1, d_model, bias=True),
            }
        )
        # Per-type decoder heads back to value space.
        self.dec_dict = nn.ModuleDict(
            {
                "number": nn.Linear(d_model, 1, bias=True),
                "text": nn.Linear(d_model, d_text, bias=True),
                "datetime": nn.Linear(d_model, 1, bias=True),
                "boolean": nn.Linear(d_model, 1, bias=True),
            }
        )
        # Norm applied to each encoder's output before it is summed into x.
        self.norm_dict = nn.ModuleDict(
            {
                "number": nn.RMSNorm(d_model),
                "text": nn.RMSNorm(d_model),
                "datetime": nn.RMSNorm(d_model),
                "col_name": nn.RMSNorm(d_model),
                "boolean": nn.RMSNorm(d_model),
            }
        )
        # Learned embedding that replaces the value of a masked cell, per type.
        self.mask_embs = nn.ParameterDict(
            {t: nn.Parameter(torch.randn(d_model)) for t in self._SEM_TYPES}
        )
        self.blocks = nn.ModuleList(
            [
                RelationalBlock(
                    d_model,
                    num_heads,
                    d_ff,
                    use_sw_attn=use_sw_attn,
                    use_qk_norm=use_qk_norm,
                )
                for _ in range(num_blocks)
            ]
        )
        self.norm_out = nn.RMSNorm(d_model)
        self.d_model = d_model

    def _build_attn_masks(self, batch):
        """Return the dense (B, S, S) boolean mask for every attention type.

        Masks are indexed ``[b, q, kv]``; a true entry lets query ``q`` attend
        to key ``kv``. Keys are "feat", "nbr", "col" and, with ``use_sw_attn``,
        "sw".

        Raises:
            ValueError: If ``use_sw_attn`` is set but ``sw_len`` is ``None`` or
                the batch has no "timestamps".
        """
        if self.use_sw_attn:
            if self.sw_len is None:
                raise ValueError("sw_len must be specified when use_sw_attn=True")
            if "timestamps" not in batch:
                raise ValueError("timestamps must be in batch when use_sw_attn=True")

        node_idxs = batch["node_idxs"]
        f2p_nbr_idxs = batch["f2p_nbr_idxs"]
        col_name_idxs = batch["col_name_idxs"]
        table_name_idxs = batch["table_name_idxs"]
        is_padding = batch["is_padding"]

        # Only non-padding queries may attend, and only to non-padding keys.
        pad = (~is_padding[:, :, None]) & (~is_padding[:, None, :])  # (B, S, S)

        # Temporal mask: key timestamp must be <= query timestamp.
        if self.use_temporal_mask and "timestamps" in batch:
            timestamps = batch["timestamps"]  # (B, S)
            temporal_mask = timestamps[:, None, :] <= timestamps[:, :, None]
            pad = pad & temporal_mask

        # Query and key cells belong to the same node.
        same_node = node_idxs[:, :, None] == node_idxs[:, None, :]  # (B, S, S)

        # The key's node is one of the query's foreign -> primary neighbors.
        kv_in_f2p = (
            node_idxs[:, None, :, None] == f2p_nbr_idxs[:, :, None, :]
        ).any(-1)  # (B, S, S)

        # Reverse relation: the query's node is one of the key's
        # foreign -> primary neighbors.
        q_in_f2p = (
            node_idxs[:, :, None, None] == f2p_nbr_idxs[:, None, :, :]
        ).any(-1)  # (B, S, S)

        # Same column AND same table.
        same_col_table = (col_name_idxs[:, :, None] == col_name_idxs[:, None, :]) & (
            table_name_idxs[:, :, None] == table_name_idxs[:, None, :]
        )  # (B, S, S)

        attn_masks = {
            "feat": (same_node | kv_in_f2p) & pad,
            "nbr": q_in_f2p & pad,
            "col": same_col_table & pad,
        }

        if self.use_sw_attn:
            # Sliding window measured directly on the values in
            # batch["timestamps"]: a key is visible when it is at most sw_len
            # behind the query. NOTE(review): the window is meant to be in
            # timestamp ranks -- confirm the loader supplies ranks here.
            timestamps = batch["timestamps"]  # (B, S)
            rank_diff = timestamps[:, :, None] - timestamps[:, None, :]  # (B, S, S)
            sw_mask = (rank_diff >= 0) & (rank_diff <= self.sw_len)
            attn_masks["sw"] = sw_mask & pad

        # Contiguous masks for better kernel performance.
        return {name: m.contiguous() for name, m in attn_masks.items()}

    def _embed_cells(self, batch, masks):
        """Build the input embedding of every cell.

        Each non-padding cell gets its column-name embedding plus, depending on
        ``masks``, either its encoded value or the learned mask embedding of its
        semantic type. Padding positions stay zero.
        """
        not_pad = ~batch["is_padding"]
        sem_types = batch["sem_types"]

        x = (
            self.norm_dict["col_name"](
                self.enc_dict["col_name"](batch["col_name_values"])
            )
            * not_pad[..., None]
        )

        for i, t in enumerate(self._SEM_TYPES):
            of_type = (sem_types == i) & not_pad
            # Visible cells contribute their encoded value ...
            x = x + (
                self.norm_dict[t](self.enc_dict[t](batch[t + "_values"]))
                * (of_type & ~masks)[..., None]
            )
            # ... masked cells contribute the per-type mask embedding instead.
            x = x + self.mask_embs[t] * (of_type & masks)[..., None]

        return x

    def forward(self, batch):
        """Compute the masked-cell loss and per-type predictions for a batch.

        Args:
            batch: Mapping of tensors. Keys read here: "node_idxs" (B, S),
                "f2p_nbr_idxs" (B, S, N), "col_name_idxs", "table_name_idxs",
                "is_padding", "sem_types", "masks" (all (B, S)),
                "col_name_values", "<type>_values" for each of number / text /
                datetime / boolean, "feat_attn_idx", "feat_attn_mask", and
                optionally "timestamps" (B, S).

        Returns:
            ``(loss_out, yhat_out)``: the scalar loss, and a dict mapping each
            type name to that decoder head's output at every position.

        Raises:
            ValueError: If any masked cell has the "text" type, or if
                sliding-window attention is enabled without ``sw_len`` or
                without "timestamps" in the batch.
        """
        node_idxs = batch["node_idxs"]
        batch_size, seq_len = node_idxs.shape
        device = node_idxs.device

        sem_types = batch["sem_types"]  # (B, S) ints indexing _SEM_TYPES
        masks = batch["masks"].bool()  # (B, S), true where a cell is masked

        # Dense masks -> flex-attention block masks, shared by all blocks.
        block_masks = {
            name: _make_block_mask(
                attn_mask, batch_size=batch_size, seq_len=seq_len, device=device
            )
            for name, attn_mask in self._build_attn_masks(batch).items()
        }

        x = self._embed_cells(batch, masks)

        # Forwarded to the blocks as-is; RelationalBlock currently ignores them.
        feat_attn_idx = (
            batch["feat_attn_idx"].view(batch_size, seq_len, -1).long()
        )  # (B, S, K)
        feat_attn_mask = batch["feat_attn_mask"].view(
            batch_size, seq_len, -1
        )  # (B, S, K)

        for block in self.blocks:
            x = block(x, block_masks, feat_attn_idx, feat_attn_mask)

        x = self.norm_out(x)

        yhat_out = {}
        loss_per_seq = x.new_zeros(batch_size)

        for i, t in enumerate(self._SEM_TYPES):
            yhat = self.dec_dict[t](x)  # (B, S, D_t)
            yhat_out[t] = yhat
            sem_type_mask = (sem_types == i) & masks  # masked cells of this type

            if not sem_type_mask.any():
                # Still touch the head's parameters to avoid an unused-parameter
                # error; the term contributes zero to the loss.
                loss_per_seq = loss_per_seq + (yhat.sum() * 0.0)
                continue

            if t == "text":
                raise ValueError("masking text not supported")

            y = batch[f"{t}_values"]  # (B, S, D_y)
            if t == "boolean":
                loss_t = F.binary_cross_entropy_with_logits(
                    yhat, (y > 0).float(), reduction="none"
                ).mean(-1)  # (B, S)
            else:  # "number" or "datetime"
                loss_t = F.huber_loss(yhat, y, reduction="none").mean(-1)  # (B, S)

            # Sum this type's loss over each sequence's masked cells.
            loss_per_seq = loss_per_seq + (loss_t * sem_type_mask).sum(dim=1)  # (B,)

        # Normalize by the number of masked cells per sequence, then average
        # across sequences. The clamp keeps a sequence with no masked cells at
        # a loss of 0 instead of 0/0 = NaN, which would poison the batch mean.
        masks_per_seq = masks.sum(dim=1).float().clamp(min=1)  # (B,)
        loss_per_seq = loss_per_seq / masks_per_seq  # (B,)
        loss_out = loss_per_seq.mean()  # scalar

        return loss_out, yhat_out
