import torch
from meta.base import Base


class Opponent(Base):
    """Class for training an opponent.

    For the repeated matrix game ("IPD-v0") the opponent uses a tabular
    policy that is not built at construction time; it is installed later
    by calling ``set_persona()``.

    Args:
        log (dict): Dictionary that contains python logging
        tb_writer (SummaryWriter): Used for tensorboard logging
        args (argparse): Python argparse that contains arguments
        name (str): Specifies agent's name
        i_agent (int): Agent index among the agents in the shared environment
        rank (int): Used for thread-specific meta-agent for multiprocessing. Default: -1
    """
    def __init__(self, log, tb_writer, args, name, i_agent, rank=-1):
        super(Opponent, self).__init__(log, tb_writer, args, name, i_agent, rank)

        # NOTE(review): _set_dim and _set_linear_baseline are not defined in
        # this class; presumably they are inherited from Base -- confirm.
        self._set_dim()
        self._set_linear_baseline()
        self._set_policy()

    def _set_policy(self):
        """Mark the policy as tabular for the repeated matrix game.

        For "IPD-v0" no policy is created here: the tabular policy is set
        directly when ``set_persona()`` is called, so only the
        ``is_tabular_policy`` flag is recorded. For any other environment
        this method does nothing.
        """
        if self.args.env_name == "IPD-v0":
            self.is_tabular_policy = True

    def set_persona(self, persona):
        """Install ``persona`` as the opponent's tabular policy parameters.

        Only has an effect when ``args.env_name`` is "IPD-v0"; for any other
        environment the call is a no-op.

        Args:
            persona (numpy.ndarray): Values of the tabular policy. They are
                copied into a new float32 ``torch.nn.Parameter`` stored in
                ``self.actor``.
        """
        if self.args.env_name != "IPD-v0":
            return

        # torch.tensor always copies its input. The previous
        # torch.from_numpy(persona).float() shared memory with ``persona``
        # whenever the array was already float32, so an in-place update of
        # the actor would have silently modified the caller's array.
        policy_table = torch.tensor(persona, dtype=torch.float32)
        self.actor = torch.nn.Parameter(policy_table, requires_grad=True)
        self.log[self.args.log_name].info("[{}] Set persona: {}".format(self.name, persona))

    # def inner_update(self, actor, memory, iteration_, iteration, rank, is_meta_train):
    #     if iteration_ == self.args.train_max_iteration - 1:
    #         return None

    #     actor_loss = self._get_inner_loss(memory, is_meta_train)
    #     actor_grad = torch.autograd.grad(actor_loss, [actor], create_graph=is_meta_train)

    #     phi = actor - 1. * actor_grad[0]

    #     if is_meta_train:
    #         self.tb_writer.add_scalars(
    #             str(rank) + "/loss/inner_actor_loss", {str(self.i_agent): actor_loss.data.numpy()}, iteration)
    # 
    #     return phi
