
import numpy as np
import torch
import torch.nn as nn
import typing as ty

from gym import spaces
# from gymnasium import spaces2


def check_param_norm(models: ty.List[torch.nn.Module]):
    """Debug helper: print the sum of squared parameter values of each model.

    Prints a header line, then one ``--  <value>`` line per model, where the
    value is the sum over all of the model's parameters of ``sum(p ** 2)``
    (the squared L2 norm of the flattened parameters). A model without
    parameters prints ``0``.
    """
    print('----check_param_norm----')
    for net in models:
        squared_norm = sum(torch.sum(param ** 2).item() for param in net.parameters())
        print(f'--  {squared_norm}')


def check_gradient(models: ty.List[torch.nn.Module]):
    print(f'----gradient----')
    for model in models:
        # for p, n in zip(model.parameters(), model._all_weights[0]):
        for p in model.parameters():
            # print(f'===========')
            # print(f'gradient:{n}')
            print(f'--  {p.grad.sum()}')
            # print(f'{p.grad}')
            # print(f'{p.grad.sum()}')


def init_linear_weights_by_xavier(module, gain):
    """Re-initialise ``module.weight`` with Xavier-normal values scaled by ``gain``.

    Only a module whose type is exactly ``torch.nn.Linear`` is modified; every
    other module, including subclasses of Linear, is left unchanged. The bias
    is not touched. Returns None.
    """
    # NOTE(review): exact-type comparison, not isinstance -- Linear subclasses
    # are skipped. Confirm whether that is intended before changing it.
    if type(module) is not torch.nn.Linear:
        return
    torch.nn.init.xavier_normal_(module.weight, gain)


def polyak_update(module_from: torch.nn.Module, module_to: torch.nn.Module, polyak: float):
    """Soft-update ``module_to`` in place toward ``module_from``.

    Parameters are paired by position in ``parameters()`` and updated as
    ``p_to <- polyak * p_to + (1 - polyak) * p_from``. ``module_from`` is not
    modified.

    Raises:
        ValueError: if the two modules expose a different number of parameter
            tensors (``zip`` would otherwise silently update only a prefix).
    """
    params_from = list(module_from.parameters())
    params_to = list(module_to.parameters())
    if len(params_from) != len(params_to):
        raise ValueError(
            f'polyak_update: modules have {len(params_from)} and {len(params_to)} '
            'parameter tensors; they must match')
    for p_from, p_to in zip(params_from, params_to):
        # Work on .data so the update is not recorded by autograd.
        p_to.data.mul_(polyak)
        p_to.data.add_((1 - polyak) * p_from.data)


def apply_along_axis(func, x: torch.Tensor, axis: int = 0, arg_dict: dict = None):
    """Apply ``func`` to every slice of ``x`` taken along ``axis`` and re-stack.

    ``x`` is split with ``torch.unbind(x, dim=axis)`` (each slice loses that
    dimension). ``func(slice, **arg_dict)`` is called on each slice, and the
    results are stacked back along ``axis``, so all results must share a
    shape. An empty ``axis`` leaves nothing to stack, and ``torch.stack``
    raises.

    Args:
        func: callable applied to each slice.
        x: input tensor.
        axis: dimension to iterate over.
        arg_dict: extra keyword arguments for ``func``. None means no extras;
            a None default avoids a shared mutable default dict.
    """
    kwargs = {} if arg_dict is None else arg_dict
    return torch.stack(
        [func(x_i, **kwargs) for x_i in torch.unbind(x, dim=axis)],
        dim=axis,
    )


