import os
import json
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, GPT2LMHeadModel, GPT2Config, GPT2TokenizerFast
from memgpt.constants import CONFIGS_DIR
from peft import PeftModel


def initialize_model_for_pretraining(model_args, resume_from_checkpoint=None, use_special_dblookup_tokens=False):
    """Return the (model, tokenizer) pair produced by the appropriate loader.

    A truthy ``resume_from_checkpoint`` takes precedence. Otherwise the loader
    is selected by substring match on ``model_args.model_name_or_path``.

    Args:
        model_args: Model configuration arguments. ``model_name_or_path`` is
            read here; the object itself is forwarded to the chosen loader
            where that loader accepts it.
        resume_from_checkpoint (str, optional): Path to the checkpoint directory to resume training from.
        use_special_dblookup_tokens (bool): Forwarded to the GPT-2 and
            tiny-llama2 loaders only; it is not used when resuming from a
            checkpoint or on the fallback path.

    Returns:
        Whatever the selected loader returns.
    """
    model_name_or_path = model_args.model_name_or_path

    # Resuming wins over any name-based dispatch.
    if resume_from_checkpoint:
        return load_model_from_checkpoint(resume_from_checkpoint, model_args)

    # Fresh initialization: the "gpt2" check runs before the "tiny-llama2" check.
    if "gpt2" in model_name_or_path:
        return load_gpt2_model(model_name_or_path, use_special_dblookup_tokens)

    if "tiny-llama2" in model_name_or_path:
        return load_tiny_llama2_model(model_name_or_path, model_args, use_special_dblookup_tokens)

    # Anything else goes to the placeholder loader for unsupported models.
    return load_custom_model(model_name_or_path, model_args)


def load_model_for_ft_baseline(model_args, resume_from_checkpoint=None, use_special_dblookup_tokens=False):
    """Load a causal LM and its tokenizer for the fine-tuning baseline.

    Args:
        model_args: An object containing model configuration parameters including:
            - model_name_or_path: Path or identifier of the model to load
            - trust_remote_code: Whether to trust remote code when loading the model
        resume_from_checkpoint (str, optional): Checkpoint to load from. When
            omitted (or falsy), ``model_args.model_name_or_path`` is used.
        use_special_dblookup_tokens (bool): When true, the DB-lookup special
            tokens are added to the tokenizer before the model is loaded.

    Returns:
        tuple: (model, tokenizer) - The loaded model and tokenizer. The model's
        ``config.pad_token_id`` mirrors the tokenizer's; no pad token is
        created here, so it stays ``None`` if the tokenizer defines none.
    """
    # Fall back to the configured model path so the default
    # resume_from_checkpoint=None does not hand None to from_pretrained.
    load_path = resume_from_checkpoint or model_args.model_name_or_path

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        load_path,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True
    )

    if use_special_dblookup_tokens:
        # No config exists yet; the embedding resize below accounts for the
        # newly added tokens instead.
        tokenizer, _ = add_dblookup_special_tokens(tokenizer, config=None)
        print(f"vocab_size: {len(tokenizer)}")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        load_path,
        trust_remote_code=model_args.trust_remote_code
    )

    # Update model configuration with pad token ID
    model.config.pad_token_id = tokenizer.pad_token_id

    # Resize token embeddings if the embedding matrix and tokenizer disagree
    if model.get_input_embeddings().weight.shape[0] != len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))
        print(f"Model vocab size updated: {model.config.vocab_size}, Tokenizer vocab size: {len(tokenizer)}")

    # Report the path the weights were actually loaded from.
    print_model_info(load_path, model)
    return model, tokenizer


def load_model_from_checkpoint(resume_from_checkpoint, model_args):
    """Load the model and tokenizer from the specified checkpoint.

    Args:
        resume_from_checkpoint: Checkpoint path passed to both
            ``from_pretrained`` calls.
        model_args: Only ``trust_remote_code`` is read.

    Returns:
        tuple: (model, tokenizer). The tokenizer's pad token is overwritten
        with its EOS token.

    Raises:
        ValueError: If any step of the loading sequence fails. The underlying
            exception is attached as ``__cause__``.
    """
    print(f"Loading model and tokenizer from checkpoint: {resume_from_checkpoint}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=model_args.trust_remote_code,
            use_fast=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=model_args.trust_remote_code
        )
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer
    except Exception as e:
        # Chain explicitly so the traceback marks the original failure as the
        # direct cause rather than an incidental error during handling.
        raise ValueError(f"Failed to load model or tokenizer from checkpoint {resume_from_checkpoint}: {str(e)}") from e


