import os
import random
import warnings

import numpy as np
import pandas as pd
import skopt
import torch

#from simulators.simulators2 import oracle_simulator
#from utils.active_learning import get_query



# NOTE(review): the string literal below is never executed; it only preserves a
# disabled version of `sample_dataset_from_simulator`. It calls `oracle_simulator`,
# whose import is commented out at the top of this file, so it cannot simply be
# un-quoted without restoring that import as well.
"""
def sample_dataset_from_simulator(args, search_space, n_samples, seed=None, stats=False):

    if seed is not None:
        np.random.seed(seed)

    # Sample training data
    x = np.array(sample_initial_inputs(n_samples, search_space, method=args.space_filling_design))
    #x = x.repeat(args.repeat_sampling, axis=0)  # *args.k_samples
    if args.simulator == "motorcycle":
        x = [np.where(search_space == int(tmp_x))[0][0] for tmp_x in x]

    # Get labels from oracle
    y, _, _ = oracle_simulator(args, x)

    x = torch.tensor(x, dtype=torch.float32) if not torch.is_tensor(x) else x
    if len(x.shape) == 1:
        x = x.view(-1, 1)
    y = torch.tensor(y, dtype=torch.float32) if not torch.is_tensor(y) else y

    return x, y
"""

def sample_initial_inputs(n_samples, search_space, method='uniform'):
    """
    Sample initial data points (x values only) from a search space.

    :param n_samples: the number of initial data points to sample
    :param search_space: candidate points, either 1-D (a single feature) or 2-D with
        one row per candidate point and one column per feature
    :param method: 'uniform' (equidistant unique points), 'random', 'lhs'
        (Latin hypercube over the per-feature min/max box) or 'sobol'
        (Sobol sequence over the sorted unique values of each feature)
    :return: the sampled points. The container depends on ``method``: a list for
        'uniform' (and for 'lhs' on a 2-D search space); a numpy array for
        'random', 'sobol' and 'lhs' on a 1-D search space.
    :raises ValueError: if ``method`` is unknown, or if ``n_samples > 0`` and the
        search space is empty.
    """
    samplers = {
        'uniform': _sample_uniform,
        'random': _sample_random,
        'lhs': _sample_lhs,
        'sobol': _sample_sobol,
    }
    try:
        sampler = samplers[method]
    except KeyError:
        # Previously an unknown method silently returned an empty list.
        raise ValueError(f"Unknown sampling method {method!r}; expected one of "
                         f"{sorted(samplers)}.") from None

    search_space = np.asarray(search_space)
    if n_samples > 0 and search_space.size == 0:
        raise ValueError("search_space is empty; cannot sample initial inputs.")
    return sampler(n_samples, search_space)


def _sample_uniform(n_samples, search_space):
    """Pick (approximately) equidistant rows from the sorted unique rows of ``search_space``."""
    unique_ss = np.unique(search_space, axis=0)
    n_unique = len(unique_ss)

    if n_unique < n_samples:
        # Not enough distinct points: return every unique point, repeated as many
        # whole times as fit into n_samples (so fewer than n_samples may be returned).
        n_repeats = n_samples // n_unique
        warnings.warn(f"\nTrying to get {n_samples} equidistant distributed points, but there are only "
                      f"{n_unique} possibilities. Now we replicates this procedure "
                      f"{n_repeats} times and get "
                      f"{n_repeats * n_unique} data points.")
        # tolist() per repetition gives every repetition its own row objects (no aliasing).
        return [row for _ in range(n_repeats) for row in unique_ss.tolist()]

    # int() truncates the evenly spaced float positions to valid indices.
    idx = [int(pos) for pos in np.linspace(0, n_unique - 1, n_samples)]
    return unique_ss[idx].tolist()


def _sample_random(n_samples, search_space):
    """Draw random rows, all distinct whenever the search space has enough unique points."""
    unique_ss = np.unique(search_space, axis=0)
    # Enough distinct candidates (including exactly enough): sample unique points
    # without replacement. Otherwise fall back to the raw search space, where
    # duplicated points keep their multiplicity as sampling weight.
    pool = unique_ss if n_samples <= len(unique_ss) else search_space
    replace = n_samples > pool.shape[0]
    idx = np.random.choice(np.arange(pool.shape[0]), size=n_samples, replace=replace)
    return pool[idx]


def _sample_lhs(n_samples, search_space):
    """Latin hypercube sample over the per-feature (min, max) bounding box of ``search_space``."""
    # Population-based: sampled values are not restricted to points present in the search space.
    is_1d = search_space.ndim == 1
    if is_1d:
        bounds = [(np.min(search_space, axis=0), np.max(search_space, axis=0))]
    else:
        bounds = list(zip(np.min(search_space, axis=0), np.max(search_space, axis=0)))
    space = skopt.space.Space(bounds)
    lhs = skopt.sampler.Lhs(lhs_type="classic", criterion='maximin')
    x_sample = lhs.generate(space.dimensions, n_samples)
    if is_1d:
        x_sample = np.array(x_sample).squeeze(-1)
    return x_sample


