import torch
import numpy as np
import re
import math

import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
)
from typing import Optional, List, Tuple, Union
from dataclasses import dataclass
from einops import rearrange, repeat
from torch_geometric.utils import to_dense_adj

def profile(func):
    """Decorator that line-profiles each call of *func* and prints the stats.

    A fresh ``LineProfiler`` is created for every call, so the printed
    statistics describe that single call only. ``line_profiler`` is imported
    when the decorator is applied, not when the decorated function runs.
    """
    import functools

    from line_profiler import LineProfiler

    @functools.wraps(func)  # keep __name__/__doc__ of the profiled function
    def wrapper(*args, **kwargs):
        lp = LineProfiler()
        lp_wrapper = lp(func)
        try:
            return lp_wrapper(*args, **kwargs)
        finally:
            # Print whatever was collected even if the call raised.
            lp.print_stats()

    return wrapper


def compute_nrmse(y_true, y_pred, norm="mean", eps=1e-8, percentage=False):
    """
    Compute Normalized Root Mean Square Error (NRMSE).

    Both inputs are flattened before comparison, so any shapes with the same
    total number of elements are accepted.

    Parameters
    ----------
    y_true : array-like, shape (N, ...)
        Ground truth values.
    y_pred : array-like, shape (N, ...)
        Predicted values; must contain as many elements as ``y_true``.
    norm : str, default="mean"
        Normalization method:
        - "mean": divide by the mean absolute value of y_true
        - "range": divide by (max - min) of y_true
        - "std": divide by std of y_true
    eps : float, default=1e-8
        Small constant added to the denominator to avoid division by zero.
    percentage : bool, default=False
        If True, return value ×100 (as %).

    Returns
    -------
    nrmse : float
        Normalized RMSE (scalar).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different numbers of elements, or
        ``norm`` is not one of the supported methods.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        # Without this check NumPy would silently broadcast (e.g. N vs 1).
        raise ValueError(
            f"y_true and y_pred must have the same number of elements "
            f"({y_true.size} != {y_pred.size})"
        )

    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))

    if norm == "mean":
        # Absolute values keep the denominator positive for signed targets.
        denom = np.mean(np.abs(y_true)) + eps
    elif norm == "range":
        denom = (np.max(y_true) - np.min(y_true)) + eps
    elif norm == "std":
        denom = np.std(y_true) + eps
    else:
        raise ValueError("norm must be one of ['mean', 'range', 'std']")

    value = rmse / denom
    if percentage:
        value *= 100.0
    return value


def compute_r2_score(y_true, y_pred, eps=1e-8):
    """
    Coefficient of Determination (R^2).

    Both inputs are flattened before comparison, so any shapes with the same
    total number of elements are accepted.

    Parameters
    ----------
    y_true : array-like
        Ground truth values.
    y_pred : array-like
        Predicted values; must contain as many elements as ``y_true``.
    eps : float, default=1e-8
        Added to the total sum of squares so a constant ``y_true`` does not
        cause a division by zero.

    Returns
    -------
    r2 : float
        ``1 - SS_res / (SS_tot + eps)``.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` have different numbers of elements.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        # Without this check NumPy would silently broadcast (e.g. N vs 1).
        raise ValueError(
            f"y_true and y_pred must have the same number of elements "
            f"({y_true.size} != {y_pred.size})"
        )

    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2) + eps

    return 1 - ss_res / ss_tot


@dataclass
class RegressionOutput:
    """Result record for a regression forward pass.

    Both fields are optional and default to ``None``.
    """

    # Loss tensor, when one was computed.
    loss: Union[torch.Tensor, None] = None
    # Prediction tensor.
    preds: Union[torch.Tensor, None] = None

class GLADataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of pre-built samples.

    Samples are returned exactly as stored. Per the author's note, each one is
    a tuple ``(graph_data, language_input_str, ground_truth_response_str)``.
    """

    def __init__(self, samples):
        """Keep a reference to *samples* (no copy is made)."""
        self.samples = samples

    def __len__(self):
        """Number of stored samples."""
        stored = self.samples
        return len(stored)

    def __getitem__(self, idx):
        """Return the sample at position *idx* unchanged."""
        stored = self.samples
        return stored[idx]


def float_discretization(input, floats_max, bins=2048):
    """Quantize values into integer bin ids centred on ``bins // 2``.

    Each value is scaled by ``(bins // 2) / floats_max``, floored, and shifted
    by ``bins // 2``; the result is converted with ``.tolist()`` (a nested list
    of ints, or a plain int for scalar input). ``input`` must support ``/``
    with ``floats_max`` (e.g. a NumPy array or a Python float).

    NOTE(review): no clipping is applied, so values outside
    ``[-floats_max, floats_max)`` produce ids outside ``[0, bins)``
    (``floats_max`` itself maps to ``bins``).
    """
    half = bins // 2
    scaled = np.floor((input / floats_max) * half)
    shifted = scaled + half
    return shifted.astype(int).tolist()


def extract_integers(s, max_float, bins=2048):
    """Decode bin ids embedded in a string back into floats.

    Every signed integer literal in *s* is read as a bin id ``i`` and mapped to
    ``(i - bins // 2) / (bins // 2) * max_float``, the inverse of the
    centred quantization used by ``float_discretization``. Returns a NumPy
    array (empty if *s* contains no integers).
    """
    center = bins // 2
    bin_ids = np.asarray([int(token) for token in re.findall(r'[-+]?\d+', s)])
    offsets = bin_ids - center
    return offsets / center * max_float

def extract_floats(s):
    """Return every decimal literal in *s* as a NumPy float array.

    Only numbers written with a decimal point are matched (``1.5``, ``-2.``);
    plain integers, a leading ``+`` and exponents are not recognized.
    """
    matches = re.finditer(r"-?\d+\.\d*", s)
    return np.asarray([float(m.group()) for m in matches])

@dataclass
class MyCausalLMOutput:
    """Lightweight causal-LM output record; every field defaults to ``None``."""

    # Output logits tensor.
    logits: Union[torch.Tensor, None] = None
    # Loss tensor, when one was computed.
    loss: Union[torch.Tensor, None] = None
    # Key/value cache returned for incremental decoding, if any.
    past_key_values: Union[Tuple, None] = None
    # Optional tuple of hidden states.
    hidden_states: Union[Tuple, None] = None
    # Optional tuple of attention maps.
    attentions: Union[Tuple, None] = None


# [NEW] Helper function for Dilated Attention, adapted for autoregressive decoding
# @profile
def dilated_attention_forward(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    dilation_rate: int,
    segment_size: int,
    dropout_p: float = 0.0,
    is_causal: bool = True,  # forwarded to F.scaled_dot_product_attention
) -> Tuple[torch.Tensor, int]:
    """
    Dilated, segmented scaled-dot-product attention.

    Each of q, k and v is processed independently:
    (1) zero-padded along the time axis to a multiple of ``dilation_rate``,
    (2) subsampled to every ``dilation_rate``-th time step,
    (3) zero-padded to a multiple of ``segment_size``, and
    (4) reshaped into segments of ``segment_size`` steps.
    Attention is then computed within each segment by
    ``F.scaled_dot_product_attention``.

    Args:
        q (torch.Tensor): Query tensor (B, nhead, T_q, head_dim)
        k (torch.Tensor): Key tensor (B, nhead, T_kv, head_dim)
        v (torch.Tensor): Value tensor (B, nhead, T_kv, head_dim)
        dilation_rate (int): Keep every ``dilation_rate``-th time step.
        segment_size (int): Number of dilated time steps per attention segment.
        dropout_p (float): Attention dropout probability passed to SDPA.
        is_causal (bool): Apply SDPA's causal mask inside each segment.

    Returns:
        Tuple[torch.Tensor, int]:
            - the attention output, flattened back to (B, nhead, T', head_dim)
              and truncated to at most ``T_q`` rows;
            - the segmentation padding amount computed in the last loop
              iteration, i.e. for ``v`` (not for ``q``).

    NOTE(review):
        - With ``dilation_rate > 1``, output row ``i`` belongs to query
          position ``i * dilation_rate``. The result is sliced to the first
          ``T_q`` rows rather than scattered back to those positions. Rows at
          index >= ceil(T_q / dilation_rate) therefore come from zero-padded
          queries, and the output can be shorter than ``T_q`` when the padded
          dilated length is < T_q.
        - Zero-padded key/value positions are not excluded by an explicit
          mask; only ``is_causal`` limits what a query can see.
        - When T_q != T_kv (e.g. single-token decoding), q and k/v can have a
          different number of segments. This presumably relies on SDPA
          broadcasting the segment dimension, and ``is_causal`` is applied per
          segment. Verify before relying on that path.
    """
    B, nhead, T_q, head_dim = q.shape

    # The same dilation and segmentation logic is applied to q, k, and v.
    tensors = [q, k, v]
    processed_tensors = []

    for t in tensors:
        T = t.shape[2]

        # 1. Zero-pad the time axis (dim 2) at the end to a multiple of dilation_rate.
        pad_amount = -T % dilation_rate
        padded_t = F.pad(t, (0, 0, 0, pad_amount))

        # 2. Keep time steps 0, r, 2r, ... : unfold with size=1/step=r over the
        #    last (time) axis, then restore the (B, nhead, T, head_dim) layout.
        rearranged_t = rearrange(padded_t, 'b h t d -> (b h) d t')
        dilated_t = rearranged_t.unfold(dimension=-1, size=1, step=dilation_rate)
        dilated_t = rearrange(dilated_t, '(b h) d t 1 -> b h t d', h=nhead)

        # 3. Zero-pad the *dilated* sequence to a multiple of segment_size.
        #    (pad_amount is overwritten here; after the loop it holds v's value.)
        T_dilated = dilated_t.shape[2]
        pad_amount = -T_dilated % segment_size
        padded_dilated_t = F.pad(dilated_t, (0, 0, 0, pad_amount))

        # 4. Reshape into segments: (B, nhead, n_segments, segment_size, head_dim).
        segmented_t = rearrange(padded_dilated_t, 'b h (n s) d -> b h n s d', s=segment_size)
        processed_tensors.append(segmented_t)

    q_segmented, k_segmented, v_segmented = processed_tensors

    # 5. Attention within each segment; the segment axis is treated as an
    #    extra batch dimension by SDPA. No attn_mask is passed.
    attn_out = F.scaled_dot_product_attention(
        q_segmented, k_segmented, v_segmented,
        dropout_p=dropout_p,
        is_causal=is_causal
    )

    # 6. Flatten segments back into one (dilated, padded) time axis.
    # (B, nhead, num_segments, segment_size, head_dim) -> (B, nhead, num_segments * segment_size, head_dim)
    attn_out = rearrange(attn_out, 'b h n s d -> b h (n s) d')

    # 7. Keep at most T_q rows. Rows are in dilated order (see NOTE above),
    #    so this does not restore the original query positions.
    attn_out = attn_out[:, :, :T_q, :]

    return attn_out, pad_amount


class MLADecoderLayer(nn.Module):
    """A single post-norm Transformer decoder layer with RoPE & key/value caching.

    Self-attention keys/values are cached in compressed ("latent") form:
    ``_compress_kv`` maps per-head keys/values of width ``head_dim`` to one
    latent of width ``head_dim // 2`` (the einsum sums over heads).
    ``_expand_kv`` maps the cached latents back to per-head keys/values before
    every attention call. Optional cross-attention attends to ``memory``.

    NOTE(review): ``use_dilated_attention``, ``dilation_rates``,
    ``segment_sizes``, ``gate`` and ``scale`` are stored but not read by
    ``forward``. The latent projections use a unit-variance ``torch.randn``
    init.
    """

    def __init__(
            self,
            hidden_dim: int,
            nhead: int,
            rope: Optional[LlamaRotaryEmbedding] = None,
            cross_attention: bool = True,
            dropout: float = 0.1,
            # LongNet (dilated attention) configuration; not used by forward().
            use_dilated_attention: bool = False,
            dilation_rates: Optional[List[int]] = None,
            segment_sizes: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        if hidden_dim % nhead != 0:
            raise ValueError("hidden_dim must be divisible by nhead")
        self.hidden_dim = hidden_dim
        self.nhead = nhead
        self.head_dim = hidden_dim // nhead
        self.scale = self.head_dim ** -0.5
        self.cross_attention = cross_attention

        # Defaults are built per instance (and caller lists copied) so that
        # no two layers share, and accidentally mutate, one list object.
        self.use_dilated_attention = use_dilated_attention
        self.dilation_rates = [1, 2, 4, 6] if dilation_rates is None else list(dilation_rates)
        self.segment_sizes = [512, 1024, 2048, 4096] if segment_sizes is None else list(segment_sizes)

        # Self-attention projections.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Cross-attention projections (optional).
        if cross_attention:
            self.q_proj_cross = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.k_proj_cross = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.v_proj_cross = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.o_proj_cross = nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.norm_cross = nn.LayerNorm(hidden_dim)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )

        # Post-attention / post-FFN layer norms.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Dropout applied to attention outputs before the residual add.
        self.drop_attn = nn.Dropout(dropout)

        # Called as rope(tensor, position_ids) -> (cos, sin); None disables RoPE.
        self.rope = rope

        # Latent KV projections: compress (head_dim -> head_dim // 2, shared by
        # all heads) and expand (per head, head_dim // 2 -> head_dim).
        self.k_comp = nn.Parameter(torch.randn(self.head_dim, self.head_dim // 2))
        self.v_comp = nn.Parameter(torch.randn(self.head_dim, self.head_dim // 2))
        self.k_exp = nn.Parameter(torch.randn(nhead, self.head_dim // 2, self.head_dim))
        self.v_exp = nn.Parameter(torch.randn(nhead, self.head_dim // 2, self.head_dim))

        self.gate = nn.Parameter(torch.ones(len(self.dilation_rates)))

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, H) -> (B, nhead, T, head_dim)"""
        B, T, _ = x.size()
        return x.view(B, T, self.nhead, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, nhead, T, head_dim) -> (B, T, H)"""
        B, _, T, _ = x.size()
        return x.transpose(1, 2).contiguous().view(B, T, self.hidden_dim)

    def _compress_kv(self, k, v):
        """(B, nhead, T, head_dim) k/v -> (B, T, head_dim // 2) latents (summed over heads)."""
        k_lat = torch.einsum('bhtd,dr->btr', k, self.k_comp)
        v_lat = torch.einsum('bhtd,dr->btr', v, self.v_comp)
        return k_lat, v_lat

    def _expand_kv(self, k_lat, v_lat):
        """(B, T, head_dim // 2) latents -> per-head (B, nhead, T, head_dim) k/v."""
        k = torch.einsum('btr,hrd->bhtd', k_lat, self.k_exp)
        v = torch.einsum('btr,hrd->bhtd', v_lat, self.v_exp)
        return k, v

    @staticmethod
    def _expand_memory_mask(memory_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Shape a memory keep-mask for SDPA as (B, 1, T or 1, L).

        Accepts (B, L), (B, 1, L) / (B, T, L), or an already 4-D mask. Integer
        masks (1 = keep) are converted to bool, since SDPA accepts only boolean
        (True = attend) or floating-point (additive) masks. Float masks pass
        through unchanged.
        """
        if memory_mask is None:
            return None
        if memory_mask.dim() == 2:
            memory_mask = memory_mask[:, None, None, :]
        elif memory_mask.dim() == 3:
            memory_mask = memory_mask[:, None, :, :]
        if memory_mask.dtype != torch.bool and not memory_mask.dtype.is_floating_point:
            memory_mask = memory_mask != 0
        return memory_mask

    def forward(
            self,
            x: torch.Tensor,
            position_ids: torch.Tensor,
            memory: Optional[torch.Tensor] = None,  # (B, L, H) or None
            memory_mask: Optional[torch.Tensor] = None,  # (B, L) / (B, 1, L) or None (1 = keep)
            past_key_value: Optional[
                Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = None,
            use_cache: bool = False,
            causal: bool = True,
    ):
        """Self-attention, optional cross-attention and FFN, each with post-norm.

        Args:
            x: (B, T, H) hidden states of the T new tokens.
            position_ids: positions of the new tokens, passed to ``self.rope``
                (presumably (B, T); confirm against the RoPE module used).
            memory: (B, L, H) states to cross-attend to. May be omitted on later
                decoding steps once cross keys/values are in ``past_key_value``.
            memory_mask: keep-mask over memory positions, see
                ``_expand_memory_mask``.
            past_key_value: ``((k_lat, v_lat), (cross_k, cross_v) or None)`` as
                returned by an earlier call with ``use_cache=True``.
            use_cache: whether to return the updated cache.
            causal: mask future positions in self-attention.

        Returns:
            ``(x, new_past)``: ``x`` is (B, T, H); ``new_past`` is None unless
            ``use_cache`` is set.
        """
        T = x.size(1)
        past_self_k = past_self_v = None
        past_cross_k = past_cross_v = None
        if past_key_value is not None:
            past_self_k, past_self_v = past_key_value[0]
            if past_key_value[1] is not None:
                past_cross_k, past_cross_v = past_key_value[1]

        # ------------------------ Self-Attention ------------------------
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        if self.rope is not None:
            cos, sin = self.rope(x, position_ids)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Only the compressed latents are cached; the full history is
        # expanded back to per-head keys/values for this step.
        k_lat, v_lat = self._compress_kv(k, v)
        past_len = 0
        if past_self_k is not None:
            past_len = past_self_k.size(1)
            k_lat = torch.cat([past_self_k, k_lat], dim=1)
            v_lat = torch.cat([past_self_v, v_lat], dim=1)
        k_full, v_full = self._expand_kv(k_lat, v_lat)

        attn_mask = None
        if causal:
            # New token i sits at absolute position past_len + i and may attend
            # to keys 0..past_len + i (boolean SDPA mask: True = attend).
            future = torch.ones((T, past_len + T), dtype=torch.bool, device=x.device)
            future = torch.triu(future, diagonal=1 + past_len)
            attn_mask = (~future).unsqueeze(0).unsqueeze(0)

        attn_out = F.scaled_dot_product_attention(
            q, k_full, v_full, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        attn_out = self.o_proj(self._merge_heads(attn_out))
        x = self.norm1(x + self.drop_attn(attn_out))

        new_self_k, new_self_v = (k_lat, v_lat) if use_cache else (None, None)

        # ------------------------ Cross-Attention ------------------------
        new_cross_k, new_cross_v = None, None
        has_cross_source = memory is not None or past_cross_k is not None
        if self.cross_attention and has_cross_source:
            q = self._split_heads(self.q_proj_cross(x))

            if past_cross_k is None:
                # Project memory once per prompt; later steps reuse the cache.
                cross_k = self._split_heads(self.k_proj_cross(memory))
                cross_v = self._split_heads(self.v_proj_cross(memory))
            else:
                cross_k, cross_v = past_cross_k, past_cross_v

            # RoPE on the query only: memory keys carry no decoder positions.
            if self.rope is not None:
                cos_q, sin_q = self.rope(q, position_ids)
                q, _ = apply_rotary_pos_emb(q, q, cos_q, sin_q)

            cross_out = F.scaled_dot_product_attention(
                q, cross_k, cross_v,
                attn_mask=self._expand_memory_mask(memory_mask),
                dropout_p=0.0, is_causal=False,
            )
            cross_out = self.o_proj_cross(self._merge_heads(cross_out))
            x = self.norm_cross(x + self.drop_attn(cross_out))

            if use_cache:
                new_cross_k, new_cross_v = cross_k, cross_v

        # --------------------------- FFN ---------------------------
        x = self.norm2(x + self.ffn(x))

        new_past = None
        if use_cache:
            new_past = ((new_self_k, new_self_v), (new_cross_k, new_cross_v))

        return x, new_past


class MLAEncoderLayer(nn.Module):
    """A single post-norm Transformer encoder layer: bidirectional self-attention + FFN.

    Uses the same residual / LayerNorm arrangement as the self-attention path of
    ``MLADecoderLayer``, without RoPE, causal masking, cross-attention or a KV cache.
    """

    def __init__(
            self,
            hidden_dim: int,
            nhead: int,
            dropout: float = 0.1
    ) -> None:
        super().__init__()

        if hidden_dim % nhead != 0:
            raise ValueError("hidden_dim must be divisible by nhead")
        self.hidden_dim = hidden_dim
        self.nhead = nhead
        self.head_dim = hidden_dim // nhead
        self.scale = self.head_dim ** -0.5

        # Self-attention projections.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )

        # Post-attention / post-FFN layer norms.
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Dropout applied to the attention output before the residual add.
        self.drop_attn = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, T, H) -> (B, nhead, T, head_dim)"""
        B, T, _ = x.size()
        return x.view(B, T, self.nhead, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(B, nhead, T, head_dim) -> (B, T, H)"""
        B, _, T, _ = x.size()
        return x.transpose(1, 2).contiguous().view(B, T, self.hidden_dim)

    def forward(
            self,
            x: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode ``x`` with bidirectional self-attention followed by the FFN.

        Args:
            x: (B, T, H) input states.
            attention_mask: optional (B, T) key mask; nonzero / True = attend.
                Queries at masked positions still produce outputs.

        Returns:
            (B, T, H) encoded states.
        """
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        attn_mask = None
        if attention_mask is not None:
            # (B, T) -> (B, 1, 1, T) boolean mask, broadcast over heads and queries.
            attn_mask = attention_mask.to(torch.bool)[:, None, None, :]

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        attn_out = self.o_proj(self._merge_heads(attn_out))
        x = self.norm1(x + self.drop_attn(attn_out))

        x = self.norm2(x + self.ffn(x))
        return x


class ScaledLlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding with Llama-3-style frequency scaling.

    ``forward`` returns ``(cos, sin)`` of shape (B, T, head_dim). Their last
    dimension is the frequency vector concatenated with itself.

    Args:
        lm_config: object exposing ``head_dim`` and ``rope_theta``. For
            ``_build_inv_freq_ntk`` it must also expose
            ``max_position_embeddings`` and, optionally, ``partial_rotary_factor``.
    """

    def __init__(
        self,
        lm_config
    ):
        super().__init__()
        self.config = lm_config
        self.dim = lm_config.head_dim
        self.base = lm_config.rope_theta
        self.attention_scaling = 1.0

        # Plain attribute (not a registered buffer): it stays float32 and is
        # moved to the input's device on every forward call.
        self.inv_freq = self._build_inv_freq_llama()
        # Alternative dynamic-NTK frequencies: self._build_inv_freq_ntk()

    def _build_inv_freq_llama(self):
        """Inverse frequencies with Llama-3-style wavelength-dependent scaling.

        Frequencies whose wavelength is below ``old_context_len / high_freq_factor``
        are kept. Those whose wavelength exceeds ``old_context_len / low_freq_factor``
        are divided by ``factor``. The band in between is linearly interpolated.
        """
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))

        factor = 64.0   # self.config.rope_scaling["factor"]  # `8` in the original implementation
        low_freq_factor = 1.0   # self.config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
        high_freq_factor = 4.0  # self.config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
        old_context_len = 32768  # self.config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / inv_freq
        # wavelen < high_freq_wavelen: do nothing
        # wavelen > low_freq_wavelen: divide by factor
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # otherwise: interpolate between the two, using a smooth factor
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        return inv_freq_llama

    def _build_inv_freq_ntk(self, seq_len=None):
        """Dynamic-NTK inverse frequencies for ``seq_len``.

        ``seq_len`` defaults to ``max_position_embeddings`` and is never taken
        below it. The result is created on ``seq_len``'s device when
        ``seq_len`` is a tensor, otherwise on the CPU. ``forward`` moves
        ``inv_freq`` to the input's device anyway.
        """
        base = self.config.rope_theta
        partial_rotary_factor = getattr(self.config, "partial_rotary_factor", 1.0)
        dim = int(self.dim * partial_rotary_factor)
        max_position_embeddings = self.config.max_position_embeddings
        factor = 2  # self.config.rope_scaling["factor"]

        device = None
        if seq_len is None:
            seq_len = max_position_embeddings
        elif isinstance(seq_len, torch.Tensor):
            device = seq_len.device
            seq_len = torch.maximum(
                seq_len,
                torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
            )
        else:
            seq_len = max(seq_len, max_position_embeddings)

        # Rescale the base so longer sequences get proportionally lower frequencies.
        base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
        exponents = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
        return 1.0 / (base ** exponents)

    def forward(self, x, position_ids=None):
        """Return ``(cos, sin)`` of shape (B, T, head_dim), cast to ``x.dtype``.

        Args:
            x: tensor shaped (B, T, dim) or (B, H, T, dim). Supplies the device
                and dtype and, when ``position_ids`` is None, the batch size
                (dim 0) and sequence length (dim -2).
            position_ids: (B, T) positions; defaults to ``0..T-1`` for every row.
        """
        if position_ids is None:
            seq_len = x.shape[-2]
            position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(x.shape[0], -1)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """
    Apply rotary position embedding to q and k tensors.

    ``cos``/``sin`` are expected in the "half" layout built by
    ``torch.cat((freqs, freqs), dim=-1)``. Channel ``j`` and channel
    ``j + D/2`` share a frequency, so they are rotated together as one pair.

    Args:
        q, k: (B, H, T, D) with D even.
        cos, sin: (B, T, D) or (1, T, D) tables for the T positions of q/k.
            When ``position_ids`` is given, they are instead a (1, P, D) table
            over P absolute positions.
        position_ids: (B, T) indices into that position table, or None.

    Returns:
        q_rot, k_rot: (B, H, T, D)
    """
    if position_ids is not None:
        # Gather rows of the (1, P, D) table -> (B, T, D) -> (B, 1, T, D).
        cos = cos[0, position_ids].unsqueeze(1)
        sin = sin[0, position_ids].unsqueeze(1)
    else:
        # (B or 1, T, D) -> (B or 1, 1, T, D) so the tables broadcast over heads.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    def rotate_half(x):
        # Pair channel j with channel j + D/2 to match the cos/sin layout.
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

def edge_index_to_adj_list(edge_index: torch.Tensor, num_nodes: int,):
    """Convert a COO ``edge_index`` into per-node neighbour lists.

    A leading dimension of size 1 is squeezed away, and the edges are
    densified with ``to_dense_adj``. Rows are then scanned in ascending node
    order. When row ``u`` is scanned, ``u`` is appended to its own list, and
    every ``v`` with a positive entry in that row is appended to ``u``'s list
    while ``u`` is appended to ``v``'s list. The result is therefore symmetric.

    NOTE(review): if ``edge_index`` already contains both directions of an
    edge, each neighbour is recorded twice. A self-loop edge adds the node to
    its own list two extra times. Confirm callers expect this.
    """
    dense = to_dense_adj(edge_index.squeeze(0), max_num_nodes=num_nodes)[0]

    neighbours = [[] for _ in range(num_nodes)]
    for src in range(num_nodes):
        neighbours[src].append(src)
        for dst in (dense[src] > 0).nonzero(as_tuple=True)[0].tolist():
            neighbours[src].append(dst)
            neighbours[dst].append(src)
    return neighbours
