import torch
from torch import nn
from torch.nn import functional as F

from modules.common import ScalarVector, safe_norm


def _get_scalar_activation(name):
    """Resolve a scalar activation from its name.

    Args:
        name: ``None`` for a pass-through, ``'sigmoid'`` or ``'relu'`` for the
            corresponding ``torch`` function, or the name of any attribute of
            ``torch.nn.functional``.

    Returns:
        An ``nn.Identity`` module when ``name`` is ``None``, otherwise the
        callable that was looked up.
    """
    if name is None:
        return nn.Identity()

    # Names that map to top-level torch functions rather than torch.nn.functional.
    torch_level = (
        ('sigmoid', torch.sigmoid),
        ('relu', torch.relu),
    )
    for known_name, func in torch_level:
        if name == known_name:
            return func

    # Everything else is looked up on torch.nn.functional.
    return getattr(F, name)


def _get_vector_activation(cfg):
    """Build a vector activation module from a small config sequence.

    Args:
        cfg: ``None`` for a pass-through, or an indexable whose first item
            selects the activation class and whose second item is its argument:
            ``('scale', scalar_activation_name)`` or ``('project', n_dims)``.

    Returns:
        ``nn.Identity``, ``VectorScaling`` or ``VectorProjection``.

    Raises:
        ValueError: If the first item of ``cfg`` is not a known class name.
    """
    if cfg is None:
        return nn.Identity()

    kind = cfg[0]
    if kind == 'scale':
        return VectorScaling(_get_scalar_activation(cfg[1]))
    if kind == 'project':
        return VectorProjection(cfg[1])
    raise ValueError('Unknown vector activation class: %s' % kind)


class VectorProjection(nn.Module):
    """Directional, ReLU-like activation for 3-vector features.

    Each of the ``n_dims`` channels owns a learnable direction. An input
    vector whose dot product with its channel's direction is non-negative is
    passed through unchanged; otherwise its component along the direction is
    removed, so the output is orthogonal to the direction.

    Args:
        n_dims: Number of vector channels (one learnable direction each).
    """

    def __init__(self, n_dims):
        super().__init__()
        self.vecs = nn.Parameter(torch.randn([n_dims, 3]), requires_grad=True)
        # weight_norm re-registers ``vecs`` as the parameters ``vecs_g`` and
        # ``vecs_v``. It is kept so existing checkpoints keep their keys;
        # only ``vecs_v`` is read in ``forward``.
        nn.utils.weight_norm(self, name='vecs', dim=1)

    def forward(self, x):
        """Apply the half-space projection to ``x`` of shape ``(*, n_dims, 3)``."""
        # ``vecs_v`` is the raw (un-normalized) direction of the weight-norm
        # decomposition. Removing a component with ``x - (x . u) u`` is only an
        # orthogonal projection when ``u`` has unit length, so normalize here.
        unit_dirs = F.normalize(self.vecs_v, dim=-1)  # (n_dims, 3)
        dot_prod = (unit_dirs * x).sum(dim=-1, keepdim=True)  # (*, n_dims, 1)
        x_proj = x - dot_prod * unit_dirs  # (*, n_dims, 3)
        return torch.where(dot_prod >= 0, x, x_proj)