def _sample_sobol(n_samples, search_space):
    """Sobol sample over indices into each feature's sorted unique values, mapped back to values."""
    is_1d = search_space.ndim == 1
    space_2d = search_space.reshape(-1, 1) if is_1d else search_space

    # Pseudo-population-based: the Sobol sequence picks an index into the sorted unique
    # values of every feature, which assumes those values are (roughly) equidistant.
    feature_values = [np.unique(space_2d[:, feature]) for feature in range(space_2d.shape[1])]
    # skopt's Integer requires low < high, so a constant feature gets [0, 1] and is
    # clipped back to its only index below.
    space = skopt.space.Space([skopt.space.Integer(low=0, high=max(len(values) - 1, 1))
                               for values in feature_values])
    sobol = skopt.sampler.Sobol()
    x_sample_idx = np.array(sobol.generate(space.dimensions, n_samples))

    x_sample = np.empty(x_sample_idx.shape)
    for feature, values in enumerate(feature_values):
        # Index directly: the former `idx - 1` mapped index 0 to the *last* value
        # through negative indexing.
        idx = np.minimum(x_sample_idx[:, feature], len(values) - 1)
        x_sample[:, feature] = np.array(values[idx], dtype="float64")
    return x_sample[:, 0] if is_1d else x_sample


def load_data(args, path):
    """
    Load a dataset and return its search space (inputs) and oracle labels (targets).

    :param args: configuration object; ``args.dataset`` selects the dataset and, for the
        Mercury datasets, ``args.outputs`` (1, 4 or 6) selects the label columns
    :param path: for the Mercury datasets, the path of a pickled DataFrame; for
        'autompg', the directory containing ``UCI_Datasets/auto-mpg.csv``
    :return: tuple ``(search_space, oracle_labels)`` of numpy arrays
    :raises NotImplementedError: for an unknown ``args.dataset``, or an unsupported
        ``args.outputs`` on a Mercury dataset
    """
    # Input columns per Mercury dataset: a single name gives a 1-D search space,
    # a list of names a 2-D one.
    mercury_inputs = {
        "mercury1d": 'alpha_tat_mean',
        "mercury3d": ['first_compensation_threshold', 'alpha_compensation_magnitude', 'fuel_price'],
        "mercury3d-12": ['compensation_magnitude_long1', 'first_compensation_threshold', 'fuel_price'],
    }
    # Label columns per number of outputs. Alternative single targets used before:
    # 'arrival_delay_min_mean', 'total_cost_mean'.
    mercury_labels = {
        1: 'fuel_cost_m3_mean',
        4: ['arrival_delay_min_mean', 'arrival_delay_min_std',
            'departure_delay_min_mean', 'departure_delay_min_std'],
        6: ['arrival_delay_min_mean', 'departure_delay_min_mean', 'total_cost_mean',
            'pax_tot_arrival_delay_mean', 'lcc_arrival_delay_min_mean', 'lcc_total_cost_mean'],
    }

    if args.dataset in mercury_inputs:
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        df = pd.read_pickle(path)
        # (mercury3d previously had this assignment commented out, which made the
        # return below fail with UnboundLocalError.)
        search_space = np.array(df[mercury_inputs[args.dataset]])

        if args.outputs not in mercury_labels:
            raise NotImplementedError("\n----> Go to utils/data_handler.py to set inputs\n")
        label_columns = mercury_labels[args.outputs]
        if isinstance(label_columns, str):
            oracle_labels = np.array(df[label_columns])
        else:
            oracle_labels = np.stack([np.array(df[col]) for col in label_columns], axis=1)

    elif args.dataset == "autompg":
        df = pd.read_csv(os.path.join(path, 'UCI_Datasets', 'auto-mpg.csv'), header=None)
        df.columns = ['mpg', 'cyl', 'dis', 'hp', 'weight', 'acc', 'year', 'origin', 'name']
        # Remove rows with no 'horsepower'; copy so the column assignment below does not
        # write into a view of the unfiltered frame (SettingWithCopyWarning).
        df = df[df['hp'] != '?'].copy()
        df['hp'] = df['hp'].astype(float)
        # One-hot encode 'origin'. dtype=float keeps the feature matrix numeric: pandas >= 2.0
        # returns bool dummies, which would turn `.values` below into an object array.
        dummy_df = pd.get_dummies(df['origin'], dtype=float)
        dummy_df.columns = ['origin_1', 'origin_2', 'origin_3']
        df = pd.concat([df.drop(columns=['origin', 'name']), dummy_df], axis=1)

        # Search space: all columns except the target 'mpg' (first) and 'origin_3' (last).
        # NOTE(review): dropping 'origin_3' looks like deliberate drop-one encoding -- confirm.
        search_space = df.iloc[:, 1:-1].values
        # Oracle labels
        oracle_labels = df['mpg'].values

    else:
        raise NotImplementedError(f"The dataset {args.dataset} is not available. Specify another dataset with "
                                  f"args.dataset.")

    return search_space, oracle_labels
