import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

def train_autoencoder(model, train_data, val_data=None, num_epochs=100, batch_size=64,
                      learning_rate=0.001, device='cuda' if torch.cuda.is_available() else 'cpu',
                      val_interval=10):
    """
    Train an autoencoder with an MSE reconstruction loss plus the model's L1 term.

    Each optimisation step minimises ``MSE(outputs, batch) + l1_loss`` with Adam,
    where ``outputs, l1_loss = model(batch)``.

    Args:
        model: Autoencoder model. ``model(batch)`` must return a tuple
            ``(outputs, l1_loss)``; ``l1_loss`` is added to the reconstruction
            loss as-is. The model is moved to ``device`` in place.
        train_data: Training data (features); anything ``DataLoader`` can batch
            into tensors that support ``.to(device)`` and ``.size(0)``.
        val_data: Optional validation data (features). When given, it is
            evaluated with ``evaluate_autoencoder`` every ``val_interval``
            epochs and a progress line is printed.
        num_epochs: Number of training epochs.
        batch_size: Batch size (used for both training and validation).
        learning_rate: Adam learning rate.
        device: Device to run on.
        val_interval: Validate every this many epochs (positive int; only
            used when ``val_data`` is provided).

    Returns:
        Mean per-sample training loss (reconstruction + L1 term) of the final
        epoch, or ``float('nan')`` if no training step was run (e.g.
        ``num_epochs == 0``).
    """
    # Batches are moved to `device` below, so the parameters must live there too.
    model.to(device)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    epoch_loss = float('nan')
    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: validation puts the model in
        # eval mode, which would otherwise leave dropout / batch-norm in
        # inference behaviour for all remaining epochs.
        model.train()
        running_loss = 0.0
        num_samples = 0
        for batch in train_loader:
            batch = batch.to(device)

            outputs, l1_loss = model(batch)
            reconstruction_loss = criterion(outputs, batch)
            loss = reconstruction_loss + l1_loss  # add L1 regularization loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # MSELoss (default 'mean' reduction) yields a batch mean; weight it by
            # batch size so the epoch figure is a per-sample mean even when the
            # last batch is short.
            running_loss += loss.item() * batch.size(0)
            num_samples += batch.size(0)

        epoch_loss = running_loss / num_samples if num_samples else float('nan')

        if val_data is not None and (epoch + 1) % val_interval == 0:
            val_loss = evaluate_autoencoder(model, val_data, batch_size, device)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.6f}, Val Loss: {val_loss:.6f}")

    return epoch_loss

def evaluate_autoencoder(model, data, batch_size=64,
                         device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Evaluate an autoencoder with the same objective used for training.

    The per-batch loss is ``MSE(outputs, batch) + l1_loss`` where
    ``outputs, l1_loss = model(batch)``; it is computed without gradients and
    averaged per sample over the whole dataset.

    Args:
        model: Autoencoder model; ``model(batch)`` must return a tuple
            ``(outputs, l1_loss)``. It is moved to ``device`` in place, run in
            eval mode, and then restored to whatever train/eval mode it was in
            when this function was called.
        data: Data to evaluate on (features); anything ``DataLoader`` can batch
            into tensors that support ``.to(device)`` and ``.size(0)``.
        batch_size: Batch size.
        device: Device to run on.

    Returns:
        Mean per-sample evaluation loss, or ``float('nan')`` if ``data`` is empty.
    """
    # Batches are moved to `device` below, so the parameters must live there too.
    model.to(device)
    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    criterion = nn.MSELoss()

    # Remember the caller's mode so a training loop that validates mid-run is
    # not silently left in eval mode (dropout / batch-norm disabled).
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for batch in data_loader:
                batch = batch.to(device)

                outputs, l1_loss = model(batch)
                reconstruction_loss = criterion(outputs, batch)
                loss = reconstruction_loss + l1_loss  # add L1 regularization loss

                # Weight the batch-mean loss by batch size for a per-sample mean.
                total_loss += loss.item() * batch.size(0)
                num_samples += batch.size(0)
    finally:
        model.train(was_training)

    return total_loss / num_samples if num_samples else float('nan')