def load_gpt2_model(model_name_or_path, use_special_dblookup_tokens=False):
    """Build a GPT-2 LM head model from a pretrained config, plus its tokenizer.

    Only the configuration and the tokenizer are fetched from
    ``model_name_or_path``; the model is constructed directly from the config
    (``GPT2LMHeadModel(config)``), not through ``from_pretrained``.

    Args:
        model_name_or_path: Identifier or path passed to the config and
            tokenizer ``from_pretrained`` calls.
        use_special_dblookup_tokens (bool): When true, the DB-lookup special
            tokens are added to the tokenizer and the config's vocabulary size
            is grown to match before the model is built.

    Returns:
        tuple: (model, tokenizer). The tokenizer's pad token is its EOS token.
    """
    # TODO: Add support for different GPT-2 model sizes
    gpt2_config = GPT2Config.from_pretrained(model_name_or_path)

    gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(model_name_or_path)
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    # The config must be enlarged before the model is constructed from it.
    if use_special_dblookup_tokens:
        gpt2_tokenizer, gpt2_config = add_dblookup_special_tokens(gpt2_tokenizer, gpt2_config)

    gpt2_model = GPT2LMHeadModel(gpt2_config)
    print_model_info(model_name_or_path, gpt2_model)
    return gpt2_model, gpt2_tokenizer

def add_dblookup_special_tokens(tokenizer, config=None):
    """Register the four DB-lookup marker tokens on ``tokenizer``.

    Args:
        tokenizer: Object exposing ``add_special_tokens(dict)`` that returns
            the number of tokens added.
        config: Optional object with a ``vocab_size`` attribute. When given,
            ``vocab_size`` is increased in place by the number of tokens added.

    Returns:
        tuple: ``(tokenizer, config)``, i.e. the same objects that were passed
        in (``config`` may be ``None``).
    """
    # Entity, relationship, return and end markers, in that order.
    new_tokens = [
        "<|db_entity|>",
        "<|db_relationship|>",
        "<|db_return|>",
        "<|db_end|>",
    ]

    num_added_tokens = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
    print(f"Added {num_added_tokens} tokens to the vocabulary")

    # Without a config there is nothing else to keep in sync.
    if config is None:
        return tokenizer, config

    config.vocab_size += num_added_tokens
    print(f"Updated vocab_size to {config.vocab_size}")
    return tokenizer, config

def load_tiny_llama2_tokenizer(add_special_tokens=False):
    """Load the tokenizer stored under ``CONFIGS_DIR/tiny-llama/tiny-llama2``.

    Args:
        add_special_tokens (bool): When true, the DB-lookup special tokens are
            added to the tokenizer before it is returned.

    Returns:
        The tokenizer, with its pad token set to its EOS token.
    """
    tokenizer_dir = os.path.join(CONFIGS_DIR, "tiny-llama/tiny-llama2")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    tokenizer.pad_token = tokenizer.eos_token

    if not add_special_tokens:
        return tokenizer

    # No config is passed here, so only the tokenizer is updated.
    tokenizer, _unused_config = add_dblookup_special_tokens(tokenizer)
    return tokenizer

def load_tiny_llama2_model(model_name_or_path, model_args, use_special_dblookup_tokens=False):
    """Build a Tiny Llama model from a config directory under ``CONFIGS_DIR``.

    The model is created with ``AutoModelForCausalLM.from_config`` from the
    config at ``CONFIGS_DIR/tiny-llama/<model_name_or_path>``. The tokenizer
    always comes from ``CONFIGS_DIR/tiny-llama/tiny-llama2``.

    Args:
        model_name_or_path: Name of the config directory inside
            ``CONFIGS_DIR/tiny-llama``.
        model_args: Only ``trust_remote_code`` is read.
        use_special_dblookup_tokens (bool): When true, the DB-lookup special
            tokens are added and the config's vocabulary size is grown before
            the model is built.

    Returns:
        tuple: (model, tokenizer). The tokenizer's pad token is its EOS token.

    Raises:
        ValueError: If the config directory does not exist.
    """
    tiny_llama_root = os.path.join(CONFIGS_DIR, "tiny-llama")
    model_path = os.path.join(CONFIGS_DIR, f"tiny-llama/{model_name_or_path}")
    if not os.path.exists(model_path):
        raise ValueError(f"Model {model_name_or_path} not found in {CONFIGS_DIR}")

    # All tiny-llama variants share the tiny-llama2 tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(CONFIGS_DIR, "tiny-llama/tiny-llama2"))
    tokenizer.pad_token = tokenizer.eos_token

    config = AutoConfig.from_pretrained(model_path)
    if use_special_dblookup_tokens:
        tokenizer, config = add_dblookup_special_tokens(tokenizer, config)

    model = AutoModelForCausalLM.from_config(
        config,
        trust_remote_code=model_args.trust_remote_code,
    )
    print_model_info(model_name_or_path, model)
    return model, tokenizer


