import torch
import numpy as np
import torch.nn as nn
import os.path as osp

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def weights_init_encoder(m):
    """Initialize a single layer in place: orthogonal weights, zero biases.

    Modules that are not ``nn.Linear``, ``nn.Conv2d`` or
    ``nn.ConvTranspose2d`` are left untouched. The single-module signature
    looks intended for ``Module.apply`` -- confirm at the call site.

    * ``nn.Linear``: the weight matrix is orthogonally initialized and the
      bias (if the layer has one) is zeroed.
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: the whole kernel and the bias
      (if present) are zeroed, then only the center spatial tap
      ``weight[:, :, mid, mid]`` is orthogonally initialized, scaled by the
      ReLU gain.

    Args:
        m: The module to initialize.

    Raises:
        AssertionError: If a convolution kernel is not square.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        # Layers constructed with ``bias=False`` have ``m.bias is None``.
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # A single ``mid`` index is used for both spatial axes below, so the
        # kernel has to be square.
        assert m.weight.size(2) == m.weight.size(3), "expected a square kernel"
        m.weight.data.fill_(0.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain("relu")
        # Integer indexing yields a view, so the in-place init writes straight
        # into the center tap of the (otherwise zero) kernel.
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


# def var(tensor):
#     return tensor.to(device)


def get_tensor(z, device):
    """Convert an array into a float tensor on ``device``.

    Args:
        z: Array-like exposing ``shape``, ``copy()`` and indexing (the dtype
            check implies a NumPy array), or ``None``.
        device: Target passed to ``Tensor.to``.

    Returns:
        ``None`` when ``z`` is ``None`` or when ``z[0]`` has object dtype.
        Otherwise a ``torch.FloatTensor`` built from a copy of ``z`` and moved
        to ``device``; inputs with a one-dimensional shape gain a leading
        axis of size 1.
    """
    # Nothing to convert, or the first entry holds generic Python objects.
    if z is None or z[0].dtype == np.dtype("O"):
        return None

    # Build the tensor from a copy of the input rather than the input itself.
    tensor = torch.FloatTensor(z.copy()).to(device)

    is_flat = len(z.shape) == 1
    if is_flat:
        tensor = tensor.unsqueeze(0)
    return tensor
