import torch
import torch.nn as nn
import pdb
import torch.nn.functional as F


class Gate(nn.Module):
    """Abstract base class for a bank of ``num_gates`` multiplicative gates.

    Subclasses provide ``get_weight``, ``get_mask`` and ``get_reg``. This
    base class supplies the forward pass, which multiplies the input by the
    gate weights, and a helper that counts the active gates.
    """

    def __init__(self, num_gates, **kwargs):
        # Extra keyword arguments are accepted so subclasses can share one
        # construction call; this base class does not use them.
        super(Gate, self).__init__()
        self.num_gates = num_gates

    def get_weight(self, x):
        """Return the gate weights used to scale ``x`` (subclass hook)."""
        raise NotImplementedError

    def get_mask(self):
        """Return the gate mask (subclass hook).

        ``get_num_active`` sums this mask, so it must be a tensor.
        """
        raise NotImplementedError

    def get_reg(self, base):
        """Return the regularisation term for ``base`` (subclass hook)."""
        raise NotImplementedError

    def get_num_active(self):
        """Return the sum of the mask entries as a Python int."""
        return self.get_mask().sum().int().item()

    def forward(self, x):
        """Scale ``x`` by the gate weights.

        For a 4-D input the weights are reshaped to
        ``(-1, num_gates, 1, 1)`` so they broadcast over the two trailing
        dimensions; otherwise they are reshaped to ``(-1, num_gates)``.
        """
        z = self.get_weight(x)  # [num_gates]
        if x.dim() == 4:
            # One scalar per gate, broadcast over each activation map.
            z = z.view(-1, self.num_gates, 1, 1)
        else:
            z = z.view(-1, self.num_gates)

        # Follow the input's device instead of forcing CUDA, so the module
        # also runs on CPU-only hosts and on non-default GPUs.
        z = z.to(x.device)
        return x * z
class GatedLayer(nn.Module):
    """Shared plumbing for a layer that can carry a gate and a second,
    optional ``dgate``.

    All collaborators start as ``None``. ``build_gate`` and
    ``build_gate_dep`` create ``gate`` and ``dgate``; ``base`` and
    ``filtersize`` are expected to be assigned by whoever specialises
    this class. Every query method returns ``None`` while no gate exists.
    """

    def __init__(self, num_gates):
        super(GatedLayer, self).__init__()
        self.num_gates = num_gates
        self.base = None
        self.gate = None
        self.dgate = None
        self.filtersize = None

    def build_gate(self, gate_fn, **kwargs):
        """Create ``self.gate`` by calling ``gate_fn(num_gates, **kwargs)``."""
        self.gate = gate_fn(self.num_gates, **kwargs)

    def build_gate_dep(self, dgate_fn, **kwargs):
        """Create ``self.dgate`` by calling ``dgate_fn(num_gates, **kwargs)``."""
        self.dgate = dgate_fn(self.num_gates, **kwargs)

    def apply_gate(self, x, S=None):
        """Pass ``x`` through whichever gates are present.

        * no gate: ``x`` is returned untouched;
        * gate only: returns ``self.gate(x)``;
        * gate and dgate: returns ``self.dgate(x, z, S)`` where ``z`` is
          ``self.gate.get_weight(x)``.
        """
        if self.gate is None:
            return x
        if self.dgate is None:
            return self.gate(x)
        weights = self.gate.get_weight(x)
        return self.dgate(x, weights, S)

    def get_mask(self):
        """Return the gate's mask, or ``None`` without a gate."""
        if self.gate is None:
            return None
        return self.gate.get_mask()

    def get_mask_dep(self):
        """Return the dgate's mask, or ``None`` without a dgate."""
        if self.dgate is None:
            return None
        return self.dgate.get_mask_dep()

    def get_reg(self):
        """Return the gate's regulariser for ``base``, or ``None``."""
        if self.gate is None:
            return None
        return self.gate.get_reg(self.base)

    def get_reg_dep(self):
        """Return the dgate's regulariser for ``base``, or ``None``."""
        if self.dgate is None:
            return None
        return self.dgate.get_reg(self.base)

    def get_num_active(self):
        """Return the gate's active count, or ``None`` without a gate."""
        if self.gate is None:
            return None
        return self.gate.get_num_active()

    def get_weight_nonactive(self):
        """Return ``(num_gates - active) * filtersize``, or ``None``.

        ``active`` comes from ``dgate.num_active`` when a dgate exists and
        from ``gate.get_num_active()`` otherwise.
        """
        if self.gate is None:
            return None
        if self.dgate is None:
            active = self.gate.get_num_active()
        else:
            active = self.dgate.num_active
        return (self.num_gates - active) * self.filtersize

    def get_weight(self):
        """Return ``active * filtersize``, or ``None`` without a gate.

        ``active`` is chosen the same way as in ``get_weight_nonactive``.
        """
        if self.gate is None:
            return None
        if self.dgate is None:
            active = self.gate.get_num_active()
        else:
            active = self.dgate.num_active
        return active * self.filtersize
