import pandas as pd
import numpy as np
from functools import wraps
import operator

# Golden ratio, used by the closed-form Fibonacci formula below.
phi = (1 + np.sqrt(5)) / 2


def fib(n):
    """Return the ``n``-th Fibonacci number(s) using Binet's closed form.

    ``n`` may be a scalar or an integer array. The real-valued result is
    rounded to the nearest integer and cast to ``np.int_``. Because the
    computation is done in floating point, results may become inexact for
    large ``n``.
    """
    root5 = np.sqrt(5)
    rising = np.power(phi, n)
    # (-1/phi)**n, written as the reciprocal of (-phi)**n.
    falling = 1 / np.power(-phi, n)
    return np.round((rising - falling) / root5).astype(np.int_)

def with_default_prng(sampler):
    """Decorate a ``sample(self, size, prng)`` method so ``prng`` is optional.

    The wrapped method accepts ``prng`` positionally or by keyword (default
    ``None``) and normalises it through ``np.random.default_rng`` before
    calling ``sampler``. The sampler therefore always receives a NumPy
    ``Generator``; numpy passes an existing Generator through unchanged.
    """
    @wraps(sampler)
    def wrapper(self, size, prng=None):
        generator = np.random.default_rng(prng)
        return sampler(self, size, generator)

    return wrapper


class SampledParam:
    """Marker base class for parameters whose values are drawn at random.

    ``RandomParams`` samples every entry of its ``space`` that is an
    instance of this class and treats all other entries as fixed values.
    """

class UniformInSequenceParam(SampledParam):
    """Draw a uniform integer index and map it through ``self.sequence``.

    Indices are drawn from ``[start_index, stop_index)`` with
    ``prng.integers``. Subclasses must define ``sequence(i)``, which turns an
    array of indices into parameter values.
    """

    def __init__(self, start_index, stop_index):
        self.start, self.stop = start_index, stop_index

    @with_default_prng
    def sample(self, size, prng):
        indices = prng.integers(self.start, self.stop, size=size)
        return self.sequence(indices)

class UniformInt(UniformInSequenceParam):
    """Uniform integer in ``[start_index, stop_index)``; the index is the value."""

    def sequence(self, i):
        """Return the sampled indices unchanged."""
        return i

class UniformCategorical(SampledParam):
    """Choose uniformly at random among a fixed collection of candidates.

    Args:
        values: the candidate values. Each draw picks one candidate with
            equal probability.

    Candidates are stored as a NumPy array. When NumPy would coerce
    non-string candidates to strings (e.g. ``[1, 'auto']``), or cannot build
    a regular array at all (ragged tuples), a 1-D ``object`` array is used
    instead. This keeps every candidate as the original Python object.
    """

    def __init__(self, values):
        try:
            arr = np.asarray(values)
        except ValueError:
            # Ragged sequences (e.g. tuples of different lengths) are
            # rejected by recent NumPy versions.
            arr = None
        stringified = (
            arr is not None
            and not isinstance(values, np.ndarray)
            and arr.dtype.kind in 'US'
            and not all(isinstance(v, (str, bytes)) for v in values)
        )
        if arr is None or stringified:
            # np.asarray([1, 'a']) silently turns 1 into '1'; keep the
            # original objects instead.
            arr = np.empty(len(values), dtype=object)
            for k, v in enumerate(values):
                arr[k] = v
        self.values = arr

    @with_default_prng
    def sample(self, size, prng):
        picks = prng.integers(len(self.values), size=size)
        return self.values[picks]
    
class OrZero(SampledParam):
    """Mix a distribution with zero.

    Each draw is 0 (cast to the distribution's dtype) with probability
    ``pzero``. Otherwise it is a sample from ``distribution``.
    """

    def __init__(self, distribution, pzero=0.5):
        self.distribution = distribution
        self.pzero = pzero

    @with_default_prng
    def sample(self, size, prng):
        # True where the slot is filled from the wrapped distribution.
        keep = prng.uniform(size=size) > self.pzero
        draws = self.distribution.sample(np.sum(keep), prng=prng)

        result = np.zeros(size, dtype=draws.dtype)
        result[keep] = draws
        return result
    
