import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_selection import f_regression, mutual_info_regression
from sklearn.model_selection import train_test_split
from sklearn import datasets, linear_model
from sklearn.metrics import r2_score
from helper import *
from helper_basis_symbolic import *
from utils_feature_create import *
import os
from joblib import Parallel, delayed
import time
import numpy as np
import ot

# Seed NumPy's global RNG once at import time so runs are reproducible
# (e.g. the synthetic non-expert log-probabilities drawn in create_basis).
np.random.seed(seed=42)

def reverse_range(y):
    """Mirror the values of ``y`` inside their own [min, max] interval.

    Every element ``v`` is mapped to ``min(y) + max(y) - v``, so the smallest
    value becomes the largest and vice versa while the range is unchanged.

    Args:
        y: array-like of numbers (must be non-empty).

    Returns:
        numpy array of the mirrored values.
    """
    lo, hi = np.min(y), np.max(y)
    return (lo + hi) - y


def create_basis(
    env_name,
    notex,
    add_non_expert,
    normalize,
    use_marginal,
    use_ortho,
    use_fourier,
    use_rbf,
    remove_outliers,
    verbose,
    num_chunks,
    dims,
    score_T,
    gamma,
    n_selected,
    drop_feats,
    save,
    method,
    data_path='none',
    norm_type='none'
):
    """Build symbolic basis features for ``env_name`` and run feature selection.

    Steps performed here:
      1. Load trajectories with ``load_data`` and normalize expert and
         non-expert trajectories with ``normalize_trajs_std``.
      2. Load cached trajectory log-probabilities from
         ``tmp/{env_name}/logP_tau_{data_path}_{num_chunks}.npy``. If the
         cache is missing or unreadable, recompute them from KDE fits and
         save them to that path.
      3. Create the symbolic basis functions, evaluate them on every
         trajectory and aggregate into per-trajectory features ``mu_tau``.
      4. Filter outliers and optionally append non-expert trajectories with
         synthetic low log-probabilities.
      5. Forward everything to ``select_top_features``.

    Args:
        env_name: environment id; used by ``load_data`` and in the tmp/ paths.
        notex: experiment tag; plots go to ``tmp/{env_name}/{notex}/``.
        add_non_expert: if True, append non-expert trajectory features with
            log-probabilities drawn just below the expert minimum.
        normalize: forwarded to ``load_data``.
        use_marginal: forwarded to ``compute_log_probabilities``.
        use_ortho, use_fourier, use_rbf, dims: forwarded to
            ``create_all_symbolic_basis_functions``.
        remove_outliers: forwarded to ``filter_outliers``.
        verbose: if True, save plots of the log-probabilities; also forwarded
            to ``select_top_features``.
        num_chunks: forwarded to ``load_data`` / ``compute_log_probabilities``
            and part of the logP cache filename.
        gamma: forwarded to ``load_data`` and ``find_feature_trajs``.
        score_T, n_selected, drop_feats, save, method: forwarded to
            ``select_top_features``.
        data_path: tag used in the logP cache filename.
        norm_type: forwarded to ``normalize_trajs_std``.

    Returns:
        None.
    """
    folder_path = f'tmp/{env_name}/{notex}/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(folder_path, exist_ok=True)

    trajs, rewards, expert_ts, non_expert_trajs, non_expert_ts = load_data(env_name, normalize, num_chunks, gamma)
    trajs, means, stds = normalize_trajs_std(trajs, norm_type=norm_type)
    print('trajs', len(trajs))
    # NOTE(review): the non-expert set is scaled with its own statistics, not
    # the expert ones; confirm this is intended.
    non_expert_trajs, _, _ = normalize_trajs_std(non_expert_trajs, norm_type=norm_type)

    logP_path = f'tmp/{env_name}/logP_tau_{data_path}_{num_chunks}.npy'
    try:
        logP_tau = np.load(logP_path)
        print('Loaded logP_tau from file')
    except (OSError, ValueError):
        # Cache missing (FileNotFoundError is an OSError) or unreadable:
        # recompute from KDE fits and refresh the cache. Programming errors
        # and KeyboardInterrupt are no longer swallowed.
        print('...Computing logP_tau')
        kde_states, kde_succ_states = fit_kde(trajs)
        logP_tau = compute_log_probabilities(trajs, expert_ts, num_chunks, kde_states, kde_succ_states, use_marginal=use_marginal)
        np.save(logP_path, logP_tau)

    all_basis_functions, all_variables = create_all_symbolic_basis_functions(
        dims=dims,
        use_fourier=use_fourier,
        use_rbf=use_rbf,
        use_ortho=use_ortho,
        means=means,
        stds=stds
        )

    feats = compute_numerical_values_for_all_trajs_optimized(all_basis_functions, all_variables, trajs)
    if verbose:
        save_array_plot(logP_tau, f'{folder_path}/logP_tau.png')

    mu_tau = find_feature_trajs(len(all_basis_functions), feats, gamma=gamma)

    mu_tau_expert, logP_tau_expert = filter_outliers(mu_tau, logP_tau, remove_outliers)

    if verbose:
        save_array_plot(logP_tau_expert, f'{folder_path}/logP_tau_sorted.png')

    if add_non_expert:
        feats = compute_numerical_values_for_all_trajs_optimized(all_basis_functions, all_variables, non_expert_trajs)
        mu_tau_agent = find_feature_trajs(len(all_basis_functions), feats, gamma=gamma)

        # TODO add other probs
        # Cap at the number of available agent trajectories so the feature
        # rows and their log-probabilities stay aligned when stacked below.
        n_agent = min(80, len(mu_tau_agent))
        low = min(logP_tau_expert)
        logP_tau_agent = np.random.uniform(low, low * 1.1, n_agent)

        mu_tau_agent = mu_tau_agent[:n_agent]
        mu_tau_combined = np.vstack([mu_tau_expert, mu_tau_agent])
        logP_tau_combined = np.hstack([logP_tau_expert, logP_tau_agent])
    else:
        mu_tau_combined = mu_tau_expert
        logP_tau_combined = logP_tau_expert

    # NOTE(review): folder_path is computed above but None is passed here;
    # confirm select_top_features should not receive it.
    select_top_features(mu_tau_combined,
                        logP_tau_combined,
                        mu_tau_expert,
                        logP_tau_expert,
                        all_variables,
                        all_basis_functions,
                        score_T, n_selected=n_selected,
                        drop_features=drop_feats,
                        env_name=env_name,
                        notex=notex,
                        save=save,
                        method=method,
                        verbose=verbose,
                        folder_path=None
                        )
    


