import numpy as np

import torch
import torch.nn as nn


class Embed(nn.Module):
    ''' Maps unordered sets of vectors in R^d into vectors in R^m.
        Also supports signed discrete measures over R^d, i.e. unordered sets with accompanying real weights.
        To guarantee that the mapping is injective, use m >= 2nd+1 or m >= 2n(d+1)+1 for sets and measures
        respectively, where n is the maximal size of the set or support of the measure.

        The input X to be embedded is assumed to be of size (<size prefix>, n, d), where <size prefix> can be
        any list of dimensions (including an empty list). The accompanying weights W, if provided, should
        be of size (<size prefix>, n) or (<size prefix>, n, 1). e.g. if size(X) = (nBatches, batchSize, n, d)
        and size(W) = (nBatches, batchSize, n), then each X(b, t, :, :) is treated as a set of n vectors in
        R^d, with corresponding weights W(b, t, :).

        Embedding of {x1,...,xn} with weights {w1,...,wn} is the vector in R^m whose t-th coordinate is
        sum_i wi*phi(a_t*xi+b_t), where a_1,...,a_m and b_1,...,b_m are, respectively, vectors in R^d and
        real scalars that are drawn randomly. If weights are not given, they are assumed to be 1.

        d: dimension of set elements
        m: embedding dimension

        std_proj, assume_std_in: If each input coordinate has mean=0 and std=<assume_std_in>, then each projection
                                 a_t*xi+b_t has mean=zero and std=<std_proj>.

        allow_offset: If set to False, then all components b_t of the projection are set to zero. In most cases
                      this prevents the embedding from being injective.

        activation: Activation function phi() to use for calculating the moments.

        projType: How to generate the random projection parameters a_t, b_t.
                  Currently, only 'gaussian' is implemented.

        dtype, device: dtype and device of the projection parameters. The input X is expected to have the
                       same dtype.
    '''
    def __init__(self, d, m, std_proj=1, assume_std_in = 1, allow_offset = True, activation=torch.sin, projType='gaussian', dtype=torch.float64, device=None):
        super().__init__()

        self.dtype = dtype
        # Device requested at construction time. Note that .to() does not update this attribute; stochastic
        # projections in forward() are therefore generated on the input's device instead.
        self.device = device

        self.activate = activation
        self.d = d
        self.m = m

        self.std_proj = std_proj
        self.assume_std_in = assume_std_in
        self.allow_offset = allow_offset
        self.projType = projType

        # Draw the fixed ("innate") projection operator once. Stored as frozen parameters so that it is part
        # of state_dict() and follows the module in .to(), but is never trained.
        projMat, offsetVec = self.get_projection_operator()
        self.projMat = nn.Parameter(projMat, requires_grad=False)
        self.offsetVec = nn.Parameter(offsetVec, requires_grad=False)

        if device is not None:
            self.to(device)


    def replace_activation(self, activation):
        ''' Replaces the activation function phi() used when forward() is not given an override. '''
        self.activate = activation


    def get_projection_operator(self, dims_prefix = (), device = None):
        ''' Draws a random projection operator (projMat, offsetVec).

            dims_prefix: If empty, returns a single operator, with projMat of size (m, d) and offsetVec of
                         size (m, 1). Otherwise, draws an independent operator for each input set. e.g. if
                         dims_prefix = (nBatches, batchSize), then X is assumed to be of size
                         (nBatches, batchSize, n, d), and projMat / offsetVec are of size
                         (nBatches, batchSize, m, d) / (nBatches, batchSize, m, 1), so that a different
                         projection is used for each set (or measure) given by X(b,t,:,:) and W(b,t,:).
            device: Device on which to create the tensors. Defaults to self.device.

            If self.allow_offset is False, the returned offsetVec is all zeros.

            Raises NotImplementedError for projType 'grassmannian', and ValueError for unknown projTypes.
        '''
        if device is None:
            device = self.device

        dims_prefix = tuple(dims_prefix)
        proj_mat_shape = dims_prefix + (self.m, self.d)
        offset_vec_shape = dims_prefix + (self.m, 1)

        proj_type = self.projType.lower()

        if proj_type == 'gaussian':
            # If every coordinate of x has mean 0 and std assume_std_in, then a_t*x has variance
            # d * std_a^2 * assume_std_in^2 = std_proj^2 / 2, and b_t has variance std_proj^2 / 2,
            # so a_t*x + b_t has std std_proj.
            std_a = self.std_proj / (np.sqrt(2*self.d) * self.assume_std_in)
            std_b = self.std_proj / np.sqrt(2)

            projMat = std_a * torch.randn(size=proj_mat_shape, dtype=self.dtype, device=device)
            offsetVec = std_b * torch.randn(size=offset_vec_shape, dtype=self.dtype, device=device)

        elif proj_type == 'grassmannian':
            raise NotImplementedError("Grassmannian projection not implemented yet")

        else:
            raise ValueError("Invalid projection type")

        if not self.allow_offset:
            # The offsets are still drawn above, so the random-number stream does not depend on allow_offset.
            offsetVec.zero_()

        return projMat, offsetVec


    def forward(self, X, W=None, include_PX=False, override_activation = None, projOp = None):
        ''' Computes the embedding of the set(s) / measure(s) given by X and the optional weights W.

            X: Tensor of size (<size_prefix>, n, d)
            W: Optional weights, of size (<size_prefix>, n) or (<size_prefix>, n, 1). If None, all weights are 1.
            include_PX: If True, also output PX of size (<size_prefix>, n, m), where each coordinate is the
                        projection a_t*xi+b_t of one point (before the activation).
            override_activation: Alternative activation to use instead of the one defined internally.
            projOp: None or 'innate': use the instance's innate projection operator.
                    'stochastic': draw a fresh random projection operator for each individual set, i.e. for
                                  each index into <size_prefix>.
                    (projMat, offsetVec) tuple: use the given operator, with sizes as returned by
                                  get_projection_operator().

            Returns Xm, of size (<size_prefix>, m), when no extra outputs apply. Otherwise returns the list
            [Xm, PX (only if include_PX), extra], where extra is the generated (projMat, offsetVec) tuple
            when projOp == 'stochastic', and None when projOp == 'innate'. A user-supplied tuple projOp
            is not appended to the output.

            Raises ValueError if projOp is a string other than 'stochastic' or 'innate'.
        '''
        # Only string values select a mode; anything else non-None is a user-supplied operator.
        mode = projOp if isinstance(projOp, str) else None
        if (mode is not None) and (mode not in ('stochastic', 'innate')):
            raise ValueError("Invalid projOp")

        if (projOp is None) or (mode == 'innate'):
            projMat, offsetVec = None, None
        elif mode == 'stochastic':
            # One independent operator per set, created where the data lives.
            projMat, offsetVec = self.get_projection_operator(dims_prefix=X.shape[:-2], device=X.device)
        else:
            projMat, offsetVec = projOp

        Xm, PX = self.calculate_embedding(X, W, projMat=projMat, offsetVec=offsetVec, activation=override_activation)

        out = [Xm]

        if include_PX:
            out.append(PX)

        if mode == 'stochastic':
            out.append((projMat, offsetVec))
        elif mode == 'innate':
            out.append(None)

        if len(out) == 1:
            return out[0]

        return out


    def calculate_embedding(self, X, W, projMat = None, offsetVec = None, activation = None):
        ''' Computes the embedding of (X, W) together with the projections of all points of X.

            X: Tensor of size (<size_prefix>, n, d)
            W: None (all weights 1), or weights of size (<size_prefix>, n) or (<size_prefix>, n, 1)
            projMat: None to use self.projMat, or a tensor of size (m, d) or (<size_prefix>, m, d)
            offsetVec: Offsets matching projMat, of size (m, 1) or (<size_prefix>, m, 1).
                       Ignored when projMat is None (self.offsetVec is used instead).
            activation: None to use self.activate, or an alternative activation function.

            Returns (Xm, PX):
            Xm: Embedding of (X, W), of size (<size_prefix>, m)
            PX: Projections of all points of X (before the activation), of size (<size_prefix>, n, m)
        '''
        # Everything that comes before the (n, d) in X.shape
        dims_prefix = X.shape[:-2]

        # Axis number of the set-element index {1,...,n} in each set, and of the ambient space R^d
        vecidx_axis = len(dims_prefix)
        ambspace_axis = vecidx_axis + 1

        if projMat is None:
            # Contract the R^d axis of X with the d axis of the (m, d) matrix -> (<size_prefix>, n, m),
            # then broadcast the m offsets over all leading axes.
            PX = torch.tensordot(X, self.projMat, dims=((ambspace_axis,), (1,)))
            PX = PX + self.offsetVec.reshape(self.m)
        else:
            # (..., m, d) @ (<size_prefix>, d, n) -> (<size_prefix>, m, n); the (m, 1) offsets broadcast over n.
            PX = torch.matmul(projMat, X.transpose(vecidx_axis, ambspace_axis)) + offsetVec
            PX = PX.transpose(vecidx_axis, ambspace_axis)

        act = self.activate if activation is None else activation
        ActPX = act(PX)

        if W is not None:
            if W.shape == X.shape[:-1]:
                # Weights given as (<size_prefix>, n): add a trailing axis so they scale each point's m
                # coordinates, instead of being broadcast against the m axis.
                W = W.unsqueeze(-1)
            ActPX = torch.mul(ActPX, W)

        Xm = torch.sum(ActPX, dim=vecidx_axis)

        return Xm, PX

