import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Model configuration
# Model identifier handed to `from_pretrained`; change it to use a different model.
MODEL_NAME = "HuggingFaceTB/SmolLM-360M-Instruct"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Passed to `generate` as `max_length` (in transformers this bounds prompt
# tokens plus generated tokens, not generated tokens alone).
MAX_LENGTH = 1024
# Sampling temperature passed to `generate`.
TEMPERATURE = 1.0
# Nucleus-sampling threshold passed to `generate` as `top_p`.
TOP_P = 0.9

def load_model():
    """Load the tokenizer and causal language model named by ``MODEL_NAME``.

    The weights are requested as float16 when ``DEVICE`` is ``"cuda"`` and as
    float32 otherwise; device placement is delegated to ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Loading model {MODEL_NAME}...")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Half precision on the GPU, full precision on the CPU.
    if DEVICE == "cuda":
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=weights_dtype,
        device_map="auto",
    )
    return lm, tok

def prompt_model(prompt: str, model, tokenizer) -> str:
    """Generate the assistant's reply to ``prompt`` using the local model.

    The prompt is wrapped in a fixed system/User/Assistant template, tokenized,
    and sampled from with ``TEMPERATURE`` and ``TOP_P``. Only the tokens
    produced after the prompt are decoded, so an ``"Assistant:"`` substring in
    the prompt or in the generated text cannot truncate the reply.

    Args:
        prompt: The user request to answer.
        model: A causal language model exposing ``generate`` (as returned by
            ``load_model``).
        tokenizer: The tokenizer matching ``model``.

    Returns:
        The generated continuation with special tokens removed and
        surrounding whitespace stripped.
    """
    # Prepare the prompt with system message
    full_prompt = f"You are a helpful assistant.\n\nUser: {prompt}\n\nAssistant:"

    # The model is placed by `device_map="auto"`, so follow the model's own
    # device when it exposes one instead of assuming the global DEVICE.
    target_device = getattr(model, "device", DEVICE)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(target_device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]

    # Generate response. The attention mask is passed explicitly because the
    # pad token is set to the EOS token, which prevents `generate` from
    # inferring the mask reliably.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=inputs.get("attention_mask"),
            max_length=MAX_LENGTH,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True
        )

    # For a causal LM the output sequence starts with the prompt tokens;
    # drop them and decode only what the model actually generated.
    generated_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return response.strip()

def main():
    """Generate a model-written text for every prompt file and save it.

    Reads each regular file in the prompts directory, sends its contents to
    the model via ``prompt_model``, and writes the reply to a file of the same
    name under a directory named after the model. A failure on one file is
    reported and the remaining files are still processed.
    """
    # Set up directories
    model_name_short = MODEL_NAME.split("/")[-1]
    ai_written_text_directory = f'dataset_creation/ghostbuster-data/essay/{model_name_short}'
    prompt_text_directory = 'dataset_creation/ghostbuster-data/essay/prompts'

    # Create output directory if it doesn't exist
    os.makedirs(ai_written_text_directory, exist_ok=True)

    # Load model and tokenizer
    model, tokenizer = load_model()

    # Process each prompt file; sorted so the run order is reproducible
    # (os.listdir returns entries in arbitrary order).
    for filename in sorted(os.listdir(prompt_text_directory)):
        prompt_text_file = os.path.join(prompt_text_directory, filename)

        # Skip sub-directories and other non-file entries.
        if not os.path.isfile(prompt_text_file):
            continue

        print(f"Processing filename: {filename}")

        try:
            # Read prompt from file. The encoding is explicit so the result
            # does not depend on the platform's locale default.
            with open(prompt_text_file, 'r', encoding="utf-8") as file:
                prompt = file.read()

            # Get AI response
            ai_generated_text = prompt_model(prompt, model, tokenizer)

            # Save response to file
            ai_generated_text_file = os.path.join(ai_written_text_directory, filename)
            with open(ai_generated_text_file, "w", encoding="utf-8") as text_file:
                text_file.write(ai_generated_text)

            print(f"Successfully processed {filename}")

        except Exception as e:
            # Best-effort batch job: report the failure and move on so one
            # bad prompt does not abort the whole run.
            print(f"Error processing {filename}: {str(e)}")

# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()