import sys

import torch

from .base_learning import BaseLearning, assign_to_device
from ..loss.supervised_learning_loss import SupervisedLearningLoss
from ..tasks import Task
from ..utils.log import get_logger


# Module-level logger for this trainer; train_epoch uses it to report the current noise level.
logger = get_logger("supervised_learning")


class SupervisedLearning(BaseLearning):
    """Supervised train/validate loop built on ``BaseLearning``.

    Batches are ``(data, target)`` pairs drawn from
    ``task.labeled_dataloader.dataloaders['train']`` / ``['val']``. The model
    prediction and the target are passed to ``task.loss`` under the keys
    ``"labeled_data_targets"`` and ``"labeled_data_pred_targets"``.
    """

    # Model names that are always trained with a hinge loss, overriding the
    # configured loss-function name.
    _HINGE_LOSS_MODELS = ("toy_svm", "toy_svm_poly")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def train(self, task: Task):
        """Build the loss and model for ``task`` and run training.

        For the toy SVM models this mutates ``task.config["loss"]`` in place,
        setting its ``loss_function`` name to ``"hinge-loss"`` (other keys of an
        existing ``loss_function`` entry are kept).

        Args:
            task: The task to train. Its ``config``, ``logdir`` and
                ``model_to_train`` are read here.

        Returns:
            Whatever ``self.train_epochs()`` returns.
        """
        self.task = task
        if self.digest is not None:
            self.digest.step = 0

        if self.task.config["model"]["name"] in self._HINGE_LOSS_MODELS:
            loss_config = self.task.config["loss"]
            # Explicit branch (rather than setdefault) so the assignment is
            # written through the config mapping itself.
            if "loss_function" in loss_config:
                loss_config["loss_function"]["name"] = "hinge-loss"
            else:
                loss_config["loss_function"] = {"name": "hinge-loss"}

        self.task.loss = SupervisedLearningLoss(self.task.logdir, **self.task.config["loss"])
        self.set_model(task.model_to_train)
        return self.train_epochs()

    def train_epoch(self, epoch: int) -> None:
        """Run one optimisation pass over the ``'train'`` dataloader.

        Args:
            epoch: Current epoch index. Used to decay the optional activation
                noise level and passed to ``task.loss.new_epoch``.
        """
        self.model.train()

        # Optional random noise on activations, decayed geometrically per epoch.
        # `or {}` also tolerates a "controller" key that is present but None.
        controller_config = self.task.config.get("controller") or {}
        noise_level = controller_config.get("noise_level", None)
        if noise_level is not None:
            noise_level_decay = controller_config.get("noise_level_decay", 1.0)
            noise_level = noise_level * (noise_level_decay ** epoch)
            self.model.set_noise_level(noise_level)
            logger.info(f"Current noise-level is: {noise_level}")

        # Optimal-path recording goes through the digest; without one it would
        # fail on `None.digest(...)`, so warn once and skip it instead.
        record_optimal_path = bool(self.task.config.get("optimal_path", False))
        if record_optimal_path and self.digest is None:
            logger.warning("'optimal_path' is enabled but no digest is configured; "
                           "skipping optimal-path recording.")
            record_optimal_path = False

        with self.task.loss.new_epoch(epoch, "train"):
            dataloader = self.task.labeled_dataloader.dataloaders['train']

            for batch_idx, (data, target) in enumerate(dataloader):
                data = assign_to_device(data, self.device)
                target = target.to(self.device)
                if self.digest is not None:
                    self.digest.digest_batch_level(self.model, data, target, batch_idx, kind="before")

                self.optimizer.zero_grad()

                pred_target = self.model(data)

                loss = self.task.loss({"labeled_data_targets": target,
                                       "labeled_data_pred_targets": pred_target})
                loss.backward()
                self.optimizer.step()

                # Gradients are cleared again after the update, before any of
                # the digest hooks below run.
                self.optimizer.zero_grad()

                if self.digest is not None:
                    self.digest.digest(batch_idx, self.model)
                    self.digest.digest_batch_level(self.model, data, target, batch_idx, kind="after")

                if record_optimal_path:
                    self.digest.digest(batch_idx, self.model, epoch=self.task.loss.epoch,
                                       variable_name=self.task.config["name"])

    def validate_epoch(self, epoch: int):
        """Evaluate the model on the ``'val'`` dataloader without gradients.

        Per-batch losses are fed to ``task.loss`` inside a ``"val"`` epoch
        context, and a text progress bar is redrawn on stdout. Afterwards the
        epoch loss and the model ``state_dict`` are passed to
        ``self.controller.add_state``.

        Args:
            epoch: Current epoch index, passed to the loss and the controller.
        """
        self.model.eval()
        with self.task.loss.new_epoch(epoch, "val"), torch.no_grad():
            dataloader = self.task.labeled_dataloader.dataloaders['val']

            for batch_idx, (data, target) in enumerate(dataloader):
                data = assign_to_device(data, self.device)
                target = target.to(self.device)

                pred_target = self.model(data)

                self.task.loss({"labeled_data_targets": target,
                                "labeled_data_pred_targets": pred_target})

                self._write_progress("Validating", batch_idx + 1, len(dataloader))
            sys.stdout.write('\n')

        self.controller.add_state(epoch, self.task.loss.get_epoch_loss(), self.model.state_dict())

    @staticmethod
    def _write_progress(label: str, done: int, total: int, width: int = 70) -> None:
        """Redraw a one-line ``<label> [====    ] NN%`` progress bar on stdout.

        ``total`` must be positive. The leading carriage return makes each
        call overwrite the previous bar on the same terminal line.
        """
        bar = '=' * (done * width // total)
        sys.stdout.write("\r%s [%-*s] %d%%" % (label, width, bar, done * 100 // total))
        sys.stdout.flush()
