import logging
import numpy as np
from functools import reduce


def solve_assortment_and_pricing(K, alpha, beta):
    """Pick a K-item assortment and per-item prices by bisection on a scalar B.

    For a trial value B each item gets the value
    ``exp(alpha - beta * B - 1) / beta``; the K items with the largest value
    form the candidate assortment.  B is searched on the bracket [0, 1000]
    until the bracket is narrower than 1e-6: when the candidate values sum to
    less than B the upper end moves down, otherwise the lower end moves up.

    Args:
        K: number of items to put in the assortment.
        alpha: per-item array entering the exponent additively.
        beta: per-item array multiplying B in the exponent (also the divisor).

    Returns:
        Tuple ``(assortment, prices)``: the indices of the selected items (in
        the order ``np.argpartition`` leaves them) and the price vector
        ``1 / beta + B`` for every item.
    """
    lower, upper = 0, 1000
    B = 1
    assortment = np.array([])
    while upper - lower > 1e-6:
        B = (upper + lower) / 2
        item_values = np.exp(alpha - beta * B - 1) / beta
        # Indices of the K largest values (unordered within the top K).
        assortment = np.argpartition(item_values, -K)[-K:]
        if np.sum(item_values[assortment]) < B:
            upper = B
        else:
            lower = B

    return assortment, 1 / beta + B


def grad_reduce(input_grad, sum_grad):
    """Add two (gradient, hessian) pairs component-wise.

    Args:
        input_grad: pair whose items 0 and 1 are added to the other pair.
        sum_grad: pair holding the other two terms.

    Returns:
        A 2-tuple ``(sum of items 0, sum of items 1)``.
    """
    first_total = sum_grad[0] + input_grad[0]
    second_total = sum_grad[1] + input_grad[1]
    return first_total, second_total


def solve_mle(offered_contexts, selected_contexts, d, init_theta=None,
              max_iter=6, tol=1e-5):
    """Fit the multinomial-logit parameter vector by Newton's method.

    Each observation is an offered set (one feature row per offered item) and
    the feature vector of the chosen item (all zeros when nothing was chosen).
    Choice probabilities are ``exp(x_i @ theta) / (1 + sum_j exp(x_j @ theta))``.
    Every iteration sums, over all observations, the score
    ``selected - sum_i p_i x_i`` and the matrix ``X.T (diag(p) - p p.T) X``,
    then takes the step that solves ``hess @ update = grad``.

    Args:
        offered_contexts: sequence of 2-D arrays, one per observation.
        selected_contexts: sequence of 1-D arrays aligned with
            ``offered_contexts`` (extra entries of the longer one are ignored).
        d: length of theta, used when no usable ``init_theta`` is given.
        init_theta: starting point; ignored if ``None`` or if it contains NaN,
            in which case the start is ``0.5 * ones(d)``.
        max_iter: maximum number of Newton steps (default 6, the previous
            hard-coded limit).
        tol: stop once the Euclidean norm of a step drops below this value.

    Returns:
        The estimated theta as a new 1-D float array.  With no observations the
        starting point is returned unchanged instead of raising.

    Raises:
        numpy.linalg.LinAlgError: if the accumulated matrix is singular.
    """
    if init_theta is None or np.isnan(init_theta).any():
        theta = 0.5 * np.ones(d)
    else:
        # Copy so the caller's array is never aliased by the return value.
        theta = np.array(init_theta, dtype=float)

    observations = list(zip(offered_contexts, selected_contexts))
    if not observations:
        return theta

    dim = theta.shape[0]
    for _ in range(max_iter):
        grad = np.zeros(dim)
        hess = np.zeros((dim, dim))
        for offered, selected in observations:
            offered = np.asarray(offered)
            terms = np.exp(offered @ theta)
            probs = terms / (1 + np.sum(terms))
            weighted = probs[:, None] * offered
            expected = weighted.sum(axis=0)
            grad += selected - expected
            # X.T diag(p) X - (X.T p)(X.T p).T, without building a K x K matrix.
            hess += weighted.T @ offered - np.outer(expected, expected)

        # Solve the linear system directly rather than forming the inverse.
        update = np.linalg.solve(hess, grad)
        theta = theta + update
        if np.linalg.norm(update) < tol:
            break

    return theta


