import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_vectorised import BaseVectorisedNetwork, VectorisedLinear

import time
class BaseCriticNetwork(BaseVectorisedNetwork):


    def __init__(self, obs_dims, model_info, ensemble_num=1, **kwargs):

    
        super().__init__(obs_dims=obs_dims, model_info=model_info,
                            ensemble_num=ensemble_num, **kwargs)

        self.critic_factor = kwargs.get('critic_factor')
        self.is_critic = True
        self.final_activation= getattr(torch.nn,model_info['critic_final_activation'],None)
        self.action_value_model = self.construct_model(model_info, add_final=True)

    def forward(self, state, action=None):

        if action is not None:
            x = torch.cat([state,action],-1)
        else:
            x = state
        critic_values = self.action_value_model(x)
        return critic_values



class DiscreteVectorisedCritic(BaseCriticNetwork):

    ''' Discrete critic produces a set of action values for all possible actions '''

    def __init__(self, obs_dims, n_actions, model_info, ensemble_num=1, **kwargs):

        self.final_layer_dim = n_actions

        super().__init__(obs_dims=obs_dims, model_info=model_info, ensemble_num=ensemble_num)

    def forward(self, state):

        x = super().forward(state)
        action_values = self.action_value_model(x)

        return action_values


class ContinuousVectorisedCritic(BaseCriticNetwork):

    ''' Continuous critic produces an output for a given state action pair '''

    def __init__(self, obs_dims, action_dims, model_info, ensemble_num=1, **kwargs):

        self.final_layer_dim = 1

        super().__init__(obs_dims=obs_dims, model_info=model_info,
                            ensemble_num=ensemble_num, action_dims=action_dims, **kwargs)


    def forward(self, state, action):

        values_list = super().forward(state,action)

        return values_list

class VectorisedStateValue(BaseCriticNetwork):
    ''' Value function produces a value for the state regardless of action chosen'''

    
    def __init__(self, obs_dims, model_info, ensemble_num=1, **kwargs):

        self.final_layer_dim = 1

        super().__init__(obs_dims=obs_dims, model_info=model_info,
                            ensemble_num=ensemble_num, **kwargs)


    def forward(self, state):

        state_value = super().forward(state)

        return state_value


class Revalued(BaseCriticNetwork):

    def __init__(self, obs_dims, model_info, action_dims, **kwargs):

        self.final_layer_dim = action_dims*obs_dims

        self.out_dims = action_dims
        self.sub_state_dims = obs_dims
        super().__init__(obs_dims=obs_dims, model_info=model_info, **kwargs)


    def forward(self, state):
        
        state_value = super().forward(state)

        state_value =  state_value.reshape(-1,self.critic_factor,self.sub_state_dims,self.out_dims)

        return state_value
