#!/usr/bin/env python3
import os
import subprocess
import json
import torch
import gc
from pathlib import Path

def get_directory_size(path):
    """Return the combined size in bytes of every file found under ``path``.

    The tree is traversed with ``os.walk`` and each file is measured with
    ``os.path.getsize``. A file that cannot be measured (for example a broken
    symlink, a file deleted during the walk, or a permission error) is skipped
    on its own, so one bad entry no longer discards the sizes that were
    already accumulated.

    Args:
        path: Directory to measure.

    Returns:
        Sum of the measurable file sizes in bytes; 0 if ``path`` does not
        exist.
    """
    if not os.path.exists(path):
        return 0

    total_size = 0
    try:
        for dirpath, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                try:
                    total_size += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    # PermissionError / FileNotFoundError are OSError subclasses;
                    # skip only this entry and keep counting the rest.
                    continue
    except OSError:
        # The traversal itself failed part-way; report what was measured so
        # far instead of pretending the directory is empty.
        return total_size
    return total_size

def format_size(size_bytes):
    """Render a byte count as text with a 1024-based unit suffix.

    The value is divided by 1024 for each unit step (B, KB, MB, GB, TB) until
    it is no longer >= 1024 or the largest unit (TB) is reached, then printed
    with two decimals. A value equal to zero is rendered as ``"0 B"``.
    """
    if size_bytes == 0:
        return "0 B"

    units = ("B", "KB", "MB", "GB", "TB")
    scaled = size_bytes
    for unit in units[:-1]:
        # Spelled as "not >=" (rather than "<") so the comparison is the same
        # one the scaling rule is defined by.
        if not (scaled >= 1024):
            return f"{scaled:.2f} {unit}"
        scaled /= 1024.0

    # Anything still this large is reported in the largest unit.
    return f"{scaled:.2f} {units[-1]}"

def get_gpu_memory_usage():
    """Snapshot torch's CUDA memory counters.

    Returns:
        ``None`` when ``torch.cuda.is_available()`` is false; otherwise a dict
        with the keys ``allocated``, ``reserved``, ``max_allocated`` and
        ``max_reserved``, each holding the value of the corresponding
        ``torch.cuda`` query. The queries are made without a device argument,
        so torch's own default device selection applies.
    """
    if not torch.cuda.is_available():
        return None

    snapshot = {}
    snapshot['allocated'] = torch.cuda.memory_allocated()
    snapshot['reserved'] = torch.cuda.memory_reserved()
    snapshot['max_allocated'] = torch.cuda.max_memory_allocated()
    snapshot['max_reserved'] = torch.cuda.max_memory_reserved()
    return snapshot

