#!/usr/bin/env python
#-*- coding:utf-8 _*-
import math
import numpy as np
import torch
import torch.nn as nn
import dgl
from einops import repeat, rearrange
from torch.nn import functional as F
from torch.nn import GELU, ReLU, Tanh, Sigmoid
from torch.nn.utils.rnn import pad_sequence

from utils import MultipleTensors
from models.mlp import MLP

# ==============================================================================
# 辅助模块 (为了使文件独立，在这里完整定义)
# ==============================================================================
try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

class GPTConfig():
    """Configuration container for the CGPT model.

    Every constructor argument is stored as an attribute of the same name,
    except ``n_inner``, which is converted from a multiplier of ``n_embd``
    into an absolute integer width.

    Args:
        attn_type: attention variant name (stored as given).
        embd_pdrop: embedding dropout probability.
        resid_pdrop: residual dropout probability.
        attn_pdrop: attention dropout probability.
        n_embd: embedding width.
        n_head: number of attention heads.
        n_layer: number of layers.
        block_size: stored as given.
        n_inner: MLP hidden width expressed as a multiple of ``n_embd``;
            the stored value is ``int(n_inner * n_embd)``.
        act: activation name.
        branch_sizes: input widths of the branch encoders, or None.
        n_inputs: number of cross-attention inputs.
    """
    def __init__(self, attn_type='linear', embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0,
                 n_embd=128, n_head=1, n_layer=3, block_size=128, n_inner=4, act='gelu',
                 branch_sizes=None, n_inputs=1):
        self.attn_type = attn_type
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.attn_pdrop = attn_pdrop
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.block_size = block_size
        # Always an int: nn.Linear needs an integer width. Previously a
        # fractional multiplier >= 1 (e.g. 2.5) produced a float here.
        self.n_inner = int(n_inner * n_embd)
        self.act = act
        self.branch_sizes = branch_sizes
        self.n_inputs = n_inputs

def horizontal_fourier_embedding(X, n=3):
    """Append multi-scale Fourier features to every scalar in the last axis of ``X``.

    Each scalar feature ``x`` is expanded to the ``4n + 3`` values
    ``[x, cos(f_0 x), ..., cos(f_2n x), sin(f_0 x), ..., sin(f_2n x)]`` with
    frequencies ``f_j = 2 ** (j - n)`` for ``j = 0 .. 2n``. The per-feature
    groups are then flattened into the last axis.

    Args:
        X: tensor of shape ``(..., F)``. Any number of leading dimensions is
            accepted (the previous implementation required exactly 3-D input).
        n: number of octaves on each side of frequency 1; ``n <= 0`` returns
            ``X`` unchanged.

    Returns:
        Tensor of shape ``(..., F * (4n + 3))``, or ``X`` itself if ``n <= 0``.
    """
    if n <= 0:
        return X
    # Powers of two from 2**-n to 2**n, created directly on X's device.
    freqs = 2 ** torch.linspace(-n, n, 2 * n + 1, device=X.device)
    # (..., F, 1) * (2n+1,) broadcasts to (..., F, 2n+1) without an explicit copy.
    angles = X.unsqueeze(-1) * freqs
    features = torch.cat([X.unsqueeze(-1), torch.cos(angles), torch.sin(angles)], dim=-1)
    return features.reshape(*X.shape[:-1], -1)

class LinearAttention(nn.Module):
    """Multi-head linear self-attention.

    Queries and keys pass through a softmax over each head's feature
    dimension, giving positive feature maps phi(q), phi(k). Output token t is

        y_t = sum_s (phi(q_t) . phi(k_s)) v_s / sum_s (phi(q_t) . phi(k_s)),

    evaluated by first forming the (d x d) matrix phi(K)^T V per head, so the
    cost is linear in the sequence length T.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # Dropout on the attention output before the projection; honours
        # config.attn_pdrop (0.0 makes it an identity).
        self.attn_drop = nn.Dropout(getattr(config, 'attn_pdrop', 0.0))
        self.n_head = config.n_head

    def forward(self, x, kv_cache=None):
        """Apply linear self-attention.

        Args:
            x: tensor of shape (B, T, C), C == config.n_embd.
            kv_cache: accepted for interface compatibility; not used.

        Returns:
            Tuple ``(y, None)`` with ``y`` of shape (B, T, C).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        # (B, T, C) -> (B, h, T, head_dim)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)

        q = q.softmax(dim=-1)
        k_softmax = k.softmax(dim=-1)

        # Contract over the sequence axis t (not the feature axis):
        # phi(K)^T V has shape (B, h, d, e), independent of T.
        kv = torch.einsum('bhtd,bhte->bhde', k_softmax, v)
        qkv = torch.einsum('bhtd,bhde->bhte', q, kv)
        # Per-query normalizer phi(q_t) . sum_s phi(k_s), shape (B, h, T, 1).
        D_inv = 1. / torch.clamp(torch.einsum('bhtd,bhud->bhtu', q, k_softmax.sum(dim=2, keepdim=True)), min=1e-6)

        y = self.attn_drop(qkv * D_inv)
        # (B, h, T, head_dim) -> (B, T, h * head_dim)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.proj(y), None

