import numpy as np
import torch
from torch import nn as nn

from lfrl.torch.networks import FlattenMlp
import lfrl.torch.pytorch_util as ptu


class QRNetwork(FlattenMlp):
    """
    Quantile regression critic network for continuous distributional RL.

    The underlying MLP maps a concatenated (observation, action) vector to
    ``n_quantiles`` values. ``forward`` reduces them to one value per sample:
    the sum of ``quantile_values * self.mask`` over the last dimension,
    divided by ``self.weight``. With the default all-ones mask and
    ``weight == n_quantiles``, this is the plain mean of the quantiles.
    Subclasses reshape the mask/weight to express risk preferences.
    """

    def __init__(self, obs_dim, action_dim, n_quantiles=32, **kwargs):
        super().__init__(
            input_size=obs_dim + action_dim,
            output_size=n_quantiles,
            **kwargs
        )

        # Overrides the MLP's output size (n_quantiles) after construction;
        # the reduced forward() yields a single value per sample.
        self.output_size = 1
        self.n_quantiles = n_quantiles

        # n_quantiles + 2 evenly spaced points on [0, 1]; dropping both
        # endpoints leaves n_quantiles levels strictly inside (0, 1).
        level_grid = np.linspace(0, 1, n_quantiles + 2)
        self.thresholds = ptu.from_numpy(level_grid[1:-1])

        # Per-quantile weighting and its normaliser (uniform average by default).
        self.mask = ptu.ones(1, self.n_quantiles)
        self.weight = self.n_quantiles

    def get_quantile_values(self, *inputs, **kwargs):
        """Return the raw per-quantile outputs of the underlying MLP."""
        return super().forward(*inputs, **kwargs)

    def forward(self, *inputs, **kwargs):
        """Return the masked, normalised sum of the predicted quantiles."""
        weighted = self.get_quantile_values(*inputs, **kwargs) * self.mask
        return weighted.sum(dim=-1) / self.weight


