
import numpy as np
import torch
import torch.nn as nn
import typing as ty

from gym import spaces
# from gymnasium import spaces as spaces

from .base import init_linear_weights_by_xavier
from .base import MLP



class ActionValueNetwork4ContinuousAction(nn.Module):
    """Scalar action-value network Q(s, a) for vector-valued actions.

    The observation and the action are concatenated along the last axis and
    passed through an MLP with a single output unit, giving one value per
    (observation, action) pair.
    """

    def __init__(
        self,
        observation_space, action_space,
        n_hidden: list = [64],
        layer_norm: bool = False,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    ):
        """Build the underlying MLP.

        Args:
            observation_space: space object; ``shape[0]`` is taken as the
                observation dimension (its type is not checked).
            action_space: space object; ``shape[0]`` is taken as the action
                dimension (its type is not checked).
            n_hidden: widths of the hidden layers of the MLP.
            layer_norm: forwarded to ``MLP``.
            device: only stored on the instance; the module is not moved to
                this device here.
        """
        super().__init__()

        self.observation_space = observation_space
        self.dim_ob = observation_space.shape[0]
        self.action_space = action_space
        self.dim_ac = action_space.shape[0]
        self.device = device
        # Copy so the instance never aliases the shared default list
        # (or a list the caller may mutate later).
        self.n_hidden = list(n_hidden)

        self.q_net = MLP(
            n_units=[self.dim_ob + self.dim_ac] + self.n_hidden + [1],
            activation=nn.ReLU, output_activation=nn.Identity,
            layer_norm=layer_norm,
        )

    def forward(self, observation: torch.Tensor, action: torch.Tensor, n_samples: int = 1) -> torch.Tensor:
        """Evaluate Q(s, a).

        Args:
            observation: tensor of shape ``(batch, dim_ob)``.
            action: tensor of shape ``(batch, dim_ac)`` when ``n_samples`` is
                1 (or less), otherwise ``(batch, n_samples, dim_ac)``.
            n_samples: number of actions evaluated per observation.

        Returns:
            Tensor of shape ``(batch,)`` when ``n_samples <= 1``, otherwise
            ``(batch, n_samples)``.
        """
        batch_size = observation.shape[0]
        assert observation.shape[0] == action.shape[0]
        assert observation.shape == (batch_size, self.dim_ob)

        if n_samples > 1:
            assert action.shape == (batch_size, n_samples, self.dim_ac)
            # Broadcast every observation against its n_samples actions;
            # expand() creates a view, torch.cat below materialises it once.
            observation = observation.unsqueeze(1).expand(-1, n_samples, -1)
            assert observation.dim() == 3
            expected_shape = (batch_size, n_samples, 1)
        else:
            assert action.shape == (batch_size, self.dim_ac)
            expected_shape = (batch_size, 1)

        q_sa = self.q_net(torch.cat([observation, action], dim=-1))
        assert q_sa.shape == expected_shape, f'{q_sa.shape} != {expected_shape}'
        # Drop the trailing singleton output dimension.
        return q_sa.squeeze(-1)