class OrValue(SampledParam):
    """Mix a distribution with a fixed fallback value.

    Each draw is ``value`` with probability ``pvalue``. Otherwise it is a
    sample from ``distribution``.

    The output dtype is chosen to hold both ``value`` and the sampled
    values. Using only the distribution's dtype would truncate e.g. ``0.5``
    to ``0`` for integer distributions, shorten long strings, and reject
    ``None``. When one side is text and the other is not, an ``object``
    array is returned so neither side is stringified.
    """

    def __init__(self, distribution, value, pvalue=0.5):
        self.distribution = distribution
        self.value = value
        self.pvalue = pvalue

    @with_default_prng
    def sample(self, size, prng):
        # True where the slot is filled from the wrapped distribution.
        use_dist = prng.uniform(size=size) > self.pvalue
        drawn = self.distribution.sample(np.sum(use_dist), prng=prng)

        output = np.full(size, self.value, dtype=self._output_dtype(drawn.dtype))
        output[use_dist] = drawn
        return output

    def _output_dtype(self, drawn_dtype):
        """Return a dtype able to hold both ``self.value`` and ``drawn_dtype``."""
        value_dtype = np.asarray(self.value).dtype
        if (drawn_dtype.kind in 'US') != (value_dtype.kind in 'US'):
            # Text mixed with non-text: NumPy would stringify the numbers
            # (or refuse to promote), so keep Python objects.
            return np.dtype(object)
        try:
            return np.result_type(drawn_dtype, value_dtype)
        except TypeError:  # no common dtype (e.g. datetime vs int)
            return np.dtype(object)

class UniformFibonacci(UniformInSequenceParam):
    """Uniform draw from Fibonacci numbers ``fib(k)``, ``k`` in ``[start_index, stop_index)``."""

    def sequence(self, i):
        """Map sampled indices to the corresponding Fibonacci numbers."""
        return fib(i)
    
class UniformPower(UniformInSequenceParam):
    """Uniform draw from ``base**k`` for ``k`` in ``[start_index, stop_index)``.

    Args:
        start_index: smallest exponent (inclusive).
        stop_index: largest exponent (exclusive).
        base: the base raised to the sampled exponent (default 2).
    """

    def __init__(self, start_index, stop_index, base=2):
        super().__init__(start_index, stop_index)
        self.base = base

    def sequence(self, i):
        """Raise ``self.base`` to the sampled exponents and cast to ``np.int_``."""
        powers = np.power(self.base, i)
        return powers.astype(np.int_)

class LowerAndUpperBoundedParam(SampledParam):
    """Base for samplers defined by a lower bound ``lb`` and an upper bound ``ub``."""

    def __init__(self, lb, ub):
        self.lb, self.ub = lb, ub
    
class Uniform(LowerAndUpperBoundedParam):
    """Continuous uniform samples between ``lb`` and ``ub`` (via ``prng.uniform``)."""

    @with_default_prng
    def sample(self, size, prng):
        low, high = self.lb, self.ub
        return prng.uniform(low, high, size=size)

class IntUniform(LowerAndUpperBoundedParam):
    """Uniform integers drawn with ``prng.integers(lb, ub)`` (``ub`` exclusive)."""

    @with_default_prng
    def sample(self, size, prng):
        low, high = self.lb, self.ub
        return prng.integers(low, high, size=size)
    
class LogUniform(LowerAndUpperBoundedParam):
    """Samples whose logarithm is uniform between ``log(lb)`` and ``log(ub)``.

    ``np.log`` is applied to the bounds, so ``lb`` and ``ub`` are expected to
    be positive.
    """

    @with_default_prng
    def sample(self, size, prng):
        log_low = np.log(self.lb)
        log_high = np.log(self.ub)
        exponents = prng.uniform(log_low, log_high, size=size)
        return np.exp(exponents)

class IntLogUniform(LogUniform):
    """Log-uniform samples cast to ``np.int_`` (fractional parts are dropped)."""

    @with_default_prng
    def sample(self, size, prng):
        continuous = super().sample(size, prng)
        return continuous.astype(np.int_)


class GridParams:
    """Cartesian-product grid over a parameter space.

    Entries of ``space`` for which ``np.isscalar`` is true are fixed values
    shared by every grid point. All other entries are axes, and the grid
    enumerates every combination of their values via ``np.meshgrid``. A
    space with no axes is a single grid point (``size == 1``).

    Attributes:
        static: fixed values, included in every grid point.
        dynamic: axis name -> flattened array with one value per grid point.
        size: number of grid points.
    """

    def __init__(self, space):
        self.static = {k: v for k, v in space.items() if np.isscalar(v)}
        axes = {k: v for k, v in space.items() if not np.isscalar(v)}
        if axes:
            grids = np.meshgrid(*axes.values())
            self.dynamic = {k: g.flatten() for k, g in zip(axes, grids)}
            self.size = next(iter(self.dynamic.values())).size
        else:
            # Only fixed values: previously crashed with IndexError; the
            # grid is exactly one point.
            self.dynamic = {}
            self.size = 1

    def __getitem__(self, i):
        """Return the parameter dict for grid point ``i`` (fixed values included)."""
        params = {k: v[i] for k, v in self.dynamic.items()}
        params.update(self.static)
        return params

    def to_pandas(self):
        """Return the varying (axis) parameters as a DataFrame, one row per grid point."""
        return pd.DataFrame(self.dynamic)

