import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.data import Data
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator, PartialState


class NodalEncoder(nn.Module):
    """Encode a per-node feature sequence with a residual 1-D conv stack.

    The input is passed through ``num_layers`` bias-free ``Conv1d`` + GELU
    layers (with a residual connection from the second layer on), pooled
    along the last axis to ``output_dim`` positions, layer-normalised over
    that axis, and flattened to one vector per sample.

    Args:
        input_dim: Number of input channels of the first convolution.
        output_dim: Length of the pooled sequence axis.
        num_layers: Number of convolution layers.
        hidden_dim: Number of channels produced by every convolution.
        kernel_size: Convolution kernel width.
    """

    def __init__(self, input_dim, output_dim, num_layers=2, hidden_dim=32, kernel_size=3):
        super().__init__()

        # "same" padding keeps the sequence length unchanged for every
        # kernel_size, so the residual additions in forward() always line up.
        # For the default kernel_size=3 this is exactly the former padding=1.
        self.layers = nn.ModuleList(
            nn.Conv1d(
                input_dim if i == 0 else hidden_dim,
                hidden_dim,
                kernel_size=kernel_size,
                padding="same",
                bias=False,
            )
            for i in range(num_layers)
        )

        self.adaptive_pool = nn.AdaptiveAvgPool1d(output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Return the flattened encoding of ``x``.

        Args:
            x: Tensor accepted by ``Conv1d`` with ``input_dim`` channels,
                i.e. ``(batch, input_dim, length)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim * output_dim)`` (or
            ``(batch, input_dim * output_dim)`` when ``num_layers`` is 0).
        """
        for i, layer in enumerate(self.layers):
            x_res = x
            x = F.gelu(layer(x))
            # The first layer changes the channel count, so it has no skip path.
            if i > 0:
                x = x + x_res
        x = self.adaptive_pool(x)
        x = self.layer_norm(x)
        # Merge the channel and pooled axes into a single feature vector.
        return x.flatten(start_dim=1)


# Step 1: Graph Encoder (Using GCN)
class GraphEncoder(nn.Module):
    """Stack of GCN layers with GELU activations and skip connections.

    The first layer maps ``input_dim`` to ``hidden_dim``; every later layer
    maps ``hidden_dim`` to ``hidden_dim`` and adds its input back to its
    activated output.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3):
        super(GraphEncoder, self).__init__()

        # Only the first convolution sees the raw input width.
        widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 0)
        self.layers = nn.ModuleList(
            GCNConv(width, hidden_dim) for width in widths[:max(num_layers, 0)]
        )

    def forward(self, x, edge_index):
        """Propagate node features ``x`` over the graph given by ``edge_index``."""
        hidden = x
        for depth, conv in enumerate(self.layers):
            activated = F.gelu(conv(hidden, edge_index))
            # No residual on the first layer: its input width may differ.
            hidden = activated if depth == 0 else activated + hidden
        return hidden


# Step 2: Language Encoder (Using LLaMA)
class LanguageEncoder(nn.Module):
    """Look up input-token embeddings from a wrapped language model.

    Args:
        model: Object exposing ``get_input_embeddings()`` that returns the
            embedding layer to use.
        use_lora: Stored for reference only; it does not affect ``forward``.
    """

    def __init__(self, model, use_lora=False):
        super().__init__()
        self.model = model
        self.use_lora = use_lora

    def forward(self, inputs):
        """Return the embeddings of ``inputs.input_ids``.

        The ids are moved to the device that holds the embedding weights
        rather than to the default CUDA device, so the lookup also works on
        CPU and when the language model lives on a non-default GPU.
        """
        embedding_layer = self.model.get_input_embeddings()
        token_ids = inputs.input_ids.to(
            device=embedding_layer.weight.device, dtype=torch.long
        )
        return embedding_layer(token_ids)

