import numpy as np
from itertools import combinations_with_replacement
from sklearn.preprocessing import RobustScaler, Normalizer
from .helper import normalize_data_point
from .helper_basis_symbolic import load_and_precompile_functions_joblib, compute_with_precompiled_functions

# Global variable to cache the precompiled basis functions returned by
# load_and_precompile_functions_joblib (used by the 'proposed' feature method).
# NOTE(review): this single slot is shared by every *_feat_extract function in
# this module and is only populated while it is None. A later call for a
# different environment, or with a different cfg["path_to_basis"], silently
# reuses whichever basis was loaded first. Reset it to None when switching
# environments/bases within one process.
_precompiled_functions_cache = None

def select_feat_extractor(env_name, states, cfg):
    """Route an observation to the feature extractor for ``env_name``.

    Args:
        env_name: Gym environment id; one of 'Walker2d-v4', 'HalfCheetah-v4',
            'Hopper-v4' or 'Ant-v4'.
        states: Observation passed unchanged to the selected extractor.
        cfg: Configuration mapping passed unchanged to the selected extractor.

    Returns:
        Whatever the selected ``*_feat_extract`` function returns.

    Raises:
        NotImplementedError: if ``env_name`` has no registered extractor.
    """
    if env_name == 'Walker2d-v4':
        return walker_feat_extract(states, cfg)
    if env_name == 'HalfCheetah-v4':
        return cheetah_feat_extract(states, cfg)
    if env_name == 'Hopper-v4':
        return hopper_feat_extract(states, cfg)
    if env_name == 'Ant-v4':
        return ant_feat_extract(states, cfg)
    # No extractor matched this environment id.
    raise NotImplementedError


#  ██████ ██   ██ ███████ ███████ ████████  █████  ██   ██ 
# ██      ██   ██ ██      ██         ██    ██   ██ ██   ██ 
# ██      ███████ █████   █████      ██    ███████ ███████ 
# ██      ██   ██ ██      ██         ██    ██   ██ ██   ██ 
#  ██████ ██   ██ ███████ ███████    ██    ██   ██ ██   ██                                    
def cheetah_feat_extract(states, cfg):
    """Compute the reward feature vector for one HalfCheetah-v4 observation.

    The observation is rescaled with ``normalize_data_point`` against hard-coded
    per-dimension bounds. It is then optionally scaled to unit L2 norm and mapped
    to features according to ``cfg["feats_method"]``.

    Args:
        states: A single observation with (at least) 17 entries, since 17
            bounds are supplied and 'first' reads indices 0..16.
            NOTE(review): presumably a 1-D numpy array — confirm with callers.
        cfg: Mapping with keys:
            - ``"normalize_feats"``: truthy to L2-normalize the rescaled state.
            - ``"feats_method"``: one of 'first', 'random', 'manual', 'proposed'.
            - ``"path_to_basis"``: basis-file suffix, read only for 'proposed'.

    Returns:
        np.ndarray of features. 'first' yields 17 entries and 'random'/'manual'
        yield 12 each. For 'proposed' the size depends on the loaded basis.

    Raises:
        NotImplementedError: if ``cfg["feats_method"]`` is not recognised.
    """
    global _precompiled_functions_cache

    # Hard-coded per-dimension bounds (17 entries) scaled by 1.3. Every min is
    # negative and every max positive, so this widens each range by 30%.
    min_vals = 1.3 * np.array([-0.600, -3.220, -0.644, -0.858, -0.555, -1.072, -1.108, -0.681, -3.315, -3.527, -6.308, -20.155, -24.883, -23.288, -22.412, -25.717, -26.089])
    max_vals = 1.3 * np.array([0.378, 3.802, 0.910, 0.865, 0.873, 0.814, 1.017, 0.664, 3.221, 3.184, 6.993, 19.490, 23.201, 20.072, 25.405, 26.938, 23.360])
    states = normalize_data_point(states, min_vals, max_vals)

    if cfg["normalize_feats"]:
        # sklearn's Normalizer works row-wise, hence the reshape to a single
        # row and the squeeze back to 1-D.
        states = Normalizer().fit_transform(states.reshape(1, -1)).squeeze()

    feat_selection = cfg["feats_method"]

    if feat_selection == 'first':
        # One feature per observation dimension.
        feats = [states[ind] for ind in range(17)]

    elif feat_selection == 'random':
        # Fixed subset of linear, squared and pairwise-product terms.
        feats = [
            states[2] ** 2,
            states[4] * states[5],
            states[6] * states[7],

            states[6],
            states[3] * states[10],
            states[11] * states[12],

            states[11],
            states[14] * states[15],
            states[16] * states[0],

            states[11] * states[4],
            states[4] ** 2,
            states[13],
        ]

    elif feat_selection == 'manual':
        # Hand-picked subset of linear, squared and pairwise-product terms.
        feats = [
            states[0],
            states[4],
            states[8],

            states[9],
            states[13],
            states[14],

            states[13] ** 2,
            states[2],
            states[4] * states[5],

            states[11],
            states[14] * states[12],
            states[13] * states[0],
        ]

    elif feat_selection == 'proposed':
        notex = cfg["path_to_basis"]
        # NOTE(review): the module-level cache is shared with the other
        # *_feat_extract functions and is never invalidated, so a basis loaded
        # earlier for another env/path is silently reused here.
        if _precompiled_functions_cache is None:
            filepath = f'feature4irl/pickles/HalfCheetah-v4_basis_{notex}.joblib'
            _precompiled_functions_cache = load_and_precompile_functions_joblib(filepath)
        feats = compute_with_precompiled_functions(_precompiled_functions_cache, states)

    else:
        # Previously the exception was constructed but never raised, so an
        # unknown method silently produced an empty feature vector.
        raise NotImplementedError(
            f"feats_method {feat_selection!r} is not supported for HalfCheetah-v4"
        )

    return np.array(feats)