class QRNetworkAdaptive(FlattenMlp):
    """
    Quantile regression critic whose risk weighting is chosen per call.

    The MLP receives ``obs_dim + action_dim + 1`` input features (the extra
    feature is a 1-D latent) and emits ``n_quantiles`` values. ``set_masks``
    precomputes one quantile-weighting mask per risk bin from
    ``inverse_beta_func``. Here that function is the identity (risk-neutral),
    and subclasses override it. ``forward`` selects a mask for each sample
    from ``risk_parameters``.
    """

    def __init__(
            self,
            obs_dim,
            action_dim,
            n_quantiles=32,
            n_parameter_bins=100,
            **kwargs
    ):
        super().__init__(
            input_size=obs_dim + action_dim + 1,  # latent is 1D
            output_size=n_quantiles,
            **kwargs
        )

        # The reduced forward() yields one value per sample.
        self.output_size = 1
        self.n_quantiles = n_quantiles
        self.n_parameter_bins = n_parameter_bins

        # n_quantiles evenly spaced levels strictly inside (0, 1).
        self.thresholds = ptu.from_numpy(
            np.linspace(0, 1, n_quantiles + 2)[1:n_quantiles + 1]
        )

        # masks[k, 0, :] weights the quantiles for risk bin k; weights[k] is
        # the matching normaliser. Both are filled in by set_masks().
        self.masks = ptu.ones(n_parameter_bins, 1, self.n_quantiles)
        self.weights = ptu.ones(n_parameter_bins) * self.n_quantiles

        self.set_masks()

    def get_quantile_values(self, *inputs, **kwargs):
        """Return the raw per-quantile outputs of the underlying MLP."""
        return super().forward(*inputs, **kwargs)

    def forward(self, *inputs, risk_parameters=None, **kwargs):
        """
        Return the risk-weighted average of the predicted quantiles.

        :param inputs: forwarded to the underlying MLP.
        :param risk_parameters: ``None`` (risk level 0.5 for every sample),
            a scalar shared by the whole batch, or exactly one value per
            sample, e.g. shape ``(B,)`` or ``(B, 1)`` for a
            ``(B, n_quantiles)`` quantile output. Each value is mapped to the
            nearest precomputed bin and clamped to the valid bin range.
        :return: tensor shaped like the quantile output minus its last dim.
        :raises ValueError: if ``risk_parameters`` is neither a scalar nor
            one value per sample.
        """
        quantile_values = self.get_quantile_values(*inputs, **kwargs)
        batch_shape = quantile_values.shape[:-1]
        n_samples = quantile_values.numel() // quantile_values.shape[-1]

        if risk_parameters is None:
            risk_parameters = 0.5
        risk_parameters = torch.as_tensor(
            risk_parameters,
            dtype=quantile_values.dtype,
            device=quantile_values.device,
        )

        risk_ids = self._risk_ids(risk_parameters)
        if risk_ids.dim() == 0:
            # One risk level shared by every sample.
            risk_ids = risk_ids.expand(batch_shape)
        elif risk_ids.numel() == n_samples:
            # One risk level per sample; drop any trailing singleton dims so
            # each sample selects a single mask row (no cross-broadcasting).
            risk_ids = risk_ids.reshape(batch_shape)
        else:
            raise ValueError(
                'risk_parameters must be a scalar or hold one value per '
                'sample; got shape {} for quantile output of shape {}'.format(
                    tuple(risk_parameters.shape), tuple(quantile_values.shape)
                )
            )

        masks = self.masks[risk_ids].reshape(*batch_shape, self.n_quantiles)
        weights = self.weights[risk_ids]
        return torch.sum(quantile_values * masks, dim=-1) / weights

    def _risk_ids(self, risk_parameters):
        """Map risk parameters to the index of the nearest precomputed bin."""
        # set_masks() builds bin k for risk parameter k / (n_parameter_bins - 1),
        # so round onto that same grid rather than using a different spacing.
        scale = max(self.n_parameter_bins - 1, 1)
        risk_ids = torch.round(risk_parameters * scale).long()
        risk_ids = torch.clamp(risk_ids, 0, self.n_parameter_bins - 1)
        return risk_ids.to(self.masks.device)

    def set_masks(self):
        """
        Fill ``masks``/``weights`` from ``inverse_beta_func``.

        Bin k uses risk parameter ``k / (n_parameter_bins - 1)``, so the bins
        span [0, 1] inclusive; a single bin uses parameter 0. The mask for a
        bin is the successive difference of the distortion evaluated on
        ``n_quantiles + 1`` evenly spaced levels of [0, 1]. Its normaliser is
        set to 1.
        """
        tau = np.linspace(0, 1, self.n_quantiles + 1)
        scale = max(self.n_parameter_bins - 1, 1)  # avoid 0-division for 1 bin
        for k in range(self.n_parameter_bins):
            param = k / scale
            betas = np.array(
                [self.inverse_beta_func(t, param=param) for t in tau],
                dtype=np.float64,
            )
            self.masks[k, 0, :] = torch.as_tensor(
                np.diff(betas),
                dtype=self.masks.dtype,
                device=self.masks.device,
            )
            self.weights[k] = 1.

    def inverse_beta_func(self, tau, param=0.5):
        # tau refers to CDF of input
        # identity is equivalent to risk-neutral
        return tau


def get_inverse(func, x, n_bins=1024, **kwargs):
    """
    Numerically invert a monotonically increasing function on (0, 1).

    Walks the grid 0, 1/n_bins, ..., (n_bins - 1)/n_bins in increasing order
    and returns the first grid point g with ``x <= func(g, **kwargs)``. If no
    grid point qualifies, returns 1.0. Resolution is 1 / n_bins. The scan is
    linear and evaluates ``func`` at most ``n_bins`` times, stopping early.
    """
    grid_points = (step / n_bins for step in range(n_bins))
    return next(
        (point for point in grid_points if x <= func(point, **kwargs)),
        1.,
    )