def domain_bounding(
        quantile: torch.tensor,
        z_min: float, z_max: float,
    ):
    """Affinely map ``quantile`` so that its range lies within [z_min, z_max].

    Must be used in "with torch.no_grad()".

    The range [lo, hi] is taken over *all* elements of ``quantile``
    (``min()`` / ``max()`` without a dim), so a batch axis, if any, is bounded
    jointly. The cases are:
      * all elements equal: translate onto the nearest end of the domain
        (or return unchanged if already inside);
      * range already inside the domain: return ``quantile`` itself;
      * only the right tail too large: keep lo and rescale so hi -> z_max;
      * only the left tail too small: keep hi and rescale so lo -> z_min;
      * both tails outside: rescale [lo, hi] onto [z_min, z_max];
      * range disjoint from the domain: move lo to z_min, shrinking only if the
        range is at least as wide as the domain.

    Raises:
        AssertionError: unless ``z_min < z_max``.
        ValueError: when no case matches (presumably only with NaN values).
    """
    assert z_min < z_max
    lo = quantile.min()
    hi = quantile.max()

    if lo == hi:
        # Concentrated distribution: translate only.
        if z_min <= lo <= z_max:
            return quantile
        if lo <= z_min:
            return quantile + z_min - lo
        if z_max <= lo:
            return quantile + z_max - lo
        raise ValueError

    span = hi - lo
    if z_min <= lo and hi <= z_max:
        # Already inside the domain.
        return quantile
    if z_min <= lo <= z_max <= hi:
        # Right tail over-estimated: pin lo, pull hi down to z_max.
        return (quantile - lo) * ((z_max - lo) / span) + lo
    if lo <= z_min <= hi <= z_max:
        # Left tail under-estimated: pin hi, push lo up to z_min.
        return (quantile - hi) * ((hi - z_min) / span) + hi
    if lo <= z_min and z_max <= hi:
        # Both tails outside: map [lo, hi] onto [z_min, z_max].
        return (quantile - lo) * ((z_max - z_min) / span) + z_min
    if hi <= z_min or z_max <= lo:
        # Supports do not overlap: anchor lo at z_min, shrink only if too wide.
        width = z_max - z_min
        scale = width / span if span >= width else 1.
        return (quantile - lo) * scale + z_min
    raise ValueError



def compute_quantile_Huber_loss(
        td_errors: torch.Tensor,
        tau: torch.Tensor,
        weight: torch.Tensor,
        kappa: float = 1.0,
        q_net_head: str = 'multiple'
    ) -> torch.Tensor:
    r"""Weighted quantile (Huber) regression loss, reduced to a scalar.

    With u = ``td_errors`` of shape (batch_size, N, N_prime), this returns the
    mean over the batch of

        \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N'} w_j \,
            \rho_{\hat{\tau}_{i}}^{\kappa}(u_{i,j})

    When the weights w over j sum to 1, the inner sum is the expectation
    E_j[\rho_{\hat{\tau}_i}^{\kappa}(T\theta_j - \theta_i(x, a))].

    * ``kappa == 0``: pinball loss ``max(tau * u, (tau - 1) * u)``.
    * ``kappa > 0``: ``|tau - 1{u < 0}| * Huber_kappa(u) / kappa``
      (Eqs. (9) and (10) of the referenced paper, presumably QR-DQN).

    Args:
        td_errors: (batch_size, N, N_prime) pairwise TD errors. Per the
            formula, i indexes predicted quantiles and j target samples.
        tau: quantile fractions, broadcast over j via ``unsqueeze(-1)``.
            Presumably (batch_size, N) or (N,); TODO confirm. Must not
            require grad.
        weight: weights over j. For ``'multiple'`` it becomes (1, 1, ...),
            so presumably (N_prime,). For ``'single'`` it is unsqueezed at
            dim 1, presumably from (batch_size, N_prime).
        kappa: Huber threshold; must be >= 0.
        q_net_head: ``'multiple'`` or ``'single'``; selects how ``weight``
            is broadcast.

    Returns:
        0-dim tensor.

    Raises:
        ValueError: unknown ``q_net_head``, or negative/NaN ``kappa``. The
            former negative-kappa branch was deprecated and inconsistent:
            its masks overlapped and failed its own partition assert.
    """
    assert not tau.requires_grad
    batch_size, N, N_prime = td_errors.shape

    if q_net_head == 'multiple':
        # One weight vector shared by the whole batch.
        weight_expanded = weight.unsqueeze(0).unsqueeze(1)
    elif q_net_head == 'single':
        # Per-sample weights. TODO(review): confirm expected weight shape.
        weight_expanded = weight.unsqueeze(1)
    else:
        raise ValueError(f"q_net_head must be 'multiple' or 'single', got {q_net_head!r}")
    # Broadcast tau across the target axis j.
    tau_expanded = tau.unsqueeze(-1)

    if kappa == 0.:
        # Plain quantile (pinball) loss.
        quantile_Huber_loss = torch.maximum(tau_expanded * td_errors, (tau_expanded - 1) * td_errors)
    elif kappa > 0.:
        # Huber loss, Eq. (9).
        abs_u = td_errors.abs()
        Huber_loss = torch.where(
            abs_u <= kappa,
            0.5 * td_errors.pow(2),
            kappa * (abs_u - 0.5 * kappa),
        )
        assert Huber_loss.shape == (batch_size, N, N_prime)
        # Quantile Huber loss, Eq. (10).
        quantile_Huber_loss = torch.abs(
            tau_expanded - (td_errors < 0).float()
        ) * Huber_loss / kappa
    else:
        raise ValueError(f'kappa must be non-negative, got {kappa}')

    assert quantile_Huber_loss.shape == (batch_size, N, N_prime)

    # Weighted sum over targets j, then mean (not sum) over quantiles i.
    per_sample = (quantile_Huber_loss * weight_expanded).sum(dim=2).mean(dim=1, keepdim=True)
    assert per_sample.shape == (batch_size, 1)

    return per_sample.mean()



