import numpy as np
import torch
import torch.nn as nn


class EqvarModuleMean(nn.Module):
    """Permutation-equivariant layer that centres each set on its mean.

    The mean over dim 1 is subtracted from every element, then a shared
    ``nn.Linear(in_dim, out_dim)`` is applied to the last dimension.
    Reordering elements along dim 1 reorders the output identically.
    Presumably the input is (batch, set_size, in_dim) -- confirm with callers.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.gamma = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        # The set mean is permutation-invariant, so subtracting it keeps
        # the layer equivariant along dim 1.
        centred = x - torch.mean(x, dim=1, keepdim=True)
        return self.gamma(centred)


class EqvarModuleMax(nn.Module):
    """Permutation-equivariant layer that centres each set on its max.

    The elementwise maximum over dim 1 is subtracted from every element,
    then a shared ``nn.Linear(in_dim, out_dim)`` is applied to the last
    dimension. Reordering elements along dim 1 reorders the output
    identically. Presumably the input is (batch, set_size, in_dim) --
    confirm with callers.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.gamma = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        # The max over the set is permutation-invariant; only the values
        # are needed, not the argmax indices.
        set_max = torch.max(x, dim=1, keepdim=True).values
        return self.gamma(x - set_max)

class DatasetClassifier(nn.Module):
    """Set classifier: equivariant encoder ``phi``, max-pool, MLP head ``rho``.

    Args:
        in_dim: feature size of each set element.
        hid_dim: width of the encoder layers and of the head's hidden layer.
        out_dim: size of the output produced by ``rho``.
        pool: ``'mean'`` builds ``phi`` from ``EqvarModuleMean`` layers; any
            other value builds it from ``EqvarModuleMax`` layers. This choice
            does NOT affect the pooling across the set in ``forward``, which
            is always a max over dim 1.
    """

    def __init__(self, in_dim, hid_dim, out_dim, pool='mean'):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        equivariant = EqvarModuleMean if pool == 'mean' else EqvarModuleMax

        # Three equivariant layers, each followed by Tanh, built in the same
        # order (and thus with the same Sequential indices) as a literal list.
        widths = [in_dim, hid_dim, hid_dim, hid_dim]
        encoder_layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            encoder_layers.append(equivariant(width_in, width_out))
            encoder_layers.append(nn.Tanh())
        self.phi = nn.Sequential(*encoder_layers)

        self.rho = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(hid_dim, hid_dim),
            nn.Tanh(),
            nn.Dropout(p=0.5),
            nn.Linear(hid_dim, out_dim),
        )

    def forward(self, x):
        # Max over dim 1 gives a permutation-invariant summary of the set.
        pooled = self.phi(x).max(dim=1).values
        return self.rho(pooled)

def clip_grad(model, max_norm):
    """Rescale gradients in place so their global L2 norm is at most ``max_norm``.

    The norm is taken over the gradients of all parameters of ``model`` that
    have one. Parameters whose ``.grad`` is ``None`` (e.g. unused in the
    backward pass, or frozen) are skipped rather than raising
    ``AttributeError``.

    Args:
        model: module whose parameters' gradients are clipped.
        max_norm: maximum allowed global L2 norm of the gradients.

    Returns:
        The global gradient norm measured before clipping (a 0-dim tensor),
        or ``0.0`` if no parameter has a gradient.
    """
    grads = [p.grad.data for p in model.parameters() if p.grad is not None]

    total_norm = 0
    for grad in grads:
        total_norm += grad.norm(2) ** 2
    total_norm = total_norm ** 0.5

    # The epsilon avoids division by zero when every gradient is zero.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for grad in grads:
            grad.mul_(clip_coef)
    return total_norm