#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch

from model.linear import Linear, Logistic
from model.cnn import CNN

def get_model(cfgs, info, device):
    """Build the model selected by ``cfgs['Model']`` and optionally restore it.

    Args:
        cfgs: Configuration mapping. Reads ``cfgs['Model']`` (optional keys
            ``'path'`` and ``'type'``; the type defaults to ``'Linear'``),
            optionally ``cfgs['device']`` when a checkpoint is loaded, and
            ``cfgs['Online']['kwargs']['R']`` for the ``'Linear'`` type.
        info: Mapping describing the data. Reads ``'dim'`` and ``'cls_num'``,
            plus ``'channel_num'`` for the ``'CNN'`` type.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, init)`` tuple. ``init`` is the object loaded from
        ``cfgs['Model']['path']``, or ``None`` when no path is configured.

    Raises:
        NotImplementedError: If the configured type is neither ``'Linear'``
            nor ``'CNN'``.
    """
    model_cfg = cfgs['Model']

    if 'path' in model_cfg:
        # Keep honouring cfgs['device'] when it is configured, but fall back
        # to the ``device`` argument instead of failing with a KeyError.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        init = torch.load(
            model_cfg['path'],
            map_location=cfgs.get('device', device),
        )
    else:
        init = None
    print('Init: {}'.format(init is not None))

    model_name = model_cfg.get('type', 'Linear')
    print('Model structure: {}'.format(model_name))

    if model_name == 'Linear':
        model = Linear(
            input_dim=info['dim'],
            output_dim=info['cls_num'],
            # The model receives half of the configured R.
            R=cfgs['Online']['kwargs']['R'] / 2
        )
    elif model_name == 'CNN':
        model = CNN(
            num_input_channels=info['channel_num'],
            hid_dim=info['dim'],
            num_classes=info['cls_num']
        )
    else:
        raise NotImplementedError(
            'Unsupported model type: {!r}'.format(model_name)
        )

    if init is not None:
        model.load_state_dict(init)

    return model.to(device), init


def get_cls_model(cfgs, info, device):
    """Build a freshly initialised classification model.

    No checkpoint is ever loaded here, so the second element of the returned
    tuple is always ``None``.

    Args:
        cfgs: Configuration mapping. Reads ``cfgs['Estimator']`` (optional key
            ``'type'``, defaulting to ``'Linear'``) and
            ``cfgs['Online']['kwargs']['R']`` for the ``'Linear'`` type.
        info: Mapping describing the data. Reads ``'dim'`` and ``'cls_num'``,
            plus ``'channel_num'`` for the ``'CNN'`` type.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, None)`` tuple.

    Raises:
        NotImplementedError: If the configured type is neither ``'Linear'``
            nor ``'CNN'``.
    """
    # This factory never restores weights; the value is kept so the log line
    # and the return shape match the other model factories.
    init = None

    print('Init: {}'.format(init is not None))

    # NOTE(review): the type is read from the 'Estimator' section, the same
    # one get_est_model uses -- confirm this is intended for the classifier.
    model_name = cfgs['Estimator'].get('type', 'Linear')
    print('Classification model: {}'.format(model_name))

    if model_name == 'Linear':
        model = Linear(
            input_dim=info['dim'],
            output_dim=info['cls_num'],
            # The model receives half of the configured R.
            R=cfgs['Online']['kwargs']['R'] / 2
        )
    elif model_name == 'CNN':
        model = CNN(
            num_input_channels=info['channel_num'],
            hid_dim=info['dim'],
            num_classes=info['cls_num']
        )
    else:
        raise NotImplementedError(
            'Unsupported classification model type: {!r}'.format(model_name)
        )

    return model.to(device), init


def get_est_model(cfgs, info, device):
    """Build a freshly initialised estimator model.

    For the ``'Linear'`` type this constructs a ``Logistic`` model. No
    checkpoint is ever loaded here, so the second element of the returned
    tuple is always ``None``.

    Args:
        cfgs: Configuration mapping. Reads ``cfgs['Estimator']`` (optional key
            ``'type'``, defaulting to ``'Linear'``) and
            ``cfgs['Online']['kwargs']['R']`` for the ``'Linear'`` type.
        info: Mapping describing the data. Reads ``'dim'`` and ``'cls_num'``,
            plus ``'channel_num'`` for the ``'CNN'`` type.
        device: Device the returned model is moved to.

    Returns:
        A ``(model, None)`` tuple.

    Raises:
        NotImplementedError: If the configured type is neither ``'Linear'``
            nor ``'CNN'``.
    """
    # This factory never restores weights; the value is kept so the log line
    # and the return shape match the other model factories.
    init = None

    print('Init: {}'.format(init is not None))

    model_name = cfgs['Estimator'].get('type', 'Linear')
    print('Estimator model: {}'.format(model_name))

    if model_name == 'Linear':
        # The 'Linear' estimator type is backed by the Logistic class.
        model = Logistic(
            input_dim=info['dim'],
            output_dim=info['cls_num'],
            # The model receives half of the configured R.
            R=cfgs['Online']['kwargs']['R'] / 2
        )
    elif model_name == 'CNN':
        model = CNN(
            num_input_channels=info['channel_num'],
            hid_dim=info['dim'],
            num_classes=info['cls_num']
        )
    else:
        raise NotImplementedError(
            'Unsupported estimator model type: {!r}'.format(model_name)
        )

    return model.to(device), init
