import abc
import numpy as np
import torch
from torch.distributions import Categorical


class Policy(abc.ABC):
    """Abstract interface shared by the policies in this file.

    A concrete subclass must implement both ``act`` and ``update`` before it
    can be instantiated.
    """

    @abc.abstractmethod
    def act(self, state):
        """Select an action for ``state``; the base implementation returns ``None``."""

    @abc.abstractmethod
    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99):
        """Update the policy from a transition; the base implementation returns ``None``."""
    

class TabularPolicy(Policy):
    """Deterministic policy over discrete states, stored as a lookup table.

    ``action_table[s]`` is the action chosen in state ``s``.
    """

    def __init__(self, observation_space, action_space, action_table=None):
        """Create the lookup table.

        Args:
            observation_space: Object whose ``n`` attribute is the number of states.
            action_space: Object whose ``n`` attribute is the number of actions.
            action_table: Optional table indexed by state. When ``None``, one
                action per state is drawn uniformly at random.
        """
        self.num_actions = action_space.n
        self.num_states = observation_space.n
        if action_table is None:
            self.action_table = np.random.randint(0, self.num_actions, self.num_states)
        else:
            self.action_table = action_table

    def act(self, state):
        """Return the tabulated action(s) for ``state`` (a scalar or an index array)."""
        return self.action_table[state]

    def dist(self, state):
        """Return the one-hot action distribution for ``state``.

        Args:
            state: A Python ``int`` or an object exposing ``ndim`` (0-D or 1-D).

        Returns:
            An array of shape ``(num_actions,)`` for a scalar state, or
            ``(len(state), num_actions)`` for a 1-D batch of states.

        Raises:
            ValueError: If ``state`` has more than one dimension.
        """
        if isinstance(state, int) or state.ndim == 0:
            dist = np.zeros(self.num_actions)
            dist[self.action_table[state]] = 1
        elif state.ndim == 1:
            dist = np.zeros((state.shape[0], self.num_actions))
            dist[np.arange(state.shape[0]), self.action_table[state]] = 1
        else:
            raise ValueError("state must be 0D or 1D")
        return dist

    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99, bc=False):
        """Make the table greedy with respect to ``q_model``.

        Args:
            q_model: Callable that receives an int64 array of every state index
                and returns values whose axis 1 indexes actions.
            state, action, reward, state_prime, done, gamma: Unused; present
                to satisfy the ``Policy.update`` interface.
            bc: Unused. Accepted so that wrappers which forward ``bc=`` (as
                ``EpsilonGreedyPolicy.update`` does) do not raise ``TypeError``.

        Returns:
            ``None``.
        """
        q_table = q_model(np.arange(self.num_states, dtype=np.int64))
        self.action_table = np.argmax(q_table, axis=1)
        return None


