import  LiteEFG
from typing import Literal

class graph(LiteEFG.Graph):
    def __init__(self, eta=0.001, tau=0.001, gamma=0.001, regularizer: Literal["Euclidean", "Entropy"]="Entropy", feedback="Q", weighted=False, bidilated=True):
        super().__init__()

        self.eta = eta
        self.tau = tau
        self.gamma = gamma
        self.regularizer = regularizer
        self.feedback = feedback

        with LiteEFG.backward(is_static=True):

            self.alpha = 1.0
            if weighted:
                self.alpha = LiteEFG.const(1, 1.0)
                self.alpha.inplace(LiteEFG.aggregate(self.alpha, "sum"))
                self.alpha.inplace((self.alpha.max() + 1) * 2)

            self.ev = LiteEFG.const(1, 0.0)
            self.coef = self.tau
            self.mu = self.subtree_size.normalize(p_norm=1.0, ignore_negative=True)
            self.eta_coef = self.eta * self.coef
            
            self.u = LiteEFG.const(self.action_set_size, 1.0 / self.action_set_size)
            self.bar_u = self.u.copy()

        with LiteEFG.backward():

            if(self.feedback in ["Q", "traj-Q", "counterfactual"]):
                self._full_information()
            elif self.feedback == "Outcome":
                self._outcome_sampling()

        print("===============Graph is ready for QFR===============")
        print("eta: %f, tau: %f, gamma: %f, regularizer: %s, feedback: %s" % (self.eta, self.tau, self.gamma, self.regularizer, self.feedback))
        print("====================================================\n")
    
    def _full_information(self):
        if self.feedback == "counterfactual":
            self.m_th = 1.0
        elif self.feedback == "traj-Q":
            self.m_th = 1.0 / self.reach_prob
        else:
            self.m_th = self.opponent_reach_prob
        
        self.eta_tau = self.eta_coef / self.m_th * self.opponent_reach_prob + 1 # 1 + eta * tau / m_th * mu_{-p}(s)
        
        if self.regularizer == "Euclidean":
            reg = LiteEFG.euclidean(self.u) * self.coef
        else:
            reg = LiteEFG.negative_entropy(self.u, shifted=True) * self.coef

        bidilated_reg = reg * self.reach_prob
        gradient = LiteEFG.aggregate(self.ev, "sum") + self.utility \
                                                    + LiteEFG.aggregate(bidilated_reg, "sum", player="opponents")
        
        self.ev.inplace(LiteEFG.dot(gradient, self.u) - reg * self.opponent_reach_prob)

        gradient.inplace(gradient / self.m_th * self.eta / self.alpha)

        self._update(self.bar_u, self.bar_u, gradient)
        self._update(self.u, self.bar_u, gradient)
    
    def _outcome_sampling(self):
        #self.m_th = 1.0 / self.reach_prob
        self.eta_tau = self.eta_coef + 1 # 1 + eta * tau

        if self.regularizer == "Euclidean":
            reg = LiteEFG.euclidean(self.u) * self.coef
        else:
            reg = LiteEFG.negative_entropy(self.u, shifted=True) * self.coef

        gradient = LiteEFG.aggregate(self.ev, "sum") + self.utility \
                                                    + LiteEFG.aggregate(reg, "sum", player="opponents")
        self.ev.inplace(LiteEFG.sum(gradient) - reg)

        gradient.inplace(gradient / self.u * self.eta / self.alpha)

        self._update(self.bar_u, self.bar_u, gradient)
        self._update(self.u, self.bar_u, gradient)

    def _update(self, upd_u, ref_u, gradient):
        if self.regularizer == "Euclidean":
            upd_u.inplace((ref_u + gradient) / self.eta_tau)
            upd_u.inplace(upd_u.project(distance="L2", gamma=self.gamma, mu=self.mu))
        else:
            upd_u.inplace((ref_u.log() + gradient) / self.eta_tau)
            upd_u.inplace(upd_u - upd_u.max())
            upd_u.inplace(upd_u.exp())
            upd_u.inplace(upd_u.project(distance="KL", gamma=self.gamma, mu=self.mu))
    
    def update_graph(self, env : LiteEFG.Environment) -> None:
        env.update(self.u)
    
    def current_strategy(self) -> LiteEFG.GraphNode:
        return self.u