import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.data import Data
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator, PartialState

# Baseline 2 pure LLM
class LLM_Model(nn.Module):
    """Causal language model wrapper with an extended numeric vocabulary.

    Loads a Hugging Face causal LM and its tokenizer, optionally wraps the
    model with LoRA adapters, registers the strings ``"0"`` .. ``str(bins)``
    as extra tokens, and exposes a single ``forward`` that either computes a
    training loss (when a target response is given) or samples a completion.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable; when it is unset, ``None`` is passed and the library falls back
    to its own default authentication.
    """

    def __init__(self, language_model_name='decapoda-research/llama-3.2-1b', use_lora=False, bins=2048,):
        """Load the model and tokenizer and extend the vocabulary.

        Args:
            language_model_name: Hub id or local path of the causal LM.
            use_lora: When true, wrap the model with LoRA adapters on the
                attention projections.
            bins: The strings ``"0"`` .. ``str(bins)`` (inclusive) are added
                to the tokenizer as extra tokens.
        """
        super().__init__()
        # Credentials must never live in source control. Any token that was
        # previously committed in this file should be treated as compromised
        # and revoked.
        hf_token = os.environ.get("HF_TOKEN")
        self.use_lora = use_lora

        # The device index must be the rank *within this node*; the global
        # process index does not map to a local GPU on multi-node runs.
        device_index = PartialState().local_process_index
        print(f'device_string={device_index}')

        uses_gemma = 'gemma' in language_model_name
        self.lm_model = AutoModelForCausalLM.from_pretrained(
            language_model_name,
            token=hf_token,
            device_map={'': device_index},
            attn_implementation="eager" if uses_gemma else "flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        self.lm_model.gradient_checkpointing_enable()

        if self.use_lora:
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            )
            self.lm_model = get_peft_model(self.lm_model, lora_config)

        self.lm_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        # NOTE(review): "[PAD]" is assigned as the pad token without being
        # added to the vocabulary here -- confirm it resolves to a valid id
        # for the chosen tokenizer. Left unchanged to keep the embedding size
        # compatible with existing checkpoints.
        self.lm_tokenizer.pad_token = "[PAD]"
        self.lm_tokenizer.pad_token_id = self.lm_tokenizer.convert_tokens_to_ids("[PAD]")

        self.lm_config = AutoConfig.from_pretrained(language_model_name, token=hf_token)

        # Register the numeric strings as tokens and grow the embedding
        # matrix to match the enlarged vocabulary.
        action_tokens = [f'{i}' for i in range(bins + 1)]
        num_added_tokens = self.lm_tokenizer.add_tokens(action_tokens)
        self.lm_model.resize_token_embeddings(len(self.lm_tokenizer))
        print(f'Added {num_added_tokens} new tokens to the tokenizer.')

        # Only max_new_tokens is set: when both limits are given,
        # max_new_tokens takes precedence and max_length merely triggers a
        # warning on every generate() call.
        self.generation_config = GenerationConfig(
            do_sample=True,
            top_k=50,
            temperature=0.7,
            pad_token_id=self.lm_tokenizer.pad_token_id,
            eos_token_id=self.lm_tokenizer.eos_token_id,
            max_new_tokens=20000,
        )

    def forward(self, instruction, target_response=None):
        """Run a training step or sample a completion for one prompt.

        Args:
            instruction: Prompt text.
            target_response: Optional ground-truth continuation. When given,
                the prompt, the response and the EOS token are concatenated
                and the model is called with labels in which the prompt
                positions are set to -100.

        Returns:
            The output of the model call (with labels) when
            ``target_response`` is given, otherwise the result of
            ``generate`` for the prompt.
        """
        # Send inputs to wherever the model actually lives instead of the
        # process-wide default CUDA device.
        device = next(self.lm_model.parameters()).device

        if target_response is None:
            prompt = self.lm_tokenizer(instruction, return_tensors="pt")
            return self.lm_model.generate(
                input_ids=prompt.input_ids.to(device),
                attention_mask=prompt.attention_mask.to(device),
                pad_token_id=self.lm_tokenizer.pad_token_id,
                generation_config=self.generation_config,
            )

        # Number of tokens the prompt occupies on its own; used to mask the
        # prompt out of the loss.
        prompt_len = self.lm_tokenizer(instruction, return_tensors="pt").input_ids.shape[1]

        full_text = instruction + target_response + self.lm_tokenizer.eos_token
        inputs = self.lm_tokenizer(full_text, return_tensors="pt")

        labels = inputs.input_ids.clone()
        labels[:, :prompt_len] = -100

        return self.lm_model(
            input_ids=inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            labels=labels.to(device),
        )

    def decode(self, outputs):
        """Decode the first sequence in ``outputs`` to text, keeping special tokens."""
        return self.lm_tokenizer.decode(outputs[0], skip_special_tokens=False)

    def save(self, model_path, tokenizer_path, lora_path=None):
        """Save the module's state dict and the tokenizer.

        Args:
            model_path: Destination file for ``state_dict()``.
            tokenizer_path: Destination directory for the tokenizer.
            lora_path: Accepted for interface compatibility; currently unused.
        """
        torch.save(self.state_dict(), model_path)
        print(f"Model saved at {model_path}")

        self.lm_tokenizer.save_pretrained(tokenizer_path)
        print(f"Tokenizer saved at {tokenizer_path}")

    def load(self, model_path, tokenizer_path, lora_path=None):
        """Restore the tokenizer and the module's state dict.

        Args:
            model_path: File previously written by ``save``.
            tokenizer_path: Directory previously written by ``save``.
            lora_path: Accepted for interface compatibility; currently unused.
        """
        self.lm_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(f"Tokenizer loaded from {tokenizer_path}")

        # Deserialize onto the CPU so that every process does not first
        # allocate the checkpoint on the GPU it was saved from;
        # load_state_dict then copies into the already-placed parameters.
        state = torch.load(model_path, map_location="cpu", weights_only=True)
        self.load_state_dict(state)
        print(f"Model loaded from {model_path}")