import numpy as np
import pandas as pd

def random_borehole(n):
    """Sample ``n`` random input points for the borehole function.

    Columns, in order: rw, r, Tu, Hu, Tl, Hl, L, Kw (the same order that
    ``BenchmarkDataset.borehole`` indexes). ``rw`` is normal, ``r`` is
    lognormal (``mean``/``sigma`` parametrise the underlying normal), and the
    remaining six columns are uniform on ``[low, high)``. The ranges appear to
    follow the standard borehole benchmark input distributions -- confirm.

    Draws come from the global ``np.random`` state, so seeding it beforehand
    makes the output reproducible.

    Returns:
        ``(n, 8)`` float64 array.
    """
    shape = (n, 1)
    # Draw order matters for reproducibility under a fixed global seed.
    columns = [
        np.random.normal(loc=0.10, scale=0.0161812, size=shape),   # rw
        np.random.lognormal(mean=7.71, sigma=1.0056, size=shape),  # r
    ]
    uniform_bounds = (
        (63070, 115600),  # Tu
        (990, 1110),      # Hu
        (63.1, 116),      # Tl
        (700, 820),       # Hl
        (1120, 1680),     # L
        (9855, 12045),    # Kw
    )
    columns.extend(np.random.uniform(low, high, size=shape) for low, high in uniform_bounds)
    return np.hstack(columns)

