import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.data import Data
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator, PartialState
from utils import extract_floats, RegressionOutput, MLADecoderLayer, ScaledLlamaRotaryEmbedding, edge_index_to_adj_list
from typing import Optional
import re
import math
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from typing import Optional, List, Tuple, Union


class NodalEncoder(nn.Module):
    """Encode a per-node 1-D signal with a small stack of Conv1d layers.

    Each layer is ``Conv1d -> GELU``; every layer after the first adds a skip
    connection from its input. The result is average-pooled along the length
    axis to ``output_dim`` positions, layer-normalised over those positions and
    flattened.

    Args:
        input_dim: number of input channels of the first convolution.
        output_dim: number of positions kept by the adaptive average pool.
        num_layers: number of convolution layers.
        hidden_dim: number of channels produced by every convolution.
        kernel_size: convolution kernel width.

    Shapes:
        input  ``(N, input_dim, L)``
        output ``(N, hidden_dim * output_dim)``
    """

    def __init__(self, input_dim, output_dim, num_layers=2, hidden_dim=32, kernel_size=3):
        super(NodalEncoder, self).__init__()

        # "same" padding keeps the sequence length unchanged for ANY kernel size,
        # which the skip connections in forward() rely on. For the default
        # kernel_size=3 this is exactly the previous fixed padding of 1.
        self.layers = nn.ModuleList([
            nn.Conv1d(
                input_dim if i == 0 else hidden_dim,
                hidden_dim,
                kernel_size=kernel_size,
                padding="same",
                bias=False,
            )
            for i in range(num_layers)
        ])

        self.adaptive_pool = nn.AdaptiveAvgPool1d(output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Return the flattened encoding of ``x`` (``(N, input_dim, L)``)."""
        for i, layer in enumerate(self.layers):
            activated = F.gelu(layer(x))
            # The first layer changes the channel count, so it has no skip path.
            x = activated + x if i > 0 else activated
        x = self.layer_norm(self.adaptive_pool(x))
        return x.view(x.size(0), -1)


# Step 1: Graph Encoder (Using GCN)
class GraphEncoder(nn.Module):
    """Stack of ``GCNConv`` layers with GELU activations.

    The first layer maps ``input_dim`` features to ``hidden_dim``; every later
    layer maps ``hidden_dim`` to ``hidden_dim`` and adds its input back onto its
    activated output.

    Args:
        input_dim: feature width fed to the first graph convolution.
        hidden_dim: feature width produced by every graph convolution.
        num_layers: number of graph convolutions.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3):
        super(GraphEncoder, self).__init__()

        # widths[k] -> widths[k + 1] is the shape of the k-th convolution
        widths = [input_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [GCNConv(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, edge_index):
        """Propagate node features ``x`` over the graph given by ``edge_index``."""
        for depth, conv in enumerate(self.layers):
            activated = F.gelu(conv(x, edge_index))
            # No skip connection on the first layer: its input width may differ.
            x = activated if depth == 0 else activated + x
        return x


# Step 2: Language Encoder (Using LLaMA)
class LanguageEncoder(nn.Module):
    """Turn tokenised text into the language model's input embeddings.

    Only the embedding table of ``model`` (obtained through
    ``model.get_input_embeddings()``) is evaluated here; no transformer layer
    is run.

    Args:
        model: object exposing ``get_input_embeddings()``.
        use_lora: stored on the instance; not read by this class.
    """

    def __init__(self, model, use_lora=False):
        super(LanguageEncoder, self).__init__()
        self.model = model
        self.use_lora = use_lora

    def forward(self, inputs):
        """Embed ``inputs.input_ids`` and return the embedding tensor."""
        embed_tokens = self.model.get_input_embeddings()
        # Follow the embedding table's own device rather than assuming the
        # *current* CUDA device: the table may live on another GPU index, or on
        # the CPU when no GPU is available.
        target_device = next(embed_tokens.parameters()).device
        token_ids = inputs.input_ids.long().to(target_device)
        return embed_tokens(token_ids)


class RegressionDecoderModel(nn.Module):
    """
    Decoder stack that regresses one scalar per sequence position.

    Supports two training objectives:
      1. autoregressive: predict the next y_t
      2. diffusion:      predict noise or reconstruct x_0
    Set ``use_diffusion=True`` for (2); otherwise (1) is used. In diffusion mode
    the ``causal`` flag handed to every layer is False and a timestep embedding
    is added to the input.

    Args:
        hidden_dim: model width ``H``.
        rope: rotary-embedding instance handed to every layer (shared to save
            GPU memory).
        nhead: number of attention heads per layer.
        num_layers: number of ``MLADecoderLayer`` blocks.
        dropout: dropout rate handed to every layer.
        use_diffusion: select the diffusion objective (see above).
        max_len: number of rows in the sinusoidal timestep table, i.e. the
            exclusive upper bound for ``diffusion_t`` (diffusion only).
        time_embed_dim: width of the sinusoidal timestep code (diffusion only).
        node_parallel: if True, the batch axis is treated as a node axis and a
            learned per-row node-id embedding is added to the projected input.
        max_nodes: size of the node-id embedding table, i.e. the largest batch
            size usable with ``node_parallel=True``.
    """
    def __init__(
        self,
        hidden_dim: int,
        rope,                     # shared RoPE instance to save GPU memory
        nhead: int = 4,
        num_layers: int = 2,
        dropout: float = 0.0,
        use_diffusion: bool = False,
        max_len: int = 16384,
        time_embed_dim: int = 512,  # diffusion only
        node_parallel: bool = False,
        max_nodes: int = 2000,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.use_diffusion = use_diffusion
        self.node_parallel = node_parallel

        # 1. Decoder stack
        self.layers = nn.ModuleList(
            [MLADecoderLayer(hidden_dim, nhead, rope=rope, dropout=dropout)
             for _ in range(num_layers)]
        )
        # Not used by forward() below; RegressionDecoderWrapper in this file
        # applies them to the hidden states between decoding steps.
        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)

        # 2. input/output projection
        self.input_proj = nn.Linear(1, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, 1)
        self.node_id_table = nn.Embedding(max_nodes, hidden_dim)

        # 3. timestep embedding (diffusion only, Sinusoidal -> MLP)
        if use_diffusion:
            self.time_embed = nn.Sequential(
                nn.Linear(time_embed_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            # Pre-computed sin/cos table; non-persistent, so it is rebuilt on
            # construction instead of being stored in checkpoints.
            self.register_buffer(
                "timestep_freqs",
                self._build_sinusoidal(max_len, time_embed_dim),
                persistent=False,
            )

    def forward(
            self,
            tgt,  # (B, T, 1)
            memory=None,  # (B, M, H)  <- from Encoder tower
            memory_mask=None,
            diffusion_t=None,  # (B,) int64  <- diffusion timestep
            use_cache: bool = False,
            past_key_values=None,  # list[ layer_idx -> past ],
            pos_ids=None,
            tgt_hidden=None,
    ):
        """
        Run the decoder stack.

        Args:
            tgt: ``(B, T, 1)`` scalar inputs. Only its shape is used when
                ``tgt_hidden`` is given.
            memory: tensor handed to every layer as ``memory``.
            memory_mask: mask handed to every layer as ``memory_mask``.
            diffusion_t: ``(B,)`` integer timesteps; required in diffusion mode.
            use_cache: collect and return each layer's cache entry.
            past_key_values: per-layer cache from a previous call.
            pos_ids: ``(B, T)`` position ids; defaults to ``0..T-1`` for every
                row, so pass them explicitly when decoding with a cache.
            tgt_hidden: ``(B, T, H)`` hidden states used INSTEAD of projecting
                ``tgt`` (the node-id embedding is not added in that case).

        Returns:
            prediction: ``(B, T, 1)``
            new_past:   list[ layer_idx -> past ] if ``use_cache`` else None
            hidden:     ``(B, T, H)`` output of the last layer

        Raises:
            ValueError: ``node_parallel`` is set and ``B`` exceeds the node-id
                table size.
        """
        B, T, _ = tgt.shape
        if tgt_hidden is not None:
            x = tgt_hidden
        else:
            x = self.input_proj(tgt)                   # (B, T, H)
            if self.node_parallel:
                max_nodes = self.node_id_table.num_embeddings
                if B > max_nodes:
                    raise ValueError(
                        f"node_parallel batch of {B} rows exceeds the node-id table size "
                        f"({max_nodes}); construct the decoder with a larger max_nodes"
                    )
                # Build the ids on the input's device instead of assuming CUDA.
                node_ids = torch.arange(B, device=tgt.device)
                node_id_emb = self.node_id_table(node_ids).unsqueeze(1)  # (B, 1, H)
                x = x + node_id_emb.expand(-1, T, -1)                    # (B, T, H)

        if pos_ids is None:
            pos_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1)

        # ---------- diffusion timestep embedding ----------
        if self.use_diffusion:
            assert diffusion_t is not None, "`diffusion_t` required when use_diffusion=True"
            t_emb = self._get_t_embed(diffusion_t)  # (B, H)
            x = x + t_emb[:, None, :]  # broadcast over the sequence axis

        # ---------- forward per layer ----------
        new_past = [] if use_cache else None
        causal = not self.use_diffusion              # disable causal mask if using diffusion objective
        for idx, layer in enumerate(self.layers):
            past = past_key_values[idx] if past_key_values is not None else None

            x, layer_past = layer(
                x,
                position_ids=pos_ids,
                memory=memory,
                memory_mask=memory_mask,
                past_key_value=past,
                use_cache=use_cache,
                causal=causal,
            )
            if use_cache:
                new_past.append(layer_past)

        prediction = self.output_proj(x)               # (B, T, 1)
        return prediction, new_past, x

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _build_sinusoidal(max_t, dim):
        """Return a ``(max_t, dim)`` table of sin/cos codes for t in [0, max_t)."""
        freqs = torch.arange(0, dim, 2, dtype=torch.float)
        freqs = 1.0 / (10000 ** (freqs / dim))
        t = torch.arange(max_t, dtype=torch.float).unsqueeze(1)  # (T, 1)
        emb = t * freqs  # (T, ceil(dim / 2))
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        # For odd `dim` the concatenation is one column too wide; trim it so the
        # table always matches the width expected by `time_embed`.
        return emb[:, :dim]

    def _get_t_embed(self, t):
        """Map integer timesteps ``t`` of shape ``(B,)`` to ``(B, hidden_dim)``."""
        sinus = self.timestep_freqs[t]  # (B, time_embed_dim)
        return self.time_embed(sinus)  # (B, hidden_dim)


class RegressionDecoderWrapper(nn.Module):
    """Training / generation driver around a ``RegressionDecoderModel``.

    Two layouts are supported:

    * sequential (``node_parallel=False``): all ``G`` nodes' length-``L`` series
      are flattened into one sequence of ``G * L`` scalars. Training is
      teacher-forced in a single decoder call; generation feeds every predicted
      value back in as the next input.
    * node-parallel (``node_parallel=True``): every node is a separate batch
      row. Decoding runs step by step with a cache, and after each step the
      hidden states are mixed across nodes by the decoder's two GCN layers; the
      mixed hidden state (not the predicted value) is the next step's input.
      This layout requires a memory batch size of 1.

    Args:
        decoder: the wrapped decoder; must expose ``gcn1`` / ``gcn2`` when
            ``node_parallel`` is True.
        node_parallel: select the node-parallel layout.
    """

    def __init__(self, decoder: "RegressionDecoderModel", node_parallel: bool=False):
        super().__init__()
        self.decoder = decoder
        self.node_parallel = node_parallel

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    def _per_node_memory(self, memory, node_num):
        """Re-lay ``(1, M, H)`` memory as one row per node: own graph token + shared rest."""
        if memory.size(0) != 1:
            raise ValueError(
                f"node_parallel decoding supports a memory batch size of 1, got {memory.size(0)}"
            )
        graph_mem = memory[:, :node_num].permute(1, 0, 2)             # (G, 1, H)
        language_mem = memory[:, node_num:].expand(node_num, -1, -1)  # (G, M - G, H)
        return torch.cat((graph_mem, language_mem), dim=1)

    def _graph_mix(self, hidden, edge_index):
        """Mix per-node hidden states ``(G, 1, H)`` across the graph: gcn1 -> GELU + skip -> gcn2."""
        edges = edge_index.squeeze(0)
        nodes_last = hidden.permute(1, 0, 2)                          # (1, G, H)
        mixed = F.gelu(self.decoder.gcn1(nodes_last, edges)) + nodes_last
        return self.decoder.gcn2(mixed, edges).permute(1, 0, 2)

    # ------------------------------------------------------------------ #
    # training
    # ------------------------------------------------------------------ #
    def forward(
        self,
        target_values: torch.Tensor,   # (B, G, L)   ground-truth sequence
        memory: torch.Tensor,          # (B, M, H)   from encoder tower
        edge_index,
    ) -> "RegressionOutput":
        """Compute the regression loss for ``target_values``.

        Args:
            target_values: ``(B, G, L)`` ground-truth values.
            memory: ``(B, M, H)`` encoder output; in node-parallel mode its
                first ``G`` positions are taken as the per-node graph tokens.
            edge_index: graph connectivity with a leading singleton axis
                (it is ``squeeze(0)``-ed before use; node-parallel mode only).

        Returns:
            ``RegressionOutput(loss, preds)``. ``loss`` is the MSE divided by
            ``G``; ``preds`` is ``(B*G, L, 1)`` in node-parallel mode and
            ``(B, G*L, 1)`` otherwise.

        Raises:
            ValueError: node-parallel mode with a memory batch size other than 1.
        """
        B, G, L = target_values.shape

        if self.node_parallel:
            tgt_ref = target_values.reshape(B * G, L, 1)
            memory = self._per_node_memory(memory, G)
            # start token 0; after step 0 only its shape is read by the decoder
            next_tok = torch.zeros(B * G, 1, 1, device=tgt_ref.device)
            pos_ids = torch.arange(L, device=memory.device).unsqueeze(0).expand(B * G, -1)

            preds = []
            past_key_values = None
            y_hidden = None
            for i in range(L):
                # only the newest position is fed; earlier KV pairs are cached
                y_hat, past_key_values, y_hidden = self.decoder(
                    next_tok,
                    memory=memory,
                    use_cache=True,
                    past_key_values=past_key_values,
                    pos_ids=pos_ids[:, i:i + 1],
                    tgt_hidden=y_hidden,
                )
                y_hidden = self._graph_mix(y_hidden, edge_index)
                next_tok = y_hat[:, -1:, :]  # take the last token
                preds.append(next_tok)
            preds = torch.cat(preds, dim=1)
        else:
            tgt_ref = target_values.reshape(B, G * L, 1)
            # teacher forcing: shift right and prepend the start token 0
            start = torch.zeros(B, 1, 1, device=tgt_ref.device)
            tgt_in = torch.cat([start, tgt_ref[:, :-1]], dim=1)
            preds, _, _ = self.decoder(
                tgt_in,
                memory=memory,
                use_cache=False
            )

        loss = F.mse_loss(preds, tgt_ref) / target_values.shape[1]
        return RegressionOutput(loss=loss, preds=preds)

    # ------------------------------------------------------------------ #
    # autoregressive inference
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def generate(
        self,
        memory: torch.Tensor,          # (B, M, H)
        edge_index,
        steps: int,
        node_num: int,
        start_val: Optional[torch.Tensor] = None,  # (B, 1, 1) or None
    ) -> torch.Tensor:
        """Generate ``steps`` values for each of ``node_num`` nodes.

        Args:
            memory: ``(B, M, H)`` encoder output.
            edge_index: graph connectivity (node-parallel mode only).
            steps: number of values to generate per node.
            node_num: number of nodes ``G``.
            start_val: optional first input; defaults to zeros. Must have one
                row per decoder batch row (``G`` rows in node-parallel mode).

        Returns:
            Node-major predictions: ``(1, node_num * steps, 1)`` in
            node-parallel mode, ``(B, node_num * steps, 1)`` otherwise.

        Raises:
            ValueError: node-parallel mode with a memory batch size other than 1.
        """
        device = memory.device

        if self.node_parallel:
            memory = self._per_node_memory(memory, node_num)
            total_steps = steps
        else:
            total_steps = node_num * steps
        B = memory.size(0)

        # beginning token
        next_tok = (torch.zeros(B, 1, 1, device=device)
                    if start_val is None else start_val.to(device))

        preds = []
        past_key_values = None
        y_hidden = None
        pos_ids = torch.arange(total_steps, device=device).unsqueeze(0).expand(B, -1)
        for i in range(total_steps):
            # only the newest position is fed; earlier KV pairs are cached
            y_hat, past_key_values, y_hidden = self.decoder(
                next_tok,
                memory=memory,
                use_cache=True,
                past_key_values=past_key_values,
                pos_ids=pos_ids[:, i:i + 1],
                tgt_hidden=y_hidden,
            )
            next_tok = y_hat[:, -1:, :]  # take the last token
            if self.node_parallel:
                y_hidden = self._graph_mix(y_hidden, edge_index)
            else:
                # Sequential training feeds projected VALUES (teacher forcing),
                # so generation must feed the predicted value as well rather
                # than the raw hidden state the decoder never saw as an input.
                y_hidden = None
            preds.append(next_tok)

        preds = torch.cat(preds, dim=1)
        if self.node_parallel:
            return preds.reshape(-1, node_num * steps, 1)
        return preds


# Step 4: Full model (nodal + graph encoders, LLaMA backbone, regression decoder)
class GLA_Model_v2(nn.Module):
    """Graph-language model: graph tokens + text prompt -> text or numeric output.

    Per-node signals are encoded by ``NodalEncoder`` and ``GraphEncoder``,
    projected to the language model's hidden size and prepended to the embedded
    instruction. Depending on the instruction text the fused sequence is either

    * handed to the causal LM for text training / generation (instructions
      containing ``'fault type and fault location'``), or
    * run through the LM once and its last hidden states used as memory for the
      numeric decoder (all other instructions).

    Args:
        graph_input_dim: pooled length kept per channel by the nodal encoder.
        graph_hidden_dim: width of the graph encoder.
        in_channels: input channels of the nodal encoder.
        out_channels: channel width of the nodal encoder; the graph encoder
            expects ``graph_input_dim * out_channels`` input features.
        language_model_name: Hugging Face model id or path.
        float_disc: accepted for backward compatibility; not used.
        bins: accepted for backward compatibility; not used.
        use_lora: wrap the language model with LoRA adapters.
        phase: 1 = freeze the LM, train the graph/nodal encoders;
            2 = train the LM, freeze the graph/nodal encoders. The projector and
            numeric decoder are trainable in both phases.
        use_diffusion: use the diffusion decoder wrapper instead of the
            autoregressive one.
        node_parallel: forwarded to the numeric decoder and its wrapper.
        hf_token: Hugging Face access token. ``None`` (default) defers to the
            library's standard credential lookup. Never hard-code a token here.

    Raises:
        NotImplementedError: ``phase`` is neither 1 nor 2.
    """

    def __init__(self, graph_input_dim, graph_hidden_dim, in_channels=1, out_channels=32,
                 language_model_name='decapoda-research/llama-3.2-1b', float_disc=False, bins=2048, use_lora=False,
                 phase=1, use_diffusion=True, node_parallel=True, hf_token=None):
        super(GLA_Model_v2, self).__init__()
        # Fail before downloading / loading the language model.
        if phase not in (1, 2):
            raise NotImplementedError

        self.graph_input_dim = graph_input_dim
        self.graph_hidden_dim = graph_hidden_dim
        # Keep the nodal encoder's channel width tied to `out_channels`, which
        # is what the graph encoder's input width below is computed from.
        self.nodal_encoder = NodalEncoder(in_channels, graph_input_dim, hidden_dim=out_channels)
        self.graph_encoder = GraphEncoder(graph_input_dim*out_channels, graph_hidden_dim)

        self.use_lora = use_lora
        self.use_diffusion = use_diffusion

        # Place the whole LM on this process's device index.
        device_string = PartialState().process_index
        print(f'device_string={device_string}')
        self.lm_model = AutoModelForCausalLM.from_pretrained(
            language_model_name, token=hf_token, device_map={'': device_string},
            attn_implementation="eager" if 'gemma' in language_model_name else "flash_attention_2",
            torch_dtype=torch.bfloat16
        )
        if self.use_lora:
            peft_conf = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            )
            self.lm_model = get_peft_model(self.lm_model, peft_conf)
        self.lm_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        # NOTE(review): "[PAD]" is not added to the vocabulary here, so the id
        # resolved below is whatever the tokenizer returns for an unknown token
        # - confirm this is intended.
        self.lm_tokenizer.pad_token = "[PAD]"  # Define a padding token
        self.lm_tokenizer.pad_token_id = self.lm_tokenizer.convert_tokens_to_ids("[PAD]")

        self.lm_config = AutoConfig.from_pretrained(language_model_name, token=hf_token)
        self.language_encoder = LanguageEncoder(self.lm_model, self.use_lora)
        self.rotary_emb = ScaledLlamaRotaryEmbedding(self.lm_config)

        self.graph_to_llm_proj = nn.Sequential(
            nn.Linear(graph_hidden_dim, self.lm_config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_config.hidden_size, self.lm_config.hidden_size),
            nn.LayerNorm(self.lm_config.hidden_size),
        )

        self.float_decoder = RegressionDecoderModel(
            self.lm_config.hidden_size, self.rotary_emb,
            nhead=self.lm_config.num_attention_heads, use_diffusion=use_diffusion,
            node_parallel=node_parallel
        )
        if self.use_diffusion:
            self.float_decoder_wrapped = DiffusionDecoderWrapper(self.float_decoder)
        else:
            self.float_decoder_wrapped = RegressionDecoderWrapper(self.float_decoder, node_parallel)

        # Phase-dependent freezing: the LM and the graph-side encoders are
        # trained in alternation; projector and numeric decoder always train.
        train_lm = phase == 2
        for param in self.lm_model.model.base_model.parameters():
            param.requires_grad = train_lm
        for module in (self.graph_encoder, self.nodal_encoder):
            for param in module.parameters():
                param.requires_grad = not train_lm
        for module in (self.graph_to_llm_proj, self.float_decoder):
            for param in module.parameters():
                param.requires_grad = True

        self.generation_config = GenerationConfig(
            do_sample=True,
            top_k=50,
            temperature=0.7,
            max_new_tokens=1024,
            pad_token_id=self.lm_tokenizer.pad_token_id,
            eos_token_id=self.lm_tokenizer.eos_token_id,
            max_length=1024
        )

    @staticmethod
    def _parse_prediction_request(instruction):
        """Extract ``(bus_num, pred_len)`` from a numeric-task instruction string."""
        bus_num = int(re.findall(r'\d+', instruction.split('in ')[1].split(' bus system')[0])[0])
        pred_len = 0
        if 'power setpoint' in instruction or 'locational marginal price' in instruction:
            pred_len = 1
        if 'real states' in instruction:
            pred_len = 2
        if 'what are the predictions of' in instruction:
            pred_len = int(re.findall(r'\d+', instruction.split('following')[1].split('steps of')[0])[0])

        assert pred_len > 0
        return bus_num, pred_len

    def forward(self, graph_data, edge_index, instruction, target_response=None):
        """Run one sample through the model.

        Args:
            graph_data: per-node signals; axis 1 is iterated one slice at a
                time and each slice is fed to the nodal encoder.
            edge_index: graph connectivity with a leading singleton axis.
            instruction: prompt text (``'[GRAPH] '`` is prepended here).
            target_response: ``None`` for inference. For the fault task a
                string appended to the prompt; for numeric tasks a tensor of
                target values.

        Returns:
            Fault task: the LM's training output (with loss) or generated token
            ids. Numeric tasks: the decoder wrapper's training output or its
            generated values.
        """
        # Step 1: Graph Encoding
        node_features = [
            self.nodal_encoder(graph_data[:, i:i + 1]).unsqueeze(0)
            for i in range(graph_data.shape[1])
        ]
        x = torch.cat(node_features, dim=1)
        graph_embedding = self.graph_encoder(x, edge_index.squeeze(0))
        graph_tokens = self.graph_to_llm_proj(graph_embedding)
        num_graph_tokens = graph_tokens.shape[1]

        # Step 2: Language Encoding
        instruction = '[GRAPH] ' + instruction
        is_fault_task = 'fault type and fault location' in instruction
        has_target = target_response is not None

        # prompt length in tokens, measured before any response is appended
        instruction_token_len = self.lm_tokenizer(instruction, return_tensors="pt").input_ids.shape[1]

        # concat instruction and response
        if is_fault_task and has_target:
            instruction = instruction + target_response + self.lm_tokenizer.eos_token

        inputs = self.lm_tokenizer(instruction, return_tensors="pt")
        language_embedding = self.language_encoder(inputs)

        # Step 3: Fusion of Graph and Language Features
        fused_embedding = torch.cat([graph_tokens, language_embedding], dim=1).to(torch.bfloat16)
        # Everything built on the host below follows the fused sequence's device
        # instead of assuming the current CUDA device.
        device = fused_embedding.device
        attention_mask = torch.cat(
            (torch.ones((1, num_graph_tokens), dtype=torch.long), inputs.attention_mask), dim=1
        ).to(device)

        # Step 4: Decoder Output
        if is_fault_task:
            if not has_target:
                return self.lm_model.generate(
                    inputs_embeds=fused_embedding, attention_mask=attention_mask,
                    generation_config=self.generation_config
                )
            # Supervise only the response: graph tokens and prompt are masked.
            labels = torch.cat(
                [torch.full((1, num_graph_tokens), -100, dtype=torch.long), inputs.input_ids.clone()],
                dim=1,
            )
            labels[:, :num_graph_tokens + instruction_token_len] = -100
            return self.lm_model(
                inputs_embeds=fused_embedding, attention_mask=attention_mask, labels=labels.to(device)
            )

        if not has_target:
            # Parse first so a malformed prompt fails before the LM forward pass.
            bus_num, pred_len = self._parse_prediction_request(instruction)

        mem = self.lm_model(
            inputs_embeds=fused_embedding, attention_mask=attention_mask, output_hidden_states=True
        ).hidden_states[-1]

        if has_target:
            return self.float_decoder_wrapped(
                target_response.unsqueeze(0).float().to(device), mem.float(), edge_index
            )
        return self.float_decoder_wrapped.generate(mem.float(), edge_index, pred_len, bus_num)

    def decode(self, outputs):
        """Decode the first sequence of ``outputs`` to text, keeping special tokens."""
        output_text = self.lm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        return output_text

    # save model and tokenizer
    def save(self, model_path, tokenizer_path):
        """Save weights and tokenizer under ``saved_models/``.

        Note that ``load`` takes full paths, i.e. including ``saved_models/``.
        """
        import os

        # save model weights
        weights_file = f'saved_models/{model_path}'
        os.makedirs(os.path.dirname(weights_file), exist_ok=True)
        torch.save(self.state_dict(), weights_file)
        print(f"Model saved at {weights_file}")

        # save tokenizer
        tokenizer_dir = f'saved_models/{tokenizer_path}'
        self.lm_tokenizer.save_pretrained(tokenizer_dir)
        print(f"Tokenizer saved at {tokenizer_dir}")

    # load model and tokenizer
    def load(self, model_path, tokenizer_path):
        """Load tokenizer and weights from the given (full) paths."""
        # load tokenizer
        self.lm_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Tokenizer loaded from {tokenizer_path}")

        # load model weights; map to CPU first so a checkpoint written on another
        # GPU index still loads - load_state_dict copies onto the live parameters
        self.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
        print(f"Model loaded from {model_path}")