def mle_online_update(theta, V, offered_contexts, selected_contexts, kappa):
    """Apply one online update of theta from a single observation.

    The choice probabilities of the offered items are
    ``exp(x_i @ theta) / (1 + sum_j exp(x_j @ theta))``.  The score is the
    chosen feature vector minus the probability-weighted sum of offered
    feature rows, and the new estimate is
    ``pinv(V) @ (V @ theta + score / kappa)``.

    Args:
        theta: current parameter vector.
        V: square matrix used both as the weight and (pseudo-inverted) as the
            preconditioner.
        offered_contexts: 2-D array with one feature row per offered item.
        selected_contexts: feature vector of the chosen item (zeros if none).
        kappa: scalar dividing the score.

    Returns:
        The updated parameter vector.
    """
    exp_utilities = np.exp(offered_contexts @ theta)
    choice_probs = exp_utilities / (1 + np.sum(exp_utilities))
    expected_features = np.sum((choice_probs * offered_contexts.T).T, axis=0)
    score = -expected_features + selected_contexts
    V_pinv = np.linalg.pinv(V)
    return V_pinv @ (V @ theta + score / kappa)


class OfflineOptimalAlgorithm:
    """Policy that solves the assortment/pricing problem from given parameters."""

    def __init__(self, K):
        """Store the assortment size K."""
        self.K = K

    def get_assortment_and_pricing(self, alpha_star, beta_star):
        """Return ``solve_assortment_and_pricing(K, alpha_star, beta_star)``.

        Args:
            alpha_star: per-item alpha values handed to the solver.
            beta_star: per-item beta values handed to the solver.

        Returns:
            The ``(assortment, prices)`` pair produced by the solver.
        """
        solver_args = (self.K, alpha_star, beta_star)
        return solve_assortment_and_pricing(*solver_args)


class DynamicAlgorithms:
    """Shared state and decision rule for the dynamic assortment/pricing policies.

    Attributes set here: problem sizes (``n`` items, context dimension ``d``,
    assortment size ``K``), the stacked parameter estimate ``theta`` of length
    ``2 * d``, the bonus scale ``alpha_g``, the observation histories, the
    round counter ``t``, the lower bound ``L0`` applied to beta, the number of
    random initial rounds ``T0`` and the ``2d x 2d`` matrix ``V``.
    """

    def __init__(self, n, d, K, L0, T0):
        dim = 2 * d
        self.n = n
        self.d = d
        self.K = K
        self.L0 = L0
        self.T0 = T0
        self.t = 0
        self.alpha_g = 0.003
        self.theta = np.zeros(dim)
        self.V = np.zeros((dim, dim))
        self.offered_contexts = []
        self.selected_contexts = []

    def get_assortment_and_pricing(self, contexts):
        """Choose the assortment and prices for the current round.

        During the first ``T0`` rounds a uniformly random K-subset is offered
        and every item's price is drawn from {0, 1}.  Afterwards alpha and
        beta are built from the current ``theta`` with a per-item bonus
        ``alpha_g * x.T @ inv(V) @ x`` (x being the item's context repeated
        twice); alpha is capped at 1, beta floored at ``L0``, and both are
        passed to ``solve_assortment_and_pricing``.

        Args:
            contexts: 2-D array with one context row per item.

        Returns:
            Tuple ``(assortment, prices)``.
        """
        if self.t < self.T0:
            random_items = np.random.choice(self.n, size=self.K, replace=False)
            random_prices = np.random.choice(2, size=self.n)
            return random_items, random_prices

        psi = self.theta[:self.d]
        phi = self.theta[self.d:]
        stacked = np.concatenate((contexts, contexts), axis=1)
        V_inv = np.linalg.inv(self.V)
        bonus = np.array(
            [self.alpha_g * np.inner(V_inv @ stacked[i], stacked[i])
             for i in range(self.n)],
            dtype=float,
        )
        # Optimistic alpha (capped at 1) and pessimistic beta (floored at L0).
        alpha = np.minimum(contexts @ psi + bonus, 1)
        beta = np.maximum(contexts @ phi - bonus, self.L0)
        return solve_assortment_and_pricing(self.K, alpha, beta)


class DynamicAssortmentPricing(DynamicAlgorithms):
    """Policy that re-fits theta with ``solve_mle`` on the full history."""

    def selection_feedback(self, i_t, contexts, assortment, prices):
        """Record one round of feedback and, once ``t >= T0``, re-fit theta.

        Args:
            i_t: index of the chosen item, or ``None`` when nothing was chosen.
            contexts: 2-D array with one context row per item.
            assortment: indices of the offered items.
            prices: per-item price array.
        """
        shown = contexts[assortment]
        # Feature row per offered item: [x, -price * x].
        x_tilde = np.concatenate([shown, (-prices[assortment] * shown.T).T], axis=1)
        self.V += x_tilde.T @ x_tilde
        self.offered_contexts.append(x_tilde)

        if i_t is None:
            chosen_features = np.zeros(2 * self.d)
        else:
            chosen_features = np.concatenate([contexts[i_t], -prices[i_t] * contexts[i_t]])
        self.selected_contexts.append(chosen_features)

        if self.t >= self.T0:
            self.theta = solve_mle(
                self.offered_contexts,
                self.selected_contexts,
                2 * self.d,
                init_theta=self.theta,
            )
        self.t += 1


