import loss_factory
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils

# import cvxpy as cp
# import time

# from train_loops.default import train_epoch_default
# from train_loops.drops import train_epoch_drops
# from train_loops.peer import train_epoch_peer

## Dispatch each training-loop type to its trainer implementation


def train_control(full_package, loop_type):
    """Build the trainer object for the requested training loop.

    Args:
        full_package: Training state bundle, forwarded unchanged to the
            trainer constructor (its contents are defined by the caller).
        loop_type: Name of the training loop. Only ``"default"`` is
            currently implemented; ``"peer"`` and ``"drops"`` are reserved.

    Returns:
        ``train_loops.trainer_default.Trainer(full_package, loop_type)`` when
        ``loop_type == "default"``.

    Raises:
        NotImplementedError: For ``"peer"`` and ``"drops"``, which are
            recognised but not implemented yet.
        ValueError: For any other ``loop_type``.
    """
    if loop_type == "default":
        # Lazy import so the trainer module is only loaded when requested.
        # No globals() guard is needed: Python caches imported modules in
        # sys.modules, and a guard that skipped this local import would leave
        # the (function-local) name unbound.
        import train_loops.trainer_default as trainer_default

        return trainer_default.Trainer(full_package, loop_type)

    if loop_type == "peer":
        # TODO: implement the peer loss; planned entry point is
        # train_loops.trainer_peer.Trainer_peer(full_package, loop_type).
        raise NotImplementedError(f"loop_type {loop_type!r} is not implemented yet")

    if loop_type == "drops":
        # TODO: implement the drop loss; planned entry point is
        # train_loops.drops.train_epoch_drops(full_package).
        raise NotImplementedError(f"loop_type {loop_type!r} is not implemented yet")

    # Fail loudly instead of silently returning None for a typo'd loop type.
    raise ValueError(
        f"Unknown loop_type {loop_type!r}; expected one of 'default', 'peer', 'drops'"
    )
