import copy
from collections.abc import Iterable
from typing import Callable, Dict, Iterable, Optional, Tuple

import torch
from hive.agents.agent import Agent
from torchtyping import TensorType

from gfn_subtb_grid.agents.losses import BaseLoss

# Lookup from a `reduction` argument (as accepted by `BaseAgent.loss`) to the
# callable applied to the unreduced loss tensor. Both 'identity' and None
# return the tensor unchanged.
REDUCTIONS = dict(
    [
        ('mean', torch.mean),
        ('sum', torch.sum),
        ('identity', lambda x: x),
        (None, lambda x: x),
    ]
)

class BaseAgent(Agent):
    """Agent that owns a torch model, a default loss and an optimizer.

    The default implementations of the hook methods below (``_get_loss_infos``,
    ``get_agent_state``, ``get_log_pf``/``get_log_pb``, ``get_metrics``,
    ``end_epoch``) are trivial; subclasses may override them.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fxn: BaseLoss,
        obs_dim: int,
        action_dim: int,
        optimizer_config: Dict[str, object],
        id: int = 0,
    ):
        """Store the model and loss, then build the optimizer and scheduler.

        Args:
            model: Module whose parameters are optimized by default.
            loss_fxn: Default loss used by ``loss`` when none is passed.
            obs_dim: Forwarded to ``Agent.__init__``.
            action_dim: Forwarded to ``Agent.__init__``.
            optimizer_config: Dict with a required ``'lr'`` key, an optional
                ``'type'`` (optimizer class, default ``torch.optim.Adam``),
                an optional ``'lr_scheduler_config'`` dict (required
                ``'type'`` key plus scheduler kwargs), and any further
                optimizer kwargs. Stored as-is on ``self.optimizer_config``;
                only a deep copy is consumed when building the optimizer.
            id: Forwarded to ``Agent.__init__``.
        """
        # NOTE(review): model/loss are assigned before super().__init__ —
        # presumably because the base constructor may use them; confirm.
        self._loss_fxn = loss_fxn
        self.model = model
        super().__init__(obs_dim, action_dim, id)

        self.optimizer_config = optimizer_config
        self.optim, self.lr_scheduler = self._build_optim(
            copy.deepcopy(self.optimizer_config)
        )

    def loss(
        self,
        update_infos: Dict[str, object],
        reduction: Optional[str] = 'mean',
        loss_fxn: Optional[BaseLoss] = None
    ) -> TensorType:
        """Compute the loss for a batch of update infos and reduce it.

        Args:
            update_infos: Passed through ``_get_loss_infos``; the result gets
                an extra ``'agent'`` key set to ``self`` and is handed to the
                loss function. With the default ``_get_loss_infos`` (which
                returns its argument) this key is written into the caller's
                dict.
            reduction: Key into ``REDUCTIONS``: ``'mean'``, ``'sum'``, or
                ``'identity'``/``None`` for the unreduced loss.
            loss_fxn: Loss to use for this call only; defaults to the loss
                given at construction.

        Returns:
            The loss function's output after applying ``reduction``.

        Raises:
            KeyError: If ``reduction`` is not a key of ``REDUCTIONS``.
        """
        augmented_infos = self._get_loss_infos(update_infos)
        augmented_infos['agent'] = self

        loss_fxn = loss_fxn if loss_fxn is not None else self._loss_fxn
        unreduced_loss : TensorType['batch_size'] = loss_fxn(augmented_infos)

        try:
            reduce_fxn = REDUCTIONS[reduction]
        except KeyError:
            # Same exception type as before, but with an actionable message.
            raise KeyError(
                f'Unknown reduction {reduction!r}; '
                f'expected one of {list(REDUCTIONS)}'
            ) from None
        return reduce_fxn(unreduced_loss)

    def _get_loss_infos(
        self,
        update_infos: Dict[str, object]
    ) -> Dict[str, object]:
        """Hook to augment ``update_infos`` before the loss; identity here."""
        return update_infos

    def parameters(self) -> Iterable:
        """Return the wrapped model's parameters (``self.model.parameters()``)."""
        return self.model.parameters()

    def get_agent_state(
        self,
        all_states: TensorType['batch_size', 'horizon', 'ndim_times_side_len'],
        all_actions: TensorType['batch_size', 'horizon'],
        all_dones: TensorType['batch_size', 'horizon'],
        all_rewards: TensorType['batch_size'],
        iter_num: int
    ) -> TensorType:
        """Return ``all_states[:, iter_num]``: every batch row at step ``iter_num``.

        ``all_actions``, ``all_dones`` and ``all_rewards`` are unused by this
        default implementation.
        """
        return all_states[:, iter_num]

    @property
    def does_grad_update(self) -> bool:
        """Flag reported to callers; always ``True`` for this base agent.

        Presumably it signals that the agent is trained with gradient
        updates — confirm against callers.
        """
        return True

    def _build_optim(
        self,
        optimizer_config: Dict[str, object],
        parameters: Optional[Iterable] = None
    ) -> Tuple[torch.optim.Optimizer, object]:
        """Build an optimizer and an optional learning-rate scheduler.

        Args:
            optimizer_config: Dict with a required ``'lr'`` key, an optional
                ``'type'`` (optimizer class, default ``torch.optim.Adam``),
                an optional ``'lr_scheduler_config'`` dict (required ``'type'``
                key plus scheduler kwargs); remaining keys are passed to the
                optimizer constructor. The dict is not modified.
            parameters: Parameters to optimize; defaults to
                ``self.parameters()`` when ``None``.

        Returns:
            ``(optimizer, lr_scheduler)``, where ``lr_scheduler`` is ``None``
            when no ``'lr_scheduler_config'`` is given.

        Raises:
            KeyError: If ``'lr'`` (or the scheduler ``'type'``) is missing.
        """
        # Pop from shallow copies so the caller's config is left intact.
        optimizer_config = dict(optimizer_config)
        optim_type = optimizer_config.pop('type', torch.optim.Adam)
        lr_scheduler_config = optimizer_config.pop('lr_scheduler_config', None)
        lr = optimizer_config.pop('lr')

        # Test for None rather than truthiness: an explicitly passed empty
        # parameter list must not silently fall back to all model parameters
        # (torch then raises on the empty list instead).
        if parameters is None:
            parameters = self.parameters()

        optim = optim_type(parameters, lr=lr, **optimizer_config)

        lr_scheduler = None
        if lr_scheduler_config is not None:
            lr_scheduler_config = dict(lr_scheduler_config)
            scheduler_type = lr_scheduler_config.pop('type')
            lr_scheduler = scheduler_type(optim, **lr_scheduler_config)

        return optim, lr_scheduler

    def get_log_pf(
        self,
        states: TensorType['batch_size', 'horizon', 'obs_dim'],
        actions: TensorType['batch_size', 'horizon']
    ) -> TensorType:
        """Default forward log-probabilities: zeros shaped like ``actions``.

        NOTE(review): the zeros take ``actions``' dtype; if actions are an
        integer tensor these log-probabilities are integer-typed too.
        """
        return torch.zeros_like(actions)

    def get_log_pb(
        self,
        states: TensorType['batch_size', 'horizon', 'obs_dim'],
        actions: TensorType['batch_size', 'horizon']
    ) -> TensorType:
        """Default backward log-probabilities: zeros shaped like ``actions``.

        NOTE(review): the zeros take ``actions``' dtype; if actions are an
        integer tensor these log-probabilities are integer-typed too.
        """
        return torch.zeros_like(actions)

    def get_metrics(self) -> Dict[str, object]:
        """Return agent-specific metrics; none for the base agent."""
        return {}

    def end_epoch(self) -> None:
        """Hook called at the end of an epoch; subclasses may add logic."""
        pass