#  ██     ██  █████  ██      ██   ██ ███████ ██████  
#  ██     ██ ██   ██ ██      ██  ██  ██      ██   ██ 
#  ██  █  ██ ███████ ██      █████   █████   ██████  
#  ██ ███ ██ ██   ██ ██      ██  ██  ██      ██   ██ 
#   ███ ███  ██   ██ ███████ ██   ██ ███████ ██   ██ 
                                                   
def walker_feat_extract(states, cfg):
    """Compute the reward feature vector for one Walker2d-v4 observation.

    The observation is rescaled with ``normalize_data_point`` against hard-coded
    per-dimension bounds. It is then optionally scaled to unit L2 norm and mapped
    to features according to ``cfg["feats_method"]``.

    Args:
        states: A single observation with (at least) 17 entries, since 17
            bounds are supplied and 'first' reads indices 0..16.
            NOTE(review): presumably a 1-D numpy array — confirm with callers.
        cfg: Mapping with keys:
            - ``"normalize_feats"``: truthy to L2-normalize the rescaled state.
            - ``"feats_method"``: one of 'first', 'random', 'manual', 'proposed'.
            - ``"path_to_basis"``: basis-file suffix, read only for 'proposed'.

    Returns:
        np.ndarray of features. 'first' yields 17 entries and 'random'/'manual'
        yield 12 each. For 'proposed' the size depends on the loaded basis.

    Raises:
        NotImplementedError: if ``cfg["feats_method"]`` is not recognised.
    """
    global _precompiled_functions_cache

    # Hard-coded per-dimension bounds (17 entries) scaled by 2.
    # NOTE(review): min_vals[0] is positive (0.3), so doubling it *raises* that
    # lower bound to 0.6 instead of widening the range — confirm intended.
    min_vals = 2 * np.array([0.3, -1.000, -1.212, -2.367, -1.187, -1.279, -2.346, -1.167, -2.909, -6.356, -10.000, -10.000, -10.000, -10.000, -10.000, -10.000, -10.000])
    max_vals = 2 * np.array([1.328, 0.121, 0.262, 0.178, 1.153, 0.243, 0.224, 1.167, 1.365, 1.912, 10.000, 10.000, 10.000, 10.000, 10.000, 10.000, 10.000])
    states = normalize_data_point(states, min_vals, max_vals)

    if cfg["normalize_feats"]:
        # sklearn's Normalizer works row-wise, hence the reshape to a single
        # row and the squeeze back to 1-D.
        states = Normalizer().fit_transform(states.reshape(1, -1)).squeeze()

    feat_selection = cfg["feats_method"]

    if feat_selection == 'first':
        # One feature per observation dimension.
        feats = [states[ind] for ind in range(17)]

    elif feat_selection == 'random':
        # Fixed subset of linear, squared and pairwise-product terms.
        feats = [
            states[3] ** 2,
            states[2] * states[5],
            states[7] * states[7],

            states[6],
            states[3] * states[10],
            states[1] * states[2],

            states[10],
            states[16] * states[7],
            states[2] * states[0],

            states[11] * states[4],
            states[4] ** 2,
            states[13],
        ]

    elif feat_selection == 'manual':
        # Hand-picked subset of linear, squared and pairwise-product terms.
        feats = [
            states[0],
            states[4] ** 2,
            states[8],

            states[9],
            states[15],
            states[16],

            states[13] ** 2,
            states[2],
            states[4] * states[5],

            states[11],
            states[14] * states[12],
            states[13] * states[0],
        ]

    elif feat_selection == 'proposed':
        notex = cfg["path_to_basis"]
        # NOTE(review): the module-level cache is shared with the other
        # *_feat_extract functions and is never invalidated, so a basis loaded
        # earlier for another env/path is silently reused here.
        if _precompiled_functions_cache is None:
            filepath = f'feature4irl/pickles/Walker2d-v4_basis_{notex}.joblib'
            _precompiled_functions_cache = load_and_precompile_functions_joblib(filepath)
        feats = compute_with_precompiled_functions(_precompiled_functions_cache, states)

    else:
        # Previously the exception was constructed but never raised, so an
        # unknown method silently produced an empty feature vector.
        raise NotImplementedError(
            f"feats_method {feat_selection!r} is not supported for Walker2d-v4"
        )

    return np.array(feats)
    


