import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.data import Data
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator, PartialState
from utils import extract_floats, RegressionOutput, MLADecoderLayer, MyCausalLMOutput, MLAEncoderLayer
from typing import Optional
import re
import math
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from typing import Optional, List, Tuple, Union



class NodalEncoder(nn.Module):
    """Per-node 1-D convolutional encoder.

    Applies ``num_layers`` bias-free ``Conv1d`` layers, each followed by SiLU
    (with a residual connection from the second layer on), then adaptively
    average-pools the length axis to ``output_dim``, layer-normalises that
    axis and flattens everything except the batch axis.

    Args:
        input_dim: Number of input channels of the first convolution.
        output_dim: Length the adaptive average pool reduces the sequence to.
        num_layers: Number of convolution layers.
        hidden_dim: Number of output channels of every convolution.
        kernel_size: Convolution kernel width. Odd widths keep the sequence
            length unchanged, which the residual additions require.
    """

    def __init__(self, input_dim, output_dim, num_layers=2, hidden_dim=32, kernel_size=3):
        super(NodalEncoder, self).__init__()

        self.layers = nn.ModuleList()

        # "Same" padding for odd kernels. The previous hard-coded padding=1 was
        # only length-preserving for kernel_size=3 and broke the residual add
        # for any other kernel width.
        padding = kernel_size // 2

        for i in range(num_layers):
            in_channels = input_dim if i == 0 else hidden_dim
            out_channels = hidden_dim
            conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False)
            self.layers.append(conv)

        self.adaptive_pool = nn.AdaptiveAvgPool1d(output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, input_dim, length).

        Returns:
            Tensor of shape (batch, hidden_dim * output_dim).
        """
        for i, layer in enumerate(self.layers):
            x_res = x
            x = F.silu(layer(x))
            # The first layer changes the channel count, so it has no skip path.
            if i > 0:
                x = x + x_res
        x = self.adaptive_pool(x)
        x = self.layer_norm(x)
        x = x.view(x.size(0), -1)
        return x


# Step 1: Graph Encoder (Using GCN)
class GraphEncoder(nn.Module):
    """Stack of ``GCNConv`` layers with SiLU activations.

    The first layer maps ``input_dim`` -> ``hidden_dim``; every later layer
    maps ``hidden_dim`` -> ``hidden_dim`` and is wrapped in a residual
    connection.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3):
        super().__init__()

        # Only the first convolution sees the raw feature width.
        self.layers = nn.ModuleList(
            [
                GCNConv(hidden_dim if depth else input_dim, hidden_dim)
                for depth in range(num_layers)
            ]
        )

    def forward(self, x, edge_index):
        """Run all graph convolutions over node features ``x``."""
        hidden = x
        for depth, conv in enumerate(self.layers):
            activated = F.silu(conv(hidden, edge_index))
            # No skip path on the first layer: its input width may differ.
            hidden = activated if depth == 0 else activated + hidden
        return hidden


# Step 2: Language Encoder (Using LLaMA)
class LanguageEncoder(nn.Module):
    """Looks up input-token embeddings from a wrapped language model.

    Only the model's input embedding layer is used here; no transformer
    layers are run.

    Args:
        model: Object exposing ``get_input_embeddings()``.
        use_lora: Stored for reference; not used by ``forward``.
    """

    def __init__(self, model, use_lora=False):
        super(LanguageEncoder, self).__init__()
        self.model = model
        self.use_lora = use_lora

    def forward(self, inputs):
        """Return the embeddings of ``inputs.input_ids``.

        The ids are moved to the device of the embedding weights instead of an
        unconditional ``.cuda()``, so the lookup works on CPU and on whichever
        GPU the model was actually placed on.
        """
        embedding_layer = self.model.get_input_embeddings()
        token_ids = inputs.input_ids.long().to(embedding_layer.weight.device)
        return embedding_layer(token_ids)