class RandomParams:
    """Independent random samples of a parameter space.

    Entries of ``space`` that are ``SampledParam`` instances are sampled
    ``size`` times, in ``space`` order, from one shared generator built by
    ``np.random.default_rng(prng)``. All other entries are fixed values
    included in every sample.
    """

    def __init__(self, space, size, prng=None):
        rng = np.random.default_rng(prng)

        self.size = size
        self.dynamic = {}
        self.static = {}
        for name, spec in space.items():
            if not isinstance(spec, SampledParam):
                self.static[name] = spec
                continue
            self.dynamic[name] = spec.sample(size, rng)

    def __getitem__(self, i):
        """Return the parameter dict for sample ``i`` (fixed values included)."""
        drawn = {name: draws[i] for name, draws in self.dynamic.items()}
        return {**drawn, **self.static}

    def to_pandas(self):
        """Return the sampled parameters as a DataFrame, one row per sample."""
        return pd.DataFrame(self.dynamic)


def line_search(points, direction='min'):
    """Yield every point in ``points`` and track the best reported value.

    Send the objective value for each yielded point back with ``send``.
    ``None`` values, including plain ``next()`` calls, are ignored. The
    generator's return value is the minimum (``direction='min'``) or maximum
    (``direction='max'``) value received. If nothing was received, it returns
    ``+inf`` or ``-inf`` respectively.

    Raises:
        ValueError: on the first ``next()`` if ``direction`` is not
            ``'min'`` or ``'max'``.
    """
    if direction not in ('min', 'max'):
        raise ValueError(f'Only "max" and "min" are valid values for direction. Got {direction}.')
    pick, best = (min, np.inf) if direction == 'min' else (max, -np.inf)

    for point in points:
        feedback = yield point
        if feedback is None:
            continue
        best = pick(feedback, best)

    return best
    

def early_stopping_line_search2(points, direction='min'):
    """Walk ``points`` until the reported value first gets strictly worse.

    Send the objective value for each yielded point back with ``send``. A
    value equal to the best so far counts as an improvement. The generator
    returns the best value seen before stopping. This behaves like
    ``early_stopping_line_search`` with ``patience=1``.

    Raises:
        ValueError: on the first ``next()`` if ``direction`` is not
            ``'min'`` or ``'max'``.
    """
    if direction not in ('min', 'max'):
        raise ValueError(f'Only "max" and "min" are valid values for direction. Got {direction}.')
    if direction == 'min':
        worse, best = operator.gt, np.inf
    else:
        worse, best = operator.lt, -np.inf

    for point in points:
        score = yield point
        if worse(score, best):
            return best
        best = score

    return best
        

def early_stopping_line_search(points, patience=1, direction='min'):
    """Walk ``points`` until ``patience`` consecutive values are worse than the best.

    Send the objective value for each yielded point back with ``send``.
    A value strictly worse than the best so far counts as a failure. Any
    other value (equal or better) resets the failure count and becomes the
    new best. The generator returns the best value seen.

    Raises:
        ValueError: on the first ``next()`` if ``direction`` is not
            ``'min'`` or ``'max'``.
    """
    if direction not in ('min', 'max'):
        raise ValueError(f'Only "max" and "min" are valid values for direction. Got {direction}.')
    minimising = direction == 'min'
    worse = operator.gt if minimising else operator.lt
    best = np.inf if minimising else -np.inf
    strikes = 0

    for point in points:
        score = yield point
        if not worse(score, best):
            strikes, best = 0, score
            continue
        strikes += 1
        if strikes >= patience:
            break

    return best

        
# Inverse golden ratio, 1/phi ~= 0.618: the bracket shrink factor per step.
r = (np.sqrt(5) - 1) / 2


def golden_search(a, b, num_evals=5, direction='min'):
    """Golden-section search over ``[min(a, b), max(a, b)]`` as a coroutine.

    Yields candidate points and expects the objective value for each to be
    sent back with ``send``. Yields ``max(num_evals, 2)`` points in total,
    reusing one interior evaluation per step. Golden-section search assumes
    the objective is unimodal on the interval.

    Returns (as the generator's return value) the better of the two final
    interior values: the smaller for ``direction='min'``, the larger for
    ``'max'``.

    Raises:
        ValueError: on the first ``next()`` if ``direction`` is not
            ``'min'`` or ``'max'``.
    """
    if direction == 'min':
        better = operator.lt
    elif direction == 'max':
        better = operator.gt
    else:
        raise ValueError(f'Only "max" and "min" are valid values for direction. Got {direction}.')

    lo, hi = min(a, b), max(a, b)
    width = hi - lo

    # Two interior probes, placed symmetrically at golden-ratio offsets.
    left = hi - r * width
    right = lo + r * width
    f_left = yield left
    f_right = yield right

    for _ in range(num_evals - 2):
        width *= r
        if better(f_left, f_right):
            # Optimum lies in [lo, right]: the old left probe becomes the
            # new right probe, and only a new left probe is evaluated.
            hi, right, f_right = right, left, f_left
            left = hi - r * width
            f_left = yield left
        else:
            # Optimum lies in [left, hi]: mirror image of the branch above.
            lo, left, f_left = left, right, f_right
            right = lo + r * width
            f_right = yield right

    pick = min if direction == 'min' else max
    return pick(f_left, f_right)