class QuantileNetwork(nn.Module):
    """Quantile-value approximator for discrete-action distributional RL.

    For every observation the network produces ``N`` quantile values for each
    discrete action, returned with shape ``(batch, N, num_actions)``. Helper
    methods reduce those quantiles to one utility value per action: a
    weighted mean, an entropic risk measure (ERM), or a CVaR-style lower-tail
    average.
    """

    def __init__(
        self,
        observation_space, action_space,
        N: int,
        utility: str,
        alpha: ty.Union[float, str],
        tau: torch.Tensor,
        tau_hat: torch.Tensor,
        weight: torch.Tensor,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        representation: str = 'tabular',
        non_crossing: bool = False,
        n_hidden: list = [64],
        layer_norm: bool = False,
    ):
        """Build the quantile network (and the scale network if requested).

        Args:
            observation_space: ``spaces.Discrete`` (dimension ``n``) or
                ``spaces.Box`` (dimension ``shape[0]``).
            action_space: must be ``spaces.Discrete``.
            N: number of atoms that represent the quantiles.
            utility: default utility name used by ``compute_utility``
                (``'mean'``, ``'erm'`` or ``'cvar'``).
            alpha: risk parameter; the strings ``'inf'`` / ``'-inf'`` are
                converted to ``np.inf`` / ``-np.inf``.
            tau: stored on the instance; not read by the methods of this class.
            tau_hat: stored on the instance; not read by the methods of this
                class.
            weight: per-atom weights; broadcast against the atom axis, so a
                1-D tensor of length ``N`` is expected.
            device: only stored on the instance; the module is not moved here.
            representation: ``'tabular'`` (one-hot input, Discrete
                observations only), ``'linear'`` or ``'NN'`` (MLP).
            non_crossing: if True, quantiles are produced as a scaled and
                shifted cumulative softmax so they are ordered along the
                atom axis.
            n_hidden: hidden-layer widths, used when ``representation == 'NN'``.
            layer_norm: forwarded to ``MLP``.

        Raises:
            NotImplementedError: unsupported observation space or
                representation.
        """
        super().__init__()

        self.observation_space = observation_space
        if isinstance(observation_space, spaces.Discrete):
            self.dim_observation = observation_space.n
        elif isinstance(observation_space, spaces.Box):
            self.dim_observation = observation_space.shape[0]
        else:
            raise NotImplementedError

        assert isinstance(action_space, spaces.Discrete)
        self.action_space = action_space
        self.num_actions = action_space.n

        self.N = N  # the number of atoms that represent quantiles
        self.device = device
        self.non_crossing = non_crossing
        # Copy so the instance never aliases the shared default list.
        self.n_hidden = list(n_hidden)  # used if representation == 'NN'

        self.representation = representation
        n_outputs = self.num_actions * self.N
        if self.representation == 'tabular':
            assert isinstance(observation_space, spaces.Discrete) and isinstance(action_space, spaces.Discrete)
            self.q_net = nn.Linear(self.dim_observation, n_outputs)
        elif self.representation == 'linear':
            self.q_net = nn.Linear(self.dim_observation, n_outputs)
        elif self.representation == 'NN':
            self.q_net = MLP(
                n_units=[self.dim_observation] + self.n_hidden + [n_outputs],
                activation=nn.ReLU, output_activation=nn.Identity,
                layer_norm=layer_norm,
            )
        else:
            raise NotImplementedError

        self.utility = utility
        if alpha == 'inf':
            alpha = np.inf
        elif alpha == '-inf':
            alpha = - np.inf
        self.alpha = alpha

        self.tau = tau
        self.tau_hat = tau_hat
        self.weight = weight

        if self.non_crossing:
            # Scale-factor network: per action one scale and one offset.
            # It must live in `self.scale_factor`; assigning it to
            # `self.q_net` would overwrite the quantile network.
            if self.representation == 'NN':
                self.scale_factor = MLP(
                    n_units=[self.dim_observation] + self.n_hidden + [self.num_actions * 2],
                    activation=nn.ReLU, output_activation=nn.Identity,
                    layer_norm=layer_norm,
                )
            else:
                self.scale_factor = nn.Linear(self.dim_observation, self.num_actions * 2)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        """Return the quantile values for every action.

        Args:
            observation: batch of observations; for the ``'tabular'``
                representation these are integer state indices that get
                one-hot encoded, otherwise feature vectors.

        Returns:
            Tensor of shape ``(batch, N, num_actions)``.
        """
        batch_size = observation.shape[0]

        if self.representation == 'tabular':
            observation = nn.functional.one_hot(observation.long(), num_classes=self.dim_observation)
        features = observation.float()

        if self.non_crossing:
            # Ordered & normalized quantiles (Eq. 18): a softmax over the atom
            # axis followed by a cumulative sum is non-decreasing and ends at 1.
            logits = self.q_net(features).view(batch_size, self.N, self.num_actions)
            log_prob = nn.functional.log_softmax(logits, dim=1)
            prob = torch.exp(log_prob)
            ordered = torch.cumsum(prob, dim=1)
            # float() works for tensors on any device (unlike .numpy()).
            total = float(ordered[:, -1, :].sum())
            assert abs(total - batch_size * self.num_actions) < 0.01, \
                f'{total}, {batch_size * self.num_actions}'

            # Scaling of the ordered quantiles (Eq. 19);
            # softplus is used instead of ReLU to keep the scale positive.
            scale_factor = self.scale_factor(features).view(batch_size, 1, self.num_actions * 2)
            scaler = nn.functional.softplus(scale_factor[:, :, :self.num_actions])
            offset = scale_factor[:, :, self.num_actions:]
            assert scaler.shape == offset.shape

            quantile = scaler * ordered + offset
        else:
            quantile = self.q_net(features).view(batch_size, self.N, self.num_actions)

        assert quantile.shape == (batch_size, self.N, self.num_actions)
        return quantile

    def compute_utility(
        self,
        observation: torch.Tensor = None,
        quantile: torch.Tensor = None,
        utility: str = None,
    ) -> torch.Tensor:
        """Reduce quantiles to one utility value per action.

        Args:
            observation: used to compute the quantiles when ``quantile`` is
                not given.
            quantile: precomputed quantiles of shape
                ``(batch, N, num_actions)``; takes precedence over
                ``observation``.
            utility: ``'mean'``, ``'erm'`` or ``'cvar'``; defaults to
                ``self.utility``.

        Returns:
            Tensor of shape ``(batch, num_actions)``.

        Raises:
            ValueError: if the utility name is not recognised.
        """
        if quantile is None:
            assert observation is not None
            quantile = self.forward(observation=observation)
        # Always derive the batch size from the quantiles, so the call also
        # works when both `observation` and `quantile` are supplied.
        batch_size = quantile.shape[0]
        assert quantile.shape == (batch_size, self.N, self.num_actions)

        if utility is None:
            utility = self.utility

        if utility == 'mean':
            return self.compute_mean(quantile)
        elif utility == 'erm':
            return self.compute_erm(quantile)
        elif utility == 'cvar':
            return self.compute_cvar(quantile)
        else:
            raise ValueError(f'unknown utility: {utility!r}')

    def compute_mean(
        self,
        quantile: torch.Tensor = None,
    ) -> torch.Tensor:
        """Weighted sum of the quantiles over the atom axis using ``self.weight``.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        # weight: (N,) -> (1, N, 1) so it broadcasts over batch and actions.
        q = (quantile * self.weight.unsqueeze(0).unsqueeze(2)).sum(dim=1, keepdim=False)

        batch_size = quantile.shape[0]
        assert q.shape == (batch_size, self.num_actions), f'{q.shape} != ({batch_size}, {self.num_actions})'
        return q

    def compute_erm(
        self,
        quantile: torch.Tensor = None,
        alpha: float = None
    ) -> torch.Tensor:
        """Entropic risk measure ``(1/alpha) * log(sum_i w_i * exp(alpha * q_i))``.

        Special cases: ``alpha == 0`` gives the weighted mean, ``alpha == inf``
        the maximum atom and ``alpha == -inf`` the minimum atom (the last two
        ignore the weights).

        Args:
            quantile: tensor of shape ``(batch, N, num_actions)``.
            alpha: risk parameter; defaults to ``self.alpha``.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        if alpha is None:
            alpha = self.alpha

        if alpha == 0:
            erm = self.compute_mean(quantile=quantile)
        elif alpha == np.inf:
            erm = torch.max(quantile, dim=1, keepdim=False)[0]
        elif alpha == -np.inf:
            erm = torch.min(quantile, dim=1, keepdim=False)[0]
        else:
            # Log-sum-exp stabilisation: shift by the extreme atom so every
            # exponent is <= 0.
            if alpha > 0:
                m = torch.max(quantile, dim=1, keepdim=True)[0]
            else:
                m = torch.min(quantile, dim=1, keepdim=True)[0]
            Z = torch.exp(alpha * (quantile - m)) * self.weight.unsqueeze(0).unsqueeze(2)
            # Squeeze only the atom axis; a bare squeeze() would also drop a
            # size-1 batch/action axis and break the broadcast below.
            erm = m.squeeze(1) + torch.log(Z.sum(dim=1)) / alpha

        batch_size = quantile.shape[0]
        assert erm.shape == (batch_size, self.num_actions)
        return erm

    def compute_cvar(
        self,
        quantile: torch.Tensor = None,
        alpha: float = None,
    ) -> torch.Tensor:
        """CVaR-style average over the atoms whose cumulative weight is below ``alpha``.

        Args:
            quantile: tensor of shape ``(batch, N, num_actions)``.
            alpha: tail level in ``(0, 1]``; defaults to ``self.alpha``.

        Returns:
            Tensor of shape ``(batch, num_actions)``.
        """
        if alpha is None:
            alpha = self.alpha
        # alpha == 0 would divide by zero below, so the lower bound is strict.
        assert 0. < alpha <= 1., f'alpha must lie in (0, 1], got {alpha}'

        atom_weight = self.weight.unsqueeze(0).unsqueeze(2)
        tau = torch.cumsum(atom_weight, dim=1)
        risk_weight = (1. / alpha) * (tau < alpha)
        # NOTE(review): the per-atom risk weight is capped at 5, which changes
        # the estimate whenever alpha < 0.2 -- confirm this is intended.
        risk_weight = risk_weight.clamp(0., 5.)
        q = (quantile * risk_weight * atom_weight).sum(dim=1, keepdim=False)

        batch_size = quantile.shape[0]
        assert q.shape == (batch_size, self.num_actions), f'{q.shape} != ({batch_size}, {self.num_actions})'
        return q



