import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/")
import global_variables
#########################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F

from block import EB


class DIDA(nn.Module):
    """Stack of EB blocks followed by a fully connected head.

    Each EB block (imported from ``block``) maps ``(x, z)`` to a new
    ``(x, z)``. After every block, ``z`` goes through a ``BatchNorm1d`` sized
    by that block's ``nmoments`` entry and then a ReLU. The final ``z`` is fed
    through one ``nn.Linear`` per entry of ``fc_metafeatures``, with
    BatchNorm1d -> ReLU -> Dropout between all but the last layer.

    Args:
        d_Mfeat: per-block ``d_Mfeat`` values forwarded to each EB.
        d_Mlab: per-block values. They are not forwarded to EB, but they are
            zipped with the other per-block lists, so they bound the number
            of blocks.
        N: forwarded unchanged to every EB.
        nmoments: per-block width of ``z``. It sizes each BatchNorm1d, and the
            last entry is the input width of the first FC layer.
        d_out: per-block ``d_out`` values forwarded to each EB.
        fc_metafeatures: output widths of the FC head layers.
        tensorizations: per-block ``tensor_order`` values forwarded to each EB.
        dropout_fc: dropout probabilities, one per non-final FC layer.
        writer: stored as ``self.writer``; not used by the code in this class.
    """

    def __init__(self, d_Mfeat, d_Mlab, N, nmoments, d_out,
                 fc_metafeatures, tensorizations, dropout_fc, writer):
        super(DIDA, self).__init__()
        self.N = N
        self.d_out = d_out
        self.d_Mlab = d_Mlab
        self.d_Mfeat = d_Mfeat
        self.nmoments = nmoments
        self.tensorizations = tensorizations
        self.fc_metafeatures = fc_metafeatures
        self.dropout_fc = dropout_fc
        self.writer = writer

        # Width of the z fed into each block. The first block starts from a
        # single zero column (see forward); every later block consumes the
        # previous block's nmoments.
        self.previous_nmoments = ([1] + list(self.nmoments[:-1])) if len(self.nmoments) else []

        n_blocks = len(self.nmoments)

        # SDN
        self.list_module = torch.nn.ModuleList()
        self.list_batch_norm = torch.nn.ModuleList()
        # d_Mlab stays in the zip only so that the block count is still bounded
        # by the shortest per-block list, exactly as before.
        for i, (_d_Ml, d_Mf, d_o, n_moment) in enumerate(zip(self.d_Mlab, self.d_Mfeat,
                                                              self.d_out, self.nmoments)):
            # The last block of a multi-block stack is tagged with position -1.
            is_last_of_many = (i == n_blocks - 1) and (n_blocks != 1)
            self.list_module.append(EB(name="EB_{}".format(i),
                                       # NOTE(review): for i == 0 this reads d_out[-1]
                                       # (Python negative indexing); confirm intended.
                                       d_feat=self.d_out[i - 1],
                                       d_Mfeat=d_Mf,
                                       d_out=d_o,
                                       N=self.N,
                                       nmoments=n_moment,
                                       previous_nmoments=self.previous_nmoments[i],
                                       tensor_order=self.tensorizations[i],
                                       position=-1 if is_last_of_many else i,
                                       with_label=(i == 0)))
            self.list_batch_norm.append(nn.BatchNorm1d(n_moment, momentum=0.1))

        # FULLY CONNECTED
        self.list_fc = torch.nn.ModuleList()
        self.list_bn_fc = torch.nn.ModuleList()
        self.list_dropout_fc = torch.nn.ModuleList()

        last_fc = len(self.fc_metafeatures) - 1
        for i, dim in enumerate(self.fc_metafeatures):
            in_dim = self.nmoments[-1] if i == 0 else self.fc_metafeatures[i - 1]
            self.list_fc.append(nn.Linear(in_dim, dim))

            # Hidden FC layers get BN + dropout; the output layer does not.
            if i != last_fc:
                self.list_bn_fc.append(nn.BatchNorm1d(dim))
                self.list_dropout_fc.append(nn.Dropout(self.dropout_fc[i]))

    def forward(self, x, lab, device=None):
        """Run the EB stack, then the FC head.

        Args:
            x: input given to the first EB block and threaded through the
                rest; ``x.size(0)`` is used as the batch size.
            lab: labels, passed only to the first EB block.
            device: device for the initial all-zero ``z``. Defaults to
                ``x.device``, so CPU and GPU inputs both work.

        Returns:
            The output of the last FC layer. If ``fc_metafeatures`` is empty,
            this is the last BN+ReLU'd ``z`` instead.
        """
        if device is None:
            device = x.device

        # SDN
        for i, (layer, layer_bn) in enumerate(zip(self.list_module, self.list_batch_norm)):
            if i == 0:
                # The first block starts from one zero column per sample.
                z = torch.zeros(x.size(0), 1, device=device)
                x, z = layer(x=x, z=z, labels=lab)
            else:
                x, z = layer(x=x, z=z)
            z = F.relu(layer_bn(z))
            # TODO: add batchnorm to x

        # Fully Connected
        last_fc = len(self.list_fc) - 1
        for i, layer in enumerate(self.list_fc):
            z = layer(z)
            if i != last_fc:
                z = self.list_dropout_fc[i](F.relu(self.list_bn_fc[i](z)))

        return z


