import math
import torch
import torch.nn.functional as F
import os
from rl_utils import nets, device, adv_estimator
from itertools import chain
import numpy as np
from torch.optim import optimizer


class Agent:
    """Stochastic actor, twin critics and a learnable entropy temperature.

    Holds:
      * ``actor``: built by ``actor_net_cls``; queried for action distributions.
      * ``critic1`` / ``critic2``: built by ``critic_net_cls``; trained by one
        shared Adam optimizer (``critic_optimizer``).
      * ``log_alpha``: a learnable scalar (initialised to ``log(0.1)``) with
        its own Adam optimizer, plus ``target_entropy = -act_space_size``.
    """

    def __init__(
        self,
        obs_space_size,
        act_space_size,
        log_std_low=-10.0,
        log_std_high=2.0,
        actor_net_cls=nets.StochasticActor,
        critic_net_cls=nets.BigCritic,
        hidden_size=256,
    ):
        """Build the networks and their optimizers.

        Args:
            obs_space_size: observation dimensionality, forwarded to both the
                actor and critic constructors.
            act_space_size: action dimensionality; also sets
                ``target_entropy`` to ``-act_space_size``.
            log_std_low, log_std_high: forwarded to the actor constructor
                (presumably clamp bounds for its log-std — confirm in ``nets``).
            actor_net_cls, critic_net_cls: network classes to instantiate.
            hidden_size: hidden width forwarded to every network.
        """
        self.actor = actor_net_cls(
            obs_space_size,
            act_space_size,
            log_std_low,
            log_std_high,
            dist_impl="pyd",
            hidden_size=hidden_size,
        )
        self.critic1 = critic_net_cls(obs_space_size, act_space_size, hidden_size)
        self.critic2 = critic_net_cls(obs_space_size, act_space_size, hidden_size)

        # A single optimizer steps both critics together.
        self.critic_optimizer = torch.optim.Adam(
            chain(self.critic1.parameters(), self.critic2.parameters(),), lr=3e-4,
        )
        self.online_actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=3e-4,
        )
        # Trick to make it easier to reload `log_alpha`'s optimizer when changing
        # devices: the value lives as the weight of a bias-free Linear module, so
        # `self._log_alpha.to(device)` moves it in place and the optimizer keeps
        # referencing the same Parameter object.
        self._log_alpha = torch.nn.Linear(1, 1, bias=False)
        self.log_alpha = torch.Tensor([math.log(0.1)])
        self.log_alpha.requires_grad = True
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=1e-4, betas=(0.5, 0.999)
        )
        self.target_entropy = -float(act_space_size)

    @property
    def log_alpha(self):
        """The learnable log-temperature Parameter (stored on ``_log_alpha``)."""
        return self._log_alpha.weight

    @log_alpha.setter
    def log_alpha(self, v):
        """Replace the log-temperature with ``v`` wrapped in a new Parameter.

        Note: this creates a new Parameter object, so an optimizer built
        before the assignment will not track it.
        """
        assert isinstance(v, torch.Tensor)
        self._log_alpha.weight = torch.nn.Parameter(v)

    def to(self, device):
        """Move all networks and ``log_alpha`` to ``device``.

        Optimizer state is round-tripped through ``state_dict`` /
        ``load_state_dict`` after the parameters move, so its state tensors
        are cast to the parameters' new device.
        """
        self.actor.to(device)
        self.critic1.to(device)
        self.critic2.to(device)

        # Reload state_dict of optimizer to account for device change
        # From https://github.com/pytorch/pytorch/issues/8741
        self.critic_optimizer.load_state_dict(self.critic_optimizer.state_dict())
        self.online_actor_optimizer.load_state_dict(
            self.online_actor_optimizer.state_dict()
        )
        self._log_alpha.to(device)
        self.log_alpha_optimizer.load_state_dict(self.log_alpha_optimizer.state_dict())

    def share_memory_(self):
        """Move network parameters and ``log_alpha`` into shared memory."""
        self.actor.share_memory()
        self.critic1.share_memory()
        self.critic2.share_memory()
        self.log_alpha.share_memory_()

    def eval(self):
        """Put the actor and both critics into eval mode."""
        self.actor.eval()
        self.critic1.eval()
        self.critic2.eval()

    def train(self):
        """Put the networks into train mode and re-enable grads on ``log_alpha``."""
        self.actor.train()
        self.critic1.train()
        self.critic2.train()
        if not self.log_alpha.requires_grad:
            self.log_alpha.requires_grad = True

    def save(self, path, id_):
        """Write actor/critic state_dicts to ``path`` as ``<name>_<id_>.pt``.

        ``log_alpha`` and the optimizer states are not saved.
        """
        actor_path = os.path.join(path, f"actor_{id_}.pt")
        critic1_path = os.path.join(path, f"critic1_{id_}.pt")
        critic2_path = os.path.join(path, f"critic2_{id_}.pt")
        torch.save(self.actor.state_dict(), actor_path)
        torch.save(self.critic1.state_dict(), critic1_path)
        torch.save(self.critic2.state_dict(), critic2_path)

    def load(self, path, id_):
        """Load actor/critic state_dicts previously written by :meth:`save`.

        Checkpoints are deserialized onto the CPU first so that files saved
        from a GPU can be read on any host; ``load_state_dict`` then copies
        the values into the existing parameters on whatever device they are.
        """
        actor_path = os.path.join(path, f"actor_{id_}.pt")
        critic1_path = os.path.join(path, f"critic1_{id_}.pt")
        critic2_path = os.path.join(path, f"critic2_{id_}.pt")
        self.actor.load_state_dict(torch.load(actor_path, map_location="cpu"))
        self.critic1.load_state_dict(torch.load(critic1_path, map_location="cpu"))
        self.critic2.load_state_dict(torch.load(critic2_path, map_location="cpu"))

    def _act(self, state, from_cpu, deterministic):
        """Query the actor without gradients and pick an action.

        The actor is put in eval mode for the query and then restored to
        whatever mode it was in before (even if the forward pass raises), so
        an agent placed in eval mode via :meth:`eval` stays in eval mode.

        Args:
            state: a single observation (numpy) if ``from_cpu``, else a tensor
                already batched and on the right device.
            from_cpu: convert input/output via ``process_state`` / ``process_act``.
            deterministic: use the distribution's mean instead of sampling.
        """
        if from_cpu:
            state = self.process_state(state)
        was_training = self.actor.training
        self.actor.eval()
        try:
            with torch.no_grad():
                act_dist = self.actor.forward(state)
                act = act_dist.mean if deterministic else act_dist.sample()
        finally:
            self.actor.train(was_training)
        if from_cpu:
            act = self.process_act(act)
        return act

    def forward(self, state, from_cpu=True):
        """Return the deterministic action (mean of the actor's distribution)."""
        return self._act(state, from_cpu, deterministic=True)

    def sample_action(self, state, from_cpu=True):
        """Return an action sampled from the actor's distribution."""
        return self._act(state, from_cpu, deterministic=False)

    def process_state(self, state):
        """Add a leading batch axis, cast to float32 and move to ``device``."""
        return torch.from_numpy(np.expand_dims(state, 0).astype(np.float32)).to(device)

    def process_act(self, act):
        """Clamp to [-1, 1], move to CPU numpy and drop the leading batch axis."""
        return np.squeeze(act.clamp(-1.0, 1.0).cpu().numpy(), 0)
