import torch
import torch.nn as nn



class SoundNeRirF(nn.Module):
    '''
    Predict an inverse-RIR and a forward-RIR from a start and a target receiver position.

    Two parallel six-layer MLP branches (one fed by the start position, one by the
    target position) are coupled after each of the first five layers by learned
    per-feature 2x2 mixing weights (``self.stitch``). Each branch ends in an RIR
    head (Linear + Tanh) and a time-shift head (Linear, clamped to the RIR index
    range). Every position is expanded into three scaled copies, so each branch
    emits three sub-RIRs that are gated, optionally flipped and summed.
    '''

    # Width of every hidden layer in both branches.
    _HIDDEN_DIM = 512
    # Number of scaled position copies / sub-RIRs per sample (see recode_micpos).
    _NUM_SUB_RIR = 3
    # Hidden layers per branch; there is one stitch stage between consecutive layers.
    _NUM_HIDDEN_LAYERS = 6

    def __init__(self, posembed_size=64, RIR_len = 257):
        '''
        :param posembed_size: number of frequency bands used by the position encoding
        :param RIR_len: number of taps of each predicted RIR
        '''
        super().__init__()
        self.poseembed_size = posembed_size
        self.RIR_len = RIR_len
        self.act_fn = torch.nn.ReLU()
        # The construction order below fixes the parameter registration order and
        # the order in which random initialisation is drawn; keep it stable.
        self.build_startpos_RIR_MLP()
        self.build_targetpos_RIR_MLP()
        self.build_layerwise_commu()
        self.build_timeshift_MPL()
        # Tap indices 0 .. RIR_len-1, stored as a frozen parameter so that they
        # follow the module across devices and are saved in the state dict.
        self.time_idx = nn.Parameter(data=torch.arange(start=0, end=self.RIR_len, dtype=torch.float32),
                                     requires_grad=False)

    def posenc(self, euc_loc):
        '''
        Position Encoding of receiver euclidean location
        :param euc_loc: [..., 3] locations (x, y, z), the x/y/z are in [-1,1] range
        :return: [..., 6*posembed_size]; for every band i the block
            [sin(2^i * loc), cos(2^i * loc)] is appended along the last axis
        '''
        bands = [fn((2. ** band) * euc_loc)
                 for band in range(self.poseembed_size)
                 for fn in (torch.sin, torch.cos)]

        return torch.cat(bands, dim=-1)

    def build_timeshift_MPL(self):
        '''
        build the two linear heads mapping a hidden feature to a scalar time shift
        :return:
        '''
        self.time_shift_startpos_param = torch.nn.Linear(in_features=self._HIDDEN_DIM,
                                                         out_features=1,
                                                         bias=True)

        self.time_shift_targetpos_param = torch.nn.Linear(in_features=self._HIDDEN_DIM,
                                                          out_features=1,
                                                          bias=True)

    def _hidden_block(self, in_features):
        '''
        One hidden layer: Linear followed by BatchNorm1d. The activation is applied
        by the caller. Inputs are [B, 3, F], so BatchNorm1d sees the three
        sub-RIR slots as its channel axis.
        '''
        return torch.nn.Sequential(torch.nn.Linear(in_features=in_features,
                                                   out_features=self._HIDDEN_DIM,
                                                   bias=True),
                                   torch.nn.BatchNorm1d(num_features=self._NUM_SUB_RIR))

    def _build_branch(self, prefix):
        '''
        Register `<prefix>_MLP1` .. `<prefix>_MLP6` and the `<prefix>_RIR` head.
        Attribute names are part of the state-dict layout and must not change.
        '''
        in_features = self.poseembed_size * 2 * 3
        for layer_no in range(1, self._NUM_HIDDEN_LAYERS + 1):
            setattr(self, f'{prefix}_MLP{layer_no}', self._hidden_block(in_features))
            in_features = self._HIDDEN_DIM

        setattr(self, f'{prefix}_RIR',
                torch.nn.Sequential(torch.nn.Linear(in_features=self._HIDDEN_DIM,
                                                    out_features=self.RIR_len,
                                                    bias=True),
                                    torch.nn.Tanh()))

    def build_startpos_RIR_MLP(self):
        '''
        build up a six layer MLPs to infer an RIR (start-position branch)
        :return:
        '''
        self._build_branch('startpos')

    def build_targetpos_RIR_MLP(self):
        '''
        build up a six layer MLPs to infer an RIR (target-position branch)
        :return:
        '''
        self._build_branch('targetpos')

    def build_layerwise_commu(self):
        '''
        build the learnable mixing weights exchanged between the two branches:
        one [512, 2, 2] tensor per stage, initialised uniformly in [0.1, 0.9].
        For feature f, row 0 produces the start-branch input and row 1 the
        target-branch input; column 0 weights the start feature, column 1 the
        target feature.
        :return:
        '''
        self.stitch = nn.ParameterList([
            nn.Parameter(torch.FloatTensor(self._HIDDEN_DIM, 2, 2).uniform_(0.1, 0.9), requires_grad=True)
            for _ in range(self._NUM_HIDDEN_LAYERS - 1)
        ])

    def recode_micpos(self, input_pos):
        '''
        The input mic position is simply a 3D euclidean position [B, 3], each row is [x, y, z]
        :param input_pos: [B, 3], float32
        :return: [B, 3, 3], where dim 1 holds the three encodings (pos, 2*pos, 3*pos)
        '''
        return torch.stack([input_pos, input_pos * 2, input_pos * 3], dim=1)

    def reorganize_timeshift(self, time_shift):
        '''
        the time shift is of shape [B, 3], we have to guarantee the first [B,0] is the smallest, then the second channel
        [B, 1], then the third channel [B, 2]
        :param time_shift: input time shift [B, 3]
        :return: reorganized time shift, sorted ascending along dim 1
        '''
        sorted_shift, _ = torch.sort(time_shift, dim=1, descending=False)

        return sorted_shift

    def add_monotonicity_causality(self, input_RIR, do_inverse = True, time_shift = None ):
        '''
        Given the input RIR, which lies in between (-1, 1), we explicitly add monotonicity and causality to the raw
        RIR to obtain the real RIR
        :param input_RIR: [B, 3, RIR_len]
        :param do_inverse: boolean, if True the gated RIR is flipped along the time axis before the
            sub-RIRs are summed
        :param time_shift: the learned time shift parameter, [B, 3]. Required in practice: the default
            of None is kept for signature compatibility but cannot be processed
        :return: mature RIR with monotonicity and causality, [B, RIR_len]
        '''
        # Keep each tap's sign (zero counts as positive) but replace its magnitude
        # with the running maximum along time, so |RIR| is non-decreasing.
        sign = (input_RIR >= 0.).to(torch.float32) - (input_RIR < 0.).to(torch.float32)
        envelope = torch.cummax(torch.abs(input_RIR), dim=2)[0]
        input_RIR = torch.mul(sign, envelope)

        #first apply time shift parameter
        time_shift = self.reorganize_timeshift(time_shift)
        # [1, 1, RIR_len] - [B, 3, 1] broadcasts to one shifted index axis per sub-RIR.
        shifted_idx = self.time_idx.view(1, 1, -1) - torch.unsqueeze(time_shift, dim=-1)

        # Soft step located at the time shift, reversed along the time axis.
        sigmoid_truncat = torch.flip(torch.sigmoid(5 * shifted_idx), dims=[2])
        input_RIR = torch.mul(sigmoid_truncat, input_RIR)

        if do_inverse:
            input_RIR = torch.flip(input_RIR, dims=[2])

        #element-wise add 3 sub-RIR to form the final RIR
        return torch.sum(input_RIR, dim=1, keepdim=False)

    def compute_RIR(self, start_pos_embed, target_pos_embed):
        '''
        given the start position embedding, and target position embedding, we explicitly encode the
        inverse-RIR and forward-RIR
        :param start_pos_embed: start position embedding, [B, 3, 6*posembed_size]
        :param target_pos_embed: end position embedding, [B, 3, 6*posembed_size]
        :return: inverse-RIR [B, 3, RIR_len], forward-RIR [B, 3, RIR_len],
            inverse time shift [B, 3], forward time shift [B, 3]
        '''
        layer_ids = range(1, self._NUM_HIDDEN_LAYERS + 1)
        start_layers = [getattr(self, f'startpos_MLP{i}') for i in layer_ids]
        target_layers = [getattr(self, f'targetpos_MLP{i}') for i in layer_ids]

        start_feat = self.act_fn(start_layers[0](start_pos_embed))
        target_feat = self.act_fn(target_layers[0](target_pos_embed))

        # After every layer but the last, each branch receives a per-feature
        # weighted sum of BOTH branches' current features from the same depth.
        for mix, start_layer, target_layer in zip(self.stitch, start_layers[1:], target_layers[1:]):
            start_in = mix[:, 0, 0] * start_feat + mix[:, 0, 1] * target_feat
            target_in = mix[:, 1, 0] * start_feat + mix[:, 1, 1] * target_feat

            start_feat = self.act_fn(start_layer(start_in))
            target_feat = self.act_fn(target_layer(target_in))

        inverse_RIR = self.startpos_RIR(start_feat)
        forward_RIR = self.targetpos_RIR(target_feat)

        # Drop only the trailing singleton axis; a bare squeeze() would also drop
        # the batch axis when B == 1.
        inverse_timeshift = torch.squeeze(self.time_shift_startpos_param(start_feat), dim=-1)
        forward_timeshift = torch.squeeze(self.time_shift_targetpos_param(target_feat), dim=-1)

        # Restrict the shifts to valid tap indices.
        inverse_timeshift = torch.clamp(inverse_timeshift, min=0., max=self.RIR_len-1.)
        forward_timeshift = torch.clamp(forward_timeshift, min=0., max=self.RIR_len-1.)

        return inverse_RIR, forward_RIR, inverse_timeshift, forward_timeshift

    def forward(self, start_pos, target_pos):
        '''
        :param start_pos: [B, 3]
        :param target_pos: [B, 3]
        :return: inverse-RIR [B, RIR_len], forward-RIR [B, RIR_len]
        '''
        start_pos = self.recode_micpos(start_pos) #[B, 3, 3]
        target_pos = self.recode_micpos(target_pos) #[B, 3, 3]

        start_pos_embed = self.posenc(start_pos) #[B, 3, 6*posembed_size]
        target_pos_embed = self.posenc(target_pos) #[B, 3, 6*posembed_size]

        inverse_RIR, forward_RIR, inverse_timeshift, forward_timeshift = self.compute_RIR(
            start_pos_embed=start_pos_embed,
            target_pos_embed=target_pos_embed)

        inverse_RIR = self.add_monotonicity_causality(input_RIR=inverse_RIR,
                                                      do_inverse=False,
                                                      time_shift=inverse_timeshift)

        forward_RIR = self.add_monotonicity_causality(input_RIR=forward_RIR,
                                                      do_inverse=True,
                                                      time_shift=forward_timeshift)

        return inverse_RIR, forward_RIR