#  █████  ███    ██ ████████ 
# ██   ██ ████   ██    ██    
# ███████ ██ ██  ██    ██    
# ██   ██ ██  ██ ██    ██    
# ██   ██ ██   ████    ██    
def ant_feat_extract(states, cfg):
    """Compute the reward feature vector for one Ant-v4 observation.

    The observation is rescaled with ``normalize_data_point`` against hard-coded
    per-dimension bounds. It is then optionally scaled to unit L2 norm and mapped
    to features according to ``cfg["feats_method"]``.

    Args:
        states: A single observation; 27 bounds are supplied and 'first'
            keeps the first 27 entries.
            NOTE(review): presumably a 1-D numpy array — confirm with callers.
        cfg: Mapping with keys:
            - ``"normalize_feats"``: truthy to L2-normalize the rescaled state.
            - ``"feats_method"``: one of 'first', 'random', 'manual', 'proposed'.
            - ``"path_to_basis"``: basis-file suffix, read only for 'proposed'.

    Returns:
        np.ndarray of features. 'first' keeps up to 27 entries and
        'random'/'manual' yield 12 each. For 'proposed' the size depends on
        the loaded basis.

    Raises:
        NotImplementedError: if ``cfg["feats_method"]`` is not recognised.
    """
    global _precompiled_functions_cache

    # Hard-coded per-dimension bounds (27 entries) scaled by 2.
    # NOTE(review): min_vals[0] is positive (0.136), so doubling it *raises*
    # that lower bound to 0.272 instead of widening the range — confirm intended.
    min_vals = 2 * np.array([0.136, -1.000, -1.000, -1.000, -1.000, -0.686, -0.100, -0.680, -1.361, -0.680, -1.357, -0.678, -0.099, -3.635, -4.142, -3.867, -10.950, -8.932, -8.740, -16.185, -14.044, -15.993, -18.374, -16.869, -18.456, -16.225, -13.949])
    max_vals = 2 * np.array([1.000, 1.000, 1.000, 1.000, 1.000, 0.681, 1.351, 0.683, 0.100, 0.682, 0.100, 0.681, 1.364, 4.081, 4.221, 5.764, 8.832, 8.763, 9.109, 16.911, 18.407, 16.541, 14.151, 16.197, 14.252, 16.509, 17.901])
    states = normalize_data_point(states, min_vals, max_vals)

    if cfg["normalize_feats"]:
        # sklearn's Normalizer works row-wise, hence the reshape to a single
        # row and the squeeze back to 1-D.
        states = Normalizer().fit_transform(states.reshape(1, -1)).squeeze()

    feat_selection = cfg["feats_method"]

    if feat_selection == 'first':
        # One feature per observation dimension (first 27).
        feats = states[:27]

    elif feat_selection == 'random':
        # Fixed subset of linear, squared and pairwise-product terms.
        # NOTE(review): states[2] * states[5] appears twice (2nd and 6th
        # entries), giving two identical features — confirm intended.
        feats = [
            states[5] ** 2,
            states[2] * states[5],
            states[2] * states[7],

            states[7],
            states[8] * states[10],
            states[5] * states[2],

            states[3],
            states[12] * states[7],
            states[17] * states[0],

            states[11] * states[20],
            states[4] ** 2,
            states[16],
        ]

    elif feat_selection == 'manual':
        # Hand-picked subset of linear, squared and pairwise-product terms.
        feats = [
            states[0],
            states[1] ** 2,
            states[2],

            states[4] ** 2,
            states[13],
            states[14],

            states[15],
            states[16] ** 2,
            states[17] ** 2,

            states[21],
            states[24] * states[2],
            states[23] * states[1],
        ]

    elif feat_selection == 'proposed':
        notex = cfg["path_to_basis"]
        # NOTE(review): the module-level cache is shared with the other
        # *_feat_extract functions and is never invalidated, so a basis loaded
        # earlier for another env/path is silently reused here.
        if _precompiled_functions_cache is None:
            filepath = f'feature4irl/pickles/Ant-v4_basis_{notex}.joblib'
            _precompiled_functions_cache = load_and_precompile_functions_joblib(filepath)
        feats = compute_with_precompiled_functions(_precompiled_functions_cache, states)

    else:
        # Previously the exception was constructed but never raised, so an
        # unknown method silently produced an empty feature vector.
        raise NotImplementedError(
            f"feats_method {feat_selection!r} is not supported for Ant-v4"
        )

    return np.array(feats)