def normalize_trajs_std(trajs, norm_type='none'):
    """Normalize a list of trajectories jointly and split them back apart.

    All trajectories are stacked row-wise with ``np.vstack`` and scaled
    together, so they share one set of statistics. The result is then sliced
    back into a list using each trajectory's original ``len``.

    Args:
        trajs: sequence of array-likes accepted by ``np.vstack``; presumably
            each is (timesteps, features) — TODO confirm.
        norm_type: 'none' (no scaling), 'std' (sklearn ``StandardScaler``)
            or 'max' (sklearn ``MinMaxScaler`` to the range (-1, 1)).

    Returns:
        (normalized_trajs_list, None, None). The two ``None`` values fill the
        (means, stds) slots the caller unpacks; scaler statistics are not
        returned.

    Raises:
        ValueError: if ``norm_type`` is not 'none', 'std' or 'max'.
    """
    # Validate up front so a bad option fails before any stacking work.
    if norm_type not in ('none', 'std', 'max'):
        raise ValueError(f'Invalid norm_type: {norm_type}')
    # np.vstack raises on an empty sequence; there is nothing to normalize.
    # len() (not truthiness) so numpy-array inputs are also handled.
    if len(trajs) == 0:
        return [], None, None

    stacked = np.vstack(trajs)
    # Import locally: the module never imports these names explicitly, so
    # relying on a wildcard import would risk a NameError.
    if norm_type == 'std':
        from sklearn.preprocessing import StandardScaler
        stacked = StandardScaler().fit_transform(stacked)
    elif norm_type == 'max':
        from sklearn.preprocessing import MinMaxScaler
        stacked = MinMaxScaler((-1, 1)).fit_transform(stacked)

    # Slice the stacked rows back into per-trajectory chunks.
    normalized_trajs_list = []
    start = 0
    for traj in trajs:
        end = start + len(traj)
        normalized_trajs_list.append(stacked[start:end])
        start = end

    return normalized_trajs_list, None, None


if __name__ == '__main__':
    # Experiment driver: run create_basis once per environment config.

    # for num_chunks in [50, 100, 150, 200, 250]:
    num_chunks = 150

    # (env_name, dims, score_T, remove_outliers, n_selected)
    env_configs = [
        ("HalfCheetah-v4", 17, 0.6, False, 12),
        # This entry previously had only 4 fields, so the 5-name unpacking
        # below raised ValueError and the remaining envs never ran.
        # n_selected=12 mirrors HalfCheetah (same dims) — TODO confirm.
        ("Walker2d-v4", 17, 0.7, True, 12),
        ("Hopper-v4", 11, 0.6, False, 10),
        ("Ant-v4", 27, 0.6, False, 20),
    ]

    # Settings shared by every environment run (loop-invariant).
    use_marginal = True
    norm_type = 'max'    # 'std', 'max', 'none'
    bwn = 'one'          # 'one', 'many'

    # data_path tag: 'B' + marginal flag + norm-type letter + bwn letter.
    data_path = 'B' + ('m' if use_marginal else 's')
    norm_suffix = {'std': 's', 'max': 'm', 'none': 'n'}
    if norm_type not in norm_suffix:
        raise ValueError(f'Invalid norm_type: {norm_type}')
    data_path += norm_suffix[norm_type]
    data_path += 's' if bwn == 'one' else 'm'

    method = 1
    drop_feats = 0

    normalize = True
    verbose = True
    save = True
    gamma = 1

    add_non_expert = False
    use_fourier = False
    use_rbf = False
    use_ortho = False

    notex = data_path + '10'
    if add_non_expert:
        notex += '_ne2'

    for env_name, dims, score_T, remove_outliers, n_selected in env_configs:
        print(f'\n{env_name}---------{data_path}--------{notex}-----------{method}')
        create_basis(
            env_name=env_name,
            notex=notex,
            add_non_expert=add_non_expert,
            normalize=normalize,
            use_marginal=use_marginal,
            use_ortho=use_ortho,
            use_fourier=use_fourier,
            use_rbf=use_rbf,
            remove_outliers=remove_outliers,
            verbose=verbose,
            num_chunks=num_chunks,
            dims=dims,
            score_T=score_T,
            gamma=gamma,
            n_selected=n_selected,
            drop_feats=drop_feats,
            method=method,
            save=save,
            data_path=data_path,
            norm_type=norm_type
        )