import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dida")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/models")
#########################################################################################

import torch
import torch.nn as nn
import global_variables
import torch.nn.functional as F
from dida_network import DIDA


class NetRanking(nn.Module):
    """Score two parameter vectors against the same extracted meta-features.

    The wrapped ``metafeatures_extractor`` is called on ``(x, lab)``; its
    output is batch-normalised and passed through a ReLU, then a single
    shared ``NN_Surrogate`` head is applied once with ``params_1`` and once
    with ``params_2``. Both scores are concatenated along dimension 1.

    Args:
        metafeatures_extractor: Callable module invoked as
            ``metafeatures_extractor(x, lab)``.
        parameters: Mapping providing the keys ``"N"``, ``"d_out"``,
            ``"d_Mlab"``, ``"d_Mfeat"``, ``"nmoments"``,
            ``"final_fc_output_dim"``, ``"fc_metafeatures"``,
            ``"tensorizations"`` and ``"dropout_fc"``. Each value is stored
            as an attribute of the same name; within this class only
            ``fc_metafeatures[-1]`` is used, as the meta-feature width.
        size_hparams: Width of the parameter vectors given to the head.
        writer: Optional object kept on ``self.writer``; not used in this
            class.
    """

    def __init__(self, metafeatures_extractor, parameters, size_hparams, writer=None):
        super(NetRanking, self).__init__()

        # Copy the configuration entries onto the instance, one attribute per
        # key, in a fixed order.
        config_keys = (
            "N",
            "d_out",
            "d_Mlab",
            "d_Mfeat",
            "nmoments",
            "final_fc_output_dim",
            "fc_metafeatures",
            "tensorizations",
            "dropout_fc",
        )
        for key in config_keys:
            setattr(self, key, parameters[key])
        self.writer = writer

        self.metafeatures_extractor = metafeatures_extractor

        # The last entry of fc_metafeatures is the width shared by the batch
        # norm layer and the scoring head.
        meta_width = self.fc_metafeatures[-1]
        self.bn_dida = nn.BatchNorm1d(meta_width)
        self.fc_final = NN_Surrogate(size_meta_features=meta_width, size_params=size_hparams)

    def forward(self, x, lab, params_1, params_2):
        """Return ``(scores, raw_metafeatures)``.

        ``scores`` is the concatenation along dim 1 of the head output for
        ``params_1`` followed by the head output for ``params_2``.
        ``raw_metafeatures`` is the extractor output before batch norm/ReLU.

        Raises:
            AssertionError: If the maximum of ``x``, ``params_1`` or
                ``params_2`` lies outside ``[0, 1]``.
        """
        # NOTE(review): only the maximum of each tensor is range-checked, so
        # negative entries are not rejected as long as the maximum is within
        # [0, 1]. If the intent is "all values in [0, 1]", the lower bound
        # should probably be tested against min() - confirm with callers.
        for tensor in (x, params_1, params_2):
            largest = tensor.max().item()
            assert 0.0 <= largest <= 1.0

        raw_metafeatures = self.metafeatures_extractor(x, lab)
        normalised = F.relu(self.bn_dida(raw_metafeatures))
        global_variables.debug_tensor("dida", raw_metafeatures)

        # Same head (shared weights) for both parameter vectors.
        score_1, score_2 = [
            self.fc_final(normalised, hparams) for hparams in (params_1, params_2)
        ]
        global_variables.debug_tensor("parm_1", score_1)
        global_variables.debug_tensor("parm_2", score_2)

        return torch.cat((score_1, score_2), dim=1), raw_metafeatures


class NN_Surrogate(nn.Module):
    """Fully connected head mapping (meta-features, parameters) to one value.

    Layer widths are ``size_meta_features + size_params -> 64 -> 32 -> 1``,
    with a ReLU after each of the first two linear layers and no activation
    on the output.

    Args:
        size_meta_features: Width of the meta-feature input.
        size_params: Width of the parameter input.
    """

    def __init__(self, size_meta_features, size_params):
        super(NN_Surrogate, self).__init__()

        input_width = size_meta_features + size_params
        self.fc_1 = nn.Linear(input_width, 64)
        self.fc_2 = nn.Linear(64, 32)
        self.fc_3 = nn.Linear(32, 1)

    def forward(self, mf, hp):
        """Concatenate ``mf`` and ``hp`` along dim 1 and run the three layers."""
        activations = torch.cat((mf, hp), dim=1)
        # Hidden layers are each followed by a ReLU; the output layer is linear.
        for hidden_layer in (self.fc_1, self.fc_2):
            activations = F.relu(hidden_layer(activations))
        return self.fc_3(activations)

class NN_patch_identification(nn.Module):
    """Compare the meta-features of two inputs and emit a two-column output.

    Both inputs go through the same ``metafeatures_extractor``. The first
    column of the result is ``exp(-||mf_1 - mf_2||_2)`` (1 when the two
    meta-feature vectors coincide, decaying towards 0 as they move apart);
    the second column is its complement ``1 - z``.

    Args:
        metafeatures_extractor: Module invoked as
            ``metafeatures_extractor(X, lab)``; it must expose ``training``,
            ``eval()`` and ``train()``.
    """

    def __init__(self, metafeatures_extractor):
        super(NN_patch_identification, self).__init__()
        self.metafeatures_extractor = metafeatures_extractor

    def forward(self, X1, lab1, X2, lab2):
        """Return ``cat((z, 1 - z), dim=1)`` for the pair of inputs.

        The first input is processed in whatever mode the extractor is
        currently in. The second input is always processed in eval mode;
        only the module mode is switched, the call is not wrapped in
        ``torch.no_grad``.
        """
        extractor = self.metafeatures_extractor
        mf_1 = extractor(X1, lab1)

        was_training = extractor.training
        if was_training:
            extractor.eval()
        try:
            mf_2 = extractor(X2, lab2)
        finally:
            # Restore training mode even if the second pass raises, so a
            # failed forward cannot leave the shared extractor stuck in eval
            # mode for subsequent training steps.
            if was_training:
                extractor.train()

        diff = torch.abs(mf_1 - mf_2)
        z = torch.exp(-torch.norm(diff, 2, dim=1, keepdim=True))
        return torch.cat((z, 1 - z), 1)