class VectorScaling(nn.Module):
    """Rescale vectors by a scalar function of ``safe_norm`` over the last axis.

    Args:
        func: Callable applied to the value returned by ``safe_norm``; its
            result multiplies the input. Defaults to ``torch.sigmoid``.
    """

    def __init__(self, func=torch.sigmoid):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return ``x`` multiplied by ``func`` of its last-axis norm."""
        # keepdim=True keeps a trailing axis so the factor broadcasts over x.
        magnitude = safe_norm(x, dim=-1, keepdim=True)
        return self.func(magnitude) * x


class SVActivation(nn.Module):
    """Apply separate activations to the two parts of a ``ScalarVector``.

    Args:
        s_act: Callable applied to the ``s`` field.
        v_act: Callable applied to the ``v`` field.
    """

    def __init__(self, s_act, v_act):
        super().__init__()
        self.s_act = s_act
        self.v_act = v_act

    @classmethod
    def from_args(cls, scalar_act, vector_act):
        """Alternate constructor that resolves both activations from configs.

        ``scalar_act`` is passed to ``_get_scalar_activation`` and
        ``vector_act`` to ``_get_vector_activation``.
        """
        return cls(
            _get_scalar_activation(scalar_act),
            _get_vector_activation(vector_act),
        )

    def forward(self, x: ScalarVector) -> ScalarVector:
        """Activate whichever fields are present; ``None`` fields stay ``None``."""
        s_out = None
        if x.s is not None:
            s_out = self.s_act(x.s)

        v_out = None
        if x.v is not None:
            v_out = self.v_act(x.v)

        return ScalarVector(s=s_out, v=v_out)


###################################################################################
# Other activation functions
###################################################################################

class SVGate(nn.Module):
    """Gated layer that mixes scalar and vector features.

    Vector channels are linearly mapped to ``h_dim`` hidden channels; the
    ``safe_norm`` of those hidden vectors is concatenated with the scalar
    input to produce the ReLU-activated scalar output. When vector outputs
    are requested, the hidden vectors are mapped to ``vo`` channels and
    scaled by a sigmoid gate computed from the scalar output.

    Args:
        in_dims: Pair ``(si, vi)`` of scalar / vector input channel counts.
        out_dims: Pair ``(so, vo)`` of scalar / vector output channel counts.
        h_dim: Hidden vector channel count; defaults to ``max(vi, vo)``.
    """

    def __init__(self, in_dims, out_dims, h_dim=None):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        if h_dim is None:
            h_dim = max(self.vi, self.vo)
        self.h_dim = h_dim

        self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
        self.ws = nn.Linear(self.h_dim + self.si, self.so)
        # The vector branch only exists when vector outputs are requested.
        if self.vo:
            self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
            self.wsv = nn.Linear(self.so, self.vo)

    def forward(self, x):
        """Return a ``ScalarVector``, or just the scalar tensor when ``vo`` is 0."""
        s_in, v_in = x.s, x.v

        # Move the channel axis last so nn.Linear mixes vector channels.
        hidden = self.wh(v_in.transpose(-1, -2))  # [...,3,h_dim]
        hidden_norm = safe_norm(hidden, dim=-2)  # [...,h_dim]
        s_out = F.relu(self.ws(torch.cat([s_in, hidden_norm], -1)))  # [...,so]

        if not self.vo:
            return s_out

        v_out = self.wv(hidden).transpose(-1, -2)  # [...,vo,3]
        gate_logits = self.wsv(torch.sigmoid(s_out))  # [...,vo]
        v_out = v_out * torch.sigmoid(gate_logits).unsqueeze(-1)  # [...,vo,3]
        return ScalarVector(s=s_out, v=v_out)


class SVReLU(nn.Module):
    """Leaky-ReLU for scalar features plus a directional leaky-ReLU for vectors.

    Scalars go through ``F.leaky_relu``. For vectors, two bias-free linear
    maps over the channel axis produce a feature ``q`` and a direction ``k``
    per channel. Where ``q . k >= 0`` the feature is kept; elsewhere its
    component along ``k`` is removed. The result is blended with ``q`` using
    ``negative_slope``.

    Args:
        vi: Number of vector channels (input and output).
        negative_slope: Leak factor used for both the scalar and vector parts.
    """

    def __init__(self, vi, negative_slope=0.2):
        super().__init__()
        self.v2q = nn.Linear(vi, vi, bias=False)
        self.v2k = nn.Linear(vi, vi, bias=False)
        self.negative_slope = negative_slope

    def forward(self, x):
        """Activate ``x``; expects ``x.s`` as ``[N,si]`` and ``x.v`` as ``[N,vi,3]``."""
        s, v = x.s, x.v  # [N,si],[N,vi,3]
        s_out = F.leaky_relu(s, negative_slope=self.negative_slope)

        # Move the channel axis last so nn.Linear mixes vector channels.
        v_t = v.transpose(1, -1)  # [N,3,vi]
        q = self.v2q(v_t).transpose(1, -1)  # [N,vi,3]
        k = self.v2k(v_t).transpose(1, -1)  # [N,vi,3]

        # The sign of the dot product selects the branch, so it must not be
        # clamped to a positive minimum: that would force mask == 1 everywhere
        # and reduce the vector path to the linear map q.
        dotprod = (q * k).sum(2, keepdim=True)  # [N,vi,1]
        mask = (dotprod >= 0).float()
        # NOTE(review): presumably the squared norm of k over the last axis;
        # confirm against safe_norm in modules.common.
        k_norm_sq = safe_norm(k, keepdim=True, sqrt=False)
        # 1e-6 guards the division when k is (near) zero.
        q_proj = q - (dotprod / (k_norm_sq + 1e-6)) * k
        v_out = self.negative_slope * q + (1 - self.negative_slope) * (
                mask * q + (1 - mask) * q_proj)
        return ScalarVector(s=s_out, v=v_out)