class QuantileNetwork4ContinuousAction(nn.Module):
    """Quantile-value approximator for continuous (vector) actions.

    The observation and the action are concatenated and mapped to ``N``
    quantile values, returned with shape ``(batch, N)``. Helper methods
    reduce the quantiles to a single utility value per sample: a weighted
    mean, an entropic risk measure (ERM), or a CVaR-style lower-tail average.
    """

    def __init__(
        self,
        observation_space, action_space,
        N: int,
        utility: str,
        alpha: ty.Union[float, str],
        tau: torch.Tensor,
        tau_hat: torch.Tensor,
        weight: torch.Tensor,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        representation: str = 'tabular',
        non_crossing: bool = False,
        n_hidden: list = [64],
        layer_norm: bool = False,
    ):
        """Build the quantile network (and the scale network if requested).

        Args:
            observation_space: must be ``spaces.Box``; ``shape[0]`` is the
                observation dimension.
            action_space: must be ``spaces.Box``; ``shape[0]`` is the action
                dimension.
            N: number of atoms that represent the quantiles.
            utility: default utility name used by ``compute_utility``
                (``'mean'``, ``'erm'`` or ``'cvar'``).
            alpha: risk parameter; the strings ``'inf'`` / ``'-inf'`` are
                converted to ``np.inf`` / ``-np.inf``.
            tau: stored on the instance; not read by the methods of this class.
            tau_hat: stored on the instance; not read by the methods of this
                class.
            weight: per-atom weights; broadcast against the atom axis, so a
                1-D tensor of length ``N`` is expected.
            device: only stored on the instance; the module is not moved here.
            representation: ``'linear'`` or ``'NN'``. The default
                ``'tabular'`` is not supported by this class and raises
                ``NotImplementedError``.
            non_crossing: if True, quantiles are produced as a scaled and
                shifted cumulative softmax so they are ordered along the
                atom axis.
            n_hidden: hidden-layer widths, used when ``representation == 'NN'``.
            layer_norm: forwarded to ``MLP``.

        Raises:
            NotImplementedError: unsupported observation space or
                representation.
        """
        super().__init__()

        self.observation_space = observation_space
        if isinstance(observation_space, spaces.Box):
            self.dim_observation = observation_space.shape[0]
        else:
            raise NotImplementedError

        assert isinstance(action_space, spaces.Box)
        self.action_space = action_space
        self.dim_action = action_space.shape[0]

        self.N = N  # the number of atoms that represent quantiles
        self.device = device
        self.non_crossing = non_crossing
        # Copy so the instance never aliases the shared default list.
        self.n_hidden = list(n_hidden)  # used if representation == 'NN'

        self.representation = representation
        dim_input = self.dim_observation + self.dim_action
        if self.representation == 'linear':
            self.q_net = nn.Linear(dim_input, self.N)
        elif self.representation == 'NN':
            self.q_net = MLP(
                n_units=[dim_input] + self.n_hidden + [self.N],
                activation=nn.ReLU, output_activation=nn.Identity,
                layer_norm=layer_norm,
            )
        else:
            raise NotImplementedError

        self.utility = utility
        if alpha == 'inf':
            alpha = np.inf
        elif alpha == '-inf':
            alpha = - np.inf
        self.alpha = alpha

        self.tau = tau
        self.tau_hat = tau_hat
        self.weight = weight

        if self.non_crossing:
            # Scale-factor network: one scale and one offset per sample.
            if self.representation == 'NN':
                self.scale_factor = MLP(
                    n_units=[dim_input] + self.n_hidden + [2],
                    activation=nn.ReLU, output_activation=nn.Identity,
                    layer_norm=layer_norm,
                )
            else:
                self.scale_factor = nn.Linear(dim_input, 2)

    def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the quantile values of each (observation, action) pair.

        Args:
            observation: tensor whose first axis is the batch.
            action: tensor with the same batch size; concatenated with
                ``observation`` along the last axis.

        Returns:
            Tensor of shape ``(batch, N)``.
        """
        assert observation.shape[0] == action.shape[0]
        batch_size = observation.shape[0]
        net_input = torch.cat([observation, action], dim=-1)

        if self.non_crossing:
            # Ordered & normalized quantiles (Eq. 18): a softmax over the atom
            # axis followed by a cumulative sum is non-decreasing and ends at 1.
            logits = self.q_net(net_input).view(batch_size, self.N)
            log_prob = nn.functional.log_softmax(logits, dim=1)
            prob = torch.exp(log_prob)
            ordered = torch.cumsum(prob, dim=1)
            # float() works for tensors on any device (unlike .numpy()).
            total = float(ordered[:, -1].sum())
            assert abs(total - batch_size) < 0.01, f'{total}, {batch_size}'
            assert ordered.shape == (batch_size, self.N)

            # Scaling of the ordered quantiles (Eq. 19);
            # softplus is used instead of ReLU to keep the scale positive.
            scale_factor = self.scale_factor(net_input).view(batch_size, 1, 2)
            scaler = nn.functional.softplus(scale_factor[:, :, 0])  # (batch, 1)
            offset = scale_factor[:, :, 1]  # (batch, 1)
            assert scaler.shape == offset.shape

            quantile = scaler * ordered + offset
        else:
            quantile = self.q_net(net_input).view(batch_size, self.N)

        assert quantile.shape == (batch_size, self.N)
        return quantile

    def compute_utility(
        self,
        observation: torch.Tensor = None,
        action: torch.Tensor = None,
        quantile: torch.Tensor = None,
        weight: torch.Tensor = None,  # dummy
        utility: str = None,
    ) -> torch.Tensor:
        """Reduce quantiles to one utility value per sample.

        Args:
            observation: used with ``action`` to compute the quantiles when
                ``quantile`` is not given.
            action: see ``observation``.
            quantile: precomputed quantiles of shape ``(batch, N)``; takes
                precedence over ``observation`` / ``action``.
            weight: unused placeholder (``self.weight`` is always used); kept
                for signature parity with the other quantile networks.
            utility: ``'mean'``, ``'erm'`` or ``'cvar'``; defaults to
                ``self.utility``.

        Returns:
            Tensor of shape ``(batch,)``.

        Raises:
            ValueError: if the utility name is not recognised.
        """
        if quantile is None:
            assert observation is not None and action is not None
            assert observation.shape[0] == action.shape[0]
            quantile = self.forward(observation=observation, action=action)
        # Always derive the batch size from the quantiles, so the call also
        # works when both `observation` and `quantile` are supplied.
        batch_size = quantile.shape[0]
        assert quantile.shape == (batch_size, self.N)
        if utility is None:
            utility = self.utility

        if utility == 'mean':
            return self.compute_mean(quantile, weight)
        elif utility == 'erm':
            return self.compute_erm(quantile, weight)
        elif utility == 'cvar':
            return self.compute_cvar(quantile, weight)
        else:
            raise ValueError(f'unknown utility: {utility!r}')

    def compute_mean(
        self,
        quantile: torch.Tensor = None,
        weight: torch.Tensor = None,  # dummy
    ) -> torch.Tensor:
        """Weighted sum of the quantiles over the atom axis using ``self.weight``.

        Args:
            quantile: tensor of shape ``(batch, N)``.
            weight: unused placeholder.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        batch_size = quantile.shape[0]
        assert quantile.shape == (batch_size, self.N)

        q = (quantile * self.weight.unsqueeze(0)).sum(dim=1, keepdim=False)

        assert q.shape == (batch_size,), f'{q.shape} != ({batch_size})'
        return q

    def compute_erm(
        self,
        quantile: torch.Tensor = None,
        weight: torch.Tensor = None,  # dummy
        alpha: float = None,
    ) -> torch.Tensor:
        """Entropic risk measure ``(1/alpha) * log(sum_i w_i * exp(alpha * q_i))``.

        Special cases: ``alpha == 0`` gives the weighted mean, ``alpha == inf``
        the maximum atom and ``alpha == -inf`` the minimum atom (the last two
        ignore the weights).

        Args:
            quantile: tensor of shape ``(batch, N)``.
            weight: unused placeholder (``self.weight`` is used).
            alpha: risk parameter; defaults to ``self.alpha``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if alpha is None:
            alpha = self.alpha

        # Remark: torch.max(...) returns (values, indices); [0] takes values.
        if alpha == 0:
            erm = self.compute_mean(quantile=quantile)
        elif alpha == np.inf:
            erm = torch.max(quantile, dim=1, keepdim=False)[0]
        elif alpha == -np.inf:
            erm = torch.min(quantile, dim=1, keepdim=False)[0]
        else:
            # Log-sum-exp stabilisation: shift by the extreme atom so every
            # exponent is <= 0.
            if alpha > 0:
                m = torch.max(quantile, dim=1, keepdim=True)[0]
            else:
                m = torch.min(quantile, dim=1, keepdim=True)[0]
            Z = torch.exp(alpha * (quantile - m)) * self.weight.unsqueeze(0)
            # Squeeze only the atom axis so a batch of one keeps its axis.
            erm = m.squeeze(1) + torch.log(Z.sum(dim=1)) / alpha

        batch_size = quantile.shape[0]
        assert erm.shape == (batch_size,)
        return erm

    def compute_cvar(
        self,
        quantile: torch.Tensor = None,
        weight: torch.Tensor = None,  # dummy
        alpha: float = None,
    ) -> torch.Tensor:
        """CVaR-style average over the atoms whose cumulative weight is below ``alpha``.

        Args:
            quantile: tensor of shape ``(batch, N)``.
            weight: unused placeholder (``self.weight`` is used).
            alpha: tail level in ``(0, 1]``; defaults to ``self.alpha``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if alpha is None:
            alpha = self.alpha
        # alpha == 0 would divide by zero below, so the lower bound is strict.
        assert 0. < alpha <= 1., f'alpha must lie in (0, 1], got {alpha}'

        atom_weight = self.weight.unsqueeze(0)
        tau = torch.cumsum(atom_weight, dim=1)
        risk_weight = (1. / alpha) * (tau < alpha)
        # NOTE(review): the per-atom risk weight is capped at 5, which changes
        # the estimate whenever alpha < 0.2 -- confirm this is intended.
        risk_weight = risk_weight.clamp(0., 5.)

        q = (quantile * risk_weight * atom_weight).sum(dim=1, keepdim=False)

        batch_size = quantile.shape[0]
        assert q.shape == (batch_size,), f'{q.shape} != ({batch_size})'
        return q