# Steps 3-4: Fusion of graph and language features, and final decoder (using the LLaMA decoder)
class GLA_Model(nn.Module):
    """Graph + language model: graph tokens are prepended to a text prompt.

    Node features are encoded by ``NodalEncoder``, propagated over the graph
    by ``GraphEncoder``, projected to the language model's hidden size and
    concatenated in front of the prompt's token embeddings. The causal LM
    then either computes a loss (training) or generates tokens (inference).

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable; credentials must never be committed to source control.

    Args:
        graph_input_dim: Pooled sequence length produced by the nodal encoder.
        graph_hidden_dim: Width of the graph encoder's hidden features.
        in_channels: Input channels of the nodal encoder. ``forward`` feeds
            one node at a time as a single channel, so this is presumably
            always 1 -- TODO confirm.
        out_channels: Channels produced by the nodal encoder convolutions.
        language_model_name: Hugging Face model id or path of the causal LM.
        float_disc: If true, add the tokens ``'0'`` .. ``str(bins)`` to the
            tokenizer and resize the LM embeddings accordingly.
        bins: Upper bound (inclusive) of the added numeric tokens.
        use_lora: Wrap the language model with LoRA adapters.
        phase: 1 trains the graph side with the LM frozen; 2 trains the LM
            with the graph encoders frozen. The projector trains in both.

    Raises:
        NotImplementedError: If ``phase`` is neither 1 nor 2.
    """

    def __init__(self, graph_input_dim, graph_hidden_dim, in_channels=1, out_channels=32,
                 language_model_name='decapoda-research/llama-3.2-1b', float_disc=False, bins=2048, use_lora=False, phase=1):
        super().__init__()
        # Fail fast, before the language model is downloaded and loaded.
        if phase not in (1, 2):
            raise NotImplementedError(f"Unsupported training phase: {phase!r} (expected 1 or 2)")

        # None falls back to the credentials cached by `huggingface-cli login`.
        hf_token = os.environ.get("HF_TOKEN")

        self.graph_input_dim = graph_input_dim
        self.graph_hidden_dim = graph_hidden_dim
        # hidden_dim must equal out_channels: the graph encoder below expects
        # graph_input_dim * out_channels features per node.
        self.nodal_encoder = NodalEncoder(in_channels, graph_input_dim, hidden_dim=out_channels)
        # TODO: try TransformerConv to include edge_attr
        self.graph_encoder = GraphEncoder(graph_input_dim * out_channels, graph_hidden_dim)

        self.use_lora = use_lora

        peft_conf = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        # Place the whole LM on the device matching this process' index.
        device_string = PartialState().process_index
        print(f'device_string={device_string}')
        self.lm_model = AutoModelForCausalLM.from_pretrained(
            language_model_name, token=hf_token, device_map={'': device_string},
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16
        )
        if self.use_lora:
            self.lm_model = get_peft_model(self.lm_model, peft_conf)
        self.lm_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        # NOTE(review): "[PAD]" is not added to the vocabulary here, so its id
        # depends on how the tokenizer resolves unknown tokens -- confirm this
        # yields a usable pad_token_id for the chosen model.
        self.lm_tokenizer.pad_token = "[PAD]"  # Define a padding token
        self.lm_tokenizer.pad_token_id = self.lm_tokenizer.convert_tokens_to_ids("[PAD]")

        self.lm_config = AutoConfig.from_pretrained(language_model_name, token=hf_token)
        self.language_encoder = LanguageEncoder(self.lm_model, self.use_lora)

        # Maps graph features into the LM embedding space.
        self.graph_to_llm_proj = nn.Sequential(
            nn.Linear(graph_hidden_dim, self.lm_config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_config.hidden_size, self.lm_config.hidden_size),
            nn.LayerNorm(self.lm_config.hidden_size),
        )

        if float_disc:
            action_tokens = [f'{i}' for i in range(bins+1)]
            num_added_tokens = self.lm_tokenizer.add_tokens(action_tokens)
            self.lm_model.resize_token_embeddings(len(self.lm_tokenizer))
            print(f'Added {num_added_tokens} new tokens to the tokenizer.')

        # Phase 1: train graph encoders, freeze the LM.
        # Phase 2: train the LM, freeze graph encoders.
        # NOTE(review): with use_lora=True, phase 2 also unfreezes the base LM
        # weights, not only the adapters -- confirm this is intended.
        train_graph_side = phase == 1
        for param in self.lm_model.model.base_model.parameters():
            param.requires_grad = not train_graph_side
        for param in self.graph_encoder.parameters():
            param.requires_grad = train_graph_side
        for param in self.nodal_encoder.parameters():
            param.requires_grad = train_graph_side
        # The projector is fine-tuned in both phases.
        for param in self.graph_to_llm_proj.parameters():
            param.requires_grad = True

        self.generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=20000,
            pad_token_id=self.lm_tokenizer.pad_token_id,
            eos_token_id=self.lm_tokenizer.eos_token_id,
            max_length=20000
        )

    def forward(self, graph_data, edge_index, instruction, target_response=None):
        """Run the model on one graph and one instruction.

        Args:
            graph_data: Node features; assumed ``(1, num_nodes, num_features)``
                (the masks below are built for batch size 1).
            edge_index: Graph connectivity passed to the graph encoder after
                ``squeeze(0)``.
            instruction: Prompt text; ``'[GRAPH] '`` is prepended to it.
            target_response: Optional target text. When given, the LM is run
                with labels (prompt and graph positions masked with -100);
                otherwise ``generate`` is called.

        Returns:
            The LM's output (with loss) when ``target_response`` is given,
            else the result of ``generate``.
        """
        # Step 1: Graph Encoding -- encode every node separately, then stack.
        node_features = [
            self.nodal_encoder(graph_data[:, i:i + 1]).unsqueeze(0)
            for i in range(graph_data.shape[1])
        ]
        x = torch.cat(node_features, dim=1)
        graph_embedding = self.graph_encoder(x, edge_index.squeeze(0))
        graph_tokens = self.graph_to_llm_proj(graph_embedding)
        num_graph_tokens = graph_tokens.shape[1]

        # Step 2: Language Encoding
        prompt = '[GRAPH] ' + instruction
        # Length of the prompt alone, used to mask it out of the labels.
        prompt_token_len = self.lm_tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

        text = prompt
        if target_response is not None:
            text = prompt + target_response + self.lm_tokenizer.eos_token
        inputs = self.lm_tokenizer(text, return_tensors="pt")

        language_embedding = self.language_encoder(inputs)
        # Everything handed to the LM follows the device/dtype of its own
        # embeddings instead of assuming the default CUDA device.
        device = language_embedding.device

        # Step 3: Fusion of Graph and Language Features
        fused_embedding = torch.cat(
            [graph_tokens.to(device=device, dtype=language_embedding.dtype), language_embedding],
            dim=1,
        )
        graph_mask = torch.ones((1, num_graph_tokens), dtype=torch.long, device=device)
        attention_mask = torch.cat((graph_mask, inputs.attention_mask.to(device)), dim=1)

        # Step 4: Decoder Output
        if target_response is None:
            return self.lm_model.generate(
                inputs_embeds=fused_embedding,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
            )

        # Supervise only the response: graph tokens and prompt get -100.
        text_labels = inputs.input_ids.to(device=device, dtype=torch.long).clone()
        text_labels[:, :prompt_token_len] = -100
        graph_labels = torch.full((1, num_graph_tokens), -100, dtype=torch.long, device=device)
        labels = torch.cat([graph_labels, text_labels], dim=1)
        return self.lm_model(inputs_embeds=fused_embedding, attention_mask=attention_mask, labels=labels)

    def decode(self, outputs):
        """Decode the first sequence in ``outputs``, keeping special tokens."""
        output_text = self.lm_tokenizer.decode(outputs[0], skip_special_tokens=False)
        return output_text

    # save model and tokenizer
    def save(self, model_path, tokenizer_path):
        """Save weights and tokenizer under the ``saved_models/`` directory.

        Note that ``load`` takes its paths verbatim, so pass it the
        ``saved_models/``-prefixed locations printed here.
        """
        model_file = f'saved_models/{model_path}'
        tokenizer_dir = f'saved_models/{tokenizer_path}'

        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(model_file), exist_ok=True)
        torch.save(self.state_dict(), model_file)
        print(f"Model saved at {model_file}")

        self.lm_tokenizer.save_pretrained(tokenizer_dir)
        print(f"Tokenizer saved at {tokenizer_dir}")

    # load model and tokenizer
    def load(self, model_path, tokenizer_path):
        """Load tokenizer and weights from the given paths (used verbatim)."""
        self.lm_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Tokenizer loaded from {tokenizer_path}")

        # Stage the checkpoint on CPU; load_state_dict copies each tensor onto
        # the device of the matching parameter, so the saving device need not
        # exist on this machine.
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        self.load_state_dict(state_dict)
        print(f"Model loaded from {model_path}")