def load_custom_model(model_name_or_path, model_args):
    """Placeholder loader for models that are not yet explicitly supported.

    The tokenizer is loaded only to check that ``model_name_or_path`` resolves
    to one; the function never returns normally.

    Args:
        model_name_or_path: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``.
        model_args: Only ``trust_remote_code`` is read.

    Raises:
        ValueError: If the tokenizer cannot be loaded. The underlying
            exception is attached as ``__cause__``.
        NotImplementedError: Always, once the tokenizer has loaded, because
            model initialization for such models is not implemented.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            use_fast=True
        )
        tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        # Chain explicitly so the root cause is reported as the direct cause.
        raise ValueError(f"Failed to load tokenizer for {model_name_or_path}: {str(e)}") from e

    raise NotImplementedError(f"Model {model_name_or_path} initialization not implemented.")


def load_lora_model(model_args, tokenizer_only=False):
    """Load a LoRA adapter on top of its base model, together with the tokenizer.

    The adapter path is ``model_args.model_name_or_path``. The base model name
    is read from the ``base_model`` key of
    ``CONFIGS_DIR/llama-8b-ft/lora-ft-hf.json``.

    Args:
        model_args: Only ``model_name_or_path`` is read.
        tokenizer_only (bool): When true, return ``(None, tokenizer)`` without
            loading any model weights.

    Returns:
        tuple: (model, tokenizer), where ``model`` is ``None`` if
        ``tokenizer_only`` is set.

    Raises:
        ValueError: If the path contains neither "ft" nor "tune", or if it
            does but lacks "8b" (case-insensitive), the only supported config.
    """
    model_name_or_path = model_args.model_name_or_path

    # Only fine-tuned / tuned adapter paths are handled.
    looks_finetuned = any(tag in model_name_or_path for tag in ("ft", "tune"))
    if not looks_finetuned:
        raise ValueError(f"Model {model_name_or_path} not implemented for LoRA")

    # The 8B LoRA config is the only one available.
    if "8b" not in model_name_or_path.lower():
        raise ValueError(f"Model {model_name_or_path} not found in {CONFIGS_DIR}")
    config_file = "llama-8b-ft/lora-ft-hf"

    with open(os.path.join(CONFIGS_DIR, f"{config_file}.json"), "r") as f:
        configs = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if tokenizer_only:
        return None, tokenizer

    base_model = AutoModelForCausalLM.from_pretrained(configs['base_model'], device_map="auto")

    # When the tokenizer and base vocabulary disagree: make sure a PAD token
    # exists, record its id on the config, and resize the embeddings.
    vocab_mismatch = len(tokenizer) != base_model.config.vocab_size
    if vocab_mismatch:
        if tokenizer.pad_token_id is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        base_model.config.pad_token_id = tokenizer.pad_token_id
        base_model.resize_token_embeddings(len(tokenizer))

        print(f"Added pad token to tokenizer: {tokenizer.pad_token} and resized token embeddings to {len(tokenizer)}")

    model = PeftModel.from_pretrained(base_model, model_name_or_path)
    print_model_info(model_name_or_path, model)
    return model, tokenizer


def load_llama3_for_instruction_tuning(model_args, model_kwargs):
    """Load a model and tokenizer for instruction tuning with a dedicated pad token.

    If the tokenizer has no pad token, or its pad token id equals its EOS
    token id, a ``<|pad|>`` special token is added. The model's
    ``config.pad_token_id`` is then synced and the embeddings are resized when
    their row count differs from the tokenizer length.

    Args:
        model_args: An object containing model configuration parameters including:
            - model_name_or_path: Path or identifier of the model to load
            - trust_remote_code: Whether to trust remote code when loading the
              tokenizer (it is not passed to the model load here)
        model_kwargs (dict): Keyword arguments forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        tuple: (model, tokenizer) - The loaded model and tokenizer
    """
    model_id = model_args.model_name_or_path

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=True,
    )

    # A missing pad token, or one aliasing EOS, is replaced by a dedicated token.
    if tokenizer.pad_token_id in (None, tokenizer.eos_token_id):
        tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

        # Log token IDs for verification
        print(f"EOS Token ID: {tokenizer.eos_token_id}")
        print(f"PAD Token ID: {tokenizer.pad_token_id}")

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    model.config.pad_token_id = tokenizer.pad_token_id

    # Grow or shrink the embedding matrix only when it disagrees with the tokenizer.
    embedding_rows = model.get_input_embeddings().weight.shape[0]
    if embedding_rows != len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))
        print(f"Model vocab size updated: {model.config.vocab_size}, Tokenizer vocab size: {len(tokenizer)}")

    print_model_info(model_id, model)
    return model, tokenizer

def load_llama3_for_sft(model_args, model_kwargs, training_args):
    """Prepare the tokenizer and the model identifier for supervised fine-tuning.

    No model is instantiated here. The returned ``model`` is the
    ``model_name_or_path`` string, and ``model_kwargs`` is stored on
    ``training_args.model_init_kwargs`` (presumably so the SFT trainer builds
    the model itself; confirm against the trainer in use).

    Args:
        model_args: An object providing ``model_name_or_path`` and
            ``trust_remote_code``.
        model_kwargs (dict): Stored on ``training_args.model_init_kwargs``.
        training_args: Training arguments object; mutated in place.

    Returns:
        tuple: (model, tokenizer), where ``model`` is a string (the model
        path) and the tokenizer's pad token is set to its EOS token.
    """
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    print(f"==== SFT ====")
    # NOTE(review): the pad token aliases EOS here, unlike
    # load_llama3_for_instruction_tuning which adds a dedicated pad token.
    # Confirm this is intended for SFT.
    tokenizer.pad_token = tokenizer.eos_token
    model = model_args.model_name_or_path

    # `model` is a plain string at this point, so parameter statistics cannot
    # be computed (print_model_info would fail on `model.parameters()`).
    # Report only the path.
    print(f"Model path: {model}")
    return model, tokenizer

def print_model_info(model_name_or_path, model):
    """Print the model path plus total and non-embedding parameter counts.

    Embedding parameters are those whose name (lower-cased) contains "embed"
    or "wte". Both counts cover every parameter regardless of
    ``requires_grad``, so frozen embeddings (e.g. under a frozen base model)
    are still excluded from the non-embedding figure.

    Args:
        model_name_or_path: Label printed as the model path.
        model: Object exposing ``parameters()`` and ``named_parameters()``
            whose items provide ``numel()``.
    """
    total_params = sum(p.numel() for p in model.parameters())

    # "embed" also matches names containing "embedding".
    embedding_keywords = ("embed", "wte")

    embedding_params = sum(
        p.numel()
        for name, p in model.named_parameters()
        if any(keyword in name.lower() for keyword in embedding_keywords)
    )

    non_embedding_params = total_params - embedding_params

    def pretty(num):
        # Millions below one billion, billions otherwise; one decimal place.
        return f"{num / 1e6:.1f}M" if num < 1e9 else f"{num / 1e9:.1f}B"

    print(f"==== Model ====")
    print(f"Model path: {model_name_or_path}")
    print(f"Total parameters: {pretty(total_params)}")
    print(f"Non-embedding parameters: {pretty(non_embedding_params)}")


def merge_model_save(model, tokenizer, save_dir):
    """Merge LoRA weights into the base model and save it with the tokenizer.

    Args:
        model: Object exposing ``merge_and_unload()`` (presumably a PEFT model
            wrapping a base model; confirm against the caller).
        tokenizer: Object exposing ``save_pretrained(path)``.
        save_dir: Target directory; created, including parents, if missing.
    """
    # Merge LoRA model weights into the base model
    merged_model = model.merge_and_unload()
    print(merged_model)  # Should not contain 'lora'
    print("LoRA model weights merged successfully.")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Save the merged model checkpoint. No debugger breakpoint here, so the
    # save also runs unattended.
    merged_model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    print(f"Merged model and tokenizer saved to {save_dir}")


if __name__ == "__main__":
    import argparse

    # Manual smoke test: build a GPT-2 model from the command-line arguments.
    # The loader prints the parameter summary itself.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name-or-path", type=str, default="gpt2")
    parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading model")
    model_args = parser.parse_args()

    # load_gpt2_model's second parameter is the boolean
    # use_special_dblookup_tokens. Passing the (always truthy) namespace there
    # would silently enable the DB-lookup tokens, so only the path is passed.
    model, tokenizer = load_gpt2_model(model_args.model_name_or_path)