class NewtonAssortmentPricing(DynamicAlgorithms):
    """Policy that fits theta once by MLE and then updates it online."""

    def __init__(self, n, d, K, L0, T0):
        self.kappa = 0.03
        super().__init__(n, d, K, L0, T0)

    def _chosen_features(self, i_t, contexts, prices):
        """Return ``[x, -price * x]`` for the chosen item, or zeros if none."""
        if i_t is None:
            return np.zeros(2 * self.d)
        return np.concatenate([contexts[i_t], -prices[i_t] * contexts[i_t]])

    def selection_feedback(self, i_t, contexts, assortment, prices):
        """Record one round of feedback and update theta.

        ``V`` is updated every round.  While ``t < T0`` the observation is
        only stored; at ``t == T0`` theta is fitted with ``solve_mle`` on the
        stored observations (the current round is not stored); afterwards
        theta is updated from the current round alone via
        ``mle_online_update``.

        Args:
            i_t: index of the chosen item, or ``None`` when nothing was chosen.
            contexts: 2-D array with one context row per item.
            assortment: indices of the offered items.
            prices: per-item price array.
        """
        shown = contexts[assortment]
        # Feature row per offered item: [x, -price * x].
        x_tilde = np.concatenate([shown, (-prices[assortment] * shown.T).T], axis=1)
        self.V += x_tilde.T @ x_tilde

        if self.t < self.T0:
            self.offered_contexts.append(x_tilde)
            self.selected_contexts.append(self._chosen_features(i_t, contexts, prices))
        elif self.t == self.T0:
            self.theta = solve_mle(
                self.offered_contexts,
                self.selected_contexts,
                2 * self.d,
                init_theta=self.theta,
            )
        else:
            chosen = self._chosen_features(i_t, contexts, prices)
            self.theta = mle_online_update(self.theta, self.V, x_tilde, chosen, self.kappa)
        self.t += 1


class JavanmardDynamicPricing:
    """Episodic policy: explore at the start of each episode, then exploit.

    Every episode begins with ``d`` rounds of random assortments and prices
    whose outcomes are stored; theta is re-fitted at the last of those rounds.
    The remaining rounds of the episode price all items from the current
    estimate and offer the best-revenue subset of at most ``K`` items.
    Episodes start with length ``d`` and grow by one round each time.
    """

    def __init__(self, n, d, K, L0):
        """Initialise the sizes, the estimate and the episode counters.

        Args:
            n: number of items.
            d: context dimension; theta has length ``2 * d``.
            K: maximum assortment size.
            L0: lower bound applied to beta.
        """
        self.K = K
        self.n = n
        self.d = d
        self.theta = np.zeros(2 * d)
        self.alpha_g = 0.003
        self.offered_contexts = []
        self.selected_contexts = []
        self.t = 0
        self.L0 = L0
        self.episode_len = d
        self.episode_t = 0
        self.V = np.zeros((2 * d, 2 * d))

    def _best_revenue_assortment(self, alpha, beta, prices):
        """Return the top-valued subset (size 1..K) with the largest revenue.

        Item utilities are ``alpha - beta * prices`` (elementwise).  For each
        size k the k highest-utility items are scored by their expected
        revenue ``sum(p_i * e^{u_i}) / (1 + sum(e^{u_i}))``; the first size
        that strictly improves on all smaller ones wins.  An empty list is
        returned when no size yields positive expected revenue.
        """
        # Elementwise product: each item's utility depends on its own price.
        values = alpha - beta * prices
        weights = np.exp(values)
        assortment = []
        best_exp_rev = 0
        for size in range(1, self.K + 1):
            candidate = np.argpartition(values, -size)[-size:]
            candidate_weights = weights[candidate]
            expected_revenue = (np.sum(prices[candidate] * candidate_weights)
                                / (1 + np.sum(candidate_weights)))
            if expected_revenue > best_exp_rev:
                assortment = candidate
                best_exp_rev = expected_revenue
        return assortment

    def get_assortment_and_pricing(self, contexts):
        """Choose the assortment and prices for the current round.

        Args:
            contexts: 2-D array with one context row per item.

        Returns:
            Tuple ``(assortment, prices)``.  In the first ``d`` rounds of an
            episode both are random (a K-subset and prices in {0, 1});
            otherwise prices come from ``solve_assortment_and_pricing`` with
            all ``n`` items allowed, and the assortment from the
            best-revenue search over sizes 1..K.
        """
        if self.episode_t < self.d:
            assortment = np.random.choice(self.n, size=self.K, replace=False)
            prices = np.random.choice(2, size=self.n)
            return assortment, prices

        psi, phi = self.theta[:self.d], self.theta[self.d:]
        alpha = np.minimum(contexts @ psi, 1)
        beta = np.maximum(contexts @ phi, self.L0)
        # Compute prices assuming all items can be offered.
        _, prices = solve_assortment_and_pricing(self.n, alpha, beta)
        # Choose the best up-to-K items under those prices.
        assortment = self._best_revenue_assortment(alpha, beta, prices)
        return assortment, prices

    def selection_feedback(self, i_t, contexts, assortment, prices):
        """Record feedback from exploration rounds and advance the episode.

        Args:
            i_t: index of the chosen item, or ``None`` when nothing was chosen.
            contexts: 2-D array with one context row per item.
            assortment: indices of the offered items.
            prices: per-item price array.
        """
        if self.episode_t < self.d:
            shown = contexts[assortment]
            # Feature row per offered item: [x, -price * x].
            x_tilde = np.concatenate([shown, (-prices[assortment] * shown.T).T], axis=1)
            self.V += x_tilde.T @ x_tilde
            self.offered_contexts.append(x_tilde)
            if i_t is not None:
                self.selected_contexts.append(
                    np.concatenate([contexts[i_t], -prices[i_t] * contexts[i_t]]))
            else:
                self.selected_contexts.append(np.zeros(2 * self.d))
        # Re-fit once the exploration rounds of this episode are complete.
        if self.episode_t == self.d - 1:
            self.theta = solve_mle(self.offered_contexts, self.selected_contexts,
                                   2 * self.d, init_theta=self.theta)
        self.t += 1
        self.episode_t += 1
        if self.episode_t == self.episode_len:
            # Start a new episode that lasts one round longer.
            self.t = 0
            self.episode_t = 0
            self.episode_len += 1