class LinearPolicy(Policy):
    """Softmax policy whose logits are linear in state-action features.

    The logit of action ``a`` in state ``s`` is ``model(Phi[s * num_actions + a])``,
    where ``model`` is a bias-free linear layer with a single output.
    """

    def __init__(self, observation_space, action_space, Phi, lr=0.01,
                 alpha_multiplier=1.0, use_automatic_entropy_tuning=False, target_entropy=0.1, alpha_lr=1e-2):
        """Build the linear model, its optimizer and the optional entropy tuner.

        Args:
            observation_space: Object whose ``n`` attribute is the number of states.
            action_space: Object whose ``n`` attribute is the number of actions.
            Phi: Feature tensor; ``Phi.shape[1]`` is the feature dimension and
                rows are looked up by ``state * num_actions + action``. The
                model is created on ``Phi.device``.
            lr: Adam learning rate for the policy weights.
            alpha_multiplier: Entropy coefficient. Used as-is when automatic
                tuning is off; otherwise it scales ``exp(log_alpha)``.
            use_automatic_entropy_tuning: Whether to learn ``log_alpha``.
            target_entropy: Entropy target used by the tuner.
            alpha_lr: Adam learning rate for ``log_alpha``.
        """
        self.Phi = Phi
        self.feature_dim = Phi.shape[1]
        self.num_states = observation_space.n
        self.num_actions = action_space.n
        self.model = torch.nn.Linear(self.feature_dim, 1, bias=False, device=self.Phi.device)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=lr)

        self.alpha_multiplier = alpha_multiplier
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # NOTE(review): log_alpha starts at log(alpha_multiplier) and the
            # effective alpha is exp(log_alpha) * alpha_multiplier, so the
            # initial coefficient is alpha_multiplier ** 2 -- confirm intended.
            self.log_alpha = torch.tensor(np.log(alpha_multiplier), dtype=torch.float32, requires_grad=True)
            self.target_entropy = target_entropy
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=alpha_lr)

    def _state_action_index(self, state):
        """Return the flat ``Phi`` row index of every action for each state."""
        offsets = torch.arange(self.num_actions, dtype=torch.int64, device=state.device)
        index = (state * self.num_actions).reshape([-1, 1]) + offsets.reshape([1, -1])
        return index.reshape([-1])

    def _model_dist(self, state):
        """Return the categorical action distribution produced by ``self.model``.

        Kept separate from ``forward`` so that ``_policy_loss`` always
        differentiates ``self.model`` even when a subclass overrides ``forward``.
        """
        bs = state.shape[0] if state.ndim > 0 else 1
        logits = self.model(self.Phi[self._state_action_index(state)])
        return Categorical(logits=logits.reshape([bs, self.num_actions]))

    def forward(self, state):
        """Return a ``Categorical`` over actions with batch shape ``(bs,)``.

        Args:
            state: Integer tensor of state indices, 0-D or 1-D. A 0-D state is
                treated as a batch of one.
        """
        return self._model_dist(state)

    def act(self, state):
        """Sample action(s) for ``state``.

        Non-tensor input is converted to an int64 tensor on ``Phi.device`` and
        the result is returned as a NumPy array; tensor input yields a tensor.
        """
        is_torch = isinstance(state, torch.Tensor)
        if not is_torch:
            state = torch.tensor(state, dtype=torch.int64, device=self.Phi.device)

        # Sampling is not differentiable, so skip building an autograd graph.
        with torch.no_grad():
            a = self.forward(state).sample()
        if not is_torch:
            a = a.detach().cpu().numpy()
        return a

    def dist(self, state):
        """Return action probabilities for ``state``.

        Non-tensor input is converted to an int64 tensor on ``Phi.device`` and
        the result is returned as a NumPy array; tensor input yields a tensor.
        """
        is_torch = isinstance(state, torch.Tensor)
        if not is_torch:
            state = torch.tensor(state, dtype=torch.int64, device=self.Phi.device)

        dist = self.forward(state).probs
        if not is_torch:
            dist = dist.detach().cpu().numpy()
        return dist

    def _policy_loss(self, q_model, state, action, reward, state_prime, done, alpha=0.0, gamma=0.99, bc=False):
        """Return the scalar loss for ``self.model``.

        With ``bc=True`` the loss is ``-(log pi(action | state) + alpha * H)``;
        otherwise it is ``-(sum_a pi(a | state) * q_model(state)[a] + alpha * H)``,
        averaged over the batch, where ``H`` is the policy entropy.

        ``reward``, ``state_prime``, ``done`` and ``gamma`` are not used here.
        """
        dist = self._model_dist(state)
        if bc:
            return -(dist.log_prob(action) + alpha * dist.entropy()).mean()

        expected_q = (q_model(state) * dist.probs).sum(-1)
        return -(expected_q + alpha * dist.entropy()).mean()

    def update_alpha(self, policy_dist):
        """Take one tuning step on the entropy coefficient, if enabled.

        Args:
            policy_dist: Distribution whose ``entropy()`` is compared with
                ``target_entropy``.

        Returns:
            ``(alpha, alpha_loss, info)``. When tuning is off these are
            ``alpha_multiplier``, ``0.0`` and an empty dict.
        """
        if self.use_automatic_entropy_tuning:
            policy_entropy = policy_dist.entropy()
            alpha_loss = self.log_alpha * (policy_entropy - self.target_entropy).mean().detach()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            # Detached: alpha is a fixed coefficient for the policy loss, so the
            # policy backward pass must not write gradients into log_alpha.
            alpha = torch.exp(self.log_alpha).detach() * self.alpha_multiplier
            info = {
                'alpha': alpha.cpu().numpy(),
                'alpha_loss': alpha_loss.detach().cpu().numpy(),
                'policy_entropy': policy_entropy.mean().detach().cpu().numpy(),
            }
        else:
            alpha = self.alpha_multiplier
            alpha_loss = 0.0
            info = {}

        return alpha, alpha_loss, info

    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99, bc=False):
        """Run one entropy-coefficient step followed by one policy step.

        Returns:
            Dict with ``'policy_loss'`` plus whatever ``update_alpha`` reported.
        """
        alpha, _, alpha_info = self.update_alpha(self.forward(state))

        policy_loss = self._policy_loss(q_model, state, action, reward, state_prime, done, alpha=alpha, gamma=gamma, bc=bc)
        self.optim.zero_grad()
        policy_loss.backward()
        self.optim.step()

        info = {'policy_loss': policy_loss.detach().cpu().numpy()}
        info.update(alpha_info)
        return info