class GatedConv2d(GatedLayer):
    """``nn.Conv2d`` with one gate per output channel.

    The constructor arguments are forwarded unchanged to ``nn.Conv2d``.
    ``filtersize`` is stored as ``in_channels * kernel_h * kernel_w`` and
    is used by the inherited weight-count helpers.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(GatedConv2d, self).__init__(out_channels)
        # nn.Conv2d accepts either an int or an (h, w) pair; support both
        # here instead of failing on the pair form.
        try:
            kernel_h, kernel_w = kernel_size
        except TypeError:
            kernel_h = kernel_w = kernel_size
        # NOTE(review): `groups` is not reflected in filtersize; with
        # groups > 1 each filter only sees in_channels // groups inputs.
        # Confirm whether that is intended.
        self.filtersize = in_channels * kernel_h * kernel_w
        self.base = nn.Conv2d(in_channels, out_channels, kernel_size,
                stride=stride, padding=padding, dilation=dilation,
                groups=groups, bias=bias)

    def forward(self, x, S):
        """Convolve ``x`` and ``S`` with the shared weights, then gate.

        ``S`` is convolved with the same weight, bias, stride, padding,
        dilation and groups as ``self.base``. The result of
        ``apply_gate(self.base(x), S)`` is returned as-is, so its form
        (single tensor or whatever the dgate returns) depends on which
        gates have been built.
        """
        S = F.conv2d(input=S, weight=self.base.weight, bias=self.base.bias,
                     stride=self.base.stride, padding=self.base.padding,
                     dilation=self.base.dilation, groups=self.base.groups)
        return self.apply_gate(self.base(x), S)

"""
debug
"""
class GatedConv2d2(GatedLayer):
    """Gated ``nn.Conv2d`` followed by ``BatchNorm2d`` and ReLU on ``x``.

    ``forward`` always returns an ``(x, S)`` pair. ``S`` goes through the
    same convolution weights and a ReLU, but not through the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
            stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(GatedConv2d2, self).__init__(out_channels)
        # nn.Conv2d accepts either an int or an (h, w) pair; support both
        # here instead of failing on the pair form.
        try:
            kernel_h, kernel_w = kernel_size
        except TypeError:
            kernel_h = kernel_w = kernel_size
        # NOTE(review): `groups` is not reflected in filtersize; confirm
        # whether that is intended for grouped convolutions.
        self.filtersize = in_channels * kernel_h * kernel_w
        self.base = nn.Conv2d(in_channels, out_channels, kernel_size,
                stride=stride, padding=padding, dilation=dilation,
                groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, S):
        """Return ``(relu(bn(gated conv(x))), relu(gated conv(S)))``.

        ``apply_gate`` returns a bare tensor when no gate or no dgate has
        been built. That tensor is taken as the gated ``x`` and ``S`` is
        kept as convolved, rather than unpacking the tensor along its
        first dimension (which fails, or silently splits a batch of two).
        """
        S = F.conv2d(input=S, weight=self.base.weight, bias=self.base.bias,
                     stride=self.base.stride, padding=self.base.padding,
                     dilation=self.base.dilation, groups=self.base.groups)
        gated = self.apply_gate(self.base(x), S)
        if torch.is_tensor(gated):
            x = gated
        else:
            x, S = gated
        x = self.bn(x)
        x = self.relu(x)
        S = F.relu(S)
        return x, S
class GatedLinear(GatedLayer):
    """Linear layer with one gate per input feature, plus BatchNorm1d/ReLU.

    The gate is applied to the layer's input (``num_gates`` is
    ``in_features``), before the linear map.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GatedLinear, self).__init__(in_features)
        self.base = nn.Linear(in_features, out_features, bias=bias)

        # Size the batch norm from the layer width rather than a fixed 512;
        # this is identical whenever out_features == 512.
        self.bn = nn.BatchNorm1d(out_features)
        self.relu = nn.ReLU()

    def forward(self, x, S):
        """Gate the inputs, then apply the linear map, norm and ReLU.

        ``apply_gate`` returns a bare tensor when no gate or no dgate has
        been built. That tensor is taken as the gated ``x`` and ``S`` is
        left as passed in, rather than unpacking the tensor along its
        first dimension.
        """
        gated = self.apply_gate(x, S)
        if torch.is_tensor(gated):
            x = gated
        else:
            x, S = gated
        S = F.linear(S, self.base.weight, self.base.bias)
        x = self.bn(self.base(x))
        x = self.relu(x)
        S = F.relu(S)
        # NOTE(review): self.base is applied to x a second time here while
        # S only passes through it once. This requires
        # in_features == out_features and looks unintended, but it is
        # kept because existing checkpoints may depend on it. Confirm.
        return self.base(x), S
class last_GatedLinear(GatedLayer):
    """Linear layer with one gate per input feature and no norm or ReLU.

    ``forward`` returns only the linear output; ``S`` is passed to the
    gate but not returned.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(last_GatedLinear, self).__init__(in_features)
        self.base = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x, S):
        """Gate ``x`` (handing ``S`` to the gate) and apply the linear map.

        ``apply_gate`` returns a bare tensor when no gate or no dgate has
        been built. That tensor is used directly rather than being
        unpacked along its first dimension.
        """
        gated = self.apply_gate(x, S)
        if torch.is_tensor(gated):
            x = gated
        else:
            x, _ = gated
        return self.base(x)
