from models.utils import standard_train
from models.basenet import BaseNet
from importlib import import_module


class baseline(BaseNet):
    """Baseline model: one backbone network trained with ``standard_train``.

    The network and optimizer are built from ``opt`` at construction time.
    Attributes read here but not assigned here (``is_3d``, ``backbone``,
    ``output_dim``, ``pretrained``, ``ssl_pretrained``, ``device``, ``opt``,
    ``epoch``, ``_criterion``, ``wandb``, ``load_moco``) are presumably
    provided by ``BaseNet`` — TODO confirm against ``models.basenet``.
    """

    def __init__(self, opt, wandb):
        """Initialize the base class, then build the network and optimizer.

        Args:
            opt: configuration mapping; ``set_optimizer`` requires an
                ``'optimizer_setting'`` entry.
            wandb: logging handle forwarded to ``BaseNet.__init__``.
        """
        super().__init__(opt, wandb)
        self.set_network(opt)
        self.set_optimizer(opt)

    def set_network(self, opt):
        """Instantiate ``self.network`` from the configured backbone class.

        The class named ``self.backbone`` is looked up in
        ``models.basemodels`` (2D) or ``models.basemodels_3d`` (3D). It is
        constructed with ``n_classes=self.output_dim`` and
        ``pretrained=self.pretrained``, then moved to ``self.device``.

        Args:
            opt: not used by this method; kept so the signature matches
                ``set_optimizer`` and any callers.
        """
        # Both branches differ only in which module provides the backbone.
        module_name = "models.basemodels_3d" if self.is_3d else "models.basemodels"
        backbone_cls = getattr(import_module(module_name), self.backbone)
        self.network = backbone_cls(
            n_classes=self.output_dim, pretrained=self.pretrained
        ).to(self.device)

        # Self-supervised weights (via ``self.load_moco``) are applied only
        # to 2D backbones. ``is_3d`` is tested first so the 3D path never
        # reads ``ssl_pretrained``, matching the original control flow.
        if not self.is_3d and self.ssl_pretrained:
            self.network = self.load_moco(self.network)

    def forward(self, x):
        """Run the network on ``x`` and return its ``(out, feature)`` pair."""
        out, feature = self.network(x)
        return out, feature

    def set_optimizer(self, opt):
        """Create ``self.optimizer`` over the network's trainable parameters.

        Args:
            opt: mapping whose ``'optimizer_setting'`` entry supplies an
                ``'optimizer'`` factory (called with ``params``, ``lr`` and
                ``weight_decay``), plus the ``'lr'`` and ``'weight_decay'``
                values.
        """
        optimizer_setting = opt['optimizer_setting']
        # Frozen parameters (requires_grad=False) are excluded from updates.
        trainable = (p for p in self.network.parameters() if p.requires_grad)
        self.optimizer = optimizer_setting['optimizer'](
            params=trainable,
            lr=optimizer_setting['lr'],
            weight_decay=optimizer_setting['weight_decay'],
        )

    def _train(self, loader):
        """Train the model for one epoch on ``loader`` and advance ``self.epoch``.

        Puts the network in training mode, delegates the epoch loop to
        ``standard_train``, prints the returned AUC and loss, then increments
        the epoch counter.
        """
        self.network.train()
        auc, train_loss = standard_train(
            self.opt, self.network, self.optimizer, loader, self._criterion, self.wandb
        )

        print(f'Training epoch {self.epoch}: AUC:{auc}')
        print(f'Training epoch {self.epoch}: loss:{train_loss}')

        self.epoch += 1
    