class LinearCrossAttention(nn.Module):
    """Multi-head linear cross-attention from ``x`` to several input sequences.

    Each input ``y_list[i]`` has its own key/value projections. The
    linear-attention results (same kernel as LinearAttention) for all inputs
    are summed before the output projection.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.keys = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        self.values = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # Dropout on the summed attention output; honours config.attn_pdrop.
        self.attn_drop = nn.Dropout(getattr(config, 'attn_pdrop', 0.0))
        self.n_head = config.n_head
        self.n_inputs = config.n_inputs

    def forward(self, x, y_list):
        """Cross-attend ``x`` to every entry of ``y_list``.

        Args:
            x: query tokens, shape (B, T1, C).
            y_list: indexable container with at least ``n_inputs`` tensors of
                shape (B, T2_i, C); T2_i may differ from T1 and between inputs.

        Returns:
            Tensor of shape (B, T1, C).
        """
        B, T1, C = x.size()
        head_dim = C // self.n_head
        q = self.query(x).view(B, T1, self.n_head, head_dim).transpose(1, 2)
        q = q.softmax(dim=-1)

        out = torch.zeros_like(q)
        for i, (key_proj, value_proj) in enumerate(zip(self.keys, self.values)):
            y_i = y_list[i]
            _, T2, _ = y_i.size()
            k = key_proj(y_i).view(B, T2, self.n_head, head_dim).transpose(1, 2)
            v = value_proj(y_i).view(B, T2, self.n_head, head_dim).transpose(1, 2)
            k_softmax = k.softmax(dim=-1)

            # Contract over the key sequence axis s: (B, h, d, e), independent of T2.
            kv = torch.einsum('bhsd,bhse->bhde', k_softmax, v)
            qkv = torch.einsum('bhtd,bhde->bhte', q, kv)
            # Per-query normalizer, shape (B, h, T1, 1).
            D_inv = 1. / torch.clamp(torch.einsum('bhtd,bhud->bhtu', q, k_softmax.sum(dim=2, keepdim=True)), min=1e-6)
            out = out + qkv * D_inv

        out = self.attn_drop(out)
        # (B, h, T1, head_dim) -> (B, T1, C)
        y = out.transpose(1, 2).reshape(B, T1, C)
        return self.proj(y)

class CrossAttentionBlock(nn.Module):
    """Standard CGPT attention block.

    Applies four pre-LayerNorm residual sub-layers, each followed by residual
    dropout: linear cross-attention from ``x`` to the branch inputs, an MLP,
    linear self-attention, and a second MLP.
    """

    # Supported values of ``config.act`` and the activation class each selects.
    _ACTIVATIONS = {'gelu': GELU, 'tanh': Tanh, 'relu': ReLU, 'sigmoid': Sigmoid}

    def __init__(self, config):
        super().__init__()
        # Validate first: an unknown name used to leave ``self.act`` unset and
        # fail later with an unrelated AttributeError.
        if config.act not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {config.act!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )
        self.ln1 = nn.LayerNorm(config.n_embd)
        # One LayerNorm per cross-attention input.
        self.ln2_branch = nn.ModuleList([nn.LayerNorm(config.n_embd) for _ in range(config.n_inputs)])
        self.ln3 = nn.LayerNorm(config.n_embd)
        self.ln4 = nn.LayerNorm(config.n_embd)
        self.ln5 = nn.LayerNorm(config.n_embd)
        self.n_inputs = config.n_inputs

        self.selfattn = LinearAttention(config)
        self.crossattn = LinearCrossAttention(config)

        # The activation *class*; each MLP instantiates its own copy.
        self.act = self._ACTIVATIONS[config.act]

        self.resid_drop = nn.Dropout(config.resid_pdrop)

        self.mlp1 = nn.Sequential(
            nn.Linear(config.n_embd, config.n_inner),
            self.act(),
            nn.Linear(config.n_inner, config.n_embd),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(config.n_embd, config.n_inner),
            self.act(),
            nn.Linear(config.n_inner, config.n_embd),
        )

    def ln_branchs(self, y):
        """Apply the per-input LayerNorm to each of the first ``n_inputs`` entries of ``y``."""
        return [self.ln2_branch[i](y[i]) for i in range(self.n_inputs)]

    def forward(self, x, y, self_attn_cache=None):
        """Run the block.

        Args:
            x: query tokens, shape (B, T, C).
            y: indexable container of ``n_inputs`` tensors to cross-attend to.
            self_attn_cache: forwarded to the self-attention layer.

        Returns:
            Tuple ``(x, cache)`` where ``cache`` is whatever the self-attention
            layer returned as its second output.
        """
        x = x + self.resid_drop(self.crossattn(self.ln1(x), self.ln_branchs(y)))
        x = x + self.resid_drop(self.mlp1(self.ln3(x)))

        attn_out, new_self_attn_cache = self.selfattn(self.ln4(x), kv_cache=self_attn_cache)
        x = x + self.resid_drop(attn_out)
        x = x + self.resid_drop(self.mlp2(self.ln5(x)))
        return x, new_self_attn_cache

class XFAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Uses ``xformers.ops.memory_efficient_attention`` when xformers was
    importable (``XFORMERS_AVAILABLE``). Otherwise it falls back to an explicit
    softmax(Q K^T / sqrt(d)) V. Supports self-attention (``y is None``) and
    cross-attention from ``x`` to ``y``.
    """
    def __init__(self, n_embd, n_head, attn_pdrop=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head
        self.attn_pdrop = attn_pdrop

    def forward(self, x, y=None, attn_bias=None):
        """Attend from ``x`` to ``y`` (or to itself).

        Args:
            x: query tokens, shape (B, N, C).
            y: key/value tokens, shape (B, M, C); defaults to ``x``.
            attn_bias: optional bias forwarded to xformers, or added to the
                (B, h, N, M) score tensor in the fallback path.

        Returns:
            Tensor of shape (B, N, C).
        """
        y = x if y is None else y
        B, N, C = x.shape
        M = y.shape[1]
        head_dim = C // self.n_head

        # (B, len, h, d) layout, as taken by memory_efficient_attention.
        q = self.query(x).view(B, N, self.n_head, head_dim)
        k = self.key(y).view(B, M, self.n_head, head_dim)
        v = self.value(y).view(B, M, self.n_head, head_dim)

        # Attention dropout only while training. F.dropout defaults to
        # training=True and xformers always applies p, so both paths used to
        # drop attention weights in eval mode too.
        p = self.attn_pdrop if self.training else 0.0

        if XFORMERS_AVAILABLE:
            out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=p)
        else:  # Fallback
            qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))  # (B, h, len, d)
            scores = torch.matmul(qh, kh.transpose(-2, -1)) / (head_dim ** 0.5)
            if attn_bias is not None:
                scores = scores + attn_bias
            attn = F.dropout(F.softmax(scores, dim=-1), p=p, training=self.training)
            out = torch.matmul(attn, vh).transpose(1, 2)  # back to (B, N, h, d)

        out = out.reshape(B, N, C)
        return self.proj(out)