class OhIyengarAssortmentSelection:
    """Assortment-only policy: every item is always priced at ``fixed_prices``.

    Keeps a length-``d`` estimate ``theta``, a ``d x d`` matrix ``V`` and the
    observation history; the first ``T0`` rounds offer random assortments.
    """

    def __init__(self, n, d, K, L0, T0, fixed_prices):
        self.n = n
        self.d = d
        self.K = K
        self.L0 = L0
        self.T0 = T0
        self.fixed_prices = fixed_prices
        self.t = 0
        self.alpha_g = 0.003
        self.theta = np.zeros(d)
        self.V = np.zeros((d, d))
        self.offered_contexts = []
        self.selected_contexts = []

    def get_assortment_and_pricing(self, contexts):
        """Choose the assortment for the current round; prices are constant.

        After the random phase each item's value is
        ``contexts @ theta + alpha_g * x.T @ inv(V) @ x``, capped at
        ``1 + fixed_prices``.  For every size 1..K the top-valued items are
        scored by ``1 - 1 / (1 + sum(exp(values)))`` and the first size that
        strictly improves on all smaller ones is kept.

        Args:
            contexts: 2-D array with one context row per item.

        Returns:
            Tuple ``(assortment, prices)`` with ``prices`` equal to
            ``fixed_prices`` for each of the ``n`` items.
        """
        if self.t < self.T0:
            chosen = np.random.choice(self.n, size=self.K, replace=False)
        else:
            V_inv = np.linalg.inv(self.V)
            bonus = np.array(
                [self.alpha_g * np.inner(V_inv @ contexts[i], contexts[i])
                 for i in range(self.n)],
                dtype=float,
            )
            values = np.minimum(contexts @ self.theta + bonus, 1 + self.fixed_prices)
            chosen = []
            top_probability = 0
            for size in range(1, self.K + 1):
                candidate = np.argpartition(values, -size)[-size:]
                total_weight = np.sum(np.exp(values[candidate]))
                purchase_probability = 1 - 1 / (1 + total_weight)
                if purchase_probability > top_probability:
                    chosen, top_probability = candidate, purchase_probability
        return chosen, self.fixed_prices * np.ones(self.n)

    def selection_feedback(self, i_t, contexts, assortment, prices):
        """Record one round of feedback and, once ``t >= T0``, re-fit theta.

        Args:
            i_t: index of the chosen item, or ``None`` when nothing was chosen.
            contexts: 2-D array with one context row per item.
            assortment: indices of the offered items.
            prices: accepted for interface compatibility; not used here.
        """
        shown = contexts[assortment]
        self.V += shown.T @ shown
        self.offered_contexts.append(shown)

        chosen_features = np.zeros(self.d) if i_t is None else contexts[i_t]
        self.selected_contexts.append(chosen_features)

        if self.t >= self.T0:
            self.theta = solve_mle(
                self.offered_contexts,
                self.selected_contexts,
                self.d,
                init_theta=self.theta,
            )
        self.t += 1
