import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical, Normal
from .base_vectorised import BaseVectorisedNetwork, VectorisedLinear

from utils.decomp_networks import DecoupledQNetwork

class BaseVectorisedActor(BaseVectorisedNetwork):

    ''' Shared parent of the vectorised actors: marks the network as an actor, clears the final activation and hands construction over to BaseVectorisedNetwork '''
    def __init__(self, obs_dims, model_info, ensemble_num=1, **kwargs):

        # Both flags are assigned ahead of the parent constructor
        # (presumably because it reads them while building the network --
        # confirm against BaseVectorisedNetwork).
        self.is_actor = True
        self.final_activation = None

        super().__init__(
            obs_dims=obs_dims,
            model_info=model_info,
            ensemble_num=ensemble_num,
            **kwargs,
        )

class DiscreteVectorisedActor(BaseVectorisedActor):

    ''' Discrete actor: wraps a DecoupledQNetwork built with n_actions states and action_dims heads and returns its output for each observation '''
    def __init__(self, obs_dims, action_dims, n_actions, model_info, ensemble_num=1, **kwargs):

        # Assigned ahead of the parent constructor (presumably consumed
        # there -- confirm against BaseVectorisedNetwork).
        self.final_layer_dim = n_actions
        self.type = 'Discrete'

        super().__init__(
            obs_dims=obs_dims,
            model_info=model_info,
            ensemble_num=ensemble_num,
            **kwargs,
        )

        # Only the first entry of model_info['layers'] is used as the
        # hidden width of the decoupled network.
        first_hidden = model_info['layers'][0]
        self.model = DecoupledQNetwork(
            state_dim=obs_dims,
            hidden_dim=first_hidden,
            num_states=n_actions,
            num_heads=action_dims,
        )

    def forward(self, state):
        ''' Run the decoupled network on state and return its output (presumably per-head action logits -- confirm against DecoupledQNetwork) under the 'action' key '''
        return {'action': self.model(state)}


class DDPGVectorisedActor(BaseVectorisedActor):

    ''' DDPG vectorised actor: produces a single action per observation from a Tanh-terminated policy network, scaled by max_val '''
    def __init__(self, obs_dims, action_dims, model_info, ensemble_num=1, **kwargs):
        '''
        Build the policy network and store the action bounds.

        kwargs must contain 'max_val' and 'min_val' (a KeyError is raised
        otherwise); the full kwargs are also forwarded to the parent
        constructor.
        '''
        self.final_layer_dim = action_dims

        super().__init__(obs_dims=obs_dims, model_info=model_info,
                         ensemble_num=ensemble_num, **kwargs)

        # Has to happen after super().__init__: BaseVectorisedActor resets
        # final_activation to None in its own constructor.
        self.final_activation = nn.Tanh
        self.policy = self.construct_model(model_info, add_final=True)
        self.max_val = self._as_bound(kwargs['max_val'])
        self.min_val = self._as_bound(kwargs['min_val'])

    @staticmethod
    def _as_bound(value):
        ''' Copy an action bound into a fresh tensor detached from any graph '''
        if isinstance(value, torch.Tensor):
            # torch.tensor(<tensor>) emits a UserWarning; detach().clone()
            # is the documented equivalent.
            return value.detach().clone()
        return torch.tensor(value)

    def forward(self, state, **kwargs):
        '''
        Return {'action': policy(state) * max_val}.

        Extra keyword arguments are accepted and ignored.
        '''
        x = self.policy(state)

        # max_val is a plain attribute, so it stays on the device it was
        # created on. A 0-dim bound can be combined with a tensor on any
        # device, but a per-dimension bound cannot: move it to the policy
        # output's device and cache it so the transfer is not repeated.
        scale = self.max_val
        if scale.dim() > 0 and scale.device != x.device:
            scale = scale.to(x.device)
            self.max_val = scale

        # NOTE(review): min_val is stored but never used here, so the
        # scaling is only right for bounds symmetric around zero -- confirm
        # that min_val == -max_val for every caller.
        action_info = {}
        action_info['action'] = x * scale
        return action_info
