import torch
import torch.nn as nn
from einops import rearrange, repeat

from memKNO.network import TransformerCatNoCls, FeedForward
from memKNO.attention import CrossLinearAttention


# ---------- Encoder: spatial Transformer + latent cross-attend ----------
class LatentGlobalEncoder2D(nn.Module):
    """
    Summarise per-point features into K global latent tokens.

    Stages:
      1) to_embedding   : bias-free linear map [B, N_pt, input_channels] -> [B, N_pt, in_emb_dim]
      2) s_transformer  : `spatial_depth` layers of TransformerCatNoCls over the points
                          (galerkin/fourier linear attention)
      3) latent tokens  : K learned tokens (+ optional learned latent positions),
                          broadcast over the batch
      4) lat_blocks     : `latent_depth` blocks, each doing
                              latents = CrossLinearAttention(latents <- points, with RoPE) + latents
                              latents = FeedForward(latents) + latents
                          When `use_latent_ln` is set, a LayerNorm is applied to the latents
                          right before each sub-layer and its output replaces the latents,
                          so the residual branch carries the normalised values.
      5) project_to_latent: [B, K, in_emb_dim] -> [B, K, token_dim]
    Output: global latents [B, K, token_dim]
    """

    @staticmethod
    def _default_spatial_scales(depth):
        """Return the default coarse->fine RoPE scale list for `depth` spatial layers."""
        if depth <= 2:
            return [32, 1][:depth]
        if depth == 3:
            return [32, 16, 1]
        # 32, 16, then as many 8s as needed, finishing on 1
        return [32, 16] + [8] * (depth - 3) + [1]

    def __init__(self,
                 input_channels,
                 in_emb_dim,
                 token_dim,
                 heads,
                 spatial_depth,                  # number of spatial Transformer layers
                 dim_head=None,                  # per-head width; None -> in_emb_dim // heads
                 mlp_dim=None,                   # latent FFN hidden width; None -> 2 * in_emb_dim
                 attn_type='galerkin',
                 dropout=0.,
                 # spatial RoPE settings
                 relative_emb_dim=2,
                 min_freq=1/64,
                 scale_spatial=None,             # one scale per spatial layer, e.g. [32,16,8,1]
                 use_ln=True,
                 # latent settings
                 latent_tokens=4,                # K
                 latent_depth=2,                 # number of cross-attention blocks
                 use_latent_ln=True,
                 use_latent_pos=True,            # learn a position for every latent token
                 scale_latent=8.0,               # RoPE scale used by the latent cross-attention
                 ):
        """
        Build the embedding, the spatial transformer, the learned latents, the
        latent cross-attention stack and the output projection.

        Raises AssertionError when `attn_type` is not 'galerkin'/'fourier', when
        `dim_head` is omitted and `in_emb_dim` is not divisible by `heads`, or when
        `scale_spatial` does not have exactly `spatial_depth` entries.
        """
        super().__init__()
        assert attn_type in ['galerkin', 'fourier']

        # plain (non-module) configuration read back in forward()
        self.pos_dim = relative_emb_dim
        self.use_latent_pos = use_latent_pos
        self.use_latent_ln = use_latent_ln

        # fill in the derived defaults
        if dim_head is None:
            assert in_emb_dim % heads == 0
            dim_head = in_emb_dim // heads
        if mlp_dim is None:
            mlp_dim = in_emb_dim * 2

        # --- point-wise embedding ---
        self.to_embedding = nn.Linear(input_channels, in_emb_dim, bias=False)

        # --- spatial transformer over the points ---
        if scale_spatial is None:
            scale_spatial = self._default_spatial_scales(spatial_depth)
        assert len(scale_spatial) == spatial_depth

        self.s_transformer = TransformerCatNoCls(
            dim=in_emb_dim,
            depth=spatial_depth,
            heads=heads,
            dim_head=dim_head,
            # the spatial FFN keeps the embedding width; `mlp_dim` is only used
            # by the latent blocks below
            mlp_dim=in_emb_dim,
            attn_type=attn_type,
            use_ln=use_ln,
            scale=scale_spatial,
            dropout=dropout,
            relative_emb_dim=relative_emb_dim,
            min_freq=min_freq,
            attention_init='orthogonal',
            init_gain=None,
            use_relu=False,
            cat_pos=False,
        )

        # --- learned latent tokens (+ positions) ---
        self.latent_tokens = nn.Parameter(torch.randn(latent_tokens, in_emb_dim) * (1 / in_emb_dim))
        if use_latent_pos:
            # uniform initialisation in [0, 1)
            self.latent_pos = nn.Parameter(torch.rand(latent_tokens, self.pos_dim))

        # --- latent cross-attention stack ---
        cross_kwargs = dict(
            dim_q=in_emb_dim,
            dim_kv=in_emb_dim,
            heads=heads,
            dim_head=dim_head,
            attn_type=attn_type,
            dropout=dropout,
            relative_emb=True,                  # RoPE inside the cross-attention
            relative_emb_dim=relative_emb_dim,
            min_freq=min_freq,
            scale=scale_latent,
            cat_pos=False,
            pos_dim=self.pos_dim,
        )

        def make_block():
            # attention and FFN are created first, the LayerNorms afterwards
            attend = CrossLinearAttention(**cross_kwargs)
            feed = FeedForward(in_emb_dim, mlp_dim, dropout=dropout)
            if not use_latent_ln:
                return nn.ModuleList([attend, feed])
            return nn.ModuleList([nn.LayerNorm(in_emb_dim), attend,
                                  nn.LayerNorm(in_emb_dim), feed])

        self.lat_blocks = nn.ModuleList([make_block() for _ in range(latent_depth)])

        # --- projection to the decoder latent width ---
        self.project_to_latent = nn.Linear(in_emb_dim, token_dim, bias=False)

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        """
        x         : [B, N_pt, input_channels]  (e.g., T_in*C_in + 2 if coords were concatenated upstream)
        input_pos : [B, N_pt, pos_dim]         (usually 2D coords)
        return    : [B, K, token_dim]          (K latent tokens, each of width token_dim)
        """
        batch, _, _ = x.shape                                      # x must be rank 3

        # embed the points, then mix them with the spatial transformer
        points = self.s_transformer(self.to_embedding(x), input_pos)   # [B,N,in_emb_dim]

        # broadcast the learned latents (and their positions) over the batch
        lat = repeat(self.latent_tokens, 't c -> b t c', b=batch)  # [B,K,in_emb_dim]
        if self.use_latent_pos:
            latent_pos = repeat(self.latent_pos, 't c -> b t c', b=batch)  # [B,K,pos_dim]
        else:
            # no learned positions: every latent sits at the origin
            num_latents = self.latent_tokens.shape[0]
            latent_pos = torch.zeros(batch, num_latents, self.pos_dim,
                                     device=points.device, dtype=points.dtype)

        # latents repeatedly read from the points, then pass through an FFN
        for blk in self.lat_blocks:
            if self.use_latent_ln:
                pre_attn_norm, attend, pre_ffn_norm, feed = blk
                lat = pre_attn_norm(lat)
            else:
                attend, feed = blk
                pre_ffn_norm = None

            lat = attend(lat, points, pos_q=latent_pos, pos_kv=input_pos) + lat
            if pre_ffn_norm is not None:
                lat = pre_ffn_norm(lat)
            lat = feed(lat) + lat

        # map to the decoder latent width
        return self.project_to_latent(lat)                         # [B,K,token_dim]
    


