import numpy as np
import numpy.linalg as linalg
import scipy
import scipy.stats

import dill as pickle

class PopulationBanditsProblem:
    """Bundle an operator callable with the parameters used to build it.

    Instances are themselves callable: ``problem(x)`` evaluates the wrapped
    operator at ``x``.
    """

    def __init__(self, operator, problem_parameters):
        """Store the operator and its parameters on the instance.

        Args:
            operator: callable invoked as ``operator(x)`` by ``__call__``.
            problem_parameters: stored unchanged as ``self.problem_parameters``.
        """
        self.operator = operator
        self.problem_parameters = problem_parameters

    def __call__(self, x):
        """Return the wrapped operator's value at ``x``."""
        evaluate = self.operator
        return evaluate(x)

def new_pd_matrix(actions, scale1, scale2, df, seed=0):
    """Return a random square matrix ``-A + B`` that is reproducible from ``seed``.

    ``A`` is one draw from a Wishart distribution with ``df`` degrees of
    freedom and scale matrix ``scale1 * I``.  ``B`` is a skew-symmetric matrix
    built from uniform ``[0, scale2)`` entries.  The symmetric part of the
    result is therefore ``-A``: despite the function name, the returned
    matrix is the *negative* of a Wishart sample plus a skew-symmetric term.

    Args:
        actions: dimension of the returned ``(actions, actions)`` matrix.
        scale1: multiplier of the identity used as the Wishart scale matrix.
        scale2: multiplier applied to the uniform entries before they are
            antisymmetrized.
        df: Wishart degrees of freedom.
        seed: seed forwarded to ``scipy.stats.wishart``.  Both ``A`` and
            ``B`` are drawn from the random state the frozen distribution
            derives from it, so equal seeds give equal matrices.

    Returns:
        The matrix ``-A + B``.
    """
    w = scipy.stats.wishart(df=df, scale=scale1 * np.eye(actions), seed=seed)
    A = w.rvs()

    # Draw the skew part from the distribution's own random state rather than
    # the global NumPy RNG; otherwise `seed` would pin down A but not B.
    # `uniform` is used because both RandomState and Generator provide it.
    rng = w.random_state
    B = rng.uniform(size=(actions, actions)) * scale2
    B = B - B.T

    return -A + B


def get_kl_operator(reference_policy, gamma=0.1):
    """Build the operator ``x -> -log(((1-gamma)*x + gamma*ref) / ref) - 1``.

    Args:
        reference_policy: value ``ref`` that ``x`` is mixed with and that the
            mixture is divided by.
        gamma: weight given to ``reference_policy`` in the mixture; ``x``
            receives weight ``1 - gamma``.

    Returns:
        A one-argument callable evaluating the expression above.
    """
    def kl_operator(x):
        mixture = (1 - gamma) * x + gamma * reference_policy
        return -np.log(mixture / reference_policy) - 1

    return kl_operator


def get_exp_potential_rewards(alpha=None):
    """Build the operator ``x -> exp(-alpha * x)``.

    Args:
        alpha: multiplier applied to ``x`` inside the exponential.  When it
            is ``None`` the operator is plain ``exp(-x)``; any other value,
            including ``0``, is used as the multiplier.

    Returns:
        A one-argument callable.
    """
    if alpha is not None:
        def scaled_exp(x):
            return np.exp(-alpha * x)

        return scaled_exp

    def plain_exp(x):
        return np.exp(-x)

    return plain_exp


def tikhonov_regularize_operator(operator, tau):
    """Wrap ``operator`` so that ``tau * x`` is subtracted from its output.

    Args:
        operator: one-argument callable to wrap.
        tau: coefficient of the subtracted ``tau * x`` term.

    Returns:
        A callable computing ``operator(x) - tau * x``.
    """
    def regularized(x):
        # Evaluate the wrapped operator first, then subtract the linear term.
        value = operator(x)
        return value - tau * x

    return regularized


## Methods for randomly generating VI problems

def generate_new_exp_problem(K, seed):
    """Create a problem whose operator is ``exp(-alpha * x)`` with random ``alpha``.

    The global NumPy RNG is re-seeded with ``seed``; ``alpha`` is then a
    length-``K`` vector of uniform draws shifted up by ``0.01``.

    Args:
        K: length of the ``alpha`` vector.
        seed: value passed to ``np.random.seed``.

    Returns:
        A ``PopulationBanditsProblem`` whose parameters are ``{"alpha": alpha}``.
    """
    np.random.seed(seed)
    rates = np.random.rand(K) + 0.01

    return PopulationBanditsProblem(
        get_exp_potential_rewards(rates),
        {"alpha": rates},
    )


