import os
from abc import ABC, abstractmethod
from copy import deepcopy

import torch
import torch.nn as nn


class Principal(ABC):
    """Abstract base class for principals.

    Holds the agent nets and provides default implementations for stepping
    them (`agent_step`) and for saving their parameters (`save_params`).
    Subclasses must implement `set_tax_vals` and `after_episode`.

    The default `agent_step` reads `self.agent_opt`, which is NOT set by this
    class's `__init__`; subclasses relying on the default step must assign it.
    """

    def __init__(self, args, agent_nets):
        """Store the configuration and agent nets, and snapshot initial weights.

        Args:
            args: configuration object. The default methods of this class read
                `agent_ent_coef`, `vf_coef` and `max_grad_norm` from it.
            agent_nets: object exposing `state_dict()` and `parameters()`
                (presumably a `torch.nn.Module` -- confirm against callers).
        """
        self.args = args
        self.agent_nets = agent_nets
        # Deep copy so the snapshot does not alias the live parameters and
        # stays unchanged when the nets are later updated.
        self.og_nets = deepcopy(self.agent_nets.state_dict())

    def agent_step(self, metrics):
        """Default for one step of agent net optimisation.

        Stepping agent nets is handled by principals, and the default step is using
        a regular torch optimizer on a combination of PG, entropy and value losses.

        Principals using this default step need agent_opt, agent_nets and args fields.

        Args:
            metrics (Metrics): losses and loss metrics to use for update; the
                fields `pg_loss`, `entropy_loss` and `v_loss` are read.
        """

        # Form loss: the entropy term is subtracted (weighted by
        # agent_ent_coef), the value loss is added (weighted by vf_coef).
        loss = (
            metrics.pg_loss
            - self.args.agent_ent_coef * metrics.entropy_loss
            + metrics.v_loss * self.args.vf_coef
        )

        # Step agent nets, clipping the gradient norm before the update.
        self.agent_opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.agent_nets.parameters(), self.args.max_grad_norm)
        self.agent_opt.step()

    @abstractmethod
    def set_tax_vals(self, *args, **kwargs):
        """Abstract hook; must be implemented by every concrete principal."""

    @abstractmethod
    def after_episode(self, *args, **kwargs):
        """Abstract hook; must be implemented by every concrete principal."""

    def make_save_dir(self, episode_number):
        """Create directory we save to.

        Creates `./saved_params/ep{episode_number}` relative to the current
        working directory, including the parent `./saved_params`. Calling it
        again for an existing directory is a no-op.

        Args:
            episode_number (int): current episode number

        Raises:
            FileExistsError: if the target path exists but is not a directory.
        """

        # exist_ok=True tolerates an existing *directory* only, so a stray
        # regular file at the target path is reported here rather than
        # surfacing later as a confusing failure while saving.
        os.makedirs(f"./saved_params/ep{episode_number}", exist_ok=True)

    def save_params(self, episode_number):
        """Default for all net parameter saving - saves just agent nets.

        Principals all contain various nets, so are made responsible for saving the
        parameters of all models. Default is to save just the agent nets.

        Principals using this default need an agent_nets field.

        Args:
            episode_number (int): current episode number
        """

        self.make_save_dir(episode_number)
        torch.save(
            self.agent_nets.state_dict(),
            f"./saved_params/ep{episode_number}/agents_net_ep{episode_number}.pt",
        )
