'''
Shared code for the comparison of different methods on the ERA5 dataset.
'''
import os
import random
import subprocess
import sys
import time
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

# Make modules in the current working directory (e.g. era5_dataset) importable.
sys.path.append('.')
from era5_dataset import ERA5

# When True, torch autograd anomaly detection is switched on as soon as this
# module is imported (the flag name suggests it is meant for NaN-gradient hunting).
CHECK_NAN_GRAD = False
torch.autograd.set_detect_anomaly(mode=CHECK_NAN_GRAD)


# Directory of the entry-point script (sys.argv[0]), not of this module;
# per-model summary CSVs are written to its 'summary' sub-folder.
path0 = os.path.split(sys.argv[0])[0]
path_summary = os.path.join(path0, 'summary')

def prepare_case(
        batch_size: int = 512,
        GPU_ID: int = 0,
        max_samples: Optional[int] = None,
        image_size: int = 64,
        ) -> dict:
    '''
    Prepare the ERA5 test case.

    - Create the summary folder (``path_summary``) if it does not exist.
    - Load the "train" and "test" splits of ``ERA5`` (two separate dataset
      objects; no splitting is done here, and ``download=False`` is passed).
    - Wrap the training split in a shuffled ``DataLoader``.

    Args:
        batch_size: Mini-batch size of the training ``DataLoader``.
        GPU_ID: Forwarded to ``ERA5`` as ``gpu_id`` for both splits.
        max_samples: Forwarded to ``ERA5`` for both splits; ``None`` is passed
            through unchanged (presumably "no limit" -- confirm in era5_dataset).
        image_size: Forwarded to ``ERA5`` for both splits. Defaults to 64,
            the value previously hard-coded here.

    Returns:
        dict with keys:
        - ``train_set`` / ``test_set``: the ``ERA5`` dataset objects.
        - ``dataloader``: ``DataLoader`` over ``train_set`` (shuffle=True,
          drop_last=False).
        - ``X_train_tensor`` / ``y_train_tensor``: ``train_set.x`` / ``train_set.y``.
        - ``X_test_tensor`` / ``y_test_tensor``: ``test_set.x`` / ``test_set.y``.
        - ``dim_input`` / ``dim_output``: taken from ``train_set``.
    '''
    os.makedirs(path_summary, exist_ok=True)

    train_set = ERA5(image_size=image_size, split="train", download=False,
                     gpu_id=GPU_ID, max_samples=max_samples)
    test_set = ERA5(image_size=image_size, split="test", download=False,
                    gpu_id=GPU_ID, max_samples=max_samples)

    dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False)

    return {
        'train_set': train_set,
        'test_set': test_set,
        'dataloader': dataloader,
        'X_train_tensor': train_set.x,
        'y_train_tensor': train_set.y,
        'X_test_tensor': test_set.x,
        'y_test_tensor': test_set.y,
        'dim_input': train_set.dim_input,
        'dim_output': train_set.dim_output,
    }

def set_seed(seed: int):
    '''
    Seed every random number generator used here, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU, plus
    all CUDA devices when CUDA is available). Then configures cuDNN for
    deterministic execution.

    Args:
        seed: Seed value applied to every generator.
    '''
    # Python's own RNG is used by some libraries and data pipelines; seeding only
    # NumPy/torch would leave it non-reproducible.
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade speed for run-to-run reproducibility: no cuDNN autotuning, deterministic
    # kernels, and cuDNN disabled altogether.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.enabled = False

def fname_summary(model_name: str) -> str:
    '''
    Return the path of the ERA5 summary CSV for ``model_name``.

    The file is named ``summary-<model_name>-era5.csv`` and lives in ``path_summary``.
    '''
    csv_name = f'summary-{model_name}-era5.csv'
    return os.path.join(path_summary, csv_name)

def assign_gpu(idx, num_gpus):
    """Round-robin GPU choice: map job index ``idx`` onto ``num_gpus`` devices.

    Returns 0 when ``num_gpus`` is 0 (no GPUs), otherwise ``idx % num_gpus``.
    """
    return 0 if num_gpus == 0 else idx % num_gpus