class POPLinearPolicy(LinearPolicy):
    """``LinearPolicy`` whose loss carries an extra "POP" term weighted by ``beta``.

    The extra term is built from ``q_model.reweight``, ``q_model.a``,
    ``q_model.b``, ``q_model.a_mag``, ``q_model.b_mag``, ``q_model.Phi`` and
    ``q_model.pop_gamma``.
    """

    def __init__(self, observation_space, action_space, Phi, lr=0.01,
                 alpha_multiplier=1.0, use_automatic_entropy_tuning=False, target_entropy=0.1, alpha_lr=1e-2,
                 beta_multiplier=1.0, use_automatic_kl_tuning=False, target_kl=1.0, beta_lr=1e-2):
        """Initialise the base policy and the optional ``beta`` tuner.

        Args:
            beta_multiplier: Weight of the POP term. Used as-is when automatic
                tuning is off; otherwise it scales ``exp(log_beta)``.
            use_automatic_kl_tuning: Whether to learn ``log_beta``.
            target_kl: KL target used by the tuner.
            beta_lr: Adam learning rate for ``log_beta``.

        The remaining arguments are forwarded to ``LinearPolicy.__init__``.
        """
        super().__init__(observation_space, action_space, Phi, lr=lr,
                         alpha_multiplier=alpha_multiplier, use_automatic_entropy_tuning=use_automatic_entropy_tuning,
                         target_entropy=target_entropy, alpha_lr=alpha_lr)
        self.beta_multiplier = beta_multiplier
        self.use_automatic_kl_tuning = use_automatic_kl_tuning
        if self.use_automatic_kl_tuning:
            self.log_beta = torch.tensor(np.log(beta_multiplier), dtype=torch.float32, requires_grad=True)
            self.target_kl = target_kl
            self.beta_optim = torch.optim.Adam([self.log_beta], lr=beta_lr)

    def update_beta(self, reweight):
        """Take one tuning step on ``log_beta``, if enabled.

        ``reweight`` is normalised by its mean and ``mean(r * log r)`` is
        compared with ``target_kl``.

        Returns:
            ``(beta, beta_loss, info)``. When tuning is off these are
            ``beta_multiplier``, ``0.0`` and an empty dict.
        """
        if self.use_automatic_kl_tuning:
            normalised = reweight / reweight.mean()
            # xlogy returns 0 where the weight is exactly 0, whereas
            # r * log(r) would evaluate to 0 * -inf = NaN and poison log_beta.
            kl = torch.xlogy(normalised, normalised).mean()
            beta_loss = -self.log_beta * (kl - self.target_kl).detach()
            self.beta_optim.zero_grad()
            beta_loss.backward()
            self.beta_optim.step()

            # Detached: beta is a fixed coefficient for the policy loss, so the
            # policy backward pass must not write gradients into log_beta.
            beta = torch.exp(self.log_beta).detach() * self.beta_multiplier
            info = {
                'beta': beta.cpu().numpy(),
                'beta_loss': beta_loss.detach().cpu().numpy(),
                'kl': kl.detach().cpu().numpy(),
            }
        else:
            beta = self.beta_multiplier
            beta_loss = 0.0
            info = {}

        return beta, beta_loss, info

    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99, bc=False):
        """Run one step on alpha, beta and the policy weights.

        The optimised loss is the base policy loss plus ``beta * mean(pop_loss)``.

        Args:
            bc: Accepted for interface compatibility; this implementation does
                not forward it to ``_policy_loss``.

        Returns:
            Dict with ``'policy_loss'``, ``'pop_loss'`` and the alpha/beta info.
        """
        x = state * self.num_actions + action
        reweight = q_model.reweight[x]

        alpha, _, alpha_info = self.update_alpha(self.forward(state))
        beta, _, beta_info = self.update_beta(reweight)

        policy_loss = self._policy_loss(q_model, state, action, reward, state_prime, done, gamma=gamma, alpha=alpha)

        # Flat Phi row index of every action in each next state.
        xp = (state_prime * self.num_actions).reshape([-1, 1])\
            + torch.arange(self.num_actions, dtype=torch.int64, device=state.device).reshape([1, -1])
        xp = xp.reshape([-1])
        distp = Categorical(logits=self.model(self.Phi[xp]).reshape([state.shape[0], self.num_actions]))

        # Largest singular value (spectral norm) of a and b. torch.svd is
        # deprecated; torch.linalg.svdvals is its documented replacement for
        # the compute_uv=False case.
        a_2_norm = torch.linalg.svdvals(q_model.a).max().detach()
        b_2_norm = torch.linalg.svdvals(q_model.b).max().detach()
        # Both factors are rescaled to unit spectral norm and then multiplied
        # by the magnitudes stored on q_model.
        m_b = q_model.b_mag * (q_model.Phi @ (q_model.b / b_2_norm))[x]
        m_a_prime = q_model.a_mag * (q_model.Phi @ (q_model.a / a_2_norm))[xp]

        c = reweight.mean()
        # Inner product of m_b with each next-state action's m_a_prime,
        # averaged under the policy's next-state action probabilities.
        expected_angle = (distp.probs * (m_b[:, None, :] * m_a_prime.reshape([state.shape[0], self.num_actions, -1])).sum(-1)).sum(-1)
        pop_loss = -(reweight / c).detach() * (2 * (1 - done.float()) * q_model.pop_gamma * expected_angle)

        policy_loss = policy_loss + beta * pop_loss.mean()
        self.optim.zero_grad()
        policy_loss.backward()
        self.optim.step()

        info = {
            'policy_loss': policy_loss.detach().cpu().numpy(),
            'pop_loss': pop_loss.mean().detach().cpu().numpy()
        }
        info.update(alpha_info)
        info.update(beta_info)
        return info


