#!/usr/bin/env python
# -*- coding: utf-8 -*-

from online.models.iwerm import CustomOGD, Fix

def get_classifier_alg(cfgs, model, train_set, device, rng, info):
    """Build the classifier algorithm selected by ``cfgs['algorithm']``.

    Args:
        cfgs: Mapping with an ``'algorithm'`` entry (``'OGD'`` or ``'FIX'``)
            and a ``'kwargs'`` mapping. ``kwargs['source_batch_size']`` is
            read for both algorithms; ``kwargs['lr']`` and
            ``kwargs['projection']`` are read only for ``'OGD'``. The whole
            ``kwargs`` mapping is also forwarded to the algorithm as ``cfgs``.
        model: Forwarded to the algorithm as ``model``.
        train_set: Forwarded to the algorithm as ``dataset``.
        device: Forwarded to the algorithm as ``device``.
        rng: Forwarded to the algorithm as ``rng``.
        info: Mapping; only ``info['dim']`` is read, and only for ``'OGD'``.

    Returns:
        A ``CustomOGD`` instance for ``'OGD'`` or a ``Fix`` instance for
        ``'FIX'``.

    Raises:
        NotImplementedError: If ``cfgs['algorithm']`` is neither ``'OGD'``
            nor ``'FIX'``.
    """
    name = cfgs['algorithm']

    # Reject unknown algorithms up front, before touching cfgs['kwargs'], so a
    # typo in the algorithm name is reported as such rather than as a KeyError
    # on some unrelated option.
    if name not in ('OGD', 'FIX'):
        raise NotImplementedError(
            "Unsupported classifier algorithm: {!r} "
            "(expected 'OGD' or 'FIX')".format(name)
        )

    kwargs = cfgs['kwargs']

    # Arguments shared by every supported algorithm.
    alg_kwargs = {
        'model': model,
        'dataset': train_set,
        'device': device,
        'rng': rng,
        'batch_size': kwargs['source_batch_size'],
    }

    if name == 'OGD':
        # OGD additionally needs its step size, projection setting and the
        # dimension reported in ``info``.
        alg_kwargs.update({
            'stepsize': kwargs['lr'],
            'projection': kwargs['projection'],
            'dim': info['dim'],
        })
        return CustomOGD(cfgs=kwargs, **alg_kwargs)

    return Fix(cfgs=kwargs, **alg_kwargs)
