

import LiteEFG
from .BaseSolver import BaseSolver


class SymPDGDA(BaseSolver):
    """Euclidean projected-gradient solver expressed as a LiteEFG graph.

    Both players are updated from the same strategy node, ``self.u``
    (see ``update_graph``). On every dynamic backward pass each node takes
    one step::

        u <- project_L2(prev_u + (g - tau * alpha * prev_u) / eta_coef)

    where ``g = LiteEFG.aggregate(ev, "sum") + utility`` (presumably the
    summed values of the children plus the immediate utility) and
    ``eta_coef = alpha / eta``.

    Args:
        eta: Step size; must be strictly positive. The gradient is divided
            by ``eta_coef = alpha / eta`` before being added to ``prev_u``.
        tau: Regularization weight; the gradient is shifted by
            ``-tau * alpha * prev_u`` before the step.
        weighted: When True, ``alpha`` is a graph value computed once in the
            static pass from the children's ``alpha`` (see ``__init__``);
            when False, ``alpha`` is the constant ``1.0``.

    Raises:
        ValueError: If ``eta`` is not strictly positive.
    """

    def __init__(self, eta: float = 0.1, tau: float = 0.1, weighted: bool = False):
        # Reject non-positive step sizes up front: eta == 0 would divide by
        # zero when computing eta_coef, and eta < 0 would step against the
        # gradient.
        if eta <= 0:
            raise ValueError(f"eta must be strictly positive, got {eta!r}")

        super().__init__()
        self.eta = eta
        # Replaced by a graph constant inside the static pass below.
        self.tau = tau
        self.timestep = 0

        # Static graph for SymPDGDA. Presumably evaluated once when the graph
        # is built (TODO confirm LiteEFG semantics of ``is_static``).
        with LiteEFG.backward(is_static=True):

            self.alpha = 1.0
            if weighted:
                # Per-node weight: sum-aggregate the children's alpha, take
                # the max, then compute (max + 1) * 2.
                self.alpha = LiteEFG.const(1, 1.0)
                self.alpha.inplace(LiteEFG.aggregate(self.alpha, "sum"))
                self.alpha.inplace((self.alpha.max() + 1) * 2)

            # Value passed up to the parent; rewritten by _get_ev each pass.
            ev = LiteEFG.const(size=1, val=0.0)

            self.tau = LiteEFG.const(1, tau)
            self.eta_coef = self.alpha / eta
            # Only read by the disabled tau schedule below.
            self.coef = self.tau
            # Initial strategy: every action gets 1 / action_set_size.
            self.u = LiteEFG.const(self.action_set_size, 1.0 / self.action_set_size)

        # Dynamic graph: one projected step per iteration.
        with LiteEFG.backward():

            gradient = LiteEFG.aggregate(ev, "sum") + self.utility
            prev_u = self.u.copy()
            self._update(gradient, self.u, prev_u)
            # NOTE(review): _update overwrites ``gradient`` in place with
            # ``gradient - prev_u * alpha * tau``, so _get_ev receives the
            # already-regularized gradient and then subtracts
            # ``<prev_u, u> * tau * alpha`` a second time. Confirm the tau
            # term is meant to be counted twice in the value sent upward.
            self._get_ev(gradient, ev, self.u, prev_u)

        # Disabled experiment: a color=1 backward pass that halves tau and
        # keeps self.coef in sync with it.
        # with LiteEFG.backward(color=1):
        #     self.tau.inplace(self.tau * 0.5)
        #     self.coef.inplace(self.tau)

        print("===============Graph is ready for SymPDGDA===============")
        print(f"eta: {self.eta:f}")
        print("=====================================================\n")

    def _get_ev(self, gradient, ev, strategy, ref_strategy):
        """Write this node's value into ``ev`` (in place).

        ev <- <gradient, strategy>
              - euclidean(strategy - ref_strategy) * eta_coef
              - <ref_strategy, strategy> * tau * alpha

        Args:
            gradient: Gradient node. In ``__init__`` this is the node that
                ``_update`` has already regularized in place.
            ev: Output node, overwritten in place.
            strategy: The freshly updated strategy node.
            ref_strategy: The strategy before the update.
        """
        ev.inplace(
            LiteEFG.dot(gradient, strategy)
            - LiteEFG.euclidean(strategy - ref_strategy) * self.eta_coef
            - (LiteEFG.dot(ref_strategy, strategy)) * self.tau * self.alpha
        )

    def _update(self, gradient, upd_u, ref_u):
        """Take one regularized, L2-projected gradient step.

        Side effect: ``gradient`` is overwritten in place with
        ``gradient - ref_u * alpha * tau``. Any later use of ``gradient``
        sees the regularized value.

        Then ``upd_u <- project(ref_u + gradient / eta_coef, distance="L2")``.
        The projection target is presumably the probability simplex
        (TODO confirm against LiteEFG's ``project``).

        Args:
            gradient: Gradient node; mutated in place.
            upd_u: Strategy node to overwrite with the new strategy.
            ref_u: Strategy the step starts from.
        """
        gradient.inplace(gradient - ref_u * self.alpha * self.tau)
        gradient_div = gradient / self.eta_coef
        upd_u.inplace(ref_u + gradient_div)
        upd_u.inplace(upd_u.project(distance="L2"))

    def update_graph(self, env : LiteEFG.Environment) -> None:
        """Register ``self.u`` as the strategy node updated for both players."""
        env.update(self.u, upd_player=1)
        env.update(self.u, upd_player=2)

    def current_strategy(self) -> LiteEFG.GraphNode:
        """Return the graph node holding the current strategy (``self.u``)."""
        return self.u
    