# NOTE(review): PhysicsSlicer, PhysicsDeslicer, GatedMLP, LightweightRouter and
# pack_tokens are placeholder stubs here (their bodies are `...`).
# StructuredRecursiveGNOT constructs them with positional arguments, e.g.
# PhysicsSlicer(n_hidden, num_coarse_nodes). A bare nn.Module subclass cannot
# accept those and would raise a TypeError. pack_tokens as written returns None.
# So StructuredRecursiveGNOT cannot run against these stubs. Presumably the
# real implementations were elided or live elsewhere — confirm before use.
# Call site expects: slicer(x) -> (coarse_tokens, slice_weights).
class PhysicsSlicer(nn.Module): ...
# Call site expects: deslicer(coarse_states, slice_weights) -> a tensor that
# can be added to the fine-resolution hidden states.
class PhysicsDeslicer(nn.Module): ...
# Called as GatedMLP(n_embd, n_inner, n_embd) and applied to token tensors.
class GatedMLP(nn.Module): ...
# Call site expects: router(x) -> per-token scores with a trailing dim that the
# caller squeezes before ranking.
class LightweightRouter(nn.Module): ...
# Presumably gathers the tokens of x at active_idx (shape (B, k)) along dim 1;
# the caller scatters the block outputs back to those same indices.
def pack_tokens(x, active_idx): ...

