import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/")
import global_variables
#########################################################################################

import os
import torch
import itertools
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from utils_sinkhorn import *

class EB(nn.Module):
    """Pairwise interactions block.

    For every point of the input measure the block looks up its ``N`` nearest
    neighbours, forms tuples of ``tensor_order`` neighbours, embeds each tuple
    feature-by-feature with a linear layer and averages the embeddings. Two
    parallel branches share this construction:

    * the *measure* branch (``meas_*`` layers) produces a new per-point
      feature tensor ``x_new``; it is absent on the last block;
    * the *vector* branch (``vect_*`` layers) produces a per-sample vector
      ``z_new`` that also receives a linear function of the incoming ``z``.

    Args:
        name: identifier stored on the module (not used by the computation).
        d_feat: stored for reference; ``forward`` reads the feature width
            from the input tensor itself.
        d_Mfeat: width of the tuple embedding ("feature moments").
        d_out: feature width of the output measure ``x_new``.
        N: number of nearest neighbours retrieved per point. Column 0 of the
            neighbour list is always part of every tuple (presumably the
            point itself -- depends on ``batch_Lpcost``; TODO confirm).
        nmoments: width of the output vector ``z_new``.
        previous_nmoments: width of the incoming ``z`` for every block that
            is not the first one (the first block expects width 1).
        tensor_order: number of neighbours per tuple.
        position: ``0`` marks the first block, ``-1`` the last block.
        with_label: if true, a learned scalar label moment is appended to the
            feature moments and ``forward`` must be given ``labels``.

    Note:
        The ``bn_*`` batch-norm layers are created but not applied by
        ``forward``. They are kept so that parameter names and existing
        checkpoints stay compatible.
    """

    def __init__(self, name, d_feat, d_Mfeat, d_out, N, nmoments, previous_nmoments, tensor_order, position, with_label = False):
        super(EB, self).__init__()
        self.name = name
        self.d_feat = d_feat
        self.d_Mfeat = d_Mfeat  # feature moments -- u
        self.d_out = d_out  # output feature dimension
        self.N = N
        self.nmoments = nmoments
        self.tensor_order = tensor_order
        self.with_label = with_label
        self.position = position
        self.first = self.position == 0
        self.last = self.position == -1

        # NOTE: layers are created in a fixed order so that seeded
        # initialisation and parameter ordering do not change.

        # Tuple embeddings; the measure branch does not exist on the last block.
        if not self.last:
            self.meas_feat = nn.Linear(tensor_order, self.d_Mfeat)
            self.bn_moment = nn.BatchNorm1d(self.d_Mfeat, momentum=0.1, track_running_stats=False)
        self.vect_feat = nn.Linear(tensor_order, self.d_Mfeat)

        # Moment -> output projections; one extra input column for the label
        # moment when labels are used.
        if self.with_label:
            self.d_Mlab = 1  # label moments
            moment_width = self.d_Mfeat + self.d_Mlab
        else:
            moment_width = self.d_Mfeat

        if not self.last:
            self.meas_x = nn.Linear(moment_width, self.d_out)
        self.vect_x = nn.Linear(moment_width, self.nmoments)

        if self.with_label:
            # Learned weights for "same label" / "different label" fractions.
            self.vect_mlab_pos = torch.nn.Parameter(torch.randn(1))
            self.vect_mlab_neg = torch.nn.Parameter(torch.randn(1))
            self.meas_mlab_pos = torch.nn.Parameter(torch.randn(1))
            self.meas_mlab_neg = torch.nn.Parameter(torch.randn(1))
        else:
            self.vect_mlab_pos = None
            self.vect_mlab_neg = None
            self.meas_mlab_pos = None
            self.meas_mlab_neg = None

        # Projections of the incoming vector z (width 1 on the first block).
        z_width = 1 if self.first else previous_nmoments
        self.meas_z = nn.Linear(z_width, self.d_out)
        self.vect_z = nn.Linear(z_width, self.nmoments)

        # Currently unused by forward(); see class docstring.
        self.bn_meas_z = nn.BatchNorm1d(self.d_out, momentum=0.1)
        self.bn_vect_z = nn.BatchNorm1d(self.nmoments, momentum=0.1)
        self.bn_v_moment = nn.BatchNorm1d(self.d_Mfeat, momentum=0.1)
        self.bn_meas_x = nn.BatchNorm1d(self.d_out, momentum=0.1)
        self.bn_vect_x = nn.BatchNorm1d(self.nmoments, momentum=0.1)
        self.bn_meas_feat = nn.BatchNorm1d(self.d_Mfeat, momentum=0.1)
        self.bn_vect_feat = nn.BatchNorm1d(self.d_Mfeat, momentum=0.1)

    def _neighbor_tuples(self, x, idx):
        """Gather the features of every selected neighbour tuple.

        Args:
            x: tensor of shape ``(batch_size, npoints, d_feat)``.
            idx: integer tensor of shape ``(batch_size, npoints, N)`` with the
                neighbour indices of every point.

        Returns:
            Tensor of shape ``(batch_size, d_feat * npoints * (N - 1),
            tensor_order)`` where row ``f * npoints * (N - 1) + i * (N - 1) + s``
            holds feature ``f`` of the ``tensor_order`` neighbours that make up
            tuple ``s`` of point ``i``.
        """
        batch_size, npoints, d_feat = x.size()
        N = self.N

        # Each tuple is neighbour 0 plus a combination of the remaining
        # neighbours; only the first N - 1 combinations are kept.
        combos = list(itertools.combinations(range(1, N), self.tensor_order - 1))[:N - 1]
        to_select = torch.tensor([[0] + list(c) for c in combos], dtype=torch.long, device=idx.device)

        # (batch_size, npoints, n_tuples, tensor_order) point indices.
        tuple_idx = idx[:, :, to_select]
        batch_ix = torch.arange(batch_size, device=idx.device).view(batch_size, 1, 1, 1)
        # One gather instead of repeating x npoints times per tuple:
        # (batch_size, npoints, n_tuples, tensor_order, d_feat).
        gathered = x[batch_ix, tuple_idx]

        # Feature-major row layout; fails if fewer than N - 1 tuples exist.
        return gathered.permute(0, 4, 1, 2, 3).reshape(
            batch_size, d_feat * npoints * (N - 1), self.tensor_order)

    @staticmethod
    def _average_over_features(embedded, batch_size, d_feat, n_tuples):
        """Average tuple embeddings over the input feature axis.

        ``embedded`` has its rows ordered feature-major (see
        ``_neighbor_tuples``), so folding the row axis into
        ``(d_feat, n_tuples)`` and averaging over ``d_feat`` yields a tensor of
        shape ``(batch_size, n_tuples, width)``.
        """
        width = embedded.size(-1)
        return embedded.view(batch_size, d_feat, n_tuples, width).mean(1)

    def forward(self, x, z, labels=None):
        """Apply the block.

        Args:
            x: input measure of shape ``(batch_size, npoints, d_feat)``.
            z: incoming vector; last dimension 1 on the first block and
                ``previous_nmoments`` otherwise.
            labels: per-point labels, reshaped internally to
                ``(batch_size, npoints, 1)``. Required iff the block was built
                with ``with_label=True``.

        Returns:
            Tuple ``(x_new, z_new)``. ``x_new`` has shape
            ``(batch_size, npoints, d_out)`` and is ``None`` on the last block;
            ``z_new`` has shape ``(batch_size, nmoments)``.

        Raises:
            ValueError: if ``labels`` is given although the block was built
                without label support, or is missing although it was built
                with label support.
        """
        if labels is not None and not self.with_label:
            raise ValueError("labels were given but the block was built with with_label=False")
        if labels is None and self.with_label:
            raise ValueError("the block was built with with_label=True but no labels were given")

        batch_size, npoints, d_feat = x.size()
        N = self.N
        n_tuples = (N - 1) * npoints

        # Pairwise distances for the nearest-neighbour search.
        x_flat = x.view(batch_size, npoints * d_feat)
        distances = torch.sqrt(batch_Lpcost(x_flat, x_flat, 2, d_feat))

        # Indices of the N closest points of every point.
        _, idx = torch.topk(distances, N, 2, largest=False, sorted=True)
        del distances

        # Tensorisation: (batch_size, d_feat * n_tuples, tensor_order).
        x_ = self._neighbor_tuples(x, idx)

        # Feature moments, size (batch_size, n_tuples, d_Mfeat).
        # TODO: add Batchnorm
        moments = None
        if not self.last:  # no measure output on the last layer
            moments = self._average_over_features(self.meas_feat(x_), batch_size, d_feat, n_tuples)
        v_moments = self._average_over_features(self.vect_feat(x_), batch_size, d_feat, n_tuples)
        del x_

        if labels is not None:
            labels = batch_index_select_NN(labels.view(batch_size, npoints, 1), idx)

            # Fraction of entries whose label equals / differs from the first
            # column's label.
            same = torch.mean((labels[:, :, :1] == labels[:, :, 1:]).float(), 2)
            diff = torch.mean((labels[:, :, :1] != labels[:, :, 1:]).float(), 2)

            v_lab = F.relu(same * self.vect_mlab_pos + diff * self.vect_mlab_neg).unsqueeze(2)
            v_moments = torch.cat([v_moments, v_lab], 2)  # size (batch_size, n_tuples, Mfeat+Mlab)

            # The last layer has no measure branch, hence nothing to extend.
            if not self.last:
                m_lab = F.relu(same * self.meas_mlab_pos + diff * self.meas_mlab_neg).unsqueeze(2)
                moments = torch.cat([moments, m_lab], 2)  # size (batch_size, n_tuples, Mfeat+Mlab)

        # Output measure: average over the tuples of every point.
        if self.last:
            x_new = None
        else:
            x_new = self.meas_x(moments)
            x_new = torch.mean(x_new.view(batch_size, npoints, N - 1, self.d_out), 2)
        del moments

        # Output deterministic vector: add the contribution of the incoming z
        # to every tuple, then average over all tuples.
        z_new = self.vect_x(v_moments) + self.vect_z(z).unsqueeze(1)
        z_new = torch.mean(z_new, 1)

        return x_new, z_new
