import gym
from copy import deepcopy
from typing import Optional
from gym.spaces.utils import flatdim
from .base_agent import BaseAgent
from .fac import FAC

def get_agent(cfg: dict,
              discrete_action: bool,
              device: str,
              env: object,
              actor: object,
              critic: object,
              replay_buffer: object,
              ) -> BaseAgent:
    """Build the agent selected by ``cfg.name``.

    The only agent name handled here is ``"fac"``, which yields a ``FAC``
    instance. State and action dimensions are obtained by applying
    ``flatdim`` to the environment's observation and action spaces.

    Args:
        cfg: Agent configuration. Despite the ``dict`` annotation, fields
            are read via attribute access (``cfg.name``, ``cfg.gamma``,
            ``cfg.buffer.batch_size``, ...), so a plain ``dict`` will not
            work; an attribute-style config object is required.
        discrete_action: Forwarded unchanged to the agent constructor.
        device: Forwarded unchanged to the agent constructor.
        env: Environment; must be a ``gym.Env`` instance.
        actor: Actor object handed to the agent. A deep copy of it is also
            passed as the agent's ``behavior_policy``.
        critic: Critic object handed to the agent.
        replay_buffer: Replay buffer handed to the agent.

    Returns:
        The constructed ``FAC`` agent (annotated as ``BaseAgent``;
        presumably ``FAC`` derives from it -- confirm in ``fac.py``).

    Raises:
        NotImplementedError: If ``env`` is not a ``gym.Env`` instance, or
            if ``cfg.name`` is anything other than ``"fac"``.
    """
    if not isinstance(env, gym.Env):
        raise NotImplementedError(
            f"Unsupported environment type {type(env).__name__!r}; "
            "expected a gym.Env instance."
        )

    state_dim = flatdim(env.observation_space)
    action_dim = flatdim(env.action_space)

    if cfg.name != "fac":
        raise NotImplementedError(
            f"Unknown agent name {cfg.name!r}; supported agents: 'fac'."
        )

    # The behavior policy starts out as an independent copy of the actor,
    # so later changes to one object do not affect the other.
    behavior_policy = deepcopy(actor)
    return FAC(
        discrete_action=discrete_action,
        action_dim=action_dim,
        state_dim=state_dim,
        gamma=cfg.gamma,
        batch_size=cfg.buffer.batch_size,
        alpha=cfg.alpha,
        device=device,
        actor=actor,
        behavior_policy=behavior_policy,
        critic=critic,
        replay_buffer=replay_buffer,
        logq_entropic_index=cfg.logq_entropic_index,
        expq_entropic_index=cfg.expq_entropic_index,
        fname=cfg.fname,
        num_terms=cfg.num_terms,
        ratio_eps=cfg.ratio_eps,
        symmetric_coef=cfg.symmetric_coef,
    )



# Placeholder entry point: no script behaviour is implemented.
# NOTE(review): this module uses relative imports (`.base_agent`, `.fac`), so
# it can only be executed as part of its package (e.g. `python -m pkg.module`),
# not as a bare script path.
if __name__ == "__main__":
    pass