import torch.nn as nn
from moduleloader import outervar
import torch

def send_net_to_device(state, event, net):
    """Resolve the compute device from the ``gpu`` global and move *net* onto it.

    Side effects: writes ``device_data``, ``device_labels`` and ``loss_labels``
    into *state*. All three are set to the same device string: ``"cuda:<gpu>"``
    when the ``gpu`` global is >= 0, otherwise ``"cpu"``.

    Args:
        state: framework state object. It is read via ``state.all["gpu"]`` and
            ``state["dataparallel"]`` and written via item assignment.
        event: unused here; part of the event-handler signature.
        net: the module to place.

    Returns:
        *net* moved to the resolved device. When ``dataparallel`` is set, *net*
        is first wrapped in ``torch.nn.DataParallel``.
    """
    gpu = state.all["gpu"]
    device = f"cuda:{gpu}" if gpu >= 0 else "cpu"
    state["device_data"] = device
    state["device_labels"] = device
    state["loss_labels"] = device

    if state["dataparallel"]:
        if gpu >= 0:
            # DataParallel requires the module's parameters to live on
            # device_ids[0] and gathers outputs there. The default device_ids
            # start at cuda:0, which breaks (and mismatches device_data /
            # device_labels) whenever gpu > 0. Put the selected GPU first and
            # keep every other visible GPU as a replica; for gpu == 0 this is
            # exactly the default ordering.
            others = [i for i in range(torch.cuda.device_count()) if i != gpu]
            net = torch.nn.DataParallel(net, device_ids=[gpu] + others)
        else:
            # NOTE(review): on a CUDA machine with gpu < 0 this wraps a CPU
            # module with all GPUs as device_ids, which DataParallel rejects at
            # forward time; behavior left unchanged here.
            net = torch.nn.DataParallel(net)
    return net.to(device)

def send_data_to_device(state, event, input):
    """Move an input batch to the device stored under ``state["device_data"]``.

    The copy is requested with ``non_blocking=True``.

    Args:
        state: mapping-like state holding the ``"device_data"`` key.
        event: unused; part of the event-handler signature.
        input: object exposing ``.to(device, non_blocking=...)``, e.g. a tensor.

    Returns:
        The result of ``input.to(...)``.
    """
    target_device = state["device_data"]
    return input.to(target_device, non_blocking=True)

def send_loss_to_device(state, event, loss):
    """Move *loss* to the device stored under ``state["device_data"]``.

    The copy is requested with ``non_blocking=True``.

    NOTE(review): this reads ``"device_data"``, not the ``"loss_labels"`` key
    that ``send_net_to_device`` also writes. Confirm which one is intended.

    Args:
        state: mapping-like state holding the ``"device_data"`` key.
        event: unused; part of the event-handler signature.
        loss: object exposing ``.to(device, non_blocking=...)``.

    Returns:
        The result of ``loss.to(...)``.
    """
    target_device = state["device_data"]
    return loss.to(target_device, non_blocking=True)

def send_labels_to_device(state, event, labels):
    """Move *labels* to the device stored under ``state["device_labels"]``.

    The copy is requested with ``non_blocking=True``.

    Args:
        state: mapping-like state holding the ``"device_labels"`` key.
        event: unused; part of the event-handler signature.
        labels: object exposing ``.to(device, non_blocking=...)``.

    Returns:
        The result of ``labels.to(...)``.
    """
    target_device = state["device_labels"]
    return labels.to(target_device, non_blocking=True)


def register(mf):
    """Register this module's defaults, globals and event handlers with *mf*.

    Registers:
    - default ``dataparallel=False``;
    - global ``gpu=0``;
    - the four ``send_*_to_device`` handlers, each registered with
      ``unique=True``, in the order: net, data, labels, loss.
    """
    mf.register_defaults(dict(dataparallel=False))
    mf.register_globals(dict(gpu=0))

    handlers = (
        ('send_net_to_device', send_net_to_device),
        ('send_data_to_device', send_data_to_device),
        ('send_labels_to_device', send_labels_to_device),
        ('send_loss_to_device', send_loss_to_device),
    )
    for event_name, handler in handlers:
        mf.register_event(event_name, handler, unique=True)