# ██   ██  ██████  ██████  ██████  ███████ ██████  
# ██   ██ ██    ██ ██   ██ ██   ██ ██      ██   ██ 
# ███████ ██    ██ ██████  ██████  █████   ██████  
# ██   ██ ██    ██ ██      ██      ██      ██   ██ 
# ██   ██  ██████  ██      ██      ███████ ██   ██ 
                                                 
                                                 
 
######################################
##      HOPPER
######################################

def _poly2_features(point):
    """Return the entries of ``point`` followed by every product ``x_i * x_j`` with ``i <= j``.

    The products are in ``itertools.combinations_with_replacement`` order:
    (0, 0), (0, 1), ..., (0, d-1), (1, 1), (1, 2), ...
    For a d-dimensional point the result therefore has d + d*(d+1)/2 entries.
    """
    point = np.asarray(point)
    second_degree = [x * y for x, y in combinations_with_replacement(point, 2)]
    return point.tolist() + second_degree


def hopper_feat_extract(states, cfg):
    """Compute the reward feature vector for one Hopper-v4 observation.

    The observation is optionally standardized with hard-coded mean/std
    vectors and then mapped to features according to ``cfg["feats_method"]``.
    The 'first', 'all' and 'random' methods select from the degree-1 plus
    degree-2 polynomial expansion of the (standardized) state.

    Args:
        states: A single observation. 'manual' reads indices up to 9, so it
            needs at least 10 entries.
            NOTE(review): presumably a 1-D numpy array — confirm with callers.
        cfg: Mapping with keys:
            - ``"normalize_feats"``: truthy to standardize with ``mu``/``std``.
            - ``"feats_method"``: one of 'first', 'all', 'random', 'manual',
              'proposed', 'other'.
            - ``"path_to_basis"``: basis-file suffix, read only for 'proposed'.

    Returns:
        np.ndarray of features. 'first' yields 11 entries, 'random' and
        'manual' yield 7 each, and 'all' yields the full expansion. 'other'
        yields an empty array. For 'proposed' the size depends on the loaded
        basis.

    Raises:
        NotImplementedError: if ``cfg["feats_method"]`` is not recognised.
    """
    global _precompiled_functions_cache

    if cfg["normalize_feats"]:
        # Standardize the observation with hard-coded per-dimension mean and std.
        # NOTE(review): these vectors have 8 entries, but the 'manual' branch
        # reads indices up to 9. An observation that long cannot broadcast
        # against them, so the subtraction raises ValueError. Confirm which
        # observation layout these statistics belong to.
        mu = np.array([0.8881140999877914, -0.8519100246011, -0.38425375192232136, 0.08373274706761998, -0.007740304511785189, 0.017643511467072517, -0.008243891646439292, -0.010322091919607066])
        std = np.array([0.4504364359989499, 0.7309412231502657, 0.7025523126710245, 0.36405457628063975, 0.538587204240941, 0.7067707653749452, 1.2019269571584708, 1.1175795983828707])
        states_std = (states - mu) / std
    else:
        states_std = states

    feat_selection = cfg["feats_method"]

    if feat_selection == 'first':
        # First 11 entries of the expansion, i.e. the raw coordinates when the
        # state has at least 11 dimensions.
        feats_list = _poly2_features(states_std)
        feats = [feats_list[i] for i in range(11)]

    elif feat_selection == 'all':
        feats = _poly2_features(states_std)

    elif feat_selection == 'random':
        # Fixed subset of the expansion: indices 21, 7, 12, 25, 5, 1, 15.
        feats_list = _poly2_features(states_std)
        feats = [feats_list[i] for i in (21, 7, 12, 25, 5, 1, 15)]

    elif feat_selection == 'manual':
        feats = [
            states_std[8] ** 2,  # pos
            states_std[9] ** 2,
            states_std[6] ** 2,  # vel
            states_std[7] ** 2,

            states_std[8] * states_std[7],
            states_std[6] * states_std[7],
            states_std[9] * states_std[7],
        ]

    elif feat_selection == 'proposed':
        notex = cfg["path_to_basis"]
        # NOTE(review): the module-level cache is shared with the other
        # *_feat_extract functions and is never invalidated, so a basis loaded
        # earlier for another env/path is silently reused here.
        if _precompiled_functions_cache is None:
            filepath = f'feature4irl/pickles/Hopper-v4_basis_{notex}.joblib'
            _precompiled_functions_cache = load_and_precompile_functions_joblib(filepath)
        feats = compute_with_precompiled_functions(_precompiled_functions_cache, states_std)

    elif feat_selection == 'other':
        # Placeholder method: produces an empty feature vector.
        feats = []

    else:
        # Previously the exception was constructed but never raised, so an
        # unknown method silently produced an empty feature vector.
        raise NotImplementedError(
            f"feats_method {feat_selection!r} is not supported for Hopper-v4"
        )

    return np.array(feats)