# unit test
if __name__ == '__main__':
    # Smoke test: run one generation pass per pickled PYPOWER case.
    import os
    import pickle
    from pypower.idx_bus import *
    from pypower.idx_gen import *
    from pypower.idx_brch import *
    from pypower.idx_cost import *
    import numpy as np
    root_path = ""  # dataset storage path
    # NOTE: pickle.load can execute arbitrary code -- only open trusted files.
    # os.path.join keeps an empty root_path relative instead of resolving to "/".
    with open(os.path.join(root_path, "ppc_lst_IEEE14.pkl"), "rb") as file:
        ppc_lst = pickle.load(file)

    instruction = "What is the best active power setpoint of generators?"
    # Build the model once; reloading the LLM for every case is wasteful.
    model = GLA_Model(graph_input_dim=128, graph_hidden_dim=1024, language_model_name='meta-llama/Llama-3.2-1B').cuda()

    for ppc in ppc_lst:
        # One edge per branch row (not per bus): take every (from, to) pair.
        # NOTE(review): assumes bus numbers are already 0-based node indices,
        # as the GEN_BUS indexing below also does -- confirm against the data.
        edge_index = torch.tensor(ppc['branch'][:, [F_BUS, T_BUS]]).long().cuda().permute(1, 0)
        # Scatter generator outputs onto their buses; non-generator buses stay 0.
        pg_qg = np.zeros((ppc['num_bus'], 2))
        pg_qg[ppc['gen'][:, GEN_BUS].astype(int).tolist()] = ppc['gen'][:, [PG, QG]]
        # Node features: [PD, QD, VM, VA, PG, QG] with a leading batch axis.
        x = torch.tensor(np.concatenate((ppc['bus'][:, [PD, QD, VM, VA]], pg_qg), axis=1)).unsqueeze(0).float().cuda()
        output = model(x, edge_index, instruction)