def in_log_space(search):
    """Wrap a point-proposing search generator function so it runs in log space.

    The returned generator function passes ``np.log`` of every positional
    argument to ``search``, leaving keyword arguments unchanged. It yields
    ``np.exp`` of each point ``search`` proposes and forwards the values sent
    back unchanged. It returns whatever ``search`` returns. ``functools.wraps``
    keeps the wrapped search's name and docstring, consistent with
    ``with_default_prng``.

    Positional arguments are expected to be positive; ``np.log`` of
    non-positive values yields ``-inf``/``nan``.
    """
    @wraps(search)
    def generator(*args, **kwargs):
        gen = search(*[np.log(arg) for arg in args], **kwargs)

        # The inner generator signals completion with StopIteration; catching
        # it here lets us re-expose its return value as our own.
        try:
            value = yield np.exp(next(gen))
            while True:
                value = yield np.exp(gen.send(value))
        except StopIteration as e:
            return e.value
    return generator

    
class ForParam:
    """Iterable adapter that drives a search generator from a ``for`` loop.

    ``generator_function(*args, **kwargs)`` must return a generator that
    yields candidate points and receives feedback via ``send``. Inside the
    loop, report the objective for the current point with :meth:`fval`. The
    most recently reported value is sent when the next point is requested.
    NOTE: the value is not cleared between steps, so skipping ``fval`` re-sends
    the previous value. When the generator finishes, its return value is
    stored in :attr:`best`.

    ``value`` and ``best`` are ``None`` until set. Each ``iter()`` restarts
    the search from scratch.
    """

    def __init__(self, generator_function, *args, **kwargs):
        self.generator_function = generator_function
        self.args = args
        self.kwargs = kwargs
        # Defined up front so these can be read before the first iteration
        # instead of raising AttributeError.
        self.value = None
        self.best = None
        self.generator = None
        self._finished = False

    def fval(self, val):
        """Record the objective value for the most recently yielded point."""
        self.value = val

    def __next__(self):
        if self.generator is None:
            # next() called without a prior iter(): start a fresh run.
            iter(self)
        if self._finished:
            # Keep raising without overwriting ``best`` with None.
            raise StopIteration
        try:
            # send(None) is equivalent to next(), so this also starts a
            # freshly created generator.
            return self.generator.send(self.value)
        except StopIteration as e:  # intercept the search's return value
            self._finished = True
            self.best = e.value
            raise StopIteration from None

    def __iter__(self):
        self.value = None
        self.best = None
        self._finished = False
        self.generator = self.generator_function(*self.args, **self.kwargs)
        return self
    
    
class LineSearch(ForParam):
    """``for``-loop driver for ``line_search`` over ``points``.

    After the loop, ``best`` holds the best value reported via ``fval``.
    """

    def __init__(self, points, direction='min'):
        super().__init__(line_search, points, direction=direction)


class EarlyStoppingLineSearch(ForParam):
    """``for``-loop driver for ``early_stopping_line_search``.

    Iteration stops once ``patience`` consecutive values reported via
    ``fval`` are worse than the best seen. ``best`` then holds that value.
    """

    def __init__(self, points, patience=1, direction='min'):
        super().__init__(
            early_stopping_line_search,
            points,
            patience=patience,
            direction=direction,
        )
    

class GoldenSearch(ForParam):
    """``for``-loop driver for ``golden_search`` on the interval ``[a, b]``.

    Args:
        a, b: interval endpoints (either order).
        num_evals: number of points to evaluate.
        log: if true, search uniformly in ``log`` space via ``in_log_space``.
            Both bounds must then be positive.
        direction: ``'min'`` or ``'max'``.

    Raises:
        ValueError: if ``log`` is true and a bound is not positive.
            Previously this silently produced ``-inf``/``nan`` points.
    """

    def __init__(self, a, b, num_evals=5, log=False, direction='min'):
        if log and (a <= 0 or b <= 0):
            raise ValueError(f'Bounds must be positive when log=True. Got a={a}, b={b}.')
        generator = in_log_space(golden_search) if log else golden_search
        super().__init__(generator, a, b, num_evals=num_evals, direction=direction)
    