def get_gpu_memory_info():
    """Query per-GPU memory usage via ``nvidia-smi``.

    Returns:
        A list with one dict per GPU line printed by ``nvidia-smi``. Each dict has
        the keys ``gpu_id`` (the nvidia-smi index) and ``total_memory`` /
        ``used_memory`` / ``free_memory``. The memory values are in bytes, i.e.
        the unit-less MiB values from ``nvidia-smi`` multiplied by 1024*1024.
        A line whose fields are not all integers is skipped, so the other GPUs
        are still reported (nvidia-smi can print placeholders such as
        ``[N/A]``). Returns ``None`` in three cases: ``nvidia-smi`` cannot be
        started, it times out (10 s), or it exits with a non-zero status.
    """
    mib = 1024 * 1024
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.total,memory.used,memory.free',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # OSError: binary missing / not executable; SubprocessError: timeout;
        # ValueError: output could not be decoded as text.
        print(f"Error getting GPU memory info: {e}")
        return None
    if result.returncode != 0:
        return None

    gpu_info = []
    for line in result.stdout.splitlines():
        # Split on ',' and strip, so both "0, 1024" and "0,1024" are accepted;
        # blank lines yield a single field and are skipped by the length check.
        parts = [p.strip() for p in line.split(',')]
        if len(parts) < 4:
            continue
        try:
            gpu_id, total, used, free = (int(p) for p in parts[:4])
        except ValueError:
            # One unparseable GPU must not discard the data for all the others.
            continue
        gpu_info.append({
            'gpu_id': gpu_id,
            'total_memory': total * mib,
            'used_memory': used * mib,
            'free_memory': free * mib,
        })
    return gpu_info

def assign_gpu_by_memory(gpu_id: int):
    """Pick the CUDA device with the most free memory, after a staggered delay.

    Args:
        gpu_id: Index of the calling job, not a device index. It only sets the
            start-up delay and the round-robin fallback (``gpu_id % num_gpus``).

    Returns:
        A torch CUDA device index, or ``0`` when CUDA is unavailable or no
        devices are reported.

    Note:
        ``nvidia-smi`` numbers GPUs independently of torch: it ignores
        ``CUDA_VISIBLE_DEVICES``, and its ordering can differ from CUDA's
        default ``CUDA_DEVICE_ORDER``. Only nvidia-smi indices below
        ``torch.cuda.device_count()`` are considered, so the result is always a
        valid torch device. For the two numberings to actually agree, run with
        ``CUDA_DEVICE_ORDER=PCI_BUS_ID`` and without a
        ``CUDA_VISIBLE_DEVICES`` subset.
    """
    if not torch.cuda.is_available():
        return 0

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        return 0

    # Stagger concurrently launched jobs so they don't all sample nvidia-smi at
    # the same moment: wait (gpu_id + 1) * 10 s plus a random 0-30 s.
    t = np.random.rand()*30 + (gpu_id+1)*10
    print(f"Waiting for {t} seconds")
    time.sleep(t)

    gpu_memory = get_gpu_memory_info()

    if not gpu_memory:
        # nvidia-smi unavailable/failed or reported no parseable GPUs.
        print("nvidia-smi failed, falling back to round-robin assignment")
        return gpu_id % num_gpus

    # Drop nvidia-smi entries that torch cannot address as a device.
    candidates = [g for g in gpu_memory if 0 <= g['gpu_id'] < num_gpus]
    if not candidates:
        print("No nvidia-smi GPU index is a valid torch device, falling back to round-robin assignment")
        return gpu_id % num_gpus

    # max() keeps the first of equally-free GPUs, like a stable descending sort.
    best = max(candidates, key=lambda g: g['free_memory'])
    best_gpu = best['gpu_id']

    print(f"Assigning GPU {best_gpu} with {best['free_memory']} bytes free memory")

    return best_gpu


if __name__ == "__main__":
    # Manual check: print what nvidia-smi reports per GPU (None if the query failed).
    print(get_gpu_memory_info())