def generate_new_kl_problem(K, seed):
    """Create a problem built on ``get_kl_operator`` with a random reference policy.

    The global NumPy RNG is re-seeded with ``seed``.  The reference policy is
    a length-``K`` vector of uniform draws shifted up by ``0.1`` and then
    normalized to sum to one; ``gamma`` is fixed at ``0.1``.

    Args:
        K: length of the reference policy.
        seed: value passed to ``np.random.seed``.

    Returns:
        A ``PopulationBanditsProblem`` whose parameters hold the normalized
        ``reference_policy`` and ``gamma``.
    """
    np.random.seed(seed)
    gamma = 0.1

    raw_weights = np.random.rand(K) + 0.1
    policy = raw_weights / raw_weights.sum()

    parameters = {"reference_policy": policy, "gamma": gamma}
    return PopulationBanditsProblem(get_kl_operator(policy, gamma=gamma), parameters)


def generate_new_linear_problem(K, seed):
    """Create a problem with the affine operator ``x -> A @ x + b``.

    ``A`` is ``new_pd_matrix(K, scale1=2, scale2=4, df=K, seed=seed) / K`` and
    ``b`` is a length-``K`` vector of uniform draws shifted up by ``1``.

    Args:
        K: dimension of ``A`` (``K x K``) and length of ``b``.
        seed: seeds the global NumPy RNG and is also forwarded to
            ``new_pd_matrix``, so the same seed reproduces the same problem.

    Returns:
        A ``PopulationBanditsProblem`` whose parameters are ``{"A": A, "b": b}``.
    """
    # Seed the global RNG like the other generators in this file do.  Without
    # this, `b` (and anything new_pd_matrix takes from the global RNG) would
    # differ between calls made with the same `seed`.
    np.random.seed(seed)

    A = new_pd_matrix(K, scale1=2, scale2=4, df=K, seed=seed) / K
    b = np.random.rand(K) + 1.

    def operator_linear(x):
        return A @ x + b

    return PopulationBanditsProblem(operator_linear, {"A": A, "b": b})


def generate_beach_bar_problem(K, seed):
    """Create the deterministic "beach bar" problem over ``K`` positions.

    For position ``i`` the operator value is
    ``1 - |i - a| / K - alpha * log(1 + x[i])`` with ``a = K // 2`` and
    ``alpha = 1.0``.

    Args:
        K: number of positions (``0 .. K-1``).
        seed: not used by this function.

    Returns:
        A ``PopulationBanditsProblem`` whose parameters hold ``K``, ``a`` and
        ``alpha``.
    """
    alpha = 1.0
    a = K // 2
    positions = np.arange(K)

    # The distance term does not depend on x, so compute it once up front.
    proximity = 1 - np.abs(positions - a) / K

    def operator_bb(x):
        return proximity - alpha * (np.log(1 + x))

    parameters = {"K": K, "a": a, "alpha": alpha}
    return PopulationBanditsProblem(operator_bb, parameters)


def load_utd_problem(file_name):
    """Load a pickled data dict and build a problem from its fitted models.

    The file must unpickle to a mapping containing the keys
    ``"fitted_models"``, ``"distances"``, ``"capacity"`` and ``"sensors"``.
    For each index ``i`` the operator evaluates
    ``model_i.predict(x[i]) / x[i] / capacity_i / distance_i`` and then
    replaces NaN and ``-inf`` by ``0`` and ``+inf`` by ``2``.

    Args:
        file_name: path of the pickle file to read.

    Returns:
        A ``PopulationBanditsProblem`` whose parameters are ``{"data": data}``.
    """
    # NOTE(review): unpickling executes arbitrary code; only load trusted files.
    with open(file_name, 'rb') as handle:
        data = pickle.load(handle)

    models = data["fitted_models"]
    distances = data["distances"]
    capacity = data["capacity"]
    # Not used by the operator; the lookup is kept so that a file without
    # this key still fails here.
    sensors = data["sensors"]

    def operator_utd(x):
        values = []
        for i, (model, distance, cap) in enumerate(zip(models, distances, capacity)):
            # predict() is fed a single-sample column; take its first output.
            prediction = model.predict(np.array(x[i]).reshape(-1, 1)).reshape(-1)[0]
            values.append((prediction / x[i] / cap) / distance)
        return np.nan_to_num(np.array(values), nan=0., posinf=2., neginf=0.)

    return PopulationBanditsProblem(operator_utd, {"data": data})
