import torch
import torch.nn as nn
from torch.nn.init import orthogonal_, constant_
import numpy as np


class MLP(nn.Module):
    """Fully connected network: Linear layers separated by an activation.

    ``sizes`` lists the layer widths (input, hidden..., output). One
    ``nn.Linear`` is built per consecutive pair of widths, and ``non_lin`` is
    inserted between consecutive Linears but not after the last one, so the
    network output is linear.

    Weights use Xavier-uniform init scaled by the gain matching ``non_lin``.
    Biases are set to 0.01.

    NOTE(review): the final Linear also receives the activation's gain even
    though no activation follows it. This is kept as-is to preserve the
    existing initialisation.

    The same ``non_lin`` instance is reused between every pair of layers, so an
    activation with parameters (e.g. ``nn.PReLU``) shares them across layers.
    """

    def __init__(self, sizes, non_lin=None):
        super().__init__()
        # Build the default per instance. A module used as a default argument
        # would be created once at import time and shared by every MLP.
        if non_lin is None:
            non_lin = nn.Tanh()
        gain = self._init_gain(non_lin)

        ops = []
        for i_size, o_size in zip(sizes[:-1], sizes[1:]):
            lin = nn.Linear(i_size, o_size)
            nn.init.xavier_uniform_(lin.weight, gain=gain)
            constant_(lin.bias, 0.01)
            ops.append(lin)
            ops.append(non_lin)

        # Drop the trailing activation so the network ends on a Linear layer.
        self.f = nn.Sequential(*ops[:-1])

    @staticmethod
    def _init_gain(non_lin):
        """Return the Xavier gain to use for layers followed by ``non_lin``.

        The activation class name is passed to ``nn.init.calculate_gain``.
        LeakyReLU is special-cased, because ``calculate_gain`` spells it
        ``"leaky_relu"`` and needs its negative slope. Activations that
        ``calculate_gain`` does not recognise (it raises ``ValueError``) fall
        back to 1.0, the gain for a linear layer.
        """
        if isinstance(non_lin, nn.LeakyReLU):
            return nn.init.calculate_gain("leaky_relu", non_lin.negative_slope)
        try:
            return nn.init.calculate_gain(type(non_lin).__name__.lower())
        except ValueError:
            return 1.0

    def forward(self, x):
        """Run the full network on ``x``."""
        return self.f(x)

    def get_features(self, x):
        """Return the values fed into the final Linear layer.

        Every layer except the last is applied, so
        ``self.lin(self.get_features(x))`` equals ``self(x)``.
        """
        for f in self.f[:-1]:
            x = f(x)
        return x

    def lin(self, feat):
        """Apply only the final Linear layer to precomputed features."""
        return self.f[-1](feat)


def copy_model(a, b):
    """Copy the parameter values of model ``b`` into model ``a`` in place.

    Parameters are paired by their order in ``parameters()``. Values are
    written into ``a``'s existing tensors, so ``a`` keeps its own device, dtype
    and Parameter objects. Buffers (e.g. BatchNorm running statistics) are not
    copied.

    Raises:
        ValueError: if the models have a different number of parameters, or if
            any paired parameters differ in shape. Validation happens before
            anything is copied, so ``a`` is left untouched on error.
    """
    params_a = list(a.parameters())
    params_b = list(b.parameters())
    if len(params_a) != len(params_b):
        raise ValueError(
            f"copy_model: parameter count mismatch "
            f"({len(params_a)} vs {len(params_b)})"
        )
    for idx, (pa, pb) in enumerate(zip(params_a, params_b)):
        if pa.shape != pb.shape:
            raise ValueError(
                f"copy_model: shape mismatch at parameter {idx} "
                f"({tuple(pa.shape)} vs {tuple(pb.shape)})"
            )
    # no_grad: an in-place write into leaf tensors that require grad must not
    # be recorded by autograd.
    with torch.no_grad():
        for pa, pb in zip(params_a, params_b):
            pa.copy_(pb)