class GlobalAttention(nn.Module):
    """Joint causal self-attention over graph tokens and language tokens.

    For the first ``len(graph_layers)`` layers, graph tokens are projected by
    the trainable graph layers and language tokens by the language-model
    layers (run under ``torch.no_grad``). Queries, keys and values of both
    modalities are concatenated along the sequence axis (graph first) and
    attended jointly with a causal mask. Any remaining language-model layers
    are then applied to the concatenated sequence.

    Args:
        language_layers: Causal LM whose decoder layers are reused. With
            ``use_lora=True`` the layers are taken from ``get_base_model()``.
        graph_layers: Iterable of graph layers exposing ``norm1``, ``q_proj``,
            ``k_proj``, ``v_proj``, ``o_proj``, ``norm2`` and ``ffn``.
        rope: Rotary-embedding module; called as ``rope(x, position_ids)`` and
            expected to return ``(cos, sin)``.
        head_dim: Per-head feature width used to split projections into heads.
        use_lora: Whether ``language_layers`` is a PEFT-wrapped model.
    """

    def __init__(self, language_layers, graph_layers, rope, head_dim, use_lora=True):
        super(GlobalAttention, self).__init__()
        self.language_layers = language_layers.model.layers if not use_lora else language_layers.get_base_model().model.layers
        self.graph_layers = graph_layers
        self.rope = rope
        self.head_dim = head_dim

    @staticmethod
    def _merge_heads(attn_out: torch.Tensor) -> torch.Tensor:
        """(batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).

        The head and sequence axes must be swapped *before* flattening;
        reshaping directly would mix features of different positions into one
        token.
        """
        batch, _, seq_len, _ = attn_out.shape
        return attn_out.transpose(1, 2).reshape(batch, seq_len, -1)

    def llama_layer_forward(self, idx, x):
        """Apply language layer ``idx`` (pre-norm attention + MLP) to ``x``.

        Args:
            idx: Index into ``self.language_layers``.
            x: Hidden states of shape (batch, seq, hidden).

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        layer = self.language_layers[idx]
        input_shape = x.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Build positions on the input's device rather than a hard-coded 'cuda'.
        position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0).expand(x.size(0), -1)
        cos, sin = self.rope(x, position_ids)
        embedding = layer.input_layernorm(x)
        q = layer.self_attn.q_proj(embedding).view(hidden_shape).transpose(1, 2)
        k = layer.self_attn.k_proj(embedding).view(hidden_shape).transpose(1, 2)
        v = layer.self_attn.v_proj(embedding).view(hidden_shape).transpose(1, 2)
        # Expand grouped K/V heads so they match the number of query heads.
        k = self.repeat_kv(k, q.shape[1] // k.shape[1])
        v = self.repeat_kv(v, q.shape[1] // v.shape[1])
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = self._merge_heads(attn_out)
        attn = layer.self_attn.o_proj(attn_out)
        x = x + attn
        embedding = layer.post_attention_layernorm(x)

        # feedforward & skip connection
        embedding = layer.mlp(embedding)
        x = x + embedding

        return x

    # Llama style pre-norm
    def forward(self, x_graph, x_lang):
        """Fuse graph and language hidden states.

        Args:
            x_graph: Graph tokens, (batch, graph_seq, hidden).
            x_lang: Language token embeddings, (batch, lang_seq, hidden).

        Returns:
            Tensor of shape (batch, graph_seq + lang_seq, hidden), graph
            tokens first.
        """
        graph_seq_len, lang_seq_len = x_graph.size(1), x_lang.size(1)
        position_ids = torch.arange(
            graph_seq_len + lang_seq_len, device=x_graph.device
        ).unsqueeze(0).expand(x_graph.size(0), -1)
        cos, sin = self.rope(torch.cat((x_graph, x_lang), dim=1), position_ids)  # rope.forward returns (cos, sin)
        graph_input_shape = x_graph.shape[:-1]
        lang_input_shape = x_lang.shape[:-1]
        graph_hidden_shape = (*graph_input_shape, -1, self.head_dim)
        lang_hidden_shape = (*lang_input_shape, -1, self.head_dim)
        for i, graph_layer in enumerate(self.graph_layers):
            lang_layer = self.language_layers[i]

            # pre-norm
            with torch.no_grad():
                lang_embedding = lang_layer.input_layernorm(x_lang)
            graph_embedding = graph_layer.norm1(x_graph)

            # proj by modality
            q_graph = graph_layer.q_proj(graph_embedding).view(graph_hidden_shape).transpose(1, 2)
            k_graph = graph_layer.k_proj(graph_embedding).view(graph_hidden_shape).transpose(1, 2)
            v_graph = graph_layer.v_proj(graph_embedding).view(graph_hidden_shape).transpose(1, 2)

            with torch.no_grad():   # freeze language tower
                q_language = lang_layer.self_attn.q_proj(lang_embedding).view(lang_hidden_shape).transpose(1, 2)
                k_language = lang_layer.self_attn.k_proj(lang_embedding).view(lang_hidden_shape).transpose(1, 2)
                v_language = lang_layer.self_attn.v_proj(lang_embedding).view(lang_hidden_shape).transpose(1, 2)
                k_language = self.repeat_kv(k_language, q_language.shape[1] // k_language.shape[1])
                v_language = self.repeat_kv(v_language, q_language.shape[1] // v_language.shape[1])

            # global attention over [graph tokens ; language tokens]
            q = torch.cat((q_graph, q_language), dim=2)
            k = torch.cat((k_graph, k_language), dim=2)
            v = torch.cat((v_graph, v_language), dim=2)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            attn_out = self._merge_heads(attn_out)

            # split by modality & output proj
            graph_attn = attn_out[:, :graph_seq_len]
            lang_attn = attn_out[:, graph_seq_len:graph_seq_len + lang_seq_len]
            graph_attn = graph_layer.o_proj(graph_attn)
            x_graph = x_graph + graph_attn
            # Pre-norm feed-forward: normalise the residual stream (not the raw
            # attention output), matching llama_layer_forward.
            graph_embedding = graph_layer.norm2(x_graph)
            graph_embedding = graph_layer.ffn(graph_embedding)
            x_graph = x_graph + graph_embedding
            with torch.no_grad():   # freeze language tower
                lang_attn = lang_layer.self_attn.o_proj(lang_attn)
                x_lang = x_lang + lang_attn
                lang_embedding = lang_layer.post_attention_layernorm(x_lang)
                lang_embedding = lang_layer.mlp(lang_embedding)
                x_lang = x_lang + lang_embedding

        x_fused = torch.cat((x_graph, x_lang), dim=1)
        # Language layers without a graph counterpart run on the fused sequence.
        for i in range(len(self.graph_layers), len(self.language_layers)):
            x_fused = self.llama_layer_forward(i, x_fused)

        return x_fused

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class RegressionDecoderModel(nn.Module):
    """
    Decoder over scalar sequences that supports two training objectives:
      1. autoregressive:  predict the next y_t (causal attention)
      2. diffusion:       predict noise / reconstruct x_0 (non-causal attention)
    Set ``use_diffusion=True`` for the diffusion objective, otherwise the
    autoregressive objective is used.

    Args:
        hidden_dim: Width of the decoder hidden states.
        rope: Rotary-embedding instance passed to every decoder layer.
        nhead: Number of attention heads per decoder layer.
        num_layers: Number of ``MLADecoderLayer`` layers.
        dropout: Dropout rate passed to every decoder layer.
        use_diffusion: Enables the timestep embedding and disables the causal mask.
        max_len: Number of rows of the sinusoidal timestep table, i.e. the
            largest usable diffusion timestep is ``max_len - 1``.
        time_embed_dim: Width of the sinusoidal timestep code (diffusion only).
        node_parallel: If True, a learned per-row id embedding is added to the
            inputs; row ``b`` of the batch receives id ``b``.
    """
    def __init__(
        self,
        hidden_dim: int,
        rope,                     # shared RoPE instance to save GPU memory
        nhead: int = 4,
        num_layers: int = 6,
        dropout: float = 0.0,
        use_diffusion: bool = False,
        max_len: int = 16384,
        time_embed_dim: int = 512,  # diffusion only
        node_parallel: bool = False,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.use_diffusion = use_diffusion
        self.node_parallel = node_parallel

        # 1. Decoder stack
        self.layers = nn.ModuleList(
            [MLADecoderLayer(hidden_dim, nhead, rope=rope, dropout=dropout)
             for _ in range(num_layers)]
        )
        # Not used by this module's forward(); kept as submodules so that code
        # holding a reference to the decoder can apply them.
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # 2. input/output projection
        self.input_proj = nn.Linear(1, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, 1)
        self.node_id_table = nn.Embedding(2000, hidden_dim)      # 2000 is the max_nodes, change if required
        # 3. timestep embedding (diffusion only, Sinusoidal -> MLP)
        if use_diffusion:
            self.time_embed = nn.Sequential(
                nn.Linear(time_embed_dim, hidden_dim * 4),
                nn.SiLU(),
                nn.Linear(hidden_dim * 4, hidden_dim)
            )
            # precomputed sin/cos table, one row per timestep
            self.register_buffer(
                "timestep_freqs",
                self._build_sinusoidal(max_len, time_embed_dim),
                persistent=False,
            )

    def forward(
            self,
            tgt,  # (B, T, 1)
            memory=None,  # (B, M, H)  <- from Encoder tower
            memory_mask=None,
            diffusion_t=None,  # (B,) int64  <- diffusion timestep
            use_cache: bool=False,
            past_key_values=None,  # list[ layer_idx -> past ],
            pos_ids=None
    ):
        """
        Run the decoder stack over ``tgt``.

        ``pos_ids`` defaults to ``0 .. T-1`` for every row when not given.

        return:
            prediction: (B, T, 1)
            new_past  : list[ layer_idx -> past ] if use_cache=True, else None
            hidden    : (B, T, H) hidden states fed to the output projection
        """
        B, T, _ = tgt.shape
        x = self.input_proj(tgt)                       # (B, T, H)
        if self.node_parallel:
            # Create the ids on the input's device instead of forcing CUDA.
            # NOTE(review): B must not exceed the embedding table size (2000).
            node_ids = torch.arange(B, device=tgt.device)
            node_id_emb = self.node_id_table(node_ids).unsqueeze(1)  # (B, 1, H)
            x = x + node_id_emb                        # broadcast over T

        if pos_ids is None:
            pos_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1)

        # ---------- addition: diffusion timestep embedding ----------
        if self.use_diffusion:
            assert diffusion_t is not None, "`diffusion_t` required when use_diffusion=True"
            t_emb = self._get_t_embed(diffusion_t)  # (B, H)
            x = x + t_emb[:, None, :]  # broadcast to seq_len dim

        # ---------- forward per layer ----------
        new_past = [] if use_cache else None
        causal = not self.use_diffusion              # disable causal mask if using diffusion objective
        for idx, layer in enumerate(self.layers):
            past = past_key_values[idx] if past_key_values is not None else None

            x, layer_past = layer(
                x,
                position_ids=pos_ids,
                memory=memory,
                memory_mask=memory_mask,
                past_key_value=past,
                use_cache=use_cache,
                causal=causal,
            )
            if use_cache:
                new_past.append(layer_past)

        prediction = self.output_proj(x)               # (B, T, 1)
        return prediction, new_past, x

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _build_sinusoidal(max_t, dim):
        """Build the sin/cos code table for timesteps ``[0, max_t)``.

        Returns a (max_t, dim) tensor for even ``dim``: the first half of each
        row holds sines, the second half cosines. For odd ``dim`` the table is
        ``dim + 1`` wide.
        """
        freqs = torch.arange(0, dim, 2, dtype=torch.float)
        freqs = 1.0 / (10000 ** (freqs / dim))
        t = torch.arange(max_t, dtype=torch.float).unsqueeze(1)  # (T, 1)
        emb = t * freqs  # (T, dim//2)
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)  # (T, dim)
        return emb

    def _get_t_embed(self, t):
        """Map integer timesteps ``t`` of shape (B,) to (B, hidden_dim) embeddings."""
        sinus = self.timestep_freqs[t]  # (B, time_embed_dim)
        return self.time_embed(sinus)  # (B, hidden_dim)


class DiffusionDecoderWrapper(nn.Module):
    """DDPM-style training and sampling wrapper around a noise-predicting decoder.

    The decoder is called as
    ``decoder(tgt=..., memory=..., diffusion_t=..., use_cache=False)`` and its
    first return value is taken as the predicted noise.

    Args:
        decoder: Noise-prediction network.
        timesteps: Number of diffusion steps of the linear beta schedule
            (``1e-4`` to ``0.02``).
    """

    def __init__(self, decoder: nn.Module, timesteps=1000):
        super().__init__()
        self.decoder = decoder
        self.timesteps = timesteps

        # Create beta schedule and precompute alphas
        self.register_buffer('betas', torch.linspace(1e-4, 0.02, timesteps))
        self.register_buffer('alphas', 1.0 - self.betas)
        self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))

    def _predict_noise(self, x_t, memory, t):
        """Run the decoder and return only its prediction.

        The decoder returns a tuple whose first element is the prediction;
        indexing avoids depending on how many extra values follow it.
        """
        return self.decoder(
            tgt=x_t,
            memory=memory,
            diffusion_t=t,
            use_cache=False,
        )[0]

    def forward(self, clean_target, memory, edge_index=None):
        """
        Compute the noise-prediction MSE loss for one random timestep per sample.

        clean_target: (B, T), (B, T, 1) or (B, G, L); flattened to (B, T', 1)
        memory: encoder output (B, M, H)
        edge_index: accepted for call-compatibility with
            RegressionDecoderWrapper.forward; unused here.

        Returns a RegressionOutput with the loss and the predicted noise.
        """
        B = clean_target.shape[0]
        clean_target = clean_target.reshape(B, -1, 1)
        device = clean_target.device

        # Random timestep per sample
        t = torch.randint(0, self.timesteps, (B,), device=device).long()

        # Get alpha_bar_t for each sample
        alpha_bar_t = self.alpha_bars[t].view(B, 1, 1)

        # Sample noise
        noise = torch.randn_like(clean_target)

        # Diffuse: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = alpha_bar_t.sqrt() * clean_target + (1.0 - alpha_bar_t).sqrt() * noise

        # === call RegressionDecoderModel (diffusion mode) ===
        pred_noise = self._predict_noise(x_t, memory, t)

        # Loss: predict noise
        loss = F.mse_loss(pred_noise, noise)
        return RegressionOutput(loss=loss, preds=pred_noise)

    @torch.no_grad()
    def generate(self, memory, seq_len: int, bus_num: int, steps: int = 1000):
        """
        DDPM ancestral sampling.

        memory: (B, M, H)
        seq_len, bus_num: the sample has ``bus_num * seq_len`` positions.
        steps: number of reverse steps, starting at timestep ``steps - 1``;
            must not exceed ``self.timesteps``.

        Returns a tensor of shape (B, bus_num * seq_len, 1).
        """
        if steps > self.timesteps:
            raise ValueError(
                f"steps ({steps}) must not exceed the schedule length ({self.timesteps})"
            )
        B, _, _ = memory.shape
        device = memory.device
        x = torch.randn(B, bus_num * seq_len, 1, device=device)

        for t_inv in reversed(range(steps)):
            t = torch.full((B,), t_inv, dtype=torch.long, device=device)

            beta_t = self.betas[t_inv]
            alpha_t = self.alphas[t_inv]
            alpha_bar_t = self.alpha_bars[t_inv]

            pred_noise = self._predict_noise(x, memory, t)

            # Posterior mean of x_{t-1} given x_t and the predicted noise:
            #   mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
            coef1 = 1 / torch.sqrt(alpha_t)
            coef2 = beta_t / torch.sqrt(1 - alpha_bar_t)
            mean = coef1 * (x - coef2 * pred_noise)

            # Add noise except at the last step. The mean is used as-is; it
            # must not be rescaled by sqrt(alpha_bar_{t-1}).
            if t_inv > 0:
                noise = torch.randn_like(x)
                x = mean + torch.sqrt(beta_t) * noise
            else:
                x = mean

        return x

    @torch.no_grad()
    def generate_ddim(self, memory, seq_len: int, steps: int = 50, eta: float = 0.0):
        """
        DDIM sampling over ``steps`` evenly spaced timesteps.

        memory: (B, M, H)
        eta: 0 gives deterministic sampling; larger values add noise.

        Returns a tensor of shape (B, seq_len, 1).
        """
        B, _, _ = memory.shape
        device = memory.device

        total_T = self.timesteps
        # Plain Python ints, descending, so they can index buffers and fill tensors.
        times = torch.linspace(0, total_T - 1, steps).long().flip(0).tolist()
        alpha_bars = self.alpha_bars

        x = torch.randn(B, seq_len, 1, device=device)

        for t, t_prev in zip(times[:-1], times[1:]):
            a_bar_t    = alpha_bars[t] + 1e-5
            a_bar_prev = alpha_bars[t_prev] + 1e-5

            t_batch = torch.full((B,), t, dtype=torch.long, device=device)
            pred_noise = self._predict_noise(x, memory, t_batch)

            x0_pred = (x - (1 - a_bar_t).sqrt() * pred_noise) / a_bar_t.sqrt()

            sigma = eta * ((1 - a_bar_prev) / (1 - a_bar_t) *
                           (1 - a_bar_t / a_bar_prev)).sqrt()
            # The "direction to x_t" coefficient shrinks by sigma^2 so that the
            # total variance stays 1 - a_bar_prev (identical to before for eta=0).
            direction = (1 - a_bar_prev - sigma ** 2).clamp(min=0).sqrt()

            x = a_bar_prev.sqrt() * x0_pred + direction * pred_noise
            if eta != 0:
                x = x + sigma * torch.randn_like(x)

        return x


class RegressionDecoderWrapper(nn.Module):
    def __init__(self, decoder: "RegressionDecoderModel", node_parallel: bool=False):
        super().__init__()
        self.decoder = decoder
        self.node_parallel = node_parallel

    # ------------------------------------------------------------------ #
    # training (teacher forcing)
    # ------------------------------------------------------------------ #
    def forward(
        self,
        target_values: torch.Tensor,   # (B, G, L)   ground-truth sequence
        memory: torch.Tensor,          # (B, M, H)   from encoder tower
        edge_index,
    ) -> RegressionOutput:
        # prepend beginning token 0
        B, G, L = target_values.shape
        if self.node_parallel:
            seq = target_values.reshape(B * G, L, 1)
            z0 = torch.zeros(B * G, 1, 1, device=seq.device)
        else:
            seq = target_values.reshape(B, G * L, 1)
            z0 = torch.zeros(B, 1, 1, device=seq.device)
        tgt_in, tgt_ref = torch.cat([z0, seq[:, :-1]], dim=1), seq

        if self.node_parallel:
            next_tok = z0
            graph_mem = memory[:, :G].permute(1, 0, 2)
            language_mem = memory[:, G:].repeat(G, 1, 1)
            memory = torch.cat((graph_mem, language_mem), dim=1)
            preds = []
            past_key_values = None
            steps = seq.shape[1]
            pos_ids = torch.arange(steps, device=memory.device).unsqueeze(0).expand(B * G, -1)
            for i in range(steps):
                # only input "newest" token; previously produced KV has been cached
                y_hat, past_key_values, y_hidden = self.decoder(
                    next_tok,
                    memory=memory,
                    use_cache=True,
                    past_key_values=past_key_values,
                    pos_ids=pos_ids[:, i:i + 1]
                )
                y_hidden_res = y_hidden.permute(1, 0, 2)
                y_hidden = self.decoder.gcn1(y_hidden.permute(1, 0, 2), edge_index.squeeze(0))  # .permute(1, 0, 2)
                y_hidden = F.gelu(y_hidden) + y_hidden_res
                y_hidden = self.decoder.gcn2(y_hidden, edge_index.squeeze(0)).permute(1, 0, 2)
                memory = torch.cat((memory, y_hidden), dim=1)
                next_tok = y_hat[:, -1:, :]  # take the last token
                preds.append(next_tok)

            preds = torch.cat(preds, dim=1)
        else:
            preds, _, _ = self.decoder(
                # use_cache=False
                tgt_in,
                memory=memory,
                use_cache=False
            )

        loss = F.mse_loss(preds, tgt_ref)  # / target_values.shape[1] #(L * G)
        return RegressionOutput(loss=loss, preds=preds)

    # ------------------------------------------------------------------ #
    # autoregressive inference
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def generate(
        self,
        memory: torch.Tensor,          # (B, M, H)
        edge_index,
        steps: int,
        node_num: int,
        start_val: Optional[torch.Tensor] = None,  # (B, 1, 1) or None
    ) -> torch.Tensor:                 # -> (B, steps, 1)
        device = memory.device
        B, G, M = memory.shape[0], node_num, memory.shape[1]

        if self.node_parallel:
            graph_mem = memory[:, :node_num].permute(1, 0, 2)
            language_mem = memory[:, node_num:].repeat(node_num, 1, 1)
            memory = torch.cat((graph_mem, language_mem), dim=1)
            B = memory.size(0)
        else:
            steps = node_num * steps

        # beginning token
        next_tok = (torch.zeros(B, 1, 1, device=device)
                    if start_val is None else start_val.to(device))

        preds = []
        past_key_values = None
        pos_ids = torch.arange(steps, device=memory.device).unsqueeze(0).expand(B, -1)
        for i in range(steps):
            # only input "newest" token; previously produced KV has been cached
            y_hat, past_key_values, y_hidden = self.decoder(
                next_tok,
                memory=memory,
                use_cache=True,
                past_key_values=past_key_values,
                pos_ids=pos_ids[:, i:i + 1]
            )
            if self.node_parallel:
                y_hidden_res = y_hidden.permute(1, 0, 2)
                y_hidden = self.decoder.gcn1(y_hidden.permute(1, 0, 2), edge_index.squeeze(0))  # .permute(1, 0, 2)
                y_hidden = F.gelu(y_hidden) + y_hidden_res
                y_hidden = self.decoder.gcn2(y_hidden, edge_index.squeeze(0)).permute(1, 0, 2)
                memory = torch.cat((memory, y_hidden), dim=1)
            next_tok = y_hat[:, -1:, :]  # take the last token
            preds.append(next_tok)

        if self.node_parallel:
            return torch.cat(preds, dim=1).reshape(-1, node_num * steps, 1)  # (B, steps, 1)
        else:
            return torch.cat(preds, dim=1)


# Step 4: Final Decoder (Using LLaMA Decoder)
class GLA_Model_MoT(nn.Module):
    """Graph + language model with a text head and a numeric regression head.

    Pipeline of ``forward``:
      1. Node signals are encoded by ``NodalEncoder`` and ``GraphEncoder`` and
         projected to the language model's hidden size ("graph tokens").
      2. The instruction is tokenised and embedded by the language model.
      3. Instructions containing ``'fault type and fault location'`` are
         answered as text by the causal LM on [graph tokens ; text tokens].
         All other instructions are fused by ``GlobalAttention`` and answered
         numerically by the float decoder (diffusion or autoregressive).

    Args:
        graph_input_dim: Pooled length produced by the nodal encoder.
        graph_hidden_dim: Hidden width of the graph encoder.
        in_channels: Input channels of the nodal encoder.
        out_channels: Channel count used to size the graph encoder input
            (``graph_input_dim * out_channels``).
        language_model_name: Hugging Face model id or path.
        float_disc, bins: Accepted for compatibility; not used in this class.
        use_lora: Wrap the language model with LoRA adapters.
        phase: 1 freezes the language model and trains the graph encoders;
            2 trains the language model and freezes the graph encoders.
            Projector, graph transformer and float decoder train in both.
        use_diffusion: Use the diffusion float decoder instead of the
            autoregressive one.
        node_parallel: Decode every node as its own batch row.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable; when it is unset, ``None`` is passed and the library's own
    default credential lookup applies.

    Raises:
        NotImplementedError: If ``phase`` is neither 1 nor 2.
    """

    def __init__(self, graph_input_dim, graph_hidden_dim, in_channels=1, out_channels=32,
                 language_model_name='decapoda-research/llama-3.2-1b', float_disc=False, bins=2048, use_lora=False,
                 phase=1, use_diffusion=True, node_parallel=True):
        super(GLA_Model_MoT, self).__init__()
        import os  # local import: only needed to read the access token

        if phase not in (1, 2):
            # fail before any model weights are downloaded / loaded
            raise NotImplementedError

        # SECURITY: never hard-code access tokens in source; take it from the
        # environment instead.
        hf_token = os.environ.get("HF_TOKEN")

        self.graph_input_dim = graph_input_dim
        self.graph_hidden_dim = graph_hidden_dim
        self.nodal_encoder = NodalEncoder(in_channels, graph_input_dim)
        self.graph_encoder = GraphEncoder(graph_input_dim*out_channels, graph_hidden_dim)

        self.use_lora = use_lora
        self.use_diffusion = use_diffusion

        peft_config = {
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "bias": "none",
            "task_type": "CAUSAL_LM",
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
        }
        peft_conf = LoraConfig(**peft_config)
        # one full model copy per process, placed on the device with that process' index
        device_string = PartialState().process_index
        print(f'device_string={device_string}')
        self.lm_model = AutoModelForCausalLM.from_pretrained(
            language_model_name, token=hf_token, device_map={'': device_string},
            attn_implementation="eager" if 'gemma' in language_model_name else "flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        if self.use_lora:
            self.lm_model = get_peft_model(self.lm_model, peft_conf)
        self.lm_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        # NOTE(review): "[PAD]" is assigned without adding it to the vocabulary
        # or resizing the embeddings — confirm the resulting pad_token_id is valid.
        self.lm_tokenizer.pad_token = "[PAD]"  # Define a padding token
        self.lm_tokenizer.pad_token_id = self.lm_tokenizer.convert_tokens_to_ids("[PAD]")

        self.lm_config = AutoConfig.from_pretrained(language_model_name, token=hf_token)
        self.language_encoder = LanguageEncoder(self.lm_model, self.use_lora)
        self.rotary_emb = LlamaRotaryEmbedding(self.lm_model.config)

        self.graph_to_llm_proj = nn.Sequential(
            nn.Linear(graph_hidden_dim, self.lm_config.hidden_size),
            nn.SiLU(),
            nn.Linear(self.lm_config.hidden_size, self.lm_config.hidden_size),
            nn.LayerNorm(self.lm_config.hidden_size),
        )

        # one graph layer per language-model layer
        self.graph_transformer = nn.ModuleList(
            [MLAEncoderLayer(hidden_dim=self.lm_config.hidden_size, nhead=self.lm_config.num_attention_heads)
             for _ in range(self.lm_config.num_hidden_layers)]
        ).to(torch.bfloat16)
        self.global_attention = GlobalAttention(self.lm_model, self.graph_transformer, self.rotary_emb, self.lm_config.head_dim, self.use_lora)

        self.float_decoder = RegressionDecoderModel(
            self.lm_config.hidden_size, self.rotary_emb,
            nhead=self.lm_config.num_attention_heads, use_diffusion=use_diffusion,
            node_parallel=node_parallel
        )
        # The two wrappers take different constructor arguments: the diffusion
        # wrapper's second parameter is the number of timesteps, so
        # node_parallel must not be forwarded to it.
        if self.use_diffusion:
            self.float_decoder_wrapped = DiffusionDecoderWrapper(self.float_decoder)
        else:
            self.float_decoder_wrapped = RegressionDecoderWrapper(self.float_decoder, node_parallel)

        # phase 1: freeze LM, train graph encoders; phase 2: the reverse
        train_lm = phase == 2
        for param in self.lm_model.model.base_model.parameters():
            param.requires_grad = train_lm
        for module in (self.graph_encoder, self.nodal_encoder):
            for param in module.parameters():
                param.requires_grad = not train_lm
        # projector, graph transformer and float decoder train in both phases
        for module in (self.graph_to_llm_proj, self.graph_transformer, self.float_decoder):
            for param in module.parameters():
                param.requires_grad = True

        self.generation_config = GenerationConfig(
            do_sample=True,
            top_k=50,
            temperature=0.7,
            max_new_tokens=1024,
            pad_token_id=self.lm_tokenizer.pad_token_id,
            eos_token_id=self.lm_tokenizer.eos_token_id,
            max_length=1024
        )

    @staticmethod
    def _parse_prediction_request(instruction):
        """Extract ``(bus_num, pred_len)`` from a numeric-task instruction.

        ``bus_num`` is the first integer between the first ``'in '`` and
        ``' bus system'``. ``pred_len`` is 1 for power-setpoint / locational
        marginal price questions, 2 for real-state questions, and the number
        of requested steps for forecasting questions.

        Raises:
            ValueError: If no known question type is found in the instruction.
        """
        bus_num = int(re.findall(r'\d+', instruction.split('in ')[1].split(' bus system')[0])[0])
        pred_len = 0
        if 'power setpoint' in instruction or 'locational marginal price' in instruction:
            pred_len = 1
        if 'real states' in instruction:
            pred_len = 2
        if 'what are the predictions of' in instruction:
            pred_len = int(re.findall(r'\d+', instruction.split('following')[1].split('steps of')[0])[0])
        if pred_len <= 0:
            raise ValueError(
                f"cannot determine the prediction length from instruction: {instruction!r}"
            )
        return bus_num, pred_len

    def forward(self, graph_data, edge_index, instruction, target_response=None, return_graph_tokens=False):
        """Run training (``target_response`` given) or inference (``None``).

        Args:
            graph_data: Node signals; channel ``i`` is encoded from
                ``graph_data[:, i:i+1]``.
            edge_index: Graph connectivity with a leading axis of size 1.
            instruction: Natural-language task description.
            target_response: Text answer (fault task) or numeric target tensor
                (all other tasks); ``None`` switches to generation.
            return_graph_tokens: Also return the projected graph tokens.

        Returns:
            The LM output / generated ids for the fault task, otherwise the
            float decoder's training output or generated values. With
            ``return_graph_tokens=True`` a ``(outputs, graph_tokens)`` tuple is
            returned, except for fault-task generation which always returns
            the generated ids only.
        """
        # Step 1: Graph Encoding
        x = torch.cat(
            [self.nodal_encoder(graph_data[:, i:i + 1]).unsqueeze(0) for i in range(graph_data.shape[1])],
            dim=1,
        )
        graph_embedding = self.graph_encoder(x, edge_index.squeeze(0))
        graph_tokens = self.graph_to_llm_proj(graph_embedding).to(torch.bfloat16)

        # Step 2: Language Encoding
        instruction = '[GRAPH] ' + instruction
        is_text_task = 'fault type and fault location' in instruction
        training = target_response is not None

        # Prompt length must be measured before the answer is appended.
        instruction_token_len = self.lm_tokenizer(instruction, return_tensors="pt").input_ids.shape[1]
        if is_text_task and training:
            instruction = instruction + target_response + self.lm_tokenizer.eos_token

        inputs = self.lm_tokenizer(instruction, return_tensors="pt")

        # retrieve embedding
        language_embedding = self.language_encoder(inputs)

        # Step 3: Decoder Output
        if is_text_task:
            fused_embedding = torch.cat([graph_tokens, language_embedding], dim=1).to(torch.bfloat16)
            attention_mask = torch.cat(
                (torch.ones((1, graph_embedding.shape[1])).long(), inputs.attention_mask), dim=1
            ).to(fused_embedding.device)
            if not training:
                # generation returns immediately and never includes graph tokens
                return self.lm_model.generate(
                    inputs_embeds=fused_embedding, attention_mask=attention_mask,
                    generation_config=self.generation_config
                )
            n_graph = graph_tokens.shape[1]
            labels = torch.cat(
                [torch.full((1, n_graph), -100, dtype=torch.long), inputs.input_ids.clone().long()], dim=1
            )
            # mask out graph tokens and the prompt so only the answer is scored
            labels[:, :n_graph + instruction_token_len] = -100
            outputs = self.lm_model(inputs_embeds=fused_embedding, attention_mask=attention_mask, labels=labels)
        else:
            fused_embedding = self.global_attention(graph_tokens, language_embedding)
            memory = fused_embedding.float()
            if training:
                target = target_response.unsqueeze(0).float().to(memory.device)
                if self.use_diffusion:
                    # the diffusion wrapper takes a flat (B, T) target and no edge_index
                    outputs = self.float_decoder_wrapped(target.reshape(target.shape[0], -1), memory)
                else:
                    outputs = self.float_decoder_wrapped(target, memory, edge_index)
            else:
                bus_num, pred_len = self._parse_prediction_request(instruction)
                if self.use_diffusion:
                    # DiffusionDecoderWrapper.generate(memory, seq_len, bus_num)
                    outputs = self.float_decoder_wrapped.generate(memory, pred_len, bus_num)
                else:
                    # RegressionDecoderWrapper.generate(memory, edge_index, steps, node_num)
                    outputs = self.float_decoder_wrapped.generate(memory, edge_index, pred_len, bus_num)

        if return_graph_tokens:
            return outputs, graph_tokens
        return outputs

    def decode(self, outputs):
        """Decode the first sequence in ``outputs`` to text, keeping special tokens."""
        output_text = self.lm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        return output_text

    # save model and tokenizer
    def save(self, model_path, tokenizer_path):
        """Save weights and tokenizer under the ``saved_models/`` directory."""
        import os  # local import: only needed to create the output directory

        weights_file = f'saved_models/{model_path}'
        # torch.save does not create missing parent directories
        os.makedirs(os.path.dirname(weights_file), exist_ok=True)

        # save model weights
        torch.save(self.state_dict(), weights_file)
        print(f"Model saved at {model_path}")

        # save tokenizer
        self.lm_tokenizer.save_pretrained(f'saved_models/{tokenizer_path}')
        print(f"Tokenizer saved at {tokenizer_path}")

    # load model and tokenizer
    def load(self, model_path, tokenizer_path):
        """Load tokenizer and weights.

        Unlike ``save``, the paths are used as given (no ``saved_models/``
        prefix is added).
        """
        # load tokenizer
        self.lm_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Tokenizer loaded from {tokenizer_path}")

        # load model weights
        self.load_state_dict(torch.load(model_path, weights_only=True))
        print(f"Model loaded from {model_path}")
