import os, sys
import torch
import numpy as np
import pickle
import igraph as ig
import matplotlib.pyplot as plt

# Decide once whether a GPU is usable; the aliases below follow this flag.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Legacy typed-tensor classes for the selected backend.
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor

# Default floating-point tensor type (e.g. the default for get_tensor).
# NOTE: for half-precision experiments this could be rebound to torch.float16.
Tensor = FloatTensor

def get_loss_scalar(loss_var):
    """Return the value of a loss tensor as host-side NumPy data.

    On PyTorch >= 0.4 the tensor is detached from the autograd graph and
    converted with ``.numpy()`` (a 0-d array for a scalar loss). On older
    releases the first element of the NumPy conversion is returned, after
    unwrapping ``Variable.data`` when the input is an autograd Variable.
    """
    # Version is judged from the first three characters, e.g. "1.1" of "1.13.0".
    modern_torch = float(torch.__version__[:3]) >= 0.4
    if modern_torch:
        return loss_var.detach().cpu().numpy()

    legacy_tensor = loss_var.data if type(loss_var) is torch.autograd.Variable else loss_var
    return legacy_tensor.cpu().numpy()[0]

def convert_tensor_to_np(tensor):
    """Convert a channels-first image batch tensor to a channels-last NumPy array.

    The tensor is detached from autograd and copied to host memory. Singleton
    axes are squeezed away when doing so leaves a 4-D ``(N, C, H, W)`` array
    (e.g. an extra leading axis of size 1). A 4-D input is otherwise kept
    as-is, so batches of size 1 or single-channel images are not collapsed
    below four dimensions.

    Returns:
        np.ndarray of shape ``(N, H, W, C)``.

    Raises:
        ValueError: if the data cannot be interpreted as a 4-D image batch.
    """
    # detach() so tensors that require grad can be converted without error.
    data = tensor.detach().cpu().numpy()
    squeezed = np.squeeze(data)
    if squeezed.ndim == 4:
        data = squeezed
    elif data.ndim != 4:
        raise ValueError(
            "expected an image batch reducible to 4-D (N, C, H, W), got shape %s"
            % (tuple(data.shape),)
        )
    # (N, C, H, W) -> (N, H, W, C)
    return np.transpose(data, (0, 2, 3, 1))

def get_tensor(arr, data_type=Tensor):
    """Convert a NumPy array into a torch tensor of the requested type.

    Args:
        arr: NumPy array to convert.
        data_type: Target type passed to ``torch.Tensor.type`` (a legacy tensor
            class such as ``FloatTensor``/``LongTensor``, or a ``torch.dtype``).
            Defaults to the module-level ``Tensor`` alias as bound at import time.

    Returns:
        A tensor holding ``arr``'s values cast to ``data_type``.
    """
    # Previously this always cast to ``Tensor``, silently ignoring ``data_type``.
    return torch.from_numpy(arr).type(data_type)

def square_norm(u):
    """Return the squared Euclidean norm ``<u, u>`` of 1-D tensor *u* (0-d tensor)."""
    inner_product = torch.dot(u, u)
    return inner_product


def safe_rnorm(u):
    """Return the reciprocal Euclidean norm ``1 / ||u||`` of tensor *u*.

    A small epsilon is added to the squared norm before ``torch.rsqrt`` so a
    zero vector produces a large finite value instead of infinity.
    """
    eps = 1e-10
    sq_norm = square_norm(u)
    return torch.rsqrt(sq_norm + eps)

def to_np(tensor, detach=False):
    """Copy *tensor* to host memory and return it as a NumPy array.

    Pass ``detach=True`` for tensors that require grad; ``.numpy()`` refuses
    to convert those unless they are detached from the autograd graph first.
    """
    source = tensor.detach() if detach else tensor
    return source.cpu().numpy()

def load_from_pickle(filename):
    """Deserialize and return the object stored in pickle file *filename*.

    WARNING: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def save_as_pickle(data, filename, folder_path):
    """Pickle *data* into ``folder_path/filename``, overwriting any existing file.

    The destination folder must already exist; it is not created here.
    """
    target_path = os.path.join(folder_path, filename)
    with open(target_path, 'wb') as out_file:
        pickle.dump(data, out_file)

def print_graph(g):
    """Draw igraph graph *g* onto a new Matplotlib figure.

    The figure is neither shown nor saved here. It is returned so the caller
    can ``plt.show()`` it, ``fig.savefig(...)`` it, or ``plt.close(fig)`` it.
    Previously the figure was created and then discarded.

    Returns:
        (fig, ax): the Matplotlib figure and axes the graph was drawn on.
    """
    fig, ax = plt.subplots()
    ig.plot(g, target=ax)
    return fig, ax