from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoConfig
import torch
import torch.nn as nn
import torch.nn.functional as F

def weights_init(m):
    """Re-initialize the weight of an ``nn.Linear`` module in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once per
    submodule. Only ``nn.Linear`` instances are touched; their ``weight`` is
    redrawn from a Xavier (Glorot) uniform distribution. Biases and every
    other module type are left unchanged.

    Args:
        m: The module to (possibly) re-initialize.
    """
    if isinstance(m, nn.Linear):
        # `xavier_uniform` (no trailing underscore) is a deprecated alias that
        # emits a warning per call; `xavier_uniform_` is the supported
        # in-place initializer and runs under no_grad internally.
        nn.init.xavier_uniform_(m.weight)

def get_untrained_model_and_tokenizer(name="EleutherAI/gpt-neo-1.3B", cache_dir = "./huggingface_cache"):
    """
    Build a causal-LM with freshly initialized (untrained) weights, plus its tokenizer.

    Only the tokenizer and the model configuration are fetched for ``name``;
    the model itself is constructed from the configuration, so no pretrained
    weights are loaded.

    Args:
        name: Model identifier passed to the tokenizer and config loaders.
        cache_dir: Directory used to cache the tokenizer and config files.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Pass cache_dir to the tokenizer as well, so that both artefacts honour
    # the caller's cache location (previously only the config did).
    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
    config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_config(config)

    return model, tokenizer

def get_checkpoint(checkpoint_filename: str = None, device: str = "cpu", name = "EleutherAI/gpt-neo-1.3B", cache_dir = "./huggingface_cache"):
    """
    Load a causal-LM and its tokenizer, ready for inference on ``device``.

    ``checkpoint_filename`` selects where the weights come from:
        string       = load the weights from that checkpoint
        None         = random weights (Xavier-uniform on every nn.Linear)
        "pretrained" = original pre-trained weights for ``name``

    Args:
        checkpoint_filename: Checkpoint location, ``"pretrained"``, or ``None``.
        device: Device the returned model is moved to.
        name: Model identifier used for the tokenizer, the architecture and
            the ``"pretrained"`` weights.
        cache_dir: Directory used to cache downloaded files.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode on ``device``.
    """
    if checkpoint_filename is None:
        model, tokenizer = get_untrained_model_and_tokenizer(name=name, cache_dir=cache_dir)
    else:
        # Only build the tokenizer here: instantiating a random model just to
        # throw it away would cost a full extra copy of the weights.
        tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
        if checkpoint_filename == "pretrained":
            model = AutoModelForCausalLM.from_pretrained(name, cache_dir=cache_dir)
        else:
            # from_pretrained is a classmethod that RETURNS a new model; its
            # result must be kept, otherwise the checkpoint is never used.
            model = AutoModelForCausalLM.from_pretrained(checkpoint_filename, cache_dir=cache_dir)
            print(f"Big success! We loaded a model from: {checkpoint_filename}")

    # Configure the model that is actually returned (not a discarded one).
    model.eval().to(device)

    if checkpoint_filename is None:
        # Re-initialize after the move so the order of operations matches the
        # previous behaviour for the random-weights case.
        model.apply(weights_init)

    # Some tokenizers define no pad token; fall back to EOS so that
    # generate() has a usable pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = pad_token_id

    return model, tokenizer

# Checkpoint to load; None or "pretrained" are also accepted by get_checkpoint.
checkpoint_filename = "/home/XXXX-4/repos/nesim/training/gpt_neo/checkpoints/apply_nesim_every_n_steps_1_nesim_config_scale_50_shrink_factor_[9.0]_layer_names_all_layers_c_fc_checkpoint_every_n_steps_1_num_warmup_steps_3000_batch_size_8_context_length_128/checkpoint-1"

# Single source of truth for the device, used for both the model and the
# inputs; falls back to CPU when no GPU is available instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"

model, tokenizer = get_checkpoint(checkpoint_filename = checkpoint_filename, device = device, name = "EleutherAI/gpt-neo-1.3B", cache_dir = "/research/XXXX-1/huggingface_cache_dir")

input_prompt = "An apple a day"
input_tokens = tokenizer(input_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    output_tokens = model.generate(
        # Forward everything the tokenizer produced (input_ids and, when
        # present, attention_mask) rather than input_ids alone, so generate()
        # does not have to guess the mask.
        **input_tokens,
        max_new_tokens=50,
        do_sample=True,  # Enable sampling (can change to False for greedy)
        top_k=10,  # Optional: controls randomness by selecting from top-k tokens
        top_p=0.8,  # Optional: controls randomness with nucleus sampling
    )

# output_tokens[0] is the first (here: only) generated sequence.
generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

print("Generated text:", generated_text)