from abc import ABCMeta
import logging
import ray
import numpy as np

from trainer.trainer import Trainer
from agents.curriculum.task_generators import UniformTaskGenerator, ContextualBanditTaskGenerator, ALPGMMTaskGenerator
from utils.typing import TrainerConfigDict, ResultDict

# Root logging setup for this module: INFO and above is sent both to the file
# "teacher.log" (path relative to the current working directory) and to a
# stream handler. Note that logging.basicConfig() does nothing if the root
# logger already has handlers configured.
_LOG_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
_log_handlers = [logging.FileHandler("teacher.log"), logging.StreamHandler()]
logging.basicConfig(handlers=_log_handlers, format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Teacher(metaclass=ABCMeta):
    """Base class for curriculum teachers.

    A teacher keeps a reference to the trainer and its configuration and
    exposes ``update_curriculum`` as a hook. Subclasses override the hook to
    implement their curriculum logic.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
    ):
        """Initializes a Teacher instance.

        Args:
            trainer: The trainer this teacher is attached to; stored as
                ``self.trainer``.
            trainer_config: The trainer configuration; stored as
                ``self.config``.
        """
        self.trainer = trainer
        self.config = trainer_config

    def update_curriculum(self, result: ResultDict) -> ResultDict:
        """Method containing curriculum logic. Called after train step.

        The base implementation has no curriculum logic and hands the result
        back untouched.

        Args:
            result: The result of the train step.

        Returns:
            The same ``result`` object that was passed in.
        """
        return result


class UniformTeacher(Teacher):
    """Teacher backed by a remote ``UniformTaskGenerator``.

    It does not override ``update_curriculum``, so the inherited hook is used.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        num_agents=None,
    ):
        """Stores the trainer/config and starts the remote task generator.

        Args:
            trainer: Passed through to ``Teacher.__init__``.
            trainer_config: Passed through to ``Teacher.__init__``; its
                ``"seed"`` entry is forwarded to the task generator.
            num_agents: Forwarded unchanged to the task generator.
        """
        super().__init__(trainer, trainer_config)

        # The generator is created under the fixed name "task_generator".
        generator_cls = UniformTaskGenerator.options(name="task_generator")
        self.task_generator = generator_cls.remote(
            seed=self.config["seed"],
            num_agents=num_agents,
        )


