import numpy as np
from typing import Optional, Dict, Any, Tuple

class OPT3:
    """Optimistic proximal bandit policy with budget and ROI constraints.

    Each round, ``act`` scores every action with a linear loss built from the
    cost model (``cmodel``) and reward model (``rmodel``) predictions::

        v1 = Q_budget * (c - budget) + Q_ROI * (c - gamma * r) - r

    It then picks a distribution ``pi`` by a proximal step around the previous
    distribution (``get_pi``) and samples an action from it. ``update`` refits
    both models on the observed outcome and advances the two nonnegative
    constraint weights ``Q_budget`` and ``Q_ROI`` with an optimistic rule of
    the form ``Q <- max(0, Q - g_prev + 2 * g_curr)``. Here ``g`` is the
    expected constraint violation under ``last_pi`` (prev) or ``pi`` (curr).
    """

    def __init__(self, K=4, dim=5, alpha=2, gamma=1.3, rmodel=None, cmodel=None, D=0, budget=0):
        """Create the policy.

        Args:
            K: number of actions; distributions over actions have length K.
            dim: stored as ``self.dim``; not used by the methods of this class.
            alpha: weight of the proximal (divergence) term in ``get_pi``;
                ``alpha == 0`` makes ``act`` choose ``argmin(v1)`` greedily.
            gamma: ROI factor; the ROI violation is ``cost - gamma * reward``.
            rmodel: reward model exposing ``predict(context)`` (presumably one
                estimate per action -- confirm) and ``update(context[a], reward)``.
            cmodel: cost model exposing ``predict(context)`` (presumably one
                estimate per action -- confirm) and ``update(a, cost)``.
            D: index of the divergence used by ``get_pi`` (0-3, see there).
            budget: per-round cost budget; the budget violation is
                ``cost - budget``.
        """
        self.K = K
        self.dim = dim
        self.Q_budget = 0  # weight on the budget violation; kept >= 0 by update()
        self.Q_ROI = 0  # weight on the ROI violation; kept >= 0 by update()
        self.last_pi = np.ones(K) / K  # distribution used in the previous round; starts uniform
        self.pi = None  # distribution chosen by the most recent act()
        self.alpha = alpha
        self.gamma = gamma
        self.rmodel = rmodel
        self.cmodel = cmodel

        # Values of Q_budget / Q_ROI recorded at the start of every update().
        self.Q_budget_list, self.Q_ROI_list = [], []
        # Rounds in which the predicted cost of the chosen action exceeded the observed cost.
        self.opti = 0
        self.minn = np.zeros(K)  # not used by the methods of this class
        self.D = D
        self.budget = budget

    def sample(self, dist: np.ndarray):
        """Draw one action index with probabilities given by ``dist`` (flattened)."""
        return np.argmax(np.random.multinomial(1, dist.flatten()))

    def get_pi(self, L_hat: np.ndarray) -> np.ndarray:
        """Proximal step over the probability simplex.

        Minimises ``<x, L_hat> + alpha * div(x, last_pi)`` subject to
        ``sum(x) == 1`` and ``0 <= x <= 1``, starting from ``last_pi``, with
        ``scipy.optimize.minimize`` (which selects SLSQP when constraints are
        supplied). ``self.D`` selects ``div``:

            0: squared Euclidean distance ``sum((x - p) ** 2)``
            1: KL(x || p) = ``sum(x * log(x / p))``
            2: KL(p || x) = ``sum(p * log(p / x))``
            3: Euclidean distance ``sqrt(sum((x - p) ** 2))``

        The arguments of ``log`` are floored at a tiny epsilon. Exact zeros in
        ``last_pi`` (produced by clipping in ``act``) or slightly infeasible
        solver iterates would otherwise make the KL objectives nan/inf. When
        every entry is at least the epsilon, the objectives are unchanged.

        Returns:
            The solver's ``x``. It may contain small excursions outside
            [0, 1], which ``act`` clips and renormalises.
        """
        from scipy.optimize import minimize

        tiny = 1e-12
        prev = self.last_pi

        def func1(x):
            return np.sum(x * L_hat) + self.alpha * np.sum((x - prev) ** 2)

        def func2(x):
            ratio = np.maximum(x, tiny) / np.maximum(prev, tiny)
            return np.sum(x * L_hat) + self.alpha * np.sum(x * np.log(ratio))

        def func3(x):
            ratio = np.maximum(prev, tiny) / np.maximum(x, tiny)
            return np.sum(x * L_hat) + self.alpha * np.sum(prev * np.log(ratio))

        def func4(x):
            return np.sum(x * L_hat) + self.alpha * np.sqrt(np.sum((x - prev) ** 2))

        func_list = [func1, func2, func3, func4]
        constraints = [
            {'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
            {'type': 'ineq', 'fun': lambda x: x},
            {'type': 'ineq', 'fun': lambda x: 1 - x},
        ]
        result = minimize(func_list[self.D], prev, constraints=constraints)
        return result.x

    def act(self, ctx: Dict[str, Any]):
        """Choose an action for ``ctx["context"]``.

        Builds the per-action loss ``v1`` from the model predictions and the
        current constraint weights.

        - If ``alpha == 0``: the action is ``argmin(v1)`` and ``pi`` becomes
          the matching one-hot vector.
        - Otherwise: ``pi`` comes from ``get_pi(v1)``, is clipped to [0, 1]
          and renormalised, and an action is sampled from it.

        Returns:
            The chosen action index.

        Raises:
            RuntimeError: if ``get_pi`` returns a vector outside
                [-0.01, 1.01], i.e. not a distribution even up to solver noise.
        """
        context = ctx.get("context")
        cost_hat = self.cmodel.predict(context)
        reward_hat = self.rmodel.predict(context)
        cb = cost_hat - self.budget
        ROI_vio = cost_hat - self.gamma * reward_hat
        v1 = self.Q_budget * cb + self.Q_ROI * ROI_vio - reward_hat

        if self.alpha == 0:
            # Greedy mode: no proximal term, so skip the optimiser entirely.
            action = np.argmin(v1)
            self.pi = np.zeros(self.K)
            self.pi[action] = 1
            return action

        self.pi = np.asarray(self.get_pi(v1), dtype=float)
        if np.max(self.pi) > 1.01 or np.min(self.pi) < -0.01:
            raise RuntimeError(f"OPT3.get_pi returned an invalid distribution: {self.pi}")
        # Remove small solver excursions outside [0, 1], then renormalise.
        self.pi[self.pi < 0] = 0
        self.pi[self.pi > 1] = 1
        self.pi /= np.sum(self.pi)
        return self.sample(self.pi)

    def update(self, reward, cost, info: Dict[str, Any]):
        """Refit the models on the observed outcome and advance Q_budget / Q_ROI.

        Args:
            reward: observed reward of the chosen action.
            cost: observed cost of the chosen action.
            info: mapping with ``"context"`` (indexable by action) and
                ``"bid"`` (the chosen action index).
        """
        context = info.get("context")
        action = int(info.get("bid"))

        # Cost estimates before refitting. They feed the optimism counter and
        # the "previous" term of the budget update.
        cost_before = self.cmodel.predict(context)
        if cost_before[action] > cost:
            self.opti += 1
        self.Q_budget_list.append(self.Q_budget)
        self.Q_ROI_list.append(self.Q_ROI)

        self.rmodel.update(context[action], reward)
        self.cmodel.update(action, cost)
        cost_after = self.cmodel.predict(context)
        reward_after = self.rmodel.predict(context)

        # The budget violation is (cost - budget), matching `cb` in act(). In
        # act() the budget is only a constant shift of v1 over the simplex, so
        # this queue is the only place it can influence the policy.
        budget_prev = np.sum(self.last_pi * (cost_before - self.budget))
        budget_curr = np.sum(self.pi * (cost_after - self.budget))
        self.Q_budget = self.Q_budget - budget_prev + 2 * budget_curr
        if self.Q_budget < 0:
            self.Q_budget = 0

        # NOTE(review): unlike the budget term, both ROI terms use the refitted
        # estimates; confirm this asymmetry is intended.
        roi_vio = cost_after - self.gamma * reward_after
        roi_prev = np.sum(self.last_pi * roi_vio)
        roi_curr = np.sum(self.pi * roi_vio)
        self.Q_ROI = self.Q_ROI - roi_prev + 2 * roi_curr
        if self.Q_ROI < 0:
            self.Q_ROI = 0

        # Copy so later in-place edits of self.pi can never alter last_pi.
        self.last_pi = np.array(self.pi, dtype=float)