class MLP(nn.Module):
    """Stack of Linear -> (optional LayerNorm) -> activation stages.

    ``n_units`` lists the layer widths: ``n_units[0]`` is the input size and
    ``n_units[-1]`` the output size. Hidden stages use ``activation`` and the
    final stage uses ``output_activation``; both are classes instantiated with
    no arguments. With ``layer_norm`` set, LayerNorm is inserted on hidden
    stages only; every other stage gets ``nn.Identity``.
    """
    def __init__(self, n_units, activation, output_activation=nn.Identity, layer_norm=False):
        super().__init__()
        self.layers = nn.ModuleList()
        self.normalizations = nn.ModuleList()
        self.activations = nn.ModuleList()
        output_stage = len(n_units) - 2
        widths = zip(n_units[:-1], n_units[1:])
        for stage, (fan_in, fan_out) in enumerate(widths):
            hidden = stage < output_stage
            norm = nn.LayerNorm(fan_out) if (layer_norm and hidden) else nn.Identity()
            linear = nn.Linear(fan_in, fan_out)
            act_cls = activation if hidden else output_activation
            self.layers.append(linear)
            self.normalizations.append(norm)
            self.activations.append(act_cls())

    def forward(self, x):
        out = x
        stages = zip(self.layers, self.normalizations, self.activations)
        for linear, norm, act in stages:
            out = act(norm(linear(out)))
        return out

# def mlp(sizes, activation, output_activation=nn.Identity):
#     layers = []
#     for j in range(len(sizes)-1):
#         act = activation if j < len(sizes)-2 else output_activation
#         layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
#     return nn.Sequential(*layers)

class ReplayBuffer:
    '''
    A simple FIFO experience replay buffer.

    Transitions are kept in preallocated numpy arrays with capacity
    ``buffer_size``. Once the buffer is full, new transitions overwrite the
    oldest ones. Only the ``.shape`` attribute of ``ob_space`` and
    ``ac_space`` is used.
    '''
    def __init__(self, ob_space, ac_space, buffer_size) -> None:
        # Per-item shapes. A space with shape () (presumably e.g. a gym
        # Discrete space) stores one scalar per slot.
        dim_ob = ob_space.shape
        dim_ac = ac_space.shape

        self.ob_buf = np.zeros(self._combined_shape(buffer_size, dim_ob), dtype=np.float32)
        self.ob2_buf = np.zeros(self._combined_shape(buffer_size, dim_ob), dtype=np.float32)
        self.ac_buf = np.zeros(self._combined_shape(buffer_size, dim_ac), dtype=np.float32)
        self.rew_buf = np.zeros(self._combined_shape(buffer_size, 1), dtype=np.float32)
        self.done_buf = np.zeros(self._combined_shape(buffer_size, 1), dtype=np.int32)
        self.ptr = 0  # next slot to write
        self.buffer_size = 0  # number of filled slots
        self.max_buffer_size = buffer_size
        # Running reward extrema, updated by get_reward_bound().
        self.reward_min = np.inf
        self.reward_max = -np.inf

    def _combined_shape(self, length: int, shape: tuple = None) -> ty.Tuple[int, ...]:
        """Return ``(length, *shape)``; a scalar ``shape`` gives ``(length, shape)``."""
        if shape is None:
            return (length,)
        return (length, shape) if np.isscalar(shape) else (length, *shape)

    def store(self, ob, ac, rew, next_ob, done) -> None:
        """Write one transition at ``ptr``, overwriting the oldest slot once full."""
        self.ob_buf[self.ptr] = ob
        self.ob2_buf[self.ptr] = next_ob
        self.ac_buf[self.ptr] = ac
        self.rew_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr+1) % self.max_buffer_size
        self.buffer_size = min(self.buffer_size+1, self.max_buffer_size)

    def sample(self, batch_size: int) -> dict:
        """Draw ``batch_size`` filled slots uniformly with replacement.

        Returns a dict of float32 tensors keyed 'ob', 'ob2', 'ac', 'rew' and
        'done'. numpy raises ValueError when the buffer is empty.
        """
        idxs = np.random.randint(0, self.buffer_size, size=batch_size)
        batch = dict(ob=self.ob_buf[idxs],
                     ob2=self.ob2_buf[idxs],
                     ac=self.ac_buf[idxs],
                     rew=self.rew_buf[idxs],
                     done=self.done_buf[idxs])
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in batch.items()}

    def sample_uniformly(self) -> dict:
        """Return every filled slot as float32 tensors.

        Entries are in slot order, which is not chronological after the
        buffer has wrapped around.
        """
        batch = dict(ob=self.ob_buf[0:self.buffer_size],
                     ob2=self.ob2_buf[0:self.buffer_size],
                     ac=self.ac_buf[0:self.buffer_size],
                     rew=self.rew_buf[0:self.buffer_size],
                     done=self.done_buf[0:self.buffer_size])
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in batch.items()}

    def get_reward_bound(self) -> ty.Tuple[float, float]:
        """Return the running (min, max) of rewards seen in filled slots.

        Only the first ``buffer_size`` slots are inspected. The unfilled tail
        of ``rew_buf`` is zero-initialised and would otherwise leak a spurious
        0 into the bounds. Because the bounds are running extrema, that 0
        would then stick forever. The extrema also persist after the
        transitions that produced them are overwritten. With an empty buffer
        the current bounds (initially ``(inf, -inf)``) are returned unchanged.
        """
        if self.buffer_size == 0:
            return self.reward_min, self.reward_max
        filled = self.rew_buf[:self.buffer_size]
        reward_min = filled.min()
        reward_max = filled.max()
        if self.reward_min > reward_min:
            self.reward_min = reward_min
        if self.reward_max < reward_max:
            self.reward_max = reward_max
        return self.reward_min, self.reward_max



