#!/usr/bin/env python3
"""
Minimal GPU memory error reproduction script for LUMI
Simple model with train/eval loop and memory monitoring
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import gc

def print_memory_info(label=""):
    """Print a one-line snapshot of GPU memory usage tagged with ``label``.

    Reports ``torch.cuda.memory_allocated()`` and ``torch.cuda.memory_reserved()``
    for the current device, both converted to GiB. If CUDA is not available,
    a short notice is printed instead.
    """
    if not torch.cuda.is_available():
        print(f"[{label}] CUDA not available")
        return

    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"[{label}] GPU Memory - Allocated: {allocated_gb:.2f}GB, Reserved: {reserved_gb:.2f}GB")

class SimpleModel(nn.Module):
    """Small fully connected network used as the memory-debug workload.

    Architecture: Linear(input_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> hidden_size) -> ReLU -> Linear(hidden_size -> output_size),
    held in a single ``nn.Sequential`` named ``layers``.
    """

    def __init__(self, input_size=1000, hidden_size=512, output_size=100):
        super().__init__()
        # Collect the modules first, then wrap them. The order is unchanged, so
        # parameter names (layers.0/2/4) and init RNG consumption match an inline Sequential.
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

def create_dummy_dataset(num_samples=1000, input_size=1000, output_size=100, batch_size=32):
    """Return a shuffled DataLoader over random inputs and random targets.

    Inputs have shape (num_samples, input_size) and targets have shape
    (num_samples, output_size). Both are drawn with ``torch.randn``, inputs first.
    """
    features = torch.randn(num_samples, input_size)
    targets = torch.randn(num_samples, output_size)
    return DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=True)

def train_loop(model, dataloader, optimizer, device, epochs=3):
    """Train ``model`` with MSE loss while printing memory usage at each step.

    Each epoch stops after batch index 50, so at most 51 batches run per epoch.
    This keeps debug runs short. Memory is reported at epoch start, after
    moving each batch to ``device``, after the forward pass, and after the
    optimizer step.

    Args:
        model: module to train; it is switched to train mode.
        dataloader: iterable of ``(data, target)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: target device (``torch.device`` or string such as ``"cuda"``).
        epochs: number of passes over ``dataloader``.

    Raises:
        Re-raises any exception from a batch after printing the epoch and
        batch index and a memory snapshot.
    """
    print("Starting training...")
    model.train()
    # torch.cuda.synchronize() raises when CUDA is unavailable, so the
    # CUDA-only calls run only when the target device is a GPU.
    on_cuda = str(device).startswith("cuda")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print_memory_info("Epoch start")

        total_loss = 0.0
        num_batches = 0  # batches actually processed, used for the average
        for batch_idx, (data, target) in enumerate(dataloader):
            try:
                # Move to device
                data = data.to(device)
                target = target.to(device)

                print_memory_info(f"Batch {batch_idx} - data loaded")

                # Forward pass
                optimizer.zero_grad()
                output = model(data)
                loss = F.mse_loss(output, target)

                print_memory_info(f"Batch {batch_idx} - forward done")

                # Backward pass
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                print_memory_info(f"Batch {batch_idx} - backward done")

                if batch_idx % 10 == 0:
                    print(f"  Batch {batch_idx}, Loss: {batch_loss:.4f}")

                # Force GPU sync so asynchronous kernel errors surface at this
                # batch rather than at some later, unrelated call.
                if on_cuda:
                    torch.cuda.synchronize()

                if batch_idx >= 50:  # Limit batches for debugging
                    break

            except Exception as e:
                print(f"\nERROR at epoch {epoch}, batch {batch_idx}:")
                print(f"Error: {e}")
                print_memory_info("ERROR")
                raise

        # Divide by the number of batches actually run. The old
        # min(len, 50) was off by one against the 51-batch limit and
        # divided by zero for an empty loader.
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

        # Cleanup between epochs
        gc.collect()
        if on_cuda:
            torch.cuda.empty_cache()
        print_memory_info("Epoch end (after cleanup)")

def eval_loop(model, dataloader, device):
    """Evaluate ``model`` with MSE loss under ``torch.no_grad()``.

    Evaluation stops after batch index 20, so at most 21 batches run. The
    average loss over processed batches is printed, or 0 if none ran.
    Memory is reported at the start and end.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        dataloader: iterable of ``(data, target)`` batches.
        device: target device (``torch.device`` or string such as ``"cuda"``).

    Raises:
        Re-raises any exception from a batch after printing the batch index
        and a memory snapshot.
    """
    print("\nStarting evaluation...")
    model.eval()
    print_memory_info("Eval start")
    # torch.cuda.synchronize() raises without CUDA; only sync on a GPU device.
    on_cuda = str(device).startswith("cuda")

    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            try:
                data = data.to(device)
                target = target.to(device)

                output = model(data)
                loss = F.mse_loss(output, target)
                total_loss += loss.item()
                num_batches += 1

                # Force GPU sync so asynchronous kernel errors surface here.
                if on_cuda:
                    torch.cuda.synchronize()

                if batch_idx >= 20:  # Limit for debugging
                    break

            except Exception as e:
                print(f"\nEVAL ERROR at batch {batch_idx}:")
                print(f"Error: {e}")
                print_memory_info("EVAL ERROR")
                raise

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    print(f"Evaluation average loss: {avg_loss:.4f}")
    print_memory_info("Eval end")

def _print_environment(cuda_ok):
    """Print PyTorch/CUDA version info and, when CUDA is present, GPU 0 details."""
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {cuda_ok}")

    if not cuda_ok:
        return
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"Total GPU memory: {total_gb:.1f}GB")


def main():
    """Run the reproduction: report the environment, train briefly, then evaluate.

    Builds a SimpleModel (2000 -> 1024 -> 500), a 2000-sample training loader
    (batch 64), a 500-sample evaluation loader (batch 32) and an Adam optimizer.
    It then runs two training epochs followed by evaluation. Any exception is
    reported with a memory snapshot and re-raised.
    """
    print("=== Simple GPU Memory Debug ===")
    cuda_ok = torch.cuda.is_available()
    _print_environment(cuda_ok)

    device = torch.device("cuda" if cuda_ok else "cpu")
    print(f"Using device: {device}")
    print_memory_info("Initial")

    # Model is built before the datasets, so random init happens before the
    # random data is drawn.
    in_dim, hidden_dim, out_dim = 2000, 1024, 500
    model = SimpleModel(input_size=in_dim, hidden_size=hidden_dim, output_size=out_dim).to(device)
    print_memory_info("Model loaded")

    train_dataloader = create_dummy_dataset(
        num_samples=2000, input_size=in_dim, output_size=out_dim, batch_size=64
    )
    eval_dataloader = create_dummy_dataset(
        num_samples=500, input_size=in_dim, output_size=out_dim, batch_size=32
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    try:
        train_loop(model, train_dataloader, optimizer, device, epochs=2)
        eval_loop(model, eval_dataloader, device)

        print("\nScript completed successfully!")
        print_memory_info("Final")
    except Exception as err:
        print(f"\nFINAL ERROR: {err}")
        print_memory_info("FINAL ERROR")
        raise err


if __name__ == "__main__":
    main()