import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Power-iteration steps run on every forward pass in training mode
# (used as `train_max_iter` by the convex-potential layers below).
MAX_ITER = 1
# Power-iteration steps run once in eval mode to fill the cached estimate
# (used as `eval_max_iter` by the convex-potential layers below).
EVAL_MAX_ITER = 100



class SpectralNormPowerMethod(nn.Module):
  """Power-iteration estimate of the largest singular value of a weight.

  A 2-D weight is handled as a dense matrix, a 4-D weight as a convolution
  kernel.  The iterate ``u`` is stored on the module and overwritten in place
  on every call, so successive calls resume from the previous estimate.

  Args:
    input_size: shape of the iterate ``u``; it must be a valid input for
      ``F.linear(u, M)`` (dense) or ``F.conv2d(F.pad(u, ...), kernel)`` (conv).
    eps: stored as ``self.eps``.  NOTE(review): it is not read anywhere in
      this class; ``normalize`` uses its own ``1e-12`` floor.
  """

  def __init__(self, input_size, eps=1e-8):
    super(SpectralNormPowerMethod, self).__init__()
    self.input_size = input_size
    self.eps = eps
    # Random unit-norm start vector.  Stored as a non-trainable parameter so
    # it follows the module across devices and is part of the state dict.
    u = torch.randn(input_size)
    self.u = nn.Parameter(u / u.norm(p=2), requires_grad=False)

  def normalize(self, arr):
    """Return ``arr`` divided by its L2 norm (plus 1e-12 to avoid 0/0)."""
    norm = torch.sqrt((arr ** 2).sum())
    return arr / (norm + 1e-12)

  @staticmethod
  def _check_max_iter(max_iter):
    """Reject iteration counts that would leave the iterate ``v`` undefined."""
    if max_iter < 1:
      raise ValueError("max_iter must be >= 1, got {}".format(max_iter))

  def _compute_dense(self, M, max_iter):
    """Run ``max_iter`` power iterations on the matrix ``M``.

    Returns the scalar ``sum((u @ M.T) * v)`` built from the final iterates.
    ``self.u`` is updated through ``.data``, i.e. outside of autograd.

    Raises:
      ValueError: if ``max_iter`` is smaller than 1.
    """
    self._check_max_iter(max_iter)
    for _ in range(max_iter):
      v = self.normalize(F.linear(self.u, M))
      self.u.data = self.normalize(F.linear(v, M.T))
    z = F.linear(self.u, M)
    sigma = torch.mul(z, v).sum()
    return sigma

  def _compute_conv(self, kernel, max_iter):
    """Run ``max_iter`` power iterations on the convolution by ``kernel``.

    The iterate is padded by one pixel on each spatial side before the
    convolution and cropped by one pixel after the transposed convolution.
    NOTE(review): this looks tailored to 3x3 kernels used with padding=1 and
    stride 1 — confirm before using other kernel sizes or strides.

    Raises:
      ValueError: if ``max_iter`` is smaller than 1.
    """
    self._check_max_iter(max_iter)
    pad = (1, 1, 1, 1)
    crop = (-1, -1, -1, -1)  # negative padding removes one pixel per side
    for _ in range(max_iter):
      v = self.normalize(F.conv2d(F.pad(self.u, pad), kernel))
      self.u.data = self.normalize(F.pad(F.conv_transpose2d(v, kernel), crop))
    z = F.conv2d(F.pad(self.u, pad), kernel)
    sigma = torch.mul(z, v).sum()
    return sigma

  def forward(self, M, max_iter):
    """Return the estimated highest singular value of ``M``.

    Args:
      M: 2-D matrix or 4-D convolution kernel.
      max_iter: number of power iterations to run (at least 1).

    Raises:
      ValueError: if ``M`` is neither 2-D nor 4-D, or ``max_iter < 1``.
    """
    ndim = len(M.shape)
    if ndim == 4:
      return self._compute_conv(M, max_iter)
    if ndim == 2:
      return self._compute_dense(M, max_iter)
    # Previously this fell through and returned None, which only failed
    # later (and less clearly) in the caller.
    raise ValueError(
        "expected a 2-D matrix or a 4-D kernel, got a {}-D tensor".format(ndim))