def linear_interpolation(l: float, r: float, weight: float) -> float:
    """Return the point ``weight`` of the way from ``l`` to ``r``.

    ``weight == 0`` gives ``l`` and ``weight == 1`` gives ``r``; other values
    interpolate (or extrapolate) linearly.
    """
    gap = r - l
    return l + weight * gap


class ConstantScheduler(object):
    """Scheduler that always yields the value it was built with."""

    def __init__(self, value: float):
        self._v = value

    def value(self, t: int) -> float:
        """Return the constant; the step ``t`` is ignored."""
        del t  # unused; kept for interface parity with the other schedulers
        return self._v


class ExponentialScheduler(object):
    """Stateful geometric decay clamped from below at ``min_value``.

    Each ``value()`` call multiplies the current value by ``discount`` (only
    while it is above ``min_value``) and then clamps it to ``min_value``. The
    step argument ``t`` is ignored: decay advances once per call, so two calls
    for the same step decay twice.
    """

    def __init__(
        self, value: float, min_value: float, discount: float
    ):
        self._v = value
        self._min_v = min_value
        self._discount = discount

    def value(self, t: int) -> float:
        """Advance the decay by one call and return the clamped value."""
        current = self._v
        if current > self._min_v:
            current = current * self._discount
        # max() keeps ``current`` unless min_v is strictly larger.
        self._v = max(current, self._min_v)
        return self._v


class PiecewiseLinearScheduler(object):
    """Piecewise interpolation between ``(t, value)`` endpoints (from OpenAI baselines).

    For ``endpoints[k][0] <= t < endpoints[k+1][0]`` the result is
    ``interpolation(v_k, v_{k+1}, fraction)``, where ``fraction`` is the
    relative position of ``t`` inside that piece. Outside every piece,
    including ``t`` equal to the last endpoint time, ``outside_value`` is
    returned, and it must then not be None (checked with ``assert``).
    """

    def __init__(
        self,
        endpoints: ty.List[ty.Tuple[int, float]],
        interpolation: ty.Callable = linear_interpolation,
        outside_value: float = None,
    ):
        times = [point[0] for point in endpoints]
        # Endpoint times must be non-decreasing.
        assert times == sorted(times)
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t: int) -> float:
        consecutive = zip(self._endpoints, self._endpoints[1:])
        for (start_t, start_v), (end_t, end_v) in consecutive:
            if start_t <= t < end_t:
                fraction = float(t - start_t) / (end_t - start_t)
                return self._interpolation(start_v, end_v, fraction)

        # t lies outside every piece.
        assert self._outside_value is not None
        return self._outside_value