# ==============================================================================
# CGPTNO 模型 (最终版)
# ==============================================================================
class CGPTNO(nn.Module):
    """CGPT neural operator.

    Per-node trunk features ([coords, node state], optionally Fourier-embedded,
    plus the broadcast global parameters ``u_p``) are encoded by an MLP. They
    then pass through a stack of CrossAttentionBlock layers that cross-attend
    to the encoded branch inputs. An output MLP decodes each node.

    Args:
        coord_dim: width of the per-node coordinates.
        state_dim: width of the per-node state read from ``g.ndata['x']``.
        theta_dim: width of the global parameter vector ``u_p``.
        branch_sizes: input width of each branch input; falsy means no branches.
        output_size: per-node output width.
        n_layers: number of CrossAttentionBlock layers.
        n_hidden: embedding width shared by encoders and blocks.
        n_head: attention heads per block.
        n_inner: MLP expansion factor passed to GPTConfig.
        mlp_layers: depth of the encoder/decoder MLPs.
        attn_type: passed to GPTConfig.
        act: activation name for the encoders and the blocks.
        ffn_dropout: used as both embedding and residual dropout in GPTConfig.
        attn_dropout: attention dropout passed to GPTConfig.
        horiz_fourier_dim: Fourier octaves for the trunk features (0 disables).
        **kwargs: ignored.
    """
    def __init__(self,
                 coord_dim, state_dim, theta_dim, branch_sizes, output_size,
                 n_layers=2, n_hidden=64, n_head=1, n_inner=4, mlp_layers=2,
                 attn_type='linear', act='gelu', ffn_dropout=0.0,
                 attn_dropout=0.0, horiz_fourier_dim=0, **kwargs):
        super().__init__()
        self.__name__ = 'CGPT_Fixed'
        self.horiz_fourier_dim = horiz_fourier_dim
        self.n_inputs = len(branch_sizes) if branch_sizes else 0
        self.output_size = output_size

        # After horizontal_fourier_embedding every raw trunk scalar becomes
        # 4 * horiz_fourier_dim + 3 values (itself, the cosines, the sines).
        if horiz_fourier_dim > 0:
            values_per_feature = 4 * horiz_fourier_dim + 3
        else:
            values_per_feature = 1
        trunk_width = (coord_dim + state_dim) * values_per_feature + theta_dim

        self.trunk_mlp = MLP(trunk_width, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList(
                MLP(width, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
                for width in branch_sizes
            )

        self.gpt_config = GPTConfig(
            attn_type=attn_type,
            embd_pdrop=ffn_dropout,
            resid_pdrop=ffn_dropout,
            attn_pdrop=attn_dropout,
            n_embd=n_hidden,
            n_head=n_head,
            n_layer=n_layers,
            block_size=128,
            act=act,
            branch_sizes=branch_sizes,
            n_inputs=self.n_inputs,
            n_inner=n_inner,
        )
        self.blocks = nn.ModuleList(
            CrossAttentionBlock(self.gpt_config) for _ in range(self.gpt_config.n_layer)
        )
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)

    def forward(self, g, coords, u_p, inputs, past_key_values=None):
        """Predict per-node outputs.

        Args:
            g: batched graph; ``g.ndata['x']`` holds the node state and
                ``g.batch_num_nodes()`` the node count of each graph.
            coords: node coordinates, shape (B, N, coord_dim).
            u_p: global parameters, shape (B, theta_dim).
            inputs: indexable branch inputs (one per branch), or empty/None.
            past_key_values: optional per-block caches forwarded to each block.

        Returns:
            Tuple ``(x_out, caches)``. ``x_out`` concatenates, along dim 0, the
            first ``batch_num_nodes()[i]`` rows of sample ``i``. ``caches``
            holds the caches returned by each block.

        NOTE(review): ``g.ndata['x']`` is viewed as (B, N, -1), which assumes
        every graph has exactly N nodes. The per-graph unpadding at the end
        suggests padded batches may be intended — confirm.
        """
        batch, n_nodes, _ = coords.shape
        node_state = g.ndata['x'].view(batch, n_nodes, -1)

        features = torch.cat([coords, node_state], dim=-1)
        if self.horiz_fourier_dim > 0:
            features = horizontal_fourier_embedding(features, self.horiz_fourier_dim)

        theta = u_p.unsqueeze(1).repeat(1, n_nodes, 1)
        encoded = self.trunk_mlp(torch.cat([features, theta], dim=-1))

        if self.n_inputs > 0 and inputs and len(inputs) > 0:
            z = MultipleTensors([branch(inputs[i]) for i, branch in enumerate(self.branch_mlps)])
        else:
            # No branch inputs: the encoded trunk tokens stand in for them.
            z = MultipleTensors([encoded])

        caches = [None] * len(self.blocks) if past_key_values is None else past_key_values

        x = encoded
        new_caches = []
        for i, block in enumerate(self.blocks):
            x, updated_cache = block(x, z, self_attn_cache=caches[i])
            new_caches.append(updated_cache)

        decoded = self.out_mlp(x)
        per_graph = [decoded[i, :num] for i, num in enumerate(g.batch_num_nodes())]
        return torch.cat(per_graph, dim=0), new_caches

# ==============================================================================
# StructuredRecursiveGNOT 模型 (最终版)
# ==============================================================================
class StructuredRecursiveGNOT(nn.Module):
    """Recursive GNOT variant that processes a sliced (coarse) token set.

    Forward pipeline:
      1. Encode [coords, node state, u_p] per node with ``trunk_mlp``.
      2. Encode the branch inputs and average them over branches, or reuse the
         trunk encoding when there are none.
      3. Slice both with ``slicer`` (constructed with ``num_coarse_nodes``).
      4. Cross-attend the coarse trunk tokens to the coarse branch tokens.
      5. Rank the coarse tokens once with ``router``. At depth ``i`` only the
         top ``capacity_factors[i]`` tokens go through recursion block ``i``;
         the results are scattered back in place.
      6. De-slice, add to the fine trunk encoding, decode with ``out_mlp``
         and keep the first ``batch_num_nodes()[i]`` rows of each sample.

    Args:
        coord_dim, state_dim, theta_dim: widths of coords, node state and u_p.
        branch_sizes: input width of each branch input; falsy means none.
        output_size: per-node output width.
        n_layers: recursion depth (number of recursion blocks); must be >= 1.
        n_hidden: hidden width.
        n_head: attention heads.
        num_fine_nodes: accepted but not used by this class.
        num_coarse_nodes: passed to the slicer; also the base for capacities.
        final_keep_ratio: fraction of coarse tokens kept at the last depth.
            The fraction is interpolated linearly from 1.0 at depth 0.
        n_inner: MLP expansion factor of the recursion blocks.
        mlp_layers: depth of the encoder/decoder MLPs.
        act: activation name for the encoder MLPs.
        ffn_dropout: residual dropout of the recursion blocks.
        attn_dropout: attention dropout.
        **kwargs: ignored.

    Raises:
        ValueError: if ``n_layers < 1``.
    """
    def __init__(self,
                 coord_dim, state_dim, theta_dim, branch_sizes, output_size,
                 n_layers, n_hidden, n_head,
                 num_fine_nodes, num_coarse_nodes,
                 final_keep_ratio=0.25,
                 n_inner=4, mlp_layers=2, act='gelu',
                 ffn_dropout=0.0, attn_dropout=0.0, **kwargs):
        super().__init__()
        self.__name__ = 'StructuredRecursiveGNOT_Fixed'
        if n_layers < 1:
            # An empty capacity schedule previously failed with an opaque IndexError.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.recursion_depth = n_layers
        self.n_hidden = n_hidden
        self.n_inputs = len(branch_sizes) if branch_sizes else 0

        # Number of tokens processed at each depth. It shrinks linearly from
        # all coarse tokens to final_keep_ratio of them. Every depth keeps at
        # least one token, because rounding down could otherwise reach 0 at
        # intermediate depths as well as the last one.
        ratios = np.linspace(1.0, final_keep_ratio, self.recursion_depth)
        self.capacity_factors = [max(int(num_coarse_nodes * r), 1) for r in ratios]

        trunk_input_dim = coord_dim + state_dim + theta_dim
        self.trunk_mlp = MLP(trunk_input_dim, n_hidden, n_hidden, n_layers=mlp_layers, act=act)

        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)

        self.slicer = PhysicsSlicer(n_hidden, num_coarse_nodes)
        self.deslicer = PhysicsDeslicer()

        self.cross_attention = XFAttention(n_hidden, n_head, attn_pdrop=attn_dropout)
        self.cross_attn_norm = nn.LayerNorm(n_hidden)

        self.router = LightweightRouter(n_hidden)

        config = {'n_embd': n_hidden, 'n_head': n_head, 'n_inner': n_inner * n_hidden,
                  'resid_pdrop': ffn_dropout, 'attn_pdrop': attn_dropout}
        self.recursion_blocks = nn.ModuleList([self._create_block(config) for _ in range(self.recursion_depth)])

    def _create_block(self, config):
        """Build one recursion block from the ``config`` dict."""
        # A module-level class (rather than one defined inside this method)
        # keeps the whole model picklable, e.g. with torch.save(model).
        return _RecursionCoreBlock(config)

    def forward(self, g, coords, u_p, inputs, **kwargs):
        """Predict per-node outputs.

        Args:
            g: batched graph; ``g.ndata['x']`` holds the node state and
                ``g.batch_num_nodes()`` the node count of each graph.
            coords: node coordinates, shape (B, N_fine, coord_dim).
            u_p: global parameters, shape (B, theta_dim).
            inputs: indexable branch inputs (one per branch), or empty/None.
            **kwargs: ignored (e.g. ``past_key_values``).

        Returns:
            Tuple ``(x_out, None)``. ``x_out`` concatenates the first
            ``batch_num_nodes()[i]`` rows of each sample along dim 0.
        """
        B, N_fine, _ = coords.shape
        state = g.ndata['x'].view(B, N_fine, -1)
        C_hidden = self.n_hidden

        u_p_expanded = u_p.unsqueeze(1).repeat(1, N_fine, 1)
        trunk_input = torch.cat([coords, state, u_p_expanded], dim=-1)

        hidden_states_fine = self.trunk_mlp(trunk_input)

        if self.n_inputs > 0 and inputs and len(inputs) > 0:
            # Index the container directly (as CGPTNO does) so that both
            # MultipleTensors and plain lists work; `inputs.x[i]` did not.
            z_fine_list = [self.branch_mlps[i](inputs[i]) for i in range(self.n_inputs)]
            z_fine = torch.mean(torch.stack(z_fine_list), dim=0)
        else:
            z_fine = hidden_states_fine

        h_coarse, slice_weights = self.slicer(hidden_states_fine)
        z_coarse, _ = self.slicer(z_fine)

        cross_features = self.cross_attention(self.cross_attn_norm(h_coarse), y=z_coarse)
        hidden_states = h_coarse + cross_features

        # NOTE(review): the router's scores are only used for ranking under
        # no_grad, so this forward pass gives the router no gradient. Confirm
        # it is trained some other way if that is intended.
        router_scores = self.router(hidden_states).squeeze(-1)
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # The ranking is computed once, so the tokens selected at each depth are
        # a prefix of the same ordering.
        for depth in range(self.recursion_depth):
            k = self.capacity_factors[depth]
            active_indices = global_ranking_indices[:, :k]
            active_h = pack_tokens(hidden_states, active_indices)
            block_output_active = self.recursion_blocks[depth](active_h)
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C_hidden)
            # Write processed tokens back in place. The cast avoids a dtype
            # mismatch under mixed precision.
            hidden_states = hidden_states.scatter(
                1,
                idx_exp,
                block_output_active.to(hidden_states.dtype)
            )

        final_states_fine_update = self.deslicer(hidden_states, slice_weights)
        final_states_fine = hidden_states_fine + final_states_fine_update
        output_padded = self.out_mlp(final_states_fine)
        x_out = torch.cat([output_padded[i, :num] for i, num in enumerate(g.batch_num_nodes())], dim=0)

        return x_out, None


class _RecursionCoreBlock(nn.Module):
    """Pre-LayerNorm transformer block (XFAttention + GatedMLP) used at each
    recursion depth of StructuredRecursiveGNOT.

    Args:
        cfg: dict with keys 'n_embd', 'n_head', 'n_inner', 'resid_pdrop' and
            'attn_pdrop'.
    """
    def __init__(self, cfg):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg['n_embd'])
        self.ln2 = nn.LayerNorm(cfg['n_embd'])
        self.attn = XFAttention(cfg['n_embd'], cfg['n_head'], cfg['attn_pdrop'])
        self.mlp = GatedMLP(cfg['n_embd'], cfg['n_inner'], cfg['n_embd'])
        self.resid_drop = nn.Dropout(cfg['resid_pdrop'])

    def forward(self, x):
        """Self-attention then MLP, each as a residual sub-layer."""
        x = x + self.resid_drop(self.attn(self.ln1(x)))
        x = x + self.resid_drop(self.mlp(self.ln2(x)))
        return x