import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

class LocalBatchNorm(nn.Module):
    """Placeholder layer that performs no normalisation.

    Despite its name, it has no parameters and ``forward`` hands back the very
    same tensor object it receives.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X):
        # Pass-through: no statistics are computed and nothing is copied.
        passthrough = X
        return passthrough

class Block_DSS_v1(nn.Module):
    """DSS block: encode a data tensor and a label tensor, concatenate, apply ``rho``.

    Each input is averaged over dim 1, and every remaining entry is passed
    through an MLP as a 1-feature input (``phi`` for ``X``, ``phi_lab`` for
    ``lab``). The results are averaged over those entries and batch-normalised.
    The two embeddings are then concatenated and fed to ``rho``.

    Args:
        dim_output_phi: output widths of the successive Linear layers of
            ``phi`` and ``phi_lab``.
        dim_output_rho: output widths of the successive Linear layers of
            ``rho``; its first layer takes ``2 * dim_output_phi[-1]`` inputs.
        with_label: accepted but unused. The label branch is always built, and
            ``forward`` always requires ``lab``.
    """

    def __init__(self, dim_output_phi, dim_output_rho, with_label=False):
        super(Block_DSS_v1, self).__init__()
        # phi: Linear layers interleaved with LocalBatchNorm. LocalBatchNorm is
        # an identity in this file, so phi is effectively a stack of affine maps
        # with no nonlinearity in between.
        phi = []
        for i, d in enumerate(dim_output_phi):
            phi.append(("Lin_{}".format(i), nn.Linear(in_features=(1 if i == 0 else dim_output_phi[i-1]),
                                                      out_features=d)))
            phi.append(("bn_{}".format(i), LocalBatchNorm()))
        self.phi = nn.Sequential(OrderedDict(phi))
        self.bn_phi = nn.BatchNorm1d(dim_output_phi[-1])

        # phi_lab: label encoder, a plain Linear stack with the same widths as phi.
        phi_lab = []
        for i, d in enumerate(dim_output_phi):
            phi_lab.append(("Lin_{}".format(i), nn.Linear(1 if i == 0 else dim_output_phi[i-1], d)))
        self.phi_lab = nn.Sequential(OrderedDict(phi_lab))
        self.bn_phi_lab = nn.BatchNorm1d(dim_output_phi[-1])

        # rho: mixes the concatenated data and label embeddings.
        rho = []
        for i, d in enumerate(dim_output_rho):
            rho.append(("Lin_{}".format(i), nn.Linear((dim_output_phi[-1] * 2) if i == 0 else dim_output_rho[i-1], d)))
            rho.append(("bn_{}".format(i), LocalBatchNorm()))
        self.rho = nn.Sequential(OrderedDict(rho))
        # NOTE(review): bn_rho is never applied in forward(). It is kept so the
        # module's parameters and state_dict keys stay unchanged.
        self.bn_rho = nn.BatchNorm1d(dim_output_rho[-1])


    def forward(self, X, lab=None):
        """Embed ``X`` and ``lab`` and return ``rho`` applied to their concatenation.

        Args:
            X: data tensor. It is expected to be 3-D, e.g. (batch, n, d) — TODO
                confirm against callers. It is averaged over dim 1, then each
                remaining entry is fed to ``phi`` as a 1-feature input.
            lab: label tensor, processed like ``X`` but by ``phi_lab``. It is
                required despite the ``None`` default.

        Returns:
            Output of ``rho``; its last dimension is ``dim_output_rho[-1]``.

        Raises:
            ValueError: if ``lab`` is None.
        """
        if lab is None:
            # The label branch is unconditional (rho expects twice phi's width),
            # so fail with a clear message instead of an AttributeError below.
            raise ValueError(
                "{}.forward requires a label tensor `lab`".format(type(self).__name__)
            )

        # Data branch: average over dim 1, add a trailing size-1 axis for phi's
        # first Linear(1, ...), encode, average over entries, normalise.
        X = X.mean(1).unsqueeze(2)
        X = self.phi(X)
        X = X.mean(1)
        X = self.bn_phi(X)

        # Label branch: same pipeline with its own weights.
        lab = lab.mean(1).unsqueeze(2)
        lab = self.phi_lab(lab)
        lab = lab.mean(1)
        lab = self.bn_phi_lab(lab)

        X = torch.cat([X, lab], dim=1)
        return self.rho(X)


class Block_DSS_v2(Block_DSS_v1):
    """Block_DSS_v1 variant whose ``phi`` and ``phi_lab`` encoders use ReLU.

    The parent constructor runs first and builds its own layers. Every one of
    them is then replaced below under the same attribute name. The ``forward``
    pass is inherited unchanged.

    Args:
        dim_output_phi: output widths of the Linear layers of ``phi``/``phi_lab``.
        dim_output_rho: output widths of the Linear layers of ``rho``.
        with_label: forwarded to the parent constructor.
    """

    def __init__(self, dim_output_phi, dim_output_rho, with_label=False):
        super(Block_DSS_v2, self).__init__(dim_output_phi, dim_output_rho, with_label)

        def fan_in(widths, idx, first):
            # Input size of layer `idx`: `first` for layer 0, else the previous width.
            return first if idx == 0 else widths[idx - 1]

        def relu_encoder():
            # Linear -> ReLU pairs, named "Lin_<i>" / "relu_<i>", starting from 1 input.
            stages = OrderedDict()
            for idx, width in enumerate(dim_output_phi):
                stages["Lin_{}".format(idx)] = nn.Linear(fan_in(dim_output_phi, idx, 1), width)
                stages["relu_{}".format(idx)] = nn.ReLU()
            return nn.Sequential(stages)

        phi_width = dim_output_phi[-1]

        # Layers are created in the same order as before (phi, phi_lab, rho),
        # because Linear initialisation draws from the global RNG.
        self.phi = relu_encoder()
        self.bn_phi = nn.BatchNorm1d(phi_width)

        self.phi_lab = relu_encoder()
        self.bn_phi_lab = nn.BatchNorm1d(phi_width)

        # rho is rebuilt with the same layout the parent uses.
        mixer = OrderedDict()
        for idx, width in enumerate(dim_output_rho):
            mixer["Lin_{}".format(idx)] = nn.Linear(fan_in(dim_output_rho, idx, 2 * phi_width), width)
            mixer["bn_{}".format(idx)] = LocalBatchNorm()
        self.rho = nn.Sequential(mixer)
        self.bn_rho = nn.BatchNorm1d(dim_output_rho[-1])


class Equivariant_Block(nn.Module):
    """Linear layer that is permutation-equivariant over dimension 1 of a 3-D input.

    For ``X`` of shape (batch, n, d) it returns::

        X @ theta_1 + (mean_over_dim1(X) - X) @ theta_2
        theta_k = lambda_X_k * I_d + gamma_X_k * ones(d, d)

    Here ``lambda_X_k`` and ``gamma_X_k`` are learnable scalars. The first
    term acts row by row, and the second uses a mean that ignores row order,
    so permuting the n rows of ``X`` permutes the output rows the same way.
    """

    def __init__(self):
        super(Equivariant_Block, self).__init__()
        self.lambda_X_1 = torch.nn.Parameter(torch.randn((1)))
        self.gamma_X_1 = torch.nn.Parameter(torch.randn((1)))
        self.lambda_X_2 = torch.nn.Parameter(torch.randn((1)))
        self.gamma_X_2 = torch.nn.Parameter(torch.randn((1)))

    def forward(self, X):
        """Apply the equivariant map; the output has the same shape as ``X``.

        Raises:
            ValueError: if ``X`` is not 3-D (raised by the size unpacking).
        """
        _, _, d = X.size()  # the unpacking also enforces a 3-D input
        # Build the mixing matrices on X's own device and dtype. Hard-coding
        # "cuda" fails on CPU and misplaces tensors for models on other devices.
        eye = torch.eye(d, device=X.device, dtype=X.dtype)
        ones = torch.ones(d, d, device=X.device, dtype=X.dtype)
        theta_1 = self.lambda_X_1 * eye + self.gamma_X_1 * ones
        theta_2 = self.lambda_X_2 * eye + self.gamma_X_2 * ones

        L_1 = torch.matmul(X, theta_1)
        # Deviation of the mean over dim 1 from each row (mean minus row).
        X_subsum = X.mean(1, keepdim=True) - X
        L_2 = torch.matmul(X_subsum, theta_2)

        return L_1 + L_2


class Block_DSS_v3(Block_DSS_v2):
    """Block_DSS_v2 with one Equivariant_Block applied to each input first.

    ``X`` goes through ``equiv_X`` and ``lab`` goes through ``equiv_lab``. Both
    are then handed to the inherited ``forward``. Equivariant_Block unpacks its
    input into three dimensions, so both ``X`` and ``lab`` must be 3-D here.
    """

    def __init__(self, dim_output_phi, dim_output_rho, with_label=False):
        super(Block_DSS_v3, self).__init__(dim_output_phi, dim_output_rho, with_label)
        # Each block is wrapped in nn.Sequential, so its parameter names are
        # "equiv_X.0.*" / "equiv_lab.0.*". Unwrapping would change state_dict keys.
        self.equiv_X = nn.Sequential(
            Equivariant_Block()
        )
        self.equiv_lab = nn.Sequential(
            Equivariant_Block()
        )

    def forward(self, X, lab=None):
        """Apply the equivariant pre-layers, then the inherited DSS forward.

        Raises:
            ValueError: if ``lab`` is None (the label branch is mandatory).
        """
        if lab is None:
            # Fail clearly here rather than with an AttributeError inside equiv_lab.
            raise ValueError(
                "{}.forward requires a label tensor `lab`".format(type(self).__name__)
            )
        X = self.equiv_X(X)
        lab = self.equiv_lab(lab)
        return super().forward(X, lab)


def get_block_dss(version, dim_output_phi, dim_output_rho, with_label):
    """Build the DSS block that matches ``version``.

    Args:
        version: one of ``"v1"``, ``"v2"``, ``"v3"``.
        dim_output_phi: Linear widths for the block's phi encoders.
        dim_output_rho: Linear widths for the block's rho MLP.
        with_label: forwarded to the block constructor.

    Returns:
        A new ``Block_DSS_v1``, ``Block_DSS_v2`` or ``Block_DSS_v3`` instance.

    Raises:
        ValueError: if ``version`` is not recognised. Returning None instead
            would only fail later, when the missing block is called.
    """
    if version == "v1":
        return Block_DSS_v1(dim_output_phi, dim_output_rho, with_label)
    if version == "v2":
        return Block_DSS_v2(dim_output_phi, dim_output_rho, with_label)
    if version == "v3":
        return Block_DSS_v3(dim_output_phi, dim_output_rho, with_label)
    raise ValueError(
        "Unknown DSS block version {!r}; expected 'v1', 'v2' or 'v3'".format(version)
    )


class DSS(nn.Module):
    """Meta-feature network: a DSS block followed by a fully connected head.

    Args:
        list_dim_output_phi: Linear widths for the DSS block's phi encoders.
        list_dim_output_rho: Linear widths for the DSS block's rho MLP. Its last
            entry (or the value itself, if an int) is the head's input width.
        fc_metafeatures: output widths of the head's Linear layers, in order.
        dropout_fc: dropout probabilities, indexed in parallel with
            ``fc_metafeatures``.
        version_dss_block: DSS block version passed to ``get_block_dss``.
    """

    def __init__(self, list_dim_output_phi, list_dim_output_rho, fc_metafeatures, dropout_fc, version_dss_block):
        super(DSS, self).__init__()
        # Never populated or read by this class; kept so the attribute exists.
        self.list_module = torch.nn.ModuleList()

        # Set encoder; the label branch is always requested here.
        self.dss = get_block_dss(version_dss_block, list_dim_output_phi, list_dim_output_rho, with_label=True)

        self.list_fc = torch.nn.ModuleList()
        self.list_bn_fc = torch.nn.ModuleList()
        self.list_dropout_fc = torch.nn.ModuleList()

        def encoder_width():
            # Accept either a bare int or a sequence of rho widths.
            if isinstance(list_dim_output_rho, int):
                return list_dim_output_rho
            return list_dim_output_rho[-1]

        for idx, width in enumerate(fc_metafeatures):
            fan_in = encoder_width() if idx == 0 else fc_metafeatures[idx - 1]
            self.list_fc.append(nn.Linear(fan_in, width))
            # One BatchNorm and one Dropout are created per Linear, including the
            # last one, even though forward() never applies them after the final Linear.
            self.list_bn_fc.append(nn.BatchNorm1d(width))
            self.list_dropout_fc.append(nn.Dropout(dropout_fc[idx]))

    def forward(self, X, lab):
        """Encode ``(X, lab)`` with the DSS block, then run the fully connected head.

        Hidden layers apply Linear -> BatchNorm -> ReLU -> Dropout. The final
        layer is a bare Linear. With an empty head, the DSS output is returned as is.
        """
        hidden = self.dss(X, lab)

        head = self.list_fc
        final_idx = len(head) - 1
        for idx, linear in enumerate(head):
            hidden = linear(hidden)
            if idx == final_idx:
                continue
            normed = self.list_bn_fc[idx](hidden)
            hidden = self.list_dropout_fc[idx](F.relu(normed))

        return hidden