def get_activation(config):
  """Build the activation module selected by ``config.activation``.

  Supported names: ``'relu'``, ``'leaky_relu'`` (negative slope 0.1),
  ``'elu'`` and ``'silu'``.

  Raises:
    ValueError: if the name is not one of the supported ones.  (The function
      used to return None in that case, which only failed later when the
      layer tried to call its activation.)
  """
  name = config.activation
  if name == 'relu':
    return nn.ReLU(inplace=False)
  if name == 'leaky_relu':
    return nn.LeakyReLU(negative_slope=0.1)
  if name == 'elu':
    return nn.ELU()
  if name == 'silu':
    return nn.SiLU()
  raise ValueError("activation not recognized: {!r}".format(name))

def init_weights(config, weights, type_weights):
  """Initialise ``weights`` in place with the scheme named by ``config.init``.

  Args:
    config: object exposing an ``init`` attribute, one of ``'original'``,
      ``'kaiming_uniform'``, ``'kaiming_normal'``, ``'xavier_uniform'``,
      ``'xavier_normal'``, ``'zeros'`` or ``'dirac'``.
    weights: tensor to initialise.
    type_weights: ``'conv'`` or ``'linear'``; only consulted for ``'dirac'``,
      where it selects ``nn.init.dirac_`` or ``nn.init.eye_``.

  Raises:
    ValueError: if ``config.init`` is not recognised, or if ``'dirac'`` is
      requested with an unknown ``type_weights``.
  """
  scheme = config.init
  if scheme == 'original':
    # Same call as the other weight initialisations in this file (a=sqrt(5)).
    nn.init.kaiming_uniform_(weights, a=np.sqrt(5))
  elif scheme == 'kaiming_uniform':
    nn.init.kaiming_uniform_(weights)
  elif scheme == 'kaiming_normal':
    nn.init.kaiming_normal_(weights)
  elif scheme == 'xavier_uniform':
    nn.init.xavier_uniform_(weights)
  elif scheme == 'xavier_normal':
    nn.init.xavier_normal_(weights)
  elif scheme == 'zeros':
    nn.init.zeros_(weights)
  elif scheme == 'dirac':
    if type_weights == 'conv':
      nn.init.dirac_(weights)
    elif type_weights == 'linear':
      nn.init.eye_(weights)
    else:
      # Used to be a silent no-op that left the weights untouched.
      raise ValueError(
          "type_weights not recognized: {!r}".format(type_weights))
  else:
    # The exception used to be constructed but never raised.
    raise ValueError("init not recognized")





class ConvexPotentialLayerConv(nn.Module):
  """Residual convolutional layer ``x - h * K^T relu(K x + b)``.

  ``K`` is the convolution by ``self.kernel`` (padding 1) and the step is
  ``h = 2 / (sigma^2 + epsilon)``, where ``sigma`` is the power-method
  estimate of the largest singular value of ``K``.  In training mode the
  estimate is refreshed on every call; in eval mode it is computed once and
  cached in the ``eval_sv_max`` buffer until the layer trains again.

  Args:
    config: accepted for signature parity with the other layers; not used.
    input_size: shape of the power-method iterate, forwarded to
      ``SpectralNormPowerMethod``.
    cin, cout: input / output channels of the kernel.
    kernel_size: spatial kernel size.  NOTE(review): padding is hard-coded
      to 1 in ``forward``, which presumably assumes 3x3 kernels — confirm.
    stride: stride of both the convolution and its transpose.
    epsilon: constant added to ``sigma^2`` in the step size.
  """

  def __init__(self, config, input_size, cin, cout, kernel_size=3, stride=1, epsilon=1e-4):
    super(ConvexPotentialLayerConv, self).__init__()

    self.activation = nn.ReLU(inplace=False)
    self.stride = stride
    # A value of 0 means "no cached eval estimate yet".
    self.register_buffer('eval_sv_max', torch.Tensor([0]))

    self.kernel = nn.Parameter(torch.zeros(cout, cin, kernel_size, kernel_size))
    self.bias = nn.Parameter(torch.zeros(cout))

    self.pm = SpectralNormPowerMethod(input_size)
    self.train_max_iter = MAX_ITER
    self.eval_max_iter = EVAL_MAX_ITER

    # initialize weights and biases
    nn.init.kaiming_uniform_(self.kernel, a=math.sqrt(5))  # weight init
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.kernel)
    bound = 1 / math.sqrt(fan_in)
    nn.init.uniform_(self.bias, -bound, bound)  # bias init

    self.epsilon = epsilon

  def forward(self, x):
    """Apply the residual step to ``x`` and return a tensor shaped like ``x``."""
    res = F.conv2d(x, self.kernel, bias=self.bias, stride=self.stride, padding=1)
    res = self.activation(res)
    res = F.conv_transpose2d(res, self.kernel, stride=self.stride, padding=1)
    if self.training:
      # Invalidate the eval cache.  zero_() is used instead of `x -= x`
      # because the latter leaves a NaN/inf cache stuck at NaN forever.
      self.eval_sv_max.zero_()
      sv_max = self.pm(self.kernel, self.train_max_iter)
    else:
      if self.eval_sv_max == 0:
        # Cache a detached value: accumulating a graph-attached tensor into
        # the buffer would turn the buffer into a non-leaf tensor, keep the
        # whole power-iteration graph alive and break a later backward().
        with torch.no_grad():
          self.eval_sv_max.copy_(self.pm(self.kernel, self.eval_max_iter))
      sv_max = self.eval_sv_max
    h = 2 / (sv_max ** 2 + self.epsilon)
    return x - h * res


