import os
from .mlp import MLP
from .gbdt import GBDT
from datasets import Dataset
from market.specifications import Specification
from utils import set_seed, select_device

class Learnware:
    """A model wrapped together with its ``Specification``.

    On construction the underlying model is chosen from ``cfg['task']``
    (``MLP`` for 'classification', ``GBDT`` for 'regression'), then loaded
    from disk or trained, and finally a ``Specification`` is built for it.
    """

    # Tasks for which __init__ knows how to build a model.
    _SUPPORTED_TASKS = ('classification', 'regression')

    def __init__(self, cfg, learnware_id, cfe=None):
        """Build, load-or-train, and specify learnware ``learnware_id``.

        Args:
            cfg: configuration mapping. Keys read by this class are 'seed',
                'task', 'retrain', 'specification' and 'dataset_path'.
            learnware_id: id of this learnware. It is added to
                ``cfg['seed']`` and used to name this learnware's files.
            cfe: forwarded unchanged to the model constructor
                (NOTE(review): semantics not visible here — confirm).

        Raises:
            ValueError: if ``cfg['task']`` is not one of the supported tasks.
        """
        task = cfg['task']
        # Fail fast. Previously an unknown task left ``self.model`` unset and
        # surfaced later as an AttributeError inside ``load()``.
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {self._SUPPORTED_TASKS}"
            )

        # Model construction and load/train run under a per-learnware seed.
        set_seed(cfg['seed'] + learnware_id)
        self.cfg = cfg
        self.task = task
        self.learnware_id = learnware_id
        model_cls = MLP if task == 'classification' else GBDT
        self.model = model_cls(cfg, learnware_id, cfe)
        self.load()

        # Reset to the base seed before building the specification.
        set_seed(self.cfg['seed'])
        kwargs = self.spec_kwargs()
        self.spec = Specification(cfg, 'learnware', learnware_id, **kwargs)

    def __call__(self, x):
        """Forward ``x`` to the wrapped model and return its output."""
        return self.model(x)

    def spec_kwargs(self):
        """Return the keyword arguments passed to ``Specification``.

        Paths live under ``cfg['dataset_path']`` and are keyed by
        ``cfg['specification']`` and this learnware's id.
        """
        spec = self.cfg['specification']
        return {
            'path': os.path.join(self.cfg['dataset_path'], 'specifications', spec, 'learnware', f'{self.learnware_id}.npz'),
            'device': select_device(self.cfg, self.learnware_id),
            'learnware': self.model,    # For LinearProxy
            'phi_path': os.path.join(self.cfg['dataset_path'], 'phi', 'learnware', f'{self.learnware_id}.pt')  # For neural phi
        }

    def _fit(self):
        """Train the wrapped model unconditionally on this learnware's data."""
        dataset = Dataset(self.cfg, 'learnware', self.learnware_id)
        trainloader = dataset.get_loader('train')
        evalloader = dataset.get_loader('eval')
        # NOTE(review): nothing here calls save(); presumably model.train()
        # persists itself or the caller saves — confirm.
        self.model.train(trainloader, evalloader)

    def train(self):
        """Train the model if retraining is requested or no file exists at ``model.path``."""
        if self.cfg['retrain'] or not os.path.exists(self.model.path):
            self._fit()

    def evaluate(self, dataloader, pred=False, prob=False):
        """Delegate to the wrapped model's ``evaluate`` with the same flags."""
        return self.model.evaluate(dataloader, pred=pred, prob=prob)

    def specification(self):
        """Return the ``Specification`` built in ``__init__``."""
        return self.spec

    def save(self):
        """Delegate to the wrapped model's ``save``."""
        self.model.save()

    def load(self):
        """Load the model from disk, training it instead when needed.

        When ``cfg['retrain']`` is set, loading is skipped. Training also
        happens when ``model.load()`` returns a falsy value. In that case
        ``_fit`` is called directly rather than going through ``train()``.
        ``train()`` re-checks whether ``model.path`` exists, so a file that
        exists but fails to load would otherwise leave the model untrained.
        """
        if self.cfg['retrain'] or not self.model.load():
            self._fit()