import io

from .dndt import *
from .dtn import *
from .losses import *
from .mlp import *
from .sdt import *


class ModelWrapper:
    """Thin facade over a torch module that adds checkpointing helpers.

    Training/evaluation calls are delegated to the wrapped ``model``. On top
    of that the wrapper offers an in-memory snapshot of the model's
    ``state_dict`` (``cache`` / ``restore``) and file-based persistence
    (``save`` / ``load``).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # In-memory buffer filled by cache(); None until cache() has run.
        self.checkpoint = None

    def forward(self, *args, **kwargs):
        """Call the wrapped model's ``forward`` with the given arguments."""
        return self.model.forward(*args, **kwargs)

    def train(self, mode=True):
        """Set the wrapped model's training mode (``mode=False`` means eval).

        Returns ``self`` so calls can be chained, mirroring ``nn.Module.train``.
        """
        self.model.train(mode)
        return self

    def eval(self):
        """Put the wrapped model in evaluation mode."""
        self.model.eval()

    def parameters(self):
        """Return the wrapped model's parameter iterator."""
        return self.model.parameters()

    def decision_paths(self, x):
        """Delegate to the wrapped model's ``decision_paths``.

        NOTE(review): only works for wrapped models that define this method.
        """
        return self.model.decision_paths(x)

    def count_parameters(self):
        """Return the total number of elements in trainable parameters."""
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def cache(self):
        """Snapshot the current ``state_dict`` into an in-memory buffer.

        A later call overwrites any earlier snapshot.
        """
        self.checkpoint = io.BytesIO()
        torch.save(self.model.state_dict(), self.checkpoint)

    def restore(self):
        """Load the snapshot taken by the most recent ``cache()`` call.

        Raises:
            RuntimeError: If ``cache()`` has never been called.
        """
        if self.checkpoint is None:
            raise RuntimeError(
                'restore() called before cache(); no checkpoint to restore')
        # Rewind: torch.save left the buffer positioned at its end.
        self.checkpoint.seek(0)
        self.model.load_state_dict(torch.load(self.checkpoint))

    def save(self, path):
        """Write the wrapped model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path, device=None):
        """Load a ``state_dict`` from ``path`` into the wrapped model.

        Args:
            path: File previously written by ``save``.
            device: Forwarded to ``torch.load`` as ``map_location``. ``None``
                keeps tensors on the device they were saved from.
        """
        self.model.load_state_dict(torch.load(path, map_location=device))


def to_model(args, dataset, in_features, num_classes, device):
    """Instantiate the architecture named by ``args.model`` and wrap it.

    Recognised names are ``'DTN'`` and ``'DTN-D'`` (both build ``DTN``),
    ``'DTN-S'``, ``'DNDT'``, ``'MLP'`` and ``'SDT'``. Only the selected
    constructor is invoked, and only the ``args`` attributes it needs are read.

    Args:
        args: Object exposing ``model`` plus the hyperparameters read by the
            chosen constructor (``layers``, ``units``, ``activation``,
            ``width``, ``prune``).
        dataset: Forwarded only to ``DNDT`` as its ``dataset`` keyword.
        in_features: First positional argument of every model constructor.
        num_classes: Second positional argument of every model constructor.
        device: Passed to ``.to()`` on the constructed model.

    Returns:
        A ``ModelWrapper`` around the model after it is moved to ``device``.

    Raises:
        ValueError: If ``args.model`` is not a recognised name. The message
            is the unrecognised name itself.
    """
    name = args.model

    def build_dtn():
        return DTN(in_features, num_classes,
                   num_layers=args.layers,
                   num_units=args.units,
                   activation=args.activation,
                   width=args.width,
                   prune=args.prune)

    # Factories are deferred (functions/lambdas) so that only the selected
    # constructor runs and only its hyperparameters are read from ``args``.
    factories = {
        'DTN': build_dtn,
        'DTN-D': build_dtn,  # same constructor and settings as 'DTN'
        'DTN-S': lambda: DTNS(in_features, num_classes,
                              num_layers=args.layers,
                              num_units=args.units),
        'DNDT': lambda: DNDT(in_features, num_classes,
                             dataset=dataset,
                             # NOTE(review): relative path; presumably
                             # resolved against the current working dir.
                             selection='../data/uci-features.tsv'),
        'MLP': lambda: MLP(in_features, num_classes,
                           num_layers=args.layers),
        'SDT': lambda: SDT(in_features, num_classes,
                           num_layers=args.layers),
    }

    factory = factories.get(name)
    if factory is None:
        raise ValueError(name)

    network = factory()
    # noinspection PyUnresolvedReferences
    return ModelWrapper(network.to(device))