def test_model_memory_usage_simulation(model_path, precision_configs, device="cuda:1"):
    """Estimate GPU memory per precision configuration using a small probe tensor.

    No model weights are loaded. For every configuration the function
    computes a theoretical weight footprint (guessed parameter count times the
    average bit width), allocates a float16 probe tensor of at most 100 MB on
    ``device``, measures how much allocated memory that probe actually costs,
    and scales the theoretical footprint by the measured/requested ratio.

    Args:
        model_path: Model path; only its lower-cased basename is inspected
            (for the substrings "7b" / "8b") to guess the parameter count.
        precision_configs: Mapping of configuration name to a list of bit
            widths (``naive_bit``). A single-element list is used as-is; a
            longer list is averaged.
        device: CUDA device string of the form ``"cuda:N"``.

    Returns:
        Dict keyed by configuration name. A successful entry holds
        ``naive_bit``, ``before_memory``, ``after_memory``, ``model_memory``,
        ``theoretical_memory``, ``estimated_memory``, ``peak_memory``,
        ``method``, ``tensor_size`` and ``test_memory_bytes``; a failed entry
        holds ``error`` and ``naive_bit``. An empty dict is returned when CUDA
        or the requested device is unavailable.
    """

    print(f"\nTesting model: {model_path}")
    print("=" * 80)

    # Check if device is available
    if not torch.cuda.is_available():
        print("Error: CUDA not available")
        return {}

    available_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    if device not in available_devices:
        print(f"Error: Device {device} not available, available devices: {available_devices}")
        return {}

    print(f"Using device: {device}")
    device_properties = torch.cuda.get_device_properties(device)
    print(f"Device properties: {device_properties}")

    # Make `device` current so the argument-less torch.cuda memory queries
    # used below refer to it.
    torch.cuda.set_device(device)

    # Total (not currently free) memory of the device; loop-invariant.
    total_device_memory = device_properties.total_memory

    # The probe tensor is float16, i.e. 2 bytes per element. The element count
    # must be derived from this size, otherwise the probe occupies only a
    # fraction of test_memory_bytes and the estimate is scaled down with it.
    bytes_per_element = 2

    memory_results = {}

    for config_name, naive_bit in precision_configs.items():
        print(f"\nTesting configuration: {config_name}")
        print(f"Naive_bit: {naive_bit}")
        print("-" * 60)

        # Start each configuration from a clean allocator state
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        dummy_tensor = None
        try:
            # Record memory before the probe allocation
            before_memory = get_gpu_memory_usage()
            print(f"Memory before loading: {format_size(before_memory['allocated'])}")

            # Guess the parameter count from the model directory name.
            # NOTE: the counts use 1024**3 per "B", so "7B" is ~7.5e9 here;
            # this is an approximation, not the model's real parameter count.
            model_name = os.path.basename(model_path).lower()
            if "7b" in model_name:
                total_params = 7 * 1024 * 1024 * 1024
            elif "8b" in model_name:
                total_params = 8 * 1024 * 1024 * 1024
            else:
                total_params = 7 * 1024 * 1024 * 1024  # Default to the 7B guess

            # Average bit width of the configuration
            if len(naive_bit) == 1:
                avg_bits = naive_bit[0]
            else:
                # Mixed precision: plain (unweighted) mean of the listed widths
                avg_bits = sum(naive_bit) / len(naive_bit)

            # Theoretical weight footprint in bytes
            theoretical_memory_bytes = (total_params * avg_bits) // 8

            # Probe with a much smaller allocation: at most 100 MB, and never
            # more than a quarter of the device's total memory.
            test_memory_bytes = min(100 * 1024 * 1024, theoretical_memory_bytes // 16)
            test_memory_bytes = min(test_memory_bytes, total_device_memory // 4)

            # Number of float16 elements that fill test_memory_bytes
            tensor_size = int(test_memory_bytes // bytes_per_element)

            print(f"Theoretical memory: {format_size(theoretical_memory_bytes)}")
            print(f"Test memory: {format_size(test_memory_bytes)}")
            print(f"Tensor size: {tensor_size} elements")

            # Allocate the probe directly on the target device
            dummy_tensor = torch.randn(tensor_size, dtype=torch.float16, device=device)

            # Record memory after the probe allocation
            after_memory = get_gpu_memory_usage()
            print(f"Memory after simulation loading: {format_size(after_memory['allocated'])}")

            # Memory attributable to the probe
            model_memory = after_memory['allocated'] - before_memory['allocated']
            print(f"Simulated model memory usage: {format_size(model_memory)}")
            print(f"Theoretical memory usage: {format_size(theoretical_memory_bytes)}")
            print(f"Peak memory: {format_size(after_memory['max_allocated'])}")

            # Scale the theoretical footprint by measured/requested probe size
            if test_memory_bytes > 0:
                estimated_memory = theoretical_memory_bytes * (model_memory / test_memory_bytes)
            else:
                estimated_memory = 0

            memory_results[config_name] = {
                'naive_bit': naive_bit,
                'before_memory': before_memory,
                'after_memory': after_memory,
                'model_memory': model_memory,
                'theoretical_memory': theoretical_memory_bytes,
                'estimated_memory': estimated_memory,
                'peak_memory': after_memory['max_allocated'],
                'method': 'simulation',
                'tensor_size': tensor_size,
                'test_memory_bytes': test_memory_bytes
            }

        except Exception as e:
            # Best-effort: record the failure and continue with the next config
            print(f"Error during testing: {e}")
            memory_results[config_name] = {
                'error': str(e),
                'naive_bit': naive_bit
            }
        finally:
            # Release the probe on both the success and the failure path so a
            # failed configuration cannot skew the next measurement.
            dummy_tensor = None
            gc.collect()
            torch.cuda.empty_cache()

    return memory_results

def check_model_space_usage():
    """Print and return the on-disk size of each known model directory.

    The model directories are looked up below the path stored in the
    ``DATA_ROOT`` environment variable. For each one a status block is
    printed, followed by a summary sorted by size (largest first) and a total.

    Returns:
        Dict mapping model label to ``{"path", "size_bytes",
        "size_formatted", "exists"}``, or ``None`` (after printing an error)
        when ``DATA_ROOT`` is unset or empty.
    """
    data_root = os.environ.get("DATA_ROOT")
    if not data_root:
        print("Error: DATA_ROOT environment variable not set, please run conda_activate.sh first")
        return

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(f"Using DATA_ROOT: {data_root}")
    print(heavy_rule)

    # Label -> location relative to DATA_ROOT
    relative_locations = (
        ("DeepSeek-R1-Distill-Qwen-7B", "DeepSeek-R1-Distill-Qwen-7B"),
        ("Qwen3-8B", "Qwen3-8B"),
        ("qwen7b-distill (quantized)", "quantize_model/packed/qwen7b-distill"),
        ("qwen3-8b (quantized)", "quantize_model/packed/qwen3-8b"),
    )
    models_to_check = {label: f"{data_root}/{rel}" for label, rel in relative_locations}

    print("Checking base model space usage:")
    print(light_rule)

    results = {}
    for model_name, model_path in models_to_check.items():
        found = os.path.exists(model_path)
        if found:
            size_bytes = get_directory_size(model_path)
            size_formatted = format_size(size_bytes)
            marker = "✓"
            detail_line = f"  Size: {size_formatted}"
        else:
            # Missing directories are recorded with a zero size
            size_bytes = 0
            size_formatted = "0 B"
            marker = "✗"
            detail_line = "  Status: Not found"

        results[model_name] = {
            "path": model_path,
            "size_bytes": size_bytes,
            "size_formatted": size_formatted,
            "exists": found
        }
        print(f"{marker} {model_name}")
        print(f"  Path: {model_path}")
        print(detail_line)
        print()

    # Summary table, largest first
    print(heavy_rule)
    print("Space usage summary (sorted by size):")
    print(light_rule)

    def _size_of(entry):
        return entry[1]["size_bytes"]

    for model_name, info in sorted(results.items(), key=_size_of, reverse=True):
        size_column = info["size_formatted"] if info["exists"] else "Not found"
        print(f"{model_name:<50} {size_column:>15}")

    # Grand total over the directories that were found
    total_size = 0
    for info in results.values():
        if info["exists"]:
            total_size += info["size_bytes"]
    print(light_rule)
    print(f"{'Total':<50} {format_size(total_size):>15}")

    return results

def test_precision_memory_usage():
    """Run the memory simulation for each packed model and each precision set.

    For every packed model directory that exists below ``DATA_ROOT`` the
    simulation is run over a fixed set of ``naive_bit`` configurations. A
    summary table is printed and the full results are written to
    ``precision_memory_usage.json`` in the current working directory.

    Returns:
        Dict mapping model path to the per-configuration results, or ``None``
        (after printing an error) when ``DATA_ROOT`` is unset or empty.
    """
    data_root = os.environ.get("DATA_ROOT")
    if not data_root:
        print("Error: DATA_ROOT environment variable not set, please run conda_activate.sh first")
        return

    banner = "=" * 80

    # Configuration name -> list of bit widths passed on as naive_bit
    precision_configs = {
        "3bit_only": [3],
        "4bit_only": [4],
        "3bit_4bit_mixed": [3, 4],
        "3bit_4bit_5bit": [3, 4, 5],
        "full_precision": [3, 4, 5, 6, 7, 8]
    }

    packed_root = f"{data_root}/quantize_model/packed"
    test_models = [f"{packed_root}/qwen7b-distill", f"{packed_root}/qwen3-8b"]

    all_memory_results = {}
    for model_path in test_models:
        if not os.path.exists(model_path):
            # Models that are not present on disk are silently skipped
            continue

        print(f"\n{banner}")
        print(f"Testing model memory usage: {model_path}")
        print(banner)

        all_memory_results[model_path] = test_model_memory_usage_simulation(
            model_path, precision_configs
        )

    # Output summary
    print(f"\n{banner}")
    print("Memory usage summary")
    print(banner)

    for model_path, results in all_memory_results.items():
        print(f"\nModel: {os.path.basename(model_path)}")
        print("-" * 60)

        for config_name, result in results.items():
            if 'error' in result:
                print(f"{config_name:<25} Error: {result['error']}")
                continue

            measured = format_size(result['model_memory'])
            theoretical = format_size(result['theoretical_memory'])
            estimated = format_size(result['estimated_memory'])
            bits = result['naive_bit']
            print(f"{config_name:<25} Memory: {measured:>15} Theoretical: {theoretical:>15} Estimated: {estimated:>15} naive_bit: {bits}")

    # Persist the detailed results as JSON
    output_file = "precision_memory_usage.json"
    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(all_memory_results, handle, indent=2, ensure_ascii=False)

    print(f"\nDetailed results saved to: {output_file}")

    return all_memory_results

def check_disk_space():
    """Print total, used and free space of the filesystem holding DATA_ROOT.

    Does nothing when the ``DATA_ROOT`` environment variable is unset or
    empty. Sizes come from ``os.statvfs``: free space is
    ``f_frsize * f_bavail``, total space is ``f_frsize * f_blocks`` and used
    space is their difference. Any exception raised along the way is reported
    with a one-line message instead of propagating.
    """
    data_root = os.environ.get("DATA_ROOT")
    if not data_root:
        return

    try:
        fs_stats = os.statvfs(data_root)
        block_size = fs_stats.f_frsize
        total_bytes = block_size * fs_stats.f_blocks
        free_bytes = block_size * fs_stats.f_bavail
        used_bytes = total_bytes - free_bytes

        print("\n" + "=" * 80)
        print("Disk space information:")
        print("-" * 80)
        print(f"Disk path: {data_root}")
        for label, amount in (
            ("Total space", total_bytes),
            ("Used space", used_bytes),
            ("Free space", free_bytes),
        ):
            print(f"{label}: {format_size(amount)}")
        # Percentage is computed only here, after the sizes were printed
        usage_percent = (used_bytes / total_bytes) * 100
        print(f"Usage rate: {usage_percent:.1f}%")
        print("=" * 80)

    except Exception as e:
        print(f"Cannot get disk space information: {e}")

if __name__ == "__main__":
    # Script entry point: report disk usage first, then run the GPU memory
    # simulation for the different precision configurations.
    for banner_line in ("Model space usage and memory usage test", "=" * 80):
        print(banner_line)

    # Step 1: per-model directory sizes, followed by filesystem statistics
    print("1. Checking disk space usage...")
    results = check_model_space_usage()
    check_disk_space()

    # Step 2: simulated GPU memory usage per precision configuration
    print("\n2. Testing memory usage of different precision configurations...")
    memory_results = test_precision_memory_usage()

    print("\nTest completed!")