class ConvexPotentialLayerLinear(nn.Module):
  """Residual dense layer ``x - h * W^T relu(W x + b)``.

  The step is ``h = 2 / (sigma^2 + epsilon)``, where ``sigma`` is the
  power-method estimate of the largest singular value of ``self.weights``.
  In training mode the estimate is refreshed on every call; in eval mode it
  is computed once and cached in the ``eval_sv_max`` buffer until the layer
  trains again.

  Args:
    config: accepted for signature parity with the other layers; not used.
    cin, cout: input / output features; ``self.weights`` is ``(cout, cin)``.
    epsilon: constant added to ``sigma^2`` in the step size.
  """

  def __init__(self, config, cin, cout, epsilon=1e-4):
    super(ConvexPotentialLayerLinear, self).__init__()
    self.activation = nn.ReLU(inplace=False)
    # A value of 0 means "no cached eval estimate yet".
    self.register_buffer('eval_sv_max', torch.Tensor([0]))

    self.weights = nn.Parameter(torch.zeros(cout, cin))
    self.bias = nn.Parameter(torch.zeros(cout))

    # The power-method iterate is a single row vector of length cin.
    self.pm = SpectralNormPowerMethod((1, cin))
    self.train_max_iter = MAX_ITER
    self.eval_max_iter = EVAL_MAX_ITER

    # initialize weights and biases
    nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))  # weight init
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
    bound = 1 / math.sqrt(fan_in)
    nn.init.uniform_(self.bias, -bound, bound)  # bias init

    self.epsilon = epsilon

  def forward(self, x):
    """Apply the residual step to ``x`` and return a tensor shaped like ``x``."""
    res = F.linear(x, self.weights, self.bias)
    res = self.activation(res)
    res = F.linear(res, self.weights.t())
    if self.training:
      # Invalidate the eval cache.  zero_() is used instead of `x -= x`
      # because the latter leaves a NaN/inf cache stuck at NaN forever.
      self.eval_sv_max.zero_()
      sv_max = self.pm(self.weights, self.train_max_iter)
    else:
      if self.eval_sv_max == 0:
        # Cache a detached value: accumulating a graph-attached tensor into
        # the buffer would turn the buffer into a non-leaf tensor, keep the
        # whole power-iteration graph alive and break a later backward().
        with torch.no_grad():
          self.eval_sv_max.copy_(self.pm(self.weights, self.eval_max_iter))
      sv_max = self.eval_sv_max
    h = 2 / (sv_max ** 2 + self.epsilon)
    return x - h * res



