
import LiteEFG
from .BaseSolver import BaseAsymSolver

class AsymPDGDA(BaseAsymSolver):
    """Asymmetric primal-dual gradient descent-ascent solver.

    Builds one ``_AsymPDOMD`` graph per player. In each player's graph only
    that player's entry of ``taus`` is set to ``tau``; the other entry is 0.

    Args:
        eta: Step size forwarded to every per-player graph.
        tau: Regularization weight given to the graph's own player.
        weighted: Forwarded unchanged to each ``_AsymPDOMD``.
    """

    def __init__(self, eta=0.1, tau=0.1, weighted=False):
        # Hyper-parameters have to exist before the base initializer runs: it
        # initializes the per-player solvers (presumably via create_solver).
        self.eta, self.tau, self.weighted = eta, tau, weighted
        super().__init__()

    def create_solver(self, player_id: int) -> LiteEFG.Graph:
        """Return the graph used for player ``player_id``.

        Only index ``player_id`` of the two-element ``taus`` list carries
        ``self.tau``; the remaining entry stays 0.0.
        """
        regularization = [0.0] * 2
        regularization[player_id] = self.tau
        solver_kwargs = dict(eta=self.eta, taus=regularization, weighted=self.weighted)
        return _AsymPDOMD(**solver_kwargs)

class _AsymPDOMD(LiteEFG.Graph):
    """LiteEFG graph for the asymmetric primal-dual gradient method.

    Each node keeps a strategy ``self.u``, initialised to uniform over the
    node's actions. On every backward pass, the incoming gradient is first
    shifted by ``-alpha * tau * u_prev``. ``self.u`` is then replaced by the
    L2 projection of ``u_prev + gradient / eta_coef``, where
    ``eta_coef = alpha / eta``.

    The color-0 pass uses ``taus[0]`` and the color-1 pass uses ``taus[1]``.
    ``update_graph`` drives player 1 with color 0 and player 2 with color 1.

    Args:
        eta: Step size.
        taus: Pair of regularization weights for the color-0 and color-1
            passes. ``None`` (the default) means ``[0.1, 0.1]``.
        weighted: If True, ``alpha`` is a graph node computed in the static
            pass instead of the constant ``1.0``.
    """

    def __init__(
        self,
        eta=0.1,
        taus=None,
        weighted=False,
    ):
        super().__init__()
        if taus is None:
            # Build a fresh list per instance; a list-literal default would be
            # one object shared by every instance that relies on it.
            taus = [0.1, 0.1]
        self.eta = eta
        self.taus = taus
        self.timestep = 0

        # Static pass: constants, step weights and the initial strategy.
        # ``self.action_set_size`` and ``self.utility`` are not assigned in this
        # class; presumably they are provided by ``LiteEFG.Graph``.
        with LiteEFG.backward(is_static=True):
            self.alpha = 1.0
            if weighted:
                # NOTE(review): weight derived from a tree aggregate of ones,
                # then (max + 1) * 2 -- confirm the intended weighting scheme.
                self.alpha = LiteEFG.const(1, 1.0)
                self.alpha.inplace(LiteEFG.aggregate(self.alpha, "sum"))
                self.alpha.inplace((self.alpha.max() + 1) * 2)

            ev = LiteEFG.const(size=1, val=0.0)
            self.tau1 = LiteEFG.const(1, self.taus[0])
            self.tau2 = LiteEFG.const(1, self.taus[1])
            self.eta_coef = self.alpha / eta
            # Uniform starting strategy over this node's actions.
            self.u = LiteEFG.const(self.action_set_size, 1.0 / self.action_set_size)

        # One backward pass per color. Both share ``ev`` and ``self.u`` and
        # differ only in the regularization weight they apply.
        for color, tau in ((0, self.tau1), (1, self.tau2)):
            with LiteEFG.backward(color=color):
                gradient = LiteEFG.aggregate(ev, "sum") + self.utility
                prev_u = self.u.copy()
                self._update(gradient, self.u, prev_u, tau)
                self._get_ev(gradient, ev, self.u, prev_u, tau)

        print("===============Graph is ready for AsymPDGDA===============")
        print("eta: %f, taus: %s" % (self.eta, self.taus))
        print("=====================================================\n")

    def _get_ev(self, gradient, ev, strategy, ref_strategy, tau):
        """Write this node's value into ``ev`` in place.

        ``ev = <gradient, strategy>
               - eta_coef * euclidean(strategy - ref_strategy)
               - alpha * tau * <ref_strategy, strategy>``

        NOTE(review): when called after ``_update``, ``gradient`` has already
        been shifted in place by ``-alpha * tau * ref_strategy``. The first
        term therefore already contains ``-alpha * tau * <ref, strategy>``,
        and the last term subtracts it a second time. Confirm this is
        intended before relying on the value.
        """
        ev.inplace(
            LiteEFG.dot(gradient, strategy)
            - LiteEFG.euclidean(strategy - ref_strategy) * self.eta_coef
            - (LiteEFG.dot(ref_strategy, strategy)) * tau * self.alpha
        )

    def _update(self, gradient, upd_u, ref_u, tau):
        """Take one projected gradient step from ``ref_u`` into ``upd_u``.

        Side effect: ``gradient`` is modified in place, shifted by
        ``-alpha * tau * ref_u``, and the caller sees the shifted node.
        """
        gradient.inplace(gradient - ref_u * self.alpha * tau)
        gradient_div = gradient / self.eta_coef
        upd_u.inplace(ref_u + gradient_div)
        upd_u.inplace(upd_u.project(distance="L2"))

    def update_graph(self, env: LiteEFG.Environment) -> None:
        """Register ``self.u``: player 1 uses color 0, player 2 uses color 1."""
        env.update(self.u, upd_player=1, upd_color=[0])
        env.update(self.u, upd_player=2, upd_color=[1])

    def current_strategy(self) -> LiteEFG.GraphNode:
        """Return the graph node holding the current strategy."""
        return self.u
