import torch

from .base_learning import BaseLearning
from ..loss.supervised_learning_loss import SupervisedLearningLoss
from ..tasks import Task
from ..utils.log import get_logger
import sys

# Module-level logger for this learner (not referenced by the code shown below).
logger = get_logger("supervised_learning")


class SupervisedLearning(BaseLearning):
    """Train a model on labeled data with a ``SupervisedLearningLoss``.

    NOTE(review): ``self.config``, ``self.model``, ``self.optimizer``,
    ``self.device``, ``self.controller``, ``set_model`` and ``train_epoches``
    are expected to come from ``BaseLearning`` (not visible here) — confirm.
    """

    # Model names whose loss function is forced to hinge loss in ``train``.
    _HINGE_LOSS_MODELS = ("toy_svm", "toy_svm_poly")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self, task: Task):
        """Build the supervised loss for ``task`` and train the model on it.

        For the ``toy_svm`` / ``toy_svm_poly`` models the task's loss
        configuration is rewritten in place so that its loss function is
        ``"hinge-loss"`` before the loss object is constructed.

        Args:
            task: Task providing the model to train (``model_to_train``),
                the log directory and the ``"loss"`` configuration.

        Returns:
            Whatever ``self.train_epoches()`` returns.
        """
        self.task = task
        if self.config["model"]["name"] in self._HINGE_LOSS_MODELS:
            loss_config = self.task.config["loss"]
            if "loss_function" in loss_config:
                # Keep any other loss-function options; only override the name.
                loss_config["loss_function"]["name"] = "hinge-loss"
            else:
                loss_config["loss_function"] = {"name": "hinge-loss"}
        self.task.loss = SupervisedLearningLoss(self.task.logdir, **self.task.config["loss"])

        self.set_model(task.model_to_train)
        return self.train_epoches()

    def train_epoch(self, epoch: int) -> None:
        """Run one optimisation pass over the labeled ``'train'`` loader.

        Each batch's loss is computed through ``self.task.loss`` inside its
        ``new_epoch(epoch, "train")`` context, then back-propagated and
        applied with ``self.optimizer.step()``.

        Args:
            epoch: Index of the current epoch, forwarded to the loss tracker.
        """
        self.model.train()
        with self.task.loss.new_epoch(epoch, "train"):
            dataloader = self.task.labeled_dataloader.dataloaders['train']

            for data, target in dataloader:
                # Start each step from clean gradients.
                self.optimizer.zero_grad()

                data, target = data.to(self.device), target.to(self.device)

                pred_target = self.model(data)
                loss = self.task.loss({"labeled_data_targets": target,
                                       "labeled_data_pred_targets": pred_target})
                loss.backward()
                self.optimizer.step()
                # Clear again so no gradients remain attached to the parameters
                # after the final batch of the epoch.
                self.optimizer.zero_grad()

    def validate_epoch(self, epoch: int) -> None:
        """Evaluate the model on the labeled ``'val'`` loader and record the result.

        Runs in eval mode under ``torch.no_grad()``, feeding each batch to
        ``self.task.loss`` inside its ``new_epoch(epoch, "val")`` context while
        drawing a progress bar on stdout. Afterwards the epoch loss and the
        model's ``state_dict()`` are passed to ``self.controller.add_state``.

        Args:
            epoch: Index of the current epoch, forwarded to the loss tracker
                and to the controller.
        """
        self.model.eval()
        with self.task.loss.new_epoch(epoch, "val"), torch.no_grad():
            dataloader = self.task.labeled_dataloader.dataloaders['val']
            num_batches = len(dataloader)
            for done, (data, target) in enumerate(dataloader, start=1):
                data, target = data.to(self.device), target.to(self.device)

                pred_target = self.model(data)
                self.task.loss({"labeled_data_targets": target,
                                "labeled_data_pred_targets": pred_target})

                self._write_progress("Validating", done, num_batches)
            sys.stdout.write('\n')

        # Record after the ``new_epoch`` context has exited.
        self.controller.add_state(epoch, self.task.loss.get_epoch_loss(), self.model.state_dict())

    @staticmethod
    def _write_progress(label: str, done: int, total: int, width: int = 70) -> None:
        """Redraw a one-line ``label [====    ] NN%`` progress bar on stdout.

        Args:
            label: Text printed before the bar.
            done: Number of completed steps (1-based).
            total: Total number of steps; must be positive.
            width: Number of characters inside the brackets.
        """
        filled = '=' * (done * width // total)
        percent = done * 100 // total
        # '\r' returns to the start of the line so the bar overwrites itself.
        sys.stdout.write('\r%s [%-*s] %d%%' % (label, width, filled, percent))
        sys.stdout.flush()