import numpy as np

def init_newsvendor_params():
    """Return a fresh dict holding the newsvendor problem constants.

    Keys:
        c_lin, c_quad: linear / quadratic ordering-cost coefficients.
        b_lin, b_quad: linear / quadratic coefficients of the penalty this
            file labels "over-order".
        h_lin, h_quad: linear / quadratic coefficients of the penalty this
            file labels "under-order".
        d: float32 array of the discrete demand levels.
        n: number of features.

    NOTE(review): in newsvendor notation ``b`` / ``h`` often stand for
    backorder / holding costs -- confirm the over/under labels against the
    code that consumes these values.
    """
    demand_levels = np.array([1, 2, 5, 10, 20], dtype=np.float32)

    return {
        # Ordering costs
        'c_lin': 10,
        'c_quad': 1e-2,
        # Over-order penalties (alternative b_quad setting: 14)
        'b_lin': 30,
        'b_quad': 1e-2,
        # Under-order penalties (alternative h_quad setting: 2)
        'h_lin': 10,
        'h_quad': 1e-2,
        # Discrete demands
        'd': demand_levels,
        # Number of features
        'n': 20,
    }

def init_theta_true(params, is_linear=True, with_seed=False):
    """Draw the ground-truth weight matrices of the demand model.

    Exactly one of the two returned matrices is random (standard normal);
    the other is all zeros.  Both have shape ``(params['n'], len(params['d']))``.

    Args:
        params: dict providing ``'n'`` (feature count) and ``'d'`` (demand
            levels; only its length is used).
        is_linear: if True the random weights go to the linear term
            (py ∝ exp(θX)); otherwise to the squared term
            (py ∝ exp((θX)^2)).
        with_seed: if True the global NumPy RNG is seeded with 42 before
            drawing, making the result reproducible.

    Returns:
        Tuple ``(Theta_true_lin, Theta_true_sq)``.

    Side effects:
        Reseeds the global ``np.random`` state, both before drawing and
        (with ``None``) before returning.
    """
    shape = (params['n'], len(params['d']))

    np.random.seed(42 if with_seed else None)
    theta_random = np.random.randn(*shape)
    theta_zero = np.zeros(shape)

    # Leave the global RNG unseeded again so later draws are not tied to 42.
    np.random.seed(None)

    if is_linear:
        return theta_random, theta_zero
    return theta_zero, theta_random


def gen_data(m, params, Theta_true_lin, Theta_true_sq, with_seed=False):
    """Sample ``m`` feature vectors and one-hot demand realizations.

    Features are i.i.d. standard normal.  For each row the demand class
    probabilities are proportional to
    ``exp(X @ Theta_true_lin + (X @ Theta_true_sq) ** 2)`` and one class is
    drawn from that distribution.

    Args:
        m: number of samples to generate.
        params: dict providing ``'n'`` (feature count) and ``'d'`` (demand
            levels; only its length is used).
        Theta_true_lin: weights of the linear term; assumed to be of shape
            ``(params['n'], len(params['d']))``.
        Theta_true_sq: weights of the squared term, same assumed shape.
        with_seed: if True the global NumPy RNG is seeded with 0 first,
            making the output reproducible.

    Returns:
        Tuple ``(X, Y)`` where ``X`` has shape ``(m, params['n'])`` and ``Y``
        has shape ``(m, len(params['d']))`` with exactly one 1 per row.

    Side effects:
        Reseeds the global ``np.random`` state, both before drawing and
        (with ``None``) before returning.
    """
    np.random.seed(0 if with_seed else None)
    X = np.random.randn(m, params['n'])

    # Row-normalised class probabilities.
    PY = np.exp(X.dot(Theta_true_lin) + (X.dot(Theta_true_sq)) ** 2)
    PY = PY / np.sum(PY, axis=1)[:, None]

    # Inverse-CDF sampling: the drawn class is the first column whose
    # cumulative probability exceeds the uniform draw.
    below_cdf = np.random.rand(m)[:, None] < np.cumsum(PY, axis=1)
    # Round-off can leave the last cumulative sum slightly below 1; without
    # this guard a draw above it would match no column and the sample would
    # be dropped, leaving Y shorter than X.
    below_cdf[:, -1] = True
    demand_idx = np.argmax(below_cdf, axis=1)  # index of the first True per row

    # One-hot encode the sampled class indices.
    Y = np.eye(len(params['d']))[demand_idx, :]

    np.random.seed(None)

    return X, Y

def log_error_and_write(e, save_folder, m, run, model, results_file, newline=False):
    """Record a failed run and keep the results file aligned.

    Appends a timestamped line describing the error to
    ``<save_folder>/errors.log`` and appends a single placeholder separator
    to ``results_file`` in place of the missing result.

    Args:
        e: the error (anything printable) to log.
        save_folder: directory containing ``errors.log``.
        m: value logged as ``m`` for the failed run.
        run: run identifier written to the log line.
        model: model name written to the log line.
        results_file: path of the results file to append to.
        newline: if True the placeholder is a newline, otherwise a comma.
    """
    # Imported locally so the function does not depend on module-level
    # imports of these names being present.
    import os
    from datetime import datetime

    with open(os.path.join(save_folder, 'errors.log'), 'a') as f:
        f.write('{}: m {}, model {}, run {}: {}\n'.format(
            datetime.now(), m, model, run, e))
    with open(results_file, 'a') as f:
        f.write('\n' if newline else ',')