class BenchmarkDataset:
    """Train/valid/test splits for a named benchmark regression dataset.

    ``dataset_config[dataset]`` must provide ``n_train``, ``n_valid``,
    ``n_test``, ``dim``, ``min_value`` and ``max_value``.

    * If it also has ``path``, the CSV at that path is loaded; ``features``
      (list of column names) and ``target`` (column name) select X and y.
      With ``train_frac`` (and then also ``valid_frac``) present, the split
      sizes are derived from the number of rows instead of the ``n_*`` keys.
      Rows are shuffled before splitting.
    * Otherwise the dataset is synthetic: the static method whose name equals
      ``dataset`` (e.g. ``griewank``, ``levy``, ``borehole``) is used as the
      target function and evaluated on freshly sampled inputs.

    When ``seed`` is not None, the *global* ``np.random`` state is seeded with
    it (a side effect visible to the rest of the process).
    """

    def __init__(self, dataset, dataset_config, seed=None):
        cfg = dataset_config[dataset]
        self.dataset_name = dataset
        self.seed = seed
        self.n_train = cfg["n_train"]
        self.n_valid = cfg["n_valid"]
        self.n_test = cfg["n_test"]
        self.dim = cfg["dim"]
        self.min_value = cfg["min_value"]
        self.max_value = cfg["max_value"]

        if "path" in cfg:
            self._load_csv(cfg)
        else:
            self.path = None
            # Synthetic dataset: the target function is the method named like the dataset.
            self.f = getattr(self, self.dataset_name)
            self._prepare_data()

    def _load_csv(self, cfg):
        """Load the CSV described by ``cfg`` and shuffle it into train/valid/test."""
        self.path = cfg["path"]
        self.data = pd.read_csv(self.path)
        X = self.data[cfg["features"]].values.astype(np.float32)
        y = self.data[cfg["target"]].values.astype(np.float32)

        if "train_frac" in cfg:
            n_data = self.data.shape[0]
            # Builtin int() truncates exactly like the old np.int alias, which
            # was removed in NumPy 1.24 and would raise AttributeError.
            self.n_train = int(n_data * cfg["train_frac"])
            self.n_valid = int(n_data * cfg["valid_frac"])
            self.n_test = n_data - self.n_train - self.n_valid

        idx = np.arange(len(X))
        if self.seed is not None:
            np.random.seed(self.seed)  # seeds the global RNG (side effect)
        np.random.shuffle(idx)

        train_end = self.n_train
        valid_end = train_end + self.n_valid
        test_end = valid_end + self.n_test
        self.X_train, self.y_train = X[idx[:train_end]], y[idx[:train_end]]
        self.X_valid, self.y_valid = X[idx[train_end:valid_end]], y[idx[train_end:valid_end]]
        self.X_test, self.y_test = X[idx[valid_end:test_end]], y[idx[valid_end:test_end]]

    def _prepare_data(self):
        """Sample synthetic inputs for each split and evaluate ``self.f`` on them."""
        if self.seed is not None:
            np.random.seed(self.seed)  # seeds the global RNG (side effect)

        # NOTE: apart from borehole, inputs are uniform on
        # [min_value, max_value)^dim -- uniform may not be the best choice.
        if self.dataset_name == 'borehole':
            # random_borehole returns float64 and is not cast to float32 here,
            # unlike the uniform branch below.
            sample = random_borehole
        else:
            def sample(n):
                return np.random.uniform(
                    self.min_value, self.max_value, size=(n, self.dim)
                ).astype(np.float32)

        # Sampled in train -> valid -> test order so seeded runs are reproducible.
        self.X_train = sample(self.n_train)
        self.X_valid = sample(self.n_valid)
        self.X_test = sample(self.n_test)

        self.y_train = self.f(self.X_train)
        self.y_valid = self.f(self.X_valid)
        self.y_test = self.f(self.X_test)

    def get_data(self):
        """Return ``((X_train, y_train), (X_valid, y_valid), (X_test, y_test))``."""
        return (self.X_train, self.y_train), (self.X_valid, self.y_valid), (self.X_test, self.y_test)

    @staticmethod
    def griewank(x):
        """Griewank function, row-wise on a 2-D array ``x``; returns float32.

        f(x) = sum(x_i^2) / 4000 - prod(cos(x_i / sqrt(i))) + 1, with i from 1.
        """
        S = np.sqrt(np.arange(1, x.shape[-1] + 1))[None]
        y = np.sum(x**2 / 4000., axis=1) - np.prod(np.cos(x / S), axis=1) + 1
        return y.astype(np.float32)

    @staticmethod
    def levy(x):
        """Levy function, row-wise on a 2-D array-like ``x``; returns float32.

        With w_i = 1 + (x_i - 1) / 4:
        f = sin^2(pi w_1)
            + sum_{i<d} (w_i - 1)^2 [1 + 10 sin^2(pi w_i + 1)]
            + (w_d - 1)^2 [1 + sin^2(2 pi w_d)]
        The global minimum is f = 0 at x = (1, ..., 1).
        """
        x = np.asarray(x, dtype=np.float32)
        w = 1. + (x - 1.) / 4
        # The first term is squared; without the square f could go negative.
        t1 = np.sin(np.pi * w[:, 0])**2
        t2 = (w[:, :-1] - 1.)**2 * (1. + 10. * (np.sin(np.pi * w[:, :-1] + 1))**2)
        t3 = (w[:, -1] - 1.)**2 * (1 + (np.sin(2 * np.pi * w[:, -1]))**2)

        y = t1 + t2.sum(axis=1) + t3
        return y.astype(np.float32)

    @staticmethod
    def borehole(x):
        """Borehole function on an ``(n, 8)`` array; returns float32 of shape ``(n,)``.

        Column order: rw, r, Tu, Hu, Tl, Hl, L, Kw (as produced by
        ``random_borehole``).
        NOTE(review): the constants 5 and 1.5 (instead of 2*pi and 1) match
        the low-fidelity borehole variant (e.g. Xiong et al. 2013) -- confirm
        this is intended.
        """
        assert x.shape[-1] == 8, "borehole expects 8 input columns"
        rw, r, tu, hu, tl, hl, length, kw = x.T
        num = 5 * tu * (hu - hl)
        ln_r_rw = np.log(r / rw)
        den = ln_r_rw * (1.5 + (2 * length * tu / (ln_r_rw * rw**2 * kw)) + tu / tl)
        return (num / den).astype(np.float32)