class POPLinearPolicy2(LinearPolicy):
    """``LinearPolicy`` variant that keeps a second, "tilde" linear model.

    ``self.model`` is trained with the base policy loss. ``self.model_tilde``
    is trained to stay close (in KL) to ``self.model`` while maximising a POP
    term weighted by ``beta``. ``forward`` -- and therefore ``act`` and
    ``dist`` -- uses ``model_tilde``.
    """

    def __init__(self, observation_space, action_space, Phi, lr=0.01,
                 alpha_multiplier=1.0, use_automatic_entropy_tuning=False, target_entropy=0.1, alpha_lr=1e-2,
                 beta_multiplier=1.0, use_automatic_kl_tuning=False, target_kl=1.0, beta_lr=1e-2):
        """Initialise the base policy, the optional ``beta`` tuner and the tilde model.

        Args:
            beta_multiplier: Weight of the POP term. Used as-is when automatic
                tuning is off; otherwise it scales ``exp(log_beta)``.
            use_automatic_kl_tuning: Whether to learn ``log_beta``.
            target_kl: KL target used by the tuner.
            beta_lr: Adam learning rate for ``log_beta``.

        The remaining arguments are forwarded to ``LinearPolicy.__init__``;
        ``lr`` is also used for the tilde model's optimizer.
        """
        super().__init__(observation_space, action_space, Phi, lr=lr,
                         alpha_multiplier=alpha_multiplier, use_automatic_entropy_tuning=use_automatic_entropy_tuning,
                         target_entropy=target_entropy, alpha_lr=alpha_lr)
        self.beta_multiplier = beta_multiplier
        self.use_automatic_kl_tuning = use_automatic_kl_tuning
        if self.use_automatic_kl_tuning:
            self.log_beta = torch.tensor(np.log(beta_multiplier), dtype=torch.float32, requires_grad=True)
            self.target_kl = target_kl
            self.beta_optim = torch.optim.Adam([self.log_beta], lr=beta_lr)

        self.model_tilde = torch.nn.Linear(self.feature_dim, 1, bias=False, device=self.Phi.device)
        self.optim_tilde = torch.optim.Adam(self.model_tilde.parameters(), lr=lr)

    def forward(self, state):
        """Return the tilde model's ``Categorical`` over actions for ``state``.

        A 0-D ``state`` is treated as a batch of one.
        """
        bs = state.shape[0] if state.ndim > 0 else 1
        x = (state * self.num_actions).reshape([-1, 1])\
            + torch.arange(self.num_actions, dtype=torch.int64, device=state.device).reshape([1, -1])
        x = x.reshape([-1])
        return Categorical(logits=self.model_tilde(self.Phi[x]).reshape([bs, self.num_actions]))

    def update_beta(self, reweight):
        """Take one tuning step on ``log_beta``, if enabled.

        ``reweight`` is normalised by its mean and ``mean(r * log r)`` is
        compared with ``target_kl``.

        Returns:
            ``(beta, beta_loss, info)``. When tuning is off these are
            ``beta_multiplier``, ``0.0`` and an empty dict.
        """
        if self.use_automatic_kl_tuning:
            normalised = reweight / reweight.mean()
            # xlogy returns 0 where the weight is exactly 0, whereas
            # r * log(r) would evaluate to 0 * -inf = NaN and poison log_beta.
            kl = torch.xlogy(normalised, normalised).mean()
            beta_loss = -self.log_beta * (kl - self.target_kl).detach()
            self.beta_optim.zero_grad()
            beta_loss.backward()
            self.beta_optim.step()

            # Detached: beta is a fixed coefficient for the tilde loss, so its
            # backward pass must not write gradients into log_beta.
            beta = torch.exp(self.log_beta).detach() * self.beta_multiplier
            info = {
                'beta': beta.cpu().numpy(),
                'beta_loss': beta_loss.detach().cpu().numpy(),
                'kl': kl.detach().cpu().numpy(),
            }
        else:
            beta = self.beta_multiplier
            beta_loss = 0.0
            info = {}

        return beta, beta_loss, info

    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99, bc=False):
        """Run one step on alpha, beta, ``model`` and ``model_tilde``.

        ``model`` is stepped on the base policy loss first; ``model_tilde`` is
        then stepped on ``KL(pi || pi_tilde) - beta * w * 2 * expected_angle``,
        where ``w`` is the mean-normalised ``q_model.reweight``.

        Args:
            bc: Accepted for interface compatibility; this implementation does
                not forward it to ``_policy_loss``.

        Returns:
            Dict with ``'policy_loss'``, ``'pop_policy_loss'`` and the
            alpha/beta info.
        """
        taken_index = state * self.num_actions + action
        reweight = q_model.reweight[taken_index]
        bs = state.shape[0] if state.ndim > 0 else 1
        action_offsets = torch.arange(self.num_actions, dtype=torch.int64, device=state.device).reshape([1, -1])

        # forward() is the tilde model here, so alpha is tuned on its entropy.
        alpha, _, alpha_info = self.update_alpha(self.forward(state))
        beta, _, beta_info = self.update_beta(reweight)

        policy_loss = self._policy_loss(q_model, state, action, reward, state_prime, done, gamma=gamma, alpha=alpha)
        self.optim.zero_grad()
        policy_loss.backward()
        self.optim.step()

        # Flat Phi row index of every action in each next state.
        xp = ((state_prime * self.num_actions).reshape([-1, 1]) + action_offsets).reshape([-1])
        distp = Categorical(logits=self.model_tilde(self.Phi[xp]).reshape([bs, self.num_actions]))
        m_b = (q_model.Phi @ q_model.b)[taken_index]
        m_a_prime = (q_model.Phi @ q_model.a)[xp]

        c = reweight.mean()
        # Inner product of m_b with each next-state action's m_a_prime,
        # averaged under the tilde policy's next-state action probabilities.
        expected_angle = (distp.probs * (m_b[:, None, :] * m_a_prime.reshape([bs, self.num_actions, -1])).sum(-1)).sum(-1)

        # Flat Phi row index of every action in each current state.
        all_index = ((state * self.num_actions).reshape([-1, 1]) + action_offsets).reshape([-1])
        # Only model_tilde is stepped below, so the anchor distribution from
        # self.model is built without a graph.
        with torch.no_grad():
            anchor_logits = self.model(self.Phi[all_index]).reshape([bs, self.num_actions])
        dist = Categorical(logits=anchor_logits)
        dist_tilde = Categorical(logits=self.model_tilde(self.Phi[all_index]).reshape([bs, self.num_actions]))

        pop_policy_loss = torch.distributions.kl_divergence(dist, dist_tilde) - beta * (reweight / c).detach() * (2 * expected_angle)
        self.optim_tilde.zero_grad()
        pop_policy_loss.mean().backward()
        self.optim_tilde.step()

        info = {
            'policy_loss': policy_loss.detach().cpu().numpy(),
            'pop_policy_loss': pop_policy_loss.mean().detach().cpu().numpy()
        }
        info.update(alpha_info)
        info.update(beta_info)
        return info