################################ Set Transformer based Encoder ################################
from memKNO.network import TrainablePosEncoder, FourierPosEncoder, PreEncoder, SetTransformer

class SetEncoder2D(nn.Module):
    """
    Set-Transformer based encoder.

    Positions are passed through a positional encoder, combined with the
    point values by a PreEncoder, and the resulting set of tokens is reduced
    to K output tokens by a SetTransformer.
    """

    def __init__(self,
                 input_channels: int,
                 pos_emb_dim: int,
                 pos_emb_type: str = "trainable",
                 pos_hidden: int = 256,
                 val_hidden: int = 128,
                 set_dim: int = 128,
                 set_hidden: int = 128,
                 num_heads: int = 4,
                 num_inds: int = 64,
                 token_dim: int = 64,
                 latent_tokens: int = 4,                # K
                 use_ln: bool = True,
                 fourier_max_freq: float = 16.0,
                 dropout: float = 0.1,
                 ):
        """
        Build the positional encoder, the pre-encoder and the set encoder.

        `pos_emb_type` selects the positional encoder: "trainable" or "fourier"
        (anything else trips the assertion).
        """
        super().__init__()
        assert pos_emb_type in ["trainable", "fourier"]

        # positional encoder: (B, N_PT, 2) -> (B, N_PT, POS_EMB_DIM)
        if pos_emb_type == "trainable":
            self.pos_encoder = TrainablePosEncoder(
                pos_dim=pos_emb_dim,
                num_frequencies=64,
            )
        elif pos_emb_type == "fourier":
            self.pos_encoder = FourierPosEncoder(
                num_bands=pos_emb_dim//4,
                max_freq=fourier_max_freq,
            )

        # fuses position embeddings with point values into SET_DIM-wide tokens
        self.pre_encoder = PreEncoder(
            pos_in_dim=pos_emb_dim,
            val_in_dim=input_channels,
            pos_hidden=pos_hidden,
            val_hidden=val_hidden,
            out_dim=set_dim,
            dropout=dropout,
        )

        # set encoder: (B, N_PT, SET_DIM) -> (B, K, token_dim)
        self.set_encoder = SetTransformer(
            dim_input=set_dim,
            num_outputs=latent_tokens,
            dim_output=token_dim,
            num_inds=num_inds,
            dim_hidden=set_hidden,
            num_heads=num_heads,
            ln=use_ln,
        )

    def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        """
        x         : [B, N_pt, input_channels]  (e.g., T_in*C_in + 2 if coords were concatenated upstream)
        input_pos : [B, N_pt, pos_dim]         (usually 2D coords)
        return    : [B, K, token_dim]          (K output tokens, each of width token_dim)
        """
        # positions -> embeddings -> fused set tokens -> K output tokens
        return self.set_encoder(self.pre_encoder(self.pos_encoder(input_pos), x))