class SDPBasedLipschitzConvLayer(nn.Module):
  """Residual convolutional layer ``x - K^T (T * act(K x + b))``.

  ``K`` is the convolution by ``self.kernel`` (padding 1).  ``T`` is a
  per-output-channel scale computed on every call from the kernel and the
  learnable vector ``q``:
  ``T_i = 2 / (sum_j |q_j| * |(K K^T)_ij| / |q_i|)``.

  Args:
    config: provides ``activation`` (read by ``get_activation``) and ``init``
      (read by ``init_weights``).
    input_size: accepted but not used by this layer.
    cin, cout: input / output channels of the kernel.
    kernel_size: spatial kernel size.
    stride: stride of both the convolution and its transpose.
    epsilon: stored as ``self.epsilon``; not read in ``forward``.
  """

  def __init__(self, config, input_size, cin, cout, kernel_size=3, stride=1, epsilon=1e-6):
    super().__init__()

    # Default activation; superseded by the configured one further down.
    self.activation = nn.ReLU(inplace=False)
    self.stride = stride

    # Random draws happen in the order kernel, bias, q.
    kernel_init = torch.randn(cout, cin, kernel_size, kernel_size)
    bias_init = torch.randn(cout)
    q_init = torch.randn(cout)
    self.kernel = nn.Parameter(kernel_init)
    self.bias = nn.Parameter(bias_init)
    self.q = nn.Parameter(q_init)

    self.activation = get_activation(config)

    # Re-initialise the kernel per config, then the bias uniformly in
    # [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    init_weights(config, self.kernel, 'conv')
    fan_in = nn.init._calculate_fan_in_and_fan_out(self.kernel)[0]
    limit = 1 / np.sqrt(fan_in)
    nn.init.uniform_(self.bias, -limit, limit)

    self.epsilon = epsilon

  def forward(self, x):
    """Apply the residual step to ``x``."""
    hidden = self.activation(
        F.conv2d(x, self.kernel, bias=self.bias, stride=self.stride, padding=1))
    # Unpacking doubles as a check that the activation output is 4-D.
    _, _, _, _ = hidden.shape
    # Correlation of the kernel with itself: shape (cout, cout, 2k-1, 2k-1).
    gram = F.conv2d(self.kernel, self.kernel, padding=self.kernel.shape[-1] - 1)
    q_abs = torch.abs(self.q)
    weighted = torch.abs(q_abs[None, :, None, None] * gram)
    T = 2 / (weighted.sum((1, 2, 3)) / q_abs)
    hidden = T[None, :, None, None] * hidden
    back = F.conv_transpose2d(hidden, self.kernel, stride=self.stride, padding=1)
    return x - back


class SDPBasedLipschitzLinearLayer(nn.Module):
  """Residual dense layer ``x - W^T (T * act(W x + b))``.

  ``T`` is a per-output-unit scale computed on every call from the weights
  and the learnable vector ``q``:
  ``T_i = 2 / sum_j |(W W^T)_ij * |q_j| / (|q_i| + epsilon)|``.

  Args:
    config: provides ``activation`` (read by ``get_activation``) and ``init``
      (read by ``init_weights``).
    cin, cout: input / output features; ``self.weights`` is ``(cout, cin)``.
    epsilon: added to ``|q|`` before it is inverted.
  """

  def __init__(self, config, cin, cout, epsilon=1e-6):
    super().__init__()
    # Default activation; superseded by the configured one further down.
    self.activation = nn.ReLU(inplace=False)

    # Random draws happen in the order weights, bias, q (q is uniform).
    weights_init = torch.randn(cout, cin)
    bias_init = torch.randn(cout)
    q_init = torch.rand(cout)
    self.weights = nn.Parameter(weights_init)
    self.bias = nn.Parameter(bias_init)
    self.q = nn.Parameter(q_init)

    self.activation = get_activation(config)

    # Re-initialise the weights per config, then the bias uniformly in
    # [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    init_weights(config, self.weights, 'linear')
    fan_in = nn.init._calculate_fan_in_and_fan_out(self.weights)[0]
    limit = 1 / np.sqrt(fan_in)
    nn.init.uniform_(self.bias, -limit, limit)

    self.epsilon = epsilon

  def forward(self, x):
    """Apply the residual step to ``x``."""
    hidden = self.activation(F.linear(x, self.weights, self.bias))
    q_abs = torch.abs(self.q)
    q_row = q_abs[None, :]
    q_inv_col = (1 / (q_abs + self.epsilon))[:, None]
    # Left-to-right, as in `q_inv * W @ W.T * q`: scale the rows of W, form
    # the Gram matrix, then scale its columns.
    gram = ((q_inv_col * self.weights) @ self.weights.T) * q_row
    T = 2 / torch.abs(gram).sum(1)
    hidden = T * hidden
    back = F.linear(hidden, self.weights.t())
    return x - back