class SingleHeadQuantileNetwork4ContinuousAction(nn.Module):

    def __init__(
        self,
        observation_space, action_space,
        N: int = 32,
        utility: str='mean',
        alpha: ty.Union[float, str] = 0.,
        # tau: torch.tensor,
        # tau_hat: torch.tensor,
        # weight: torch.tensor,
        # representation: str = 'tabular',
        fraction_proposal: str = 'fix',
        non_crossing: bool = False,
        n_hidden: list = [64, 64],
        n_embedding: int = 64,
        layer_norm: bool = False,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # self,
        # hidden_sizes,
        # output_size,
        # input_size,
        # n_embedding=64,
        # num_quantiles=32,
        # layer_norm=True,
        # device='cpu',
        # **kwargs,
    ):

        '''
        num_inputs = env.observation_space.shape[0]
        self.zf1 = QuantileMlp(input_size=num_inputs+action_space.shape[0],
                          output_size=1,
                          num_quantiles=num_quantiles,
                          hidden_sizes=[hidden_size, hidden_size]).to(self.device)
        '''
        super().__init__()
        assert isinstance(observation_space, spaces.Box)
        assert isinstance(action_space, spaces.Box)

        self.observation_space = observation_space
        self.action_space = action_space

        self.dim_observation = observation_space.shape[0]
        self.dim_action = action_space.shape[0]
        self.input_size = self.dim_observation + self.dim_action

        self.layer_norm = layer_norm

        self.utility = utility
        if alpha == 'inf':
            alpha = np.inf
        elif alpha == '-inf':
            alpha = - np.inf
        self.alpha = alpha

        self.N = N # the number of atoms that represent quantiles
        self.fraction_proposal = fraction_proposal
        self.non_crossing = non_crossing
        self.n_hidden = n_hidden # used if representation=='NN'
        # self.q_net = MLP(
        #     n_units=[self.dim_observation+self.dim_action]+self.n_hidden+[self.N],
        #     activation=nn.ReLU, output_activation=nn.Identity,
        #     layer_norm=layer_norm,
        # )
        # self.q_net.apply(lambda x: init_linear_weights_by_xavier(x, 3))
        self.n_embedding = n_embedding
        self.device = device

        self.base_fc = []
        last_size = self.input_size
        for next_size in self.n_hidden[:-1]:
            self.base_fc += [
                nn.Linear(last_size, next_size),
                nn.LayerNorm(next_size) if layer_norm else nn.Identity(),
                nn.ReLU(inplace=True),
            ]
            last_size = next_size
        self.base_fc = nn.Sequential(*self.base_fc)

        # self.integers = torch.from_numpy(np.arange(1, 1 + self.n_embedding)).float().to(device)
        self.integers = torch.as_tensor(
            np.arange(1, 1 + self.n_embedding), dtype=torch.float32, device=self.device
        )
        self.tau_fc = nn.Sequential(
            nn.Linear(n_embedding, last_size),
            nn.LayerNorm(last_size) if layer_norm else nn.Identity(),
            nn.Sigmoid(),
        )

        self.merge_fc = nn.Sequential(
            nn.Linear(last_size, self.n_hidden[-1]),
            nn.LayerNorm(self.n_hidden[-1]) if layer_norm else nn.Identity(),
            nn.ReLU(inplace=True),
        )
        self.last_fc = nn.Linear(self.n_hidden[-1], 1)


    def forward(
        self,
        observation: torch.tensor, action: torch.tensor,
        tau_hat: torch.tensor = None
    ) -> torch.tensor:
        assert observation.shape[0] == action.shape[0]
        batch_size = observation.shape[0]

        if tau_hat is None:
            tau, tau_hat, weight = self.get_quantile_fractions(observation, action)
            # print(f'hoge.shape={hoge.shape}')
            # print(f'tau.shape={tau.shape}')
            # print(f'tau_hat.shape={tau_hat.shape}')
            # print(f'weight.shape={weight.shape}')
            assert tau_hat.shape == (batch_size, self.N)

        h_x = torch.cat([observation, action], dim=1)
        h_x = self.base_fc(h_x)
        # print(f'base_fc: h_x.shape={h_x.shape}')
        assert h_x.shape == (batch_size, self.n_hidden[-2])

        h_t = torch.cos(tau_hat.unsqueeze(-1) * self.integers * np.pi)  # (N, T, E)
        # print(f'cos: h_t.shape={h_t.shape}')
        assert h_t.shape == (batch_size, self.N, self.n_embedding)

        h_t = self.tau_fc(h_t)
        # print(f'tau_fc: h_t.shape={h_t.shape}')
        assert h_t.shape == (batch_size, self.N, self.n_hidden[-2])

        h = torch.mul(h_t, h_x.unsqueeze(-2))
        # print(f'mul: h.shape={h.shape}')
        assert h.shape == (batch_size, self.N, self.n_hidden[-2])

        h = self.merge_fc(h)
        # print(f'merge_fc: h.shape={h.shape}')
        assert h.shape == (batch_size, self.N, self.n_hidden[-1])

        output = self.last_fc(h).squeeze(-1)
        # print(f'output.shape={output.shape}')
        assert output.shape == (batch_size, self.N)

        return output


    def get_quantile_fractions(self, observation, action, fp=None):
        assert observation.shape[0] == action.shape[0]
        batch_size = observation.shape[0]

        if self.fraction_proposal == 'fix':
            weight = torch.zeros(batch_size, self.N, dtype=torch.float32, device=self.device) + 1. / self.N
        elif self.fraction_proposal == 'iqn':
            ''' add 0.1 to prevent tau getting too close '''
            # weight = torch.rand(batch_size, self.N, ) + 0.1 # original
            weight = torch.rand(batch_size, self.N, dtype=torch.float32, device=self.device) + 0.1
            weight /= weight.sum(dim=-1, keepdims=True)
        elif self.fraction_proposal == 'fqf':
            raise NotImplementedError
            if fp is None:
                fp = self.fp
            weight = fp(observation, action)
        '''
            \tau_1 ... \tau_N.
            \tau_0 in the paper is omitted; \tau_i is actually not used.
        '''
        tau = torch.cumsum(weight, dim=1)
        with torch.no_grad():
            tau_hat = torch.zeros_like(tau, dtype=torch.float32, device=self.device)
            tau_hat[:, 0:1] = tau[:, 0:1] / 2.
            tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.
        return tau, tau_hat, weight


    def compute_utility(
        self,
        observation: torch.tensor = None,
        action: torch.tensor = None,
        quantile: torch.tensor = None,
        weight: torch.tensor = None, # dummy
        utility: str = None,
     ) -> torch.tensor:

        if quantile is None:
            assert observation is not None and action is not None
            assert observation.shape[0] == action.shape[0]
            batch_size = observation.shape[0]
            tau, tau_hat, weight \
                = self.get_quantile_fractions(observation=observation, action=action)
            quantile = self.forward(observation=observation, action=action, tau_hat=tau_hat)
        if observation is None:
            assert quantile is not None and weight is not None
            batch_size = quantile.shape[0]
        assert quantile.shape == (batch_size, self.N)
        assert weight.shape == (batch_size, self.N)

        if utility == 'mean':
            return self.compute_mean(quantile, weight)
        elif utility == 'erm':
            return self.compute_erm(quantile, weight)
        elif utility == 'cvar':
            return self.compute_cvar(quantile, weight)
        else:
            raise ValueError


    def compute_mean(
        self,
        # observation: torch.tensor = None,
        # action: torch.tensor = None,
        quantile: torch.tensor = None,
        weight: torch.tensor = None,
    ) -> torch.tensor:

        q = (quantile * weight).sum(dim=1, keepdim=False)

        batch_size = quantile.shape[0]
        assert q.shape == (batch_size,), f'{q.shape} != ({batch_size})'
        return q


    def compute_erm(
        self,
        # observation: torch.tensor = None,
        # action: torch.tensor = None,
        quantile: torch.tensor = None,
        weight: torch.tensor = None,
        alpha: float = None
    ) -> torch.tensor:

        if alpha is None:
            alpha = self.alpha

        '''
            Remark: torch.max(**) -> (values: Tensor, indexes: LongTensor)
        '''
        if alpha == 0:
            erm = self.compute_mean(quantile=quantile)
        elif alpha == np.inf:
            erm = torch.max(quantile, dim=1, keepdim=False)[0]
        elif alpha == -np.inf:
            erm = torch.min(quantile, dim=1, keepdim=False)[0]
        else:
            if alpha > 0:
                m = torch.max(quantile, dim=1, keepdim=True)[0]
            elif alpha < 0:
                m = torch.min(quantile, dim=1, keepdim=True)[0]
            Z = torch.exp(alpha * (quantile - m)) * weight
            erm = m.squeeze() + torch.log(Z.sum(dim=1)) / alpha

        batch_size = quantile.shape[0]
        assert erm.shape == (batch_size,)
        return erm



    def compute_cvar(
        self,
        # observation: torch.tensor = None,
        # action: torch.tensor = None,
        quantile: torch.tensor = None,
        weight: torch.tensor = None, # dummy
        alpha: float = None,
    ) -> torch.tensor:
        if alpha is None:
            alpha = self.alpha
        assert alpha >= 0. and alpha <= 1.

        tau = torch.cumsum(weight, dim=1)
        risk_weight = (1. / alpha) * (tau < alpha)
        risk_weight = risk_weight.clamp(0., 5.)
        q = (quantile * risk_weight * weight).sum(dim=1, keepdim=False)

        batch_size = quantile.shape[0]
        assert q.shape == (batch_size,), f'{q.shape} != ({batch_size})'
        return q