class EpsilonGreedyPolicy(Policy):
    """Wrap a policy so that a uniformly random action is taken with probability ``epsilon``."""

    def __init__(self, policy, action_space, epsilon=0.1):
        """Store the wrapped policy and the exploration rate.

        Args:
            policy: Wrapped policy; must provide ``act``, ``dist`` and ``update``.
            action_space: Object whose ``n`` attribute is the number of actions.
            epsilon: Probability of replacing the wrapped policy's action with
                a uniformly random one.
        """
        self.policy = policy
        self.epsilon = epsilon
        self.action_dim = action_space.n

    def act(self, state):
        """Return epsilon-greedy action(s) for ``state``.

        Args:
            state: A Python ``int`` or an object exposing ``ndim`` (0-D or 1-D).

        Raises:
            ValueError: If ``state`` has more than one dimension.
        """
        if isinstance(state, int) or state.ndim == 0:
            if np.random.rand() < self.epsilon:
                return np.random.randint(0, self.action_dim)
            return self.policy.act(state)

        if state.ndim > 1:
            raise ValueError("state must be 0D or 1D")

        # Batched case: decide per element whether to explore.
        bs = state.shape[0]
        return np.where(np.random.rand(bs) < self.epsilon, np.random.randint(0, self.action_dim, bs), self.policy.act(state))

    def dist(self, state):
        """Return ``epsilon * uniform + (1 - epsilon) * policy.dist(state)``.

        Raises:
            ValueError: If ``state`` has more than one dimension.
        """
        if isinstance(state, int) or state.ndim == 0:
            uniform_dist = np.ones(self.action_dim) / self.action_dim
        elif state.ndim == 1:
            uniform_dist = np.ones((state.shape[0], self.action_dim)) / self.action_dim
        else:
            raise ValueError("state must be 0D or 1D")

        return uniform_dist * self.epsilon + (1 - self.epsilon) * self.policy.dist(state)

    def update(self, q_model, state, action, reward, state_prime, done, gamma=0.99, bc=False):
        """Delegate the update to the wrapped policy and return its result.

        ``bc`` is not part of the abstract ``Policy.update`` signature, so it is
        forwarded only when set. A wrapped policy whose ``update`` has no ``bc``
        parameter (such as ``TabularPolicy``) then still works for the default
        ``bc=False``.
        """
        if bc:
            return self.policy.update(q_model, state, action, reward, state_prime, done, gamma=gamma, bc=bc)
        return self.policy.update(q_model, state, action, reward, state_prime, done, gamma=gamma)
