import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import torch.optim as optim
from torch.nn.parameter import Parameter
from itertools import product
import numpy as np

class COP(nn.Module):
    """Multi-layer perceptron mapping ``nfeat`` features back to ``nfeat`` features.

    Layout: ``Linear(nfeat, nhid)``, then ``nlayers - 2`` hidden
    ``Linear(nhid, nhid)`` layers, and finally ``Linear(nhid, nfeat)``.
    Every layer except the last is followed by ``BatchNorm1d`` and ReLU. The
    output layer is left purely linear (no normalization, no activation).

    Args:
        nfeat: Input and output feature dimension.
        nhid: Hidden width.
        nlayers: Total number of ``Linear`` layers. Values below 2 still
            produce two layers (input and output), because the hidden-layer
            count ``nlayers - 2`` yields an empty ``range``.
        device: Stored on ``self.device`` only; the module is NOT moved to
            this device here.
        args: Unused; accepted for call-site compatibility.
    """

    def __init__(self, nfeat, nhid=256, nlayers=4, device=None, args=None):
        super().__init__()
        # Construction order (input Linear, BN, hidden Linear/BN pairs, output
        # Linear) is kept so RNG consumption, and hence seeded init, is stable.
        self.layers = nn.ModuleList([nn.Linear(nfeat, nhid)])
        # track_running_stats=False: BatchNorm normalizes with the current
        # batch's statistics in both train and eval mode (no running buffers).
        self.bns = nn.ModuleList([nn.BatchNorm1d(nhid, track_running_stats=False)])
        for _ in range(nlayers - 2):
            self.layers.append(nn.Linear(nhid, nhid))
            self.bns.append(nn.BatchNorm1d(nhid, track_running_stats=False))
        self.layers.append(nn.Linear(nhid, nfeat))

        self.device = device
        self.reset_parameters()

    def forward(self, x):
        """Run ``x`` through the MLP.

        Args:
            x: Input tensor. It is assumed to have shape ``(batch, nfeat)``,
                since each hidden ``Linear`` output goes straight into
                ``BatchNorm1d(nhid)``. Because BatchNorm always uses batch
                statistics here, each row's output depends on the other rows
                in the batch.

        Returns:
            Tensor with last dimension ``nfeat`` (output of the final Linear).
        """
        # self.bns[i] pairs with self.layers[i] for every non-output layer.
        *hidden_layers, output_layer = self.layers
        for layer, bn in zip(hidden_layers, self.bns):
            x = F.relu(bn(layer(x)))
        return output_layer(x)

    def reset_parameters(self):
        """Re-initialize every ``Linear`` and ``BatchNorm1d`` submodule in place."""
        def weight_reset(m):
            if isinstance(m, (nn.Linear, nn.BatchNorm1d)):
                m.reset_parameters()
        self.apply(weight_reset)
