'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import math
import torch
from torch import nn
from torch.nn import functional as F


class SimpleLinear(nn.Module):
  '''
  Fully connected layer that returns its output wrapped in a dict.

  The weight has shape (out_features, in_features). When `bias` is true a
  bias of shape (out_features,) is learned as well; otherwise `bias` is
  registered as None.

  Reference:
  https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
  '''

  def __init__(self, in_features, out_features, bias=True):
    super().__init__()
    self.in_features, self.out_features = in_features, out_features
    self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
    # Register 'bias' in both cases so `self.bias` always exists (maybe None).
    bias_param = nn.Parameter(torch.Tensor(out_features)) if bias else None
    self.register_parameter('bias', bias_param)
    self.reset_parameters()

  def reset_parameters(self):
    '''Kaiming-uniform init (linear gain) for the weight, zeros for the bias.'''
    nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
    if self.bias is None:
      return
    nn.init.zeros_(self.bias)

  def forward(self, input):
    '''Apply the affine map to `input` and return {'logits': result}.'''
    logits = F.linear(input, self.weight, self.bias)
    return {'logits': logits}


class CosineLinear(nn.Module):
  '''
  Linear layer producing cosine-similarity logits.

  Both the input rows and the weight rows are L2-normalised before the
  matrix product, so every raw score is a cosine similarity. The weight
  holds `out_features * nb_proxy` rows. When `to_reduce` is true the
  scores are passed through `reduce_proxies` with `nb_proxy`. When `sigma`
  is true a learnable one-element scale (initialised to 1) multiplies the
  result; otherwise `sigma` is registered as None.
  '''

  def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
    super().__init__()
    self.in_features = in_features
    # Stored output width counts every proxy row, not just the classes.
    self.out_features = out_features * nb_proxy
    self.nb_proxy = nb_proxy
    self.to_reduce = to_reduce
    self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
    # Register 'sigma' in both cases so `self.sigma` always exists (maybe None).
    scale = nn.Parameter(torch.Tensor(1)) if sigma else None
    self.register_parameter('sigma', scale)
    self.reset_parameters()

  def reset_parameters(self):
    '''Draw the weight uniformly from +-1/sqrt(in_features); set sigma to 1.'''
    bound = 1. / math.sqrt(self.weight.size(1))
    nn.init.uniform_(self.weight, -bound, bound)
    if self.sigma is not None:
      nn.init.constant_(self.sigma, 1)

  def forward(self, input):
    '''Return {'logits': scores} for `input`, normalised along dim 1.'''
    unit_input = F.normalize(input, p=2, dim=1)
    unit_weight = F.normalize(self.weight, p=2, dim=1)
    scores = F.linear(unit_input, unit_weight)

    if self.to_reduce:
      # Collapse the proxy scores before any scaling is applied.
      scores = reduce_proxies(scores, self.nb_proxy)

    if self.sigma is None:
      return {'logits': scores}
    return {'logits': self.sigma * scores}


class SplitCosineLinear(nn.Module):
  '''
  Two cosine heads evaluated on the same input and concatenated.

  `fc1` and `fc2` are `CosineLinear` layers built without their own sigma
  and without internal proxy reduction; this module reduces the proxies
  itself and applies a single optional learnable sigma (initialised to 1)
  to the concatenated logits only.
  '''

  def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
    super().__init__()
    self.in_features = in_features
    self.out_features = (out_features1 + out_features2) * nb_proxy
    self.nb_proxy = nb_proxy
    self.fc1 = CosineLinear(in_features, out_features1, nb_proxy,
                            to_reduce=False, sigma=False)
    self.fc2 = CosineLinear(in_features, out_features2, nb_proxy,
                            to_reduce=False, sigma=False)
    if not sigma:
      self.register_parameter('sigma', None)
    else:
      self.sigma = nn.Parameter(torch.Tensor(1))
      self.sigma.data.fill_(1)

  def forward(self, x):
    '''
    Return a dict with:
      'old_scores': proxy-reduced fc1 scores (not scaled by sigma),
      'new_scores': proxy-reduced fc2 scores (not scaled by sigma),
      'logits': proxy-reduced concatenation of both, scaled by sigma if set.
    '''
    old_raw = self.fc1(x)['logits']
    new_raw = self.fc2(x)['logits']

    # Join both heads along dim 1, then collapse the proxies.
    joined = torch.cat((old_raw, new_raw), dim=1)
    logits = reduce_proxies(joined, self.nb_proxy)

    if self.sigma is not None:
      logits = self.sigma * logits

    return {
        'old_scores': reduce_proxies(old_raw, self.nb_proxy),
        'new_scores': reduce_proxies(new_raw, self.nb_proxy),
        'logits': logits
    }


def reduce_proxies(out, nb_proxy):
  '''
  Collapse per-proxy scores into one score per class.

  `out` is viewed as (out.shape[0], out.shape[1] / nb_proxy, nb_proxy).
  Within each class the proxy scores are weighted by their own softmax and
  summed. With `nb_proxy == 1` the input is returned unchanged. An
  AssertionError ('Shape error') is raised when out.shape[1] is not a
  multiple of `nb_proxy`.
  '''
  if nb_proxy == 1:
    return out

  batch_size, nb_columns = out.shape[0], out.shape[1]
  ratio = nb_columns / nb_proxy
  assert ratio.is_integer(), 'Shape error'

  per_class = out.view(batch_size, int(ratio), nb_proxy)
  # Softmax over the proxy axis gives the attention weight of each proxy.
  weights = F.softmax(per_class, dim=-1)
  weighted = weights * per_class
  return weighted.sum(-1)


'''
class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        if self.sigma is not None:
            out = self.sigma * out
        return {'logits': out}


class SplitCosineLinear(nn.Module):
    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)
        else:
            self.register_parameter('sigma', None)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(x)

        out = torch.cat((out1['logits'], out2['logits']), dim=1)  # concatenate along the channel
        if self.sigma is not None:
            out = self.sigma * out

        return {
            'old_scores': out1['logits'],
            'new_scores': out2['logits'],
            'logits': out
        }
'''