class QRNetworkRisk(QRNetwork):
    """
    Fixed distortion-risk critic built on ``QRNetwork``.

    ``set_mask`` turns ``inverse_beta_func`` into per-quantile weights.
    This class does not call ``set_mask`` itself; subclasses are expected to
    call it once any parameters the distortion needs have been set.
    """

    def set_mask(self):
        """
        Rebuild ``self.mask`` from the distortion and set ``self.weight = 1``.

        The distortion is evaluated on ``n_quantiles + 1`` evenly spaced
        levels of [0, 1]. The mask entries are the successive differences of
        those values.
        """
        levels = np.linspace(0, 1, self.n_quantiles + 1)
        distorted = np.zeros(self.n_quantiles + 1)
        for idx, level in enumerate(levels):
            distorted[idx] = self.inverse_beta_func(level)

        self.mask[0, :] = torch.as_tensor(
            np.diff(distorted),
            dtype=self.mask.dtype,
            device=self.mask.device,
        )
        self.weight = 1.

    def inverse_beta_func(self, tau):
        # tau is a level of the return CDF; the identity distortion leaves the
        # quantiles uniformly weighted, i.e. risk-neutral.
        return tau


class QRNetworkCVaRAdaptive(QRNetworkAdaptive):
    """
    Adaptive critic whose per-bin masks implement CVaR at level ``param``.
    """

    def inverse_beta_func(self, tau, param=0.5):
        """
        CVaR distortion ``min(tau / param, 1)``.

        All weight falls on the lowest ``param`` fraction of the return
        distribution. ``tau <= 0`` always maps to 0. This makes ``param == 0``
        yield the limiting step function, which puts all weight on the lowest
        quantile. Without it, every level would map to 1 and the mask would
        be all zeros, giving a constant-zero Q-value.
        """
        if tau <= 0:
            return 0.
        if tau < param:
            return tau / param
        return 1.


class QRNetworkCVaR(QRNetwork):
    """
    Fixed-level CVaR critic that averages a contiguous subset of quantiles.

    Let ``n_split = int(alpha * n_quantiles)``.

    * ``mode='averse'`` keeps the lowest ``n_split`` quantiles. At least one
      and at most all of them are kept.
    * ``mode='seeking'`` zeroes the lowest ``n_split`` quantiles and keeps the
      rest. At least one quantile is always kept.

    ``self.weight`` is set to the number of retained quantiles, so
    ``forward`` returns their mean.

    :raises NameError: if ``mode`` is neither 'averse' nor 'seeking'.
    """

    def __init__(self, alpha=0.5, mode='averse', **kwargs):
        super().__init__(**kwargs)

        self.alpha = alpha
        self.mode = mode  # 'averse' or 'seeking'

        n_split = int(self.alpha * self.n_quantiles)
        if mode == 'averse':
            # Clamp so the normaliser is never 0 (NaN output) or larger than
            # the number of quantiles actually kept.
            n_kept = min(max(n_split, 1), self.n_quantiles)
            self.mask[0, n_kept:] = 0
        elif mode == 'seeking':
            # NOTE(review): this keeps the upper (1 - alpha) fraction, which
            # matches "upper alpha tail" only at alpha = 0.5 -- confirm intent.
            n_dropped = min(max(n_split, 0), self.n_quantiles - 1)
            self.mask[0, :n_dropped] = 0
            n_kept = self.n_quantiles - n_dropped
        else:
            raise NameError('mask mode not recognized')

        # Normalise by the count of retained quantiles (previously 'seeking'
        # divided by the number of *dropped* quantiles).
        self.weight = n_kept


def cumulative_probability_weighting(tau, eta):
    """
    Probability weighting function from cumulative prospect theory.

    Computes ``tau**eta / (tau**eta + (1 - tau)**eta) ** (1 / eta)``. Only
    arithmetic operators are used, so scalars and NumPy arrays both work.
    """
    weighted_tau = tau ** eta
    weighted_complement = (1 - tau) ** eta
    return weighted_tau / ((weighted_tau + weighted_complement) ** (1 / eta))


class QRNetworkCPW(QRNetworkRisk):
    """
    Fixed-risk critic whose quantile mask comes from the cumulative
    prospect theory weighting ``cumulative_probability_weighting(., eta)``.

    The distortion applied to each level is the numerical inverse of that
    weighting, obtained with ``get_inverse``.
    """

    def __init__(self, eta=0.71, **kwargs):
        super().__init__(**kwargs)
        # eta must be stored before set_mask(), which reads it through
        # inverse_beta_func.
        self.eta = eta
        self.set_mask()

    def inverse_beta_func(self, tau):
        weighting_kwargs = {'eta': self.eta}
        return get_inverse(cumulative_probability_weighting, tau, **weighting_kwargs)