class ALPGMMTeacher(Teacher):
    """Teacher backed by a remote ``ALPGMMTaskGenerator``.

    After each train step the generator's info values are copied into the
    training result under ``alp_gmm_``-prefixed keys.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        num_agents=None,
        gmm_fitness_func="aic",
        warm_start=False,
        nb_em_init=1,
        fit_rate=250,
        alp_window_size=None,
        potential_ks=np.arange(2, 11, 1),
        random_task_ratio=0.2,
        nb_bootstrap=None,
        initial_dist=None,
    ):
        """Stores the trainer/config and starts the remote task generator.

        Args:
            trainer: Passed through to ``Teacher.__init__``.
            trainer_config: Passed through to ``Teacher.__init__``; its
                ``"seed"`` entry is forwarded to the task generator.
            num_agents, gmm_fitness_func, warm_start, nb_em_init, fit_rate,
            alp_window_size, potential_ks, random_task_ratio, nb_bootstrap,
            initial_dist: Forwarded unchanged, by keyword, to
                ``ALPGMMTaskGenerator``.

        Note:
            The default for ``potential_ks`` is a single array built once
            when this class is defined and reused by every call that omits
            the argument.
        """
        super().__init__(trainer, trainer_config)

        # The generator is created under the fixed name "task_generator".
        generator_cls = ALPGMMTaskGenerator.options(name="task_generator")
        self.task_generator = generator_cls.remote(
            seed=self.config["seed"],
            num_agents=num_agents,
            gmm_fitness_func=gmm_fitness_func,
            warm_start=warm_start,
            nb_em_init=nb_em_init,
            fit_rate=fit_rate,
            alp_window_size=alp_window_size,
            potential_ks=potential_ks,
            random_task_ratio=random_task_ratio,
            nb_bootstrap=nb_bootstrap,
            initial_dist=initial_dist,
        )

    def update_curriculum(self, result: ResultDict) -> dict:
        """Adds the task generator's infos to ``result`` and returns it.

        Blocks on ``ray.get`` until the generator answers ``get_infos``, then
        writes each entry into ``result`` as ``alp_gmm_<key>``.

        Args:
            result: The result of the train step; modified in place.

        Returns:
            The same ``result`` object, with the extra entries added.
        """
        generator_infos = ray.get(self.task_generator.get_infos.remote())
        for name, value in generator_infos.items():
            result[f"alp_gmm_{name}"] = value
        return result


class ContextualBanditTeacher(Teacher):
    """Teacher backed by a remote ``ContextualBanditTaskGenerator``.

    After each train step it reads a context from the first available policy
    model and sends it to the task generator.
    """

    def __init__(
        self,
        trainer: Trainer,
        trainer_config: TrainerConfigDict,
        num_contexts=3,
        gamma=0.3,
        num_agents=None,
        min_rew=0,
        max_rew=1,
    ):
        """Stores the trainer/config and starts the remote task generator.

        Args:
            trainer: Passed through to ``Teacher.__init__``.
            trainer_config: Passed through to ``Teacher.__init__``; its
                ``"seed"`` entry is forwarded to the task generator.
            num_contexts, gamma, num_agents, min_rew, max_rew: Forwarded
                unchanged, by keyword, to ``ContextualBanditTaskGenerator``.
        """
        super().__init__(
            trainer,
            trainer_config,
        )
        # The generator is created under the fixed name "task_generator".
        self.task_generator = ContextualBanditTaskGenerator.options(name="task_generator").remote(
            seed=self.config["seed"],
            num_contexts=num_contexts,
            gamma=gamma,
            num_agents=num_agents,
            min_rew=min_rew,
            max_rew=max_rew,
        )

    def update_curriculum(self, result: ResultDict) -> ResultDict:
        """Sends the current context to the task generator.

        Policies are tried in a fixed priority order; the first truthy one
        supplies the context via its model's ``last_hx`` attribute. If the
        model has no such attribute, or no policy is found, ``[0]`` is used.

        Args:
            result: The result of the train step; not modified here.

        Returns:
            The same ``result`` object that was passed in.
        """
        # Arguments for trainer.get_policy(), in priority order. The empty
        # tuple means "call get_policy() with its default policy id".
        policy_lookups = (
            ("high_level_policy",),  # HRL
            (),  # ATT-COM
            ("agent_0",),  # PPO
        )

        context = [0]
        for lookup_args in policy_lookups:
            # Look each policy up once and reuse it for both the check and
            # the attribute read.
            policy = self.trainer.get_policy(*lookup_args)
            if policy:
                context = getattr(policy.model, "last_hx", [0])
                break

        # The value returned by .remote() is deliberately not awaited.
        self.task_generator.update_context.remote(context)

        return result


# class TimeTeacher(Teacher):
#     def __init__(
#         self,
#         trainer: Trainer,
#         trainer_config: TrainerConfigDict,
#     ):
#         super().__init__(
#             trainer,
#             trainer_config,
#         )
#
#         self.task_generator = TimeBasedTaskGenerator.options(name="task_generator").remote(
#             self.match_func,
#             win_rate_threshold,
#             iter_threshold,
#             newest_prob,
#             max_league_size,
#         )
#
#     def update_curriculum(self, result: ResultDict) -> None:
#         timesteps_total = result["timesteps_total"]
#         self.task_generator.update_timesteps.remote(timesteps_total)
#         if "evaluation" in result:
#             self.task_generator.set_sample_flag.remote(True)