class Net(nn.Module):
    """DIDA encoder followed by a linear layer over [normalised DIDA output, params].

    Args:
        parameters: dict read for the keys ``N``, ``d_out``, ``d_Mlab``,
            ``d_Mfeat``, ``nmoments``, ``final_fc_output_dim``,
            ``fc_metafeatures``, ``tensorizations`` and ``dropout_fc``.
        size_hparams: width of the ``params`` tensor that ``forward``
            concatenates after the DIDA output along dim 1.
        writer: optional object forwarded to DIDA and stored as ``self.writer``.
    """

    def __init__(self, parameters, size_hparams, writer=None):
        super(Net, self).__init__()

        self.N = parameters["N"]
        self.d_out = parameters["d_out"]
        self.d_Mlab = parameters["d_Mlab"]
        self.d_Mfeat = parameters["d_Mfeat"]
        self.nmoments = parameters["nmoments"]
        self.final_fc_output_dim = parameters["final_fc_output_dim"]
        self.fc_metafeatures = parameters["fc_metafeatures"]
        self.tensorizations = parameters["tensorizations"]
        self.dropout_fc = parameters["dropout_fc"]
        self.writer = writer

        # N comes from the configuration; it was previously ignored in favour
        # of a hard-coded 30.
        self.dida = DIDA(d_Mfeat=self.d_Mfeat,
                         d_Mlab=self.d_Mlab,
                         N=self.N,
                         nmoments=self.nmoments,
                         d_out=self.d_out,
                         fc_metafeatures=self.fc_metafeatures,
                         tensorizations=self.tensorizations,
                         dropout_fc=self.dropout_fc,
                         writer=self.writer)
        # Input = DIDA output (fc_metafeatures[-1] wide) + the hyper-parameter vector.
        self.fc_final = nn.Linear(self.fc_metafeatures[-1] + size_hparams, self.final_fc_output_dim)

    def forward(self, x, lab, params):
        """Encode ``(x, lab)`` with DIDA and score it together with ``params``.

        Args:
            x: input for DIDA; every value must lie in [0, 1].
            lab: labels forwarded to DIDA.
            params: tensor concatenated along dim 1 after the L2-normalised
                DIDA output; every value must lie in [0, 1].

        Returns:
            ``(output, dida_z)``. ``output`` is ``fc_final``'s result with dim 1
            squeezed (a no-op unless ``final_fc_output_dim == 1``). ``dida_z``
            is the raw, un-normalised DIDA output.

        Raises:
            AssertionError: if ``x`` or ``params`` has a value outside [0, 1]
                (when assertions are enabled).
        """
        # Check both bounds; the lower bound previously tested max() instead of min().
        assert 0.0 <= x.min().item() and x.max().item() <= 1.0, "x must lie in [0, 1]"
        assert 0.0 <= params.min().item() and params.max().item() <= 1.0, "params must lie in [0, 1]"

        dida_z = self.dida(x, lab)
        z = F.normalize(dida_z, p=2, dim=1)
        global_variables.debug_tensor("dida", z)
        z = torch.cat([z, params], dim=1)
        global_variables.debug_tensor("concat", z)
        z = self.fc_final(z).squeeze(1)
        global_variables.debug_tensor("output", z)

        return z, dida_z
