import torch
import torch.nn as nn

from .decoders import register
from ..modules import *

# Names exported by `from <this module> import *`: only the decoder class.
__all__ = ['FCdecoder']


@register('FC_decoder')
class FCdecoder(Module):
  """Fully-connected decoder: one ``Linear`` layer followed by ReLU.

  Registered under the name ``'FC_decoder'``. ``forward`` requires a 2-D
  input and returns ``relu(linear1(x))``. A second linear layer and
  batch-norm layers are present only as commented-out code and are not used.

  Args:
    in_dim (int): input feature size passed to ``Linear``.
    out_dim (int): output feature size passed to ``Linear``.
    bn_args (dict, optional): batch-norm options. A private copy is made with
      ``'episodic'`` forced to ``False``. It is only consumed by the disabled
      batch-norm layers, and the caller's dict is never modified.
    temp (float): temperature stored on the module as ``self.temp``. Note that
      ``forward`` does not use it.
    learn_temp (bool): if True, ``self.temp`` is registered as a learnable
      ``nn.Parameter`` instead of a plain number.
  """

  def __init__(self, in_dim, out_dim, bn_args=None, temp=1., learn_temp=False):
    super(FCdecoder, self).__init__()
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.temp = temp
    self.learn_temp = learn_temp
    # Copy instead of mutating the caller's (possibly shared) config dict, and
    # honour the None default instead of crashing on `None['episodic']`.
    bn_args = dict(bn_args) if bn_args is not None else {}
    bn_args['episodic'] = False  # only used by the disabled BN layers below

    self.linear1 = Linear(in_dim, out_dim)
    # self.bn1 = BatchNorm1d(out_dim, **bn_args)
    self.Relu1 = nn.ReLU(inplace=True)
    # self.linear2 = Linear(out_dim, out_dim)
    # self.bn2 = BatchNorm1d(out_dim, **bn_args)
    # self.Relu2 = nn.ReLU(inplace=True)
    if self.learn_temp:
      # float() guarantees a floating-point tensor: an integer temp such as 1
      # would otherwise produce an int64 tensor, which nn.Parameter rejects
      # because integer tensors cannot require gradients.
      self.temp = nn.Parameter(torch.tensor(float(temp)))

  def forward(self, x_shot, params=None, episode=None):
    """Apply ``linear1`` and ReLU to a 2-D input.

    Args:
      x_shot (Tensor): 2-D input tensor. Presumably ``(N, in_dim)``; TODO
        confirm against callers.
      params (dict, optional): parameter overrides. The ``'linear1'`` child
        (extracted with ``get_child_dict``) is passed to ``self.linear1``.
      episode: unused while the batch-norm layers are disabled.

    Returns:
      Tensor: ``relu(linear1(x_shot))``.
    """
    assert x_shot.dim() == 2, (
        'FCdecoder expects a 2-D input, got {}-D'.format(x_shot.dim()))
    out = self.linear1(x_shot, get_child_dict(params, 'linear1'))
    # out = self.bn1(out, get_child_dict(params, 'bn1'), episode)
    out = self.Relu1(out)
    # out = self.linear2(out, get_child_dict(params, 'linear2'))
    # out = self.bn2(out, get_child_dict(params, 'bn2'), episode)
    # out = self.Relu2(out)
    return out