import time
import ray

import numpy as np

from expground.types import AgentID, PolicyID, LambdaType, Dict, Union
from expground.learner.base_learner import Learner
from expground.utils.stoppers import DEFAULT_STOP_CONDITIONS, get_stopper


class BiLevel(Learner):
    """Skeleton learner for bilevel optimization.

    ``learn`` runs an outer loop until the stopper reports termination. Each
    outer iteration performs ``D`` inner calls to
    :meth:`optimize_leader_objective`, which refines ``y`` for the current
    ``x``. It then asks :meth:`hyper_gradient_estimation` for a hypergradient
    intended to update ``x``. Subclasses must implement both hooks.
    """

    NAME = "bi_level"

    def __init__(
        self,
        experiment: str,
        summary_writer=None,
        seed: int = None,
        evaluation_worker_num: int = 0,
        ray_mode: bool = False,
        agent_mapping: LambdaType = ...,
    ) -> None:
        """Forward every argument unchanged to ``Learner.__init__``.

        Args:
            experiment: Experiment identifier, forwarded to the base learner.
            summary_writer: Optional summary writer, forwarded as-is.
            seed: Optional seed, forwarded as-is.
            evaluation_worker_num: Forwarded as-is (default 0).
            ray_mode: Forwarded as-is (default False).
            agent_mapping: Forwarded as-is; defaults to ``...`` (Ellipsis).
        """
        super().__init__(
            experiment,
            summary_writer=summary_writer,
            seed=seed,
            evaluation_worker_num=evaluation_worker_num,
            ray_mode=ray_mode,
            agent_mapping=agent_mapping,
        )

    def hyper_gradient_estimation(self):
        """Subclass hook: return the hypergradient used to update ``x``.

        ``learn`` calls this once per outer iteration, after the inner loop.
        Per the original design notes, the estimate is expected to come from
        AID or ITD.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def optimize_leader_objective(self, x, y):
        """Subclass hook: run one inner optimization step and return the new ``y``.

        ``learn`` calls this ``D`` times per outer iteration. Each call
        receives the current ``x`` and the ``y`` returned by the previous
        call (``None`` on the very first call).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def learn(
        self,
        sample_config: Union[Dict, LambdaType],
        stop_conditions: Dict = None,
        inner_conditions: Dict = None,
    ) -> tuple:
        """Run the bilevel optimization loop.

        Args:
            sample_config: Sampling configuration (dict or callable). It is
                currently unused by this skeleton.
            stop_conditions: Passed to ``get_stopper``. Falls back to
                ``DEFAULT_STOP_CONDITIONS`` when ``None`` or empty.
            inner_conditions: Must contain ``"D"``, the number of inner
                ``optimize_leader_objective`` calls per outer iteration.

        Returns:
            The final ``(x, y)`` pair.

        Raises:
            ValueError: If ``inner_conditions`` is ``None`` or lacks ``"D"``.
        """
        # Validate before starting; otherwise a missing "D" only surfaces
        # as an opaque TypeError/KeyError inside the first outer iteration.
        if inner_conditions is None or "D" not in inner_conditions:
            raise ValueError(
                "inner_conditions must provide 'D', the number of inner iterations"
            )
        num_inner_iters = inner_conditions["D"]

        self.stopper = get_stopper(stop_conditions or DEFAULT_STOP_CONDITIONS)

        # Upper-level (x) and lower-level (y) variables; both are still
        # placeholders. x is initialised once, outside the loop, so that it
        # carries over between outer iterations instead of being reset.
        x = None
        y = None
        # NOTE(review): nothing in this loop advances or feeds metrics to
        # self.stopper. Unless is_terminal() depends on something external
        # (e.g. elapsed time), this loop never exits. Confirm against
        # get_stopper.
        while not self.stopper.is_terminal():
            # Inner loop: refine y D times for the current x.
            for _ in range(num_inner_iters):
                y = self.optimize_leader_objective(x, y)
            # Estimate the hypergradient w.r.t. x (e.g. via AID or ITD).
            grad_x = self.hyper_gradient_estimation()
            # TODO: apply grad_x to x. No update rule is implemented yet,
            # so x is left unchanged here.

        return x, y
