

import logging 
from enum import Enum 
from typing import List 

import numpy 
import torch 

logger =logging .getLogger (__name__ )


class Split (Enum ):
    train =0 
    valid =1 
    test =2 


def compile_helpers ():
    """Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
    """
    import os 
    import subprocess 

    command =["make","-C",os .path .abspath (os .path .dirname (__file__ ))]
    if subprocess .run (command ).returncode !=0 :
        import sys 

        log_single_rank (logger ,logging .ERROR ,"Failed to compile the C++ dataset helper functions")
        sys .exit (1 )


def log_single_rank (logger :logging .Logger ,*args ,rank =0 ,**kwargs ):
    """If torch distributed is initialized, log only on rank

    Args:
        logger (logging.Logger): The logger to write the logs

        rank (int, optional): The rank to write on. Defaults to 0.
    """
    if torch .distributed .is_initialized ():
        if torch .distributed .get_rank ()==rank :
            logger .log (*args ,**kwargs )
    else :
        logger .log (*args ,**kwargs )


def normalize (weights :List [float ])->List [float ]:
    """Do non-exponentiated normalization

    Args:
        weights (List[float]): The weights

    Returns:
        List[float]: The normalized weights
    """
    w =numpy .array (weights ,dtype =numpy .float64 )
    w_sum =numpy .sum (w )
    w =(w /w_sum ).tolist ()
    return w 
