
from .callback.core import * 
from .callback.tracking import * 
from .callback.scheduler import *

from .learner import Learner
from .datacore import *

from sklearn.base import BaseEstimator
from torch.optim import Adam


class SKLearner(Learner):
    """scikit-learn style ``fit``/``predict``/``score`` facade over ``Learner``.

    The constructor only records configuration. Each call to :meth:`fit`
    builds a fresh inner ``Learner`` from raw arrays and stores it on
    ``self.learner``.

    NOTE(review): ``Learner.__init__`` is deliberately not called, so any
    attributes a ``Learner`` normally initializes are absent on this object
    itself; all training goes through ``self.learner``. Confirm whether
    inheriting from ``Learner`` (rather than plain composition) is intended.
    """

    def __init__(self, dataset, dataloader, model,
                 loss_func=None, lr=1e-3,
                 cbs=None, metrics=None, opt_func=Adam,
                 scoring=None, **kwargs):
        """Store the configuration later used by :meth:`fit`.

        Args:
            dataset: callable invoked as ``dataset(X, y)`` to build a dataset.
            dataloader: callable invoked as ``dataloader(ds, batch_size=...)``
                to wrap a dataset.
            model: model handed to the inner ``Learner``.
            loss_func, lr, cbs, metrics, opt_func: passed unchanged to the
                inner ``Learner`` built in :meth:`fit`.
            scoring: stored for :meth:`score`; not used yet.
            **kwargs: accepted for signature compatibility; currently ignored.
        """
        self.dataset, self.dataloader = dataset, dataloader
        self.model = model
        self.loss_func, self.lr = loss_func, lr
        self.cbs, self.metrics = cbs, metrics
        self.opt_func = opt_func
        self.scoring = scoring

    def get_dls(self, X_train, y_train, X_valid=None, y_valid=None, batch_size=64):
        """Wrap raw arrays into ``DataLoaders``.

        Args:
            X_train, y_train: training inputs and targets.
            X_valid, y_valid: optional validation inputs and targets; the
                validation loader is only built when ``X_valid`` is not None.
            batch_size: batch size used for every dataloader.

        Returns:
            ``DataLoaders(train_dl, valid_dl)`` when validation data is given,
            otherwise ``DataLoaders(train_dl)``.
        """
        train_ds = self.dataset(X_train, y_train)
        train_dl = self.dataloader(train_ds, batch_size=batch_size)
        # Compare against None explicitly: array-likes (numpy, torch) raise on
        # bool(), and a provided validation set must not be silently dropped.
        if X_valid is None:
            return DataLoaders(train_dl)
        valid_ds = self.dataset(X_valid, y_valid)
        valid_dl = self.dataloader(valid_ds, batch_size=batch_size)
        return DataLoaders(train_dl, valid_dl)

    def fit(self, X_train, y_train, X_valid=None, y_valid=None, batch_size=64, **fit_params):
        """Build a fresh inner ``Learner`` on the given data and train it.

        The learning rate returned by ``lr_finder(suggestion='valley')`` is
        used as ``lr_max`` for ``fit_one_cycle``.

        Args:
            X_train, y_train: training inputs and targets.
            X_valid, y_valid: optional validation inputs and targets.
            batch_size: batch size for the dataloaders.
            **fit_params: ``n_epochs`` (defaults to ``0``) and any further
                keyword arguments other than ``lr_max``; all are forwarded
                to ``fit_one_cycle``.

        Returns:
            self, following the scikit-learn estimator convention.
        """
        dls = self.get_dls(X_train, y_train, X_valid, y_valid, batch_size)
        self.learner = Learner(dls, self.model,
                               loss_func=self.loss_func, lr=self.lr,
                               cbs=self.cbs, metrics=self.metrics,
                               opt_func=self.opt_func)
        # Run the LR finder on the learner that was just built.
        suggested_lr = self.learner.lr_finder(suggestion='valley', show_plot=False)
        # NOTE(review): with the default of 0 epochs, fit_one_cycle presumably
        # performs no training; callers should pass n_epochs explicitly.
        n_epochs = fit_params.pop('n_epochs', 0)
        self.learner.fit_one_cycle(n_epochs=n_epochs, lr_max=suggested_lr, **fit_params)
        return self

    def predict(self, X_test, y_test=None):
        """Predict targets for ``X_test``.

        Not implemented yet. It raises instead of returning ``None``, so that
        callers such as :meth:`score` fail loudly rather than silently.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("SKLearner.predict is not implemented yet")

    def score(self, X, y, sample_weight=None):
        """Score the fitted model on ``X``/``y``.

        Raises:
            NotImplementedError: prediction and scoring are not implemented yet.
        """
        # predict() takes raw inputs, so hand it X directly rather than a
        # DataLoaders object built from it.
        preds = self.predict(X)
        # TODO: compare ``preds`` with ``y`` (presumably via ``self.scoring``,
        # weighting by ``sample_weight``) and return the resulting score.
        raise NotImplementedError("SKLearner.score is not implemented yet")
        