class PaddingChannels(nn.Module):
  """Expand the channel dimension of a 4-D input from ``ncin`` to ``ncout``.

  Modes:
    ``"zero"``: copy the input into the first ``ncin`` channels of a
      zero tensor with ``ncout`` channels.
    ``"clone"``: tile the input ``r = int(ncout / ncin)`` times along the
      channel axis and divide by ``sqrt(r)``.  The result has ``r * ncin``
      channels, which equals ``ncout`` only when ``ncin`` divides ``ncout``.

  Args:
    ncout: target number of channels.
    ncin: number of input channels.
    mode: ``"zero"`` or ``"clone"``.
  """

  def __init__(self, ncout, ncin=3, mode="zero"):
    super(PaddingChannels, self).__init__()
    self.ncout = ncout
    self.ncin = ncin
    self.mode = mode

  def forward(self, x):
    """Return the channel-expanded tensor.

    Raises:
      ValueError: if ``self.mode`` is neither ``"clone"`` nor ``"zero"``
        (this used to return None silently).
    """
    if self.mode == "clone":
      repeats = int(self.ncout / self.ncin)
      return x.repeat(1, repeats, 1, 1) / math.sqrt(repeats)
    if self.mode == "zero":
      bs, _, size1, size2 = x.shape
      # new_zeros keeps the dtype as well as the device of x; the previous
      # torch.zeros(..., device=x.device) always produced the default dtype.
      out = x.new_zeros(bs, self.ncout, size1, size2)
      out[:, :self.ncin] = x
      return out
    raise ValueError("mode not recognized: {!r}".format(self.mode))


class PoolingLinear(nn.Module):
  """Reduce a ``(batch, ncin)`` input to ``(batch, ncout)`` without weights.

  Aggregations:
    ``"trunc"``: keep the first ``ncout`` features.
    ``"mean"``: split the first ``ncout * int(ncin / ncout)`` features into
      ``ncout`` consecutive groups and return ``sqrt(ncin / ncout)`` times
      each group's mean.
    ``"max"``: same grouping, return each group's maximum.

  NOTE(review): the ``"mean"`` scale uses the fractional ratio
  ``ncin / ncout`` while each group holds ``int(ncin / ncout)`` entries; the
  two differ when ``ncout`` does not divide ``ncin`` — confirm this is
  intended.

  Args:
    ncin: number of input features.
    ncout: number of output features.
    agg: ``"mean"``, ``"max"`` or ``"trunc"``.
  """

  def __init__(self, ncin, ncout, agg="mean"):
    super(PoolingLinear, self).__init__()
    self.ncout = ncout
    self.ncin = ncin
    self.agg = agg

  def forward(self, x):
    """Return the pooled tensor.

    Raises:
      ValueError: if ``self.agg`` is not a supported aggregation (this used
        to return the un-aggregated 3-D grouping silently).
    """
    if self.agg == "trunc":
      return x[:, :self.ncout]
    if self.agg not in ("mean", "max"):
      raise ValueError("agg not recognized: {!r}".format(self.agg))
    k = 1. * self.ncin / self.ncout
    grouped = x[:, :self.ncout * int(k)].view(x.shape[0], self.ncout, -1)
    if self.agg == "mean":
      # Plain Python float so the result keeps the dtype of x.
      return math.sqrt(k) * grouped.mean(dim=2)
    values, _ = grouped.max(dim=2)
    return values


class LinearNormalized(nn.Linear):
  """``nn.Linear`` whose weight rows are L2-normalised on every forward pass.

  The normalised matrix used by the most recent call is kept on the module
  as ``self.Q``.
  """

  def __init__(self, in_features, out_features, bias=True):
    super().__init__(in_features, out_features, bias)

  def forward(self, x):
    """Return ``x @ Q.T + bias`` where ``Q`` is the row-normalised weight."""
    unit_rows = F.normalize(self.weight, p=2, dim=1)
    self.Q = unit_rows
    return F.linear(x, unit_rows, self.bias)


