from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional

import torch
from torch.utils.data import DataLoader
from ..models import pick_model
from .learning_intervention import intervention
from ..tasks import Task
from ..utils import Controller, freeze_model_layer
from ..utils.log import get_logger
from ..utils.digest import TensorboardDigest
from ..utils.lr_scheduler import get_scheduler


# Module-level logger used by BaseLearning for warnings and progress messages.
logger = get_logger("base_learning")


class BaseLearning(ABC):
    """Abstract base class wiring a task's model, SGD optimizer, optional scheduler and training loop.

    Subclasses must implement ``train``, ``train_epoch`` and ``validate_epoch``.
    ``train_epochs`` drives the epoch loop through a ``Controller``, optionally
    writes checkpoints and feeds a ``TensorboardDigest``.
    """

    def __init__(self, **kwargs):
        """Store the configuration and initialise all per-task state to empty.

        Args:
            **kwargs: Learner configuration. The keys read here are ``"cluster"``
                (device string, default ``"cuda"``) and ``"digest"`` (keyword
                arguments for ``TensorboardDigest``, default ``{}``).
        """
        self.config = kwargs
        self.user_device = self.config.get("cluster", "cuda")
        # Is overwritten in set_model_to_device()
        self.device: torch.device = torch.device("cpu:0")

        self.digest_config: Dict = self.config.get("digest", {})

        self.digest = None

        self.model = None
        self.optimizer = None
        self.scheduler = None
        # Created in train_epochs(); declared here so the attribute exists before training or reset()
        self.controller = None

        self.task: Optional[Task] = None

    def reset(self, task=None):
        """Clear the task, model, optimizer, scheduler and controller.

        The digest and the device are left untouched.

        Args:
            task: Unused; kept for interface compatibility.
        """
        self.task = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.controller = None

    def set_model(self, model: Optional[torch.nn.Module] = None):
        """Install a model for the current task, move it to the device and build its optimizer.

        Args:
            model: The model to use. ``None`` builds an untrained model from the task's
                ``"model"`` config, the string ``"pretrained"`` builds a pretrained one,
                and any other value is used as given.
        """
        if model is None or model == "pretrained":
            # Copy so that toggling use_pretrained does not leak into the task's own config
            model_config = deepcopy(self.task.config["model"])

            if model is None:
                logger.warning(f"An untrained model is used for task {self.task.uid}")
                model_config["use_pretrained"] = False
            else:
                logger.warning(f"A pretrained model is used for task {self.task.uid}")
                model_config["use_pretrained"] = True

            model = pick_model(**model_config)

        # NOTE(review): self.device here is still the value from before set_model_to_device() runs
        model = intervention(model=model, task=self.task, logger=logger, device=self.device)

        self.model = model
        self.set_model_to_device()
        self.pick_optimizer()

    def set_model_to_device(self):
        """Resolve ``self.user_device`` to a ``torch.device`` and move the model onto it.

        A device given without an index (e.g. ``"cuda"``) is pinned to index 0;
        an explicit index (e.g. ``"cuda:1"``) is respected as given.
        """
        requested = torch.device(self.user_device)
        if requested.index is None:
            requested = torch.device(requested.type, 0)
        self.device = requested
        self.model.to(self.device)

    def pick_optimizer(self):
        """Build the optimizer (and optional scheduler) from the current task's config.

        Reads ``"optimizer"``, ``"scheduler"`` and ``"loss" -> "task_lr"`` from
        ``self.task.config``. A ``task_lr`` overrides the optimizer's ``lr``; a
        ``"freeze_layer"`` entry is passed to ``freeze_model_layer``.

        Raises:
            NotImplementedError: If ``optimizer_type`` is anything other than ``"SGD"``.
        """
        task_config = self.task.config
        scheduler_config = task_config.get("scheduler", None)
        task_lr = task_config.get("loss", {}).get("task_lr", None)
        optimizer_config = task_config["optimizer"]

        if optimizer_config["optimizer_type"] != "SGD":
            raise NotImplementedError("Pick a valid optimizer")

        # Work on a copy so the task's config is not mutated
        optimizer_config_sgd = deepcopy(optimizer_config)
        if task_lr is not None:
            optimizer_config_sgd['lr'] = task_lr

        # Log the effective learning rate, i.e. after the task_lr override
        logger.info(f"Learning rate: {optimizer_config_sgd['lr']}")
        del optimizer_config_sgd["optimizer_type"]
        freeze_layer = optimizer_config_sgd.pop("freeze_layer", None)
        if freeze_layer is not None:
            self.model = freeze_model_layer(self.model, freeze_layer)

        if hasattr(self.model, "get_parameters"):
            model_parameters = self.model.get_parameters()
        else:
            model_parameters = self.model.parameters()
        self.optimizer = torch.optim.SGD(params=model_parameters, **optimizer_config_sgd)

        # A scheduler left over from an earlier call is bound to the previous optimizer; drop it
        self.scheduler = None
        if scheduler_config is not None:
            scheduler_name = scheduler_config.get("name", None)
            self.scheduler = get_scheduler(scheduler_name, self.optimizer, kwargs=scheduler_config)

    @abstractmethod
    def train(self, task: "Task"):
        """Train on ``task``; implemented by subclasses."""
        pass

    def train_epochs(self):
        """Run the controller-driven epoch loop for the current task.

        Validates once before training, then trains and validates every epoch,
        saving checkpoints when the controller has a checkpoint config and
        stepping the scheduler if one exists. Afterwards the controller's best
        state is loaded into the model, the model is saved through the task and
        ``test`` is run.

        Returns:
            The model with the best state loaded.
        """
        # Passing task to allow controller to change if specified for individual task
        self.controller = Controller(**self.task.config["controller"])

        if len(self.digest_config) > 0:
            # Initialize digest based on current task
            self.digest = TensorboardDigest(logdir=self.task.logdir, **self.digest_config)

        if self.digest is not None:

            self.digest.loss = self.task.loss.loss_functions['train']["callable"]
            # Note that "dataloader" contains both training and validation data
            self.digest.dataloader = DataLoader(self.task.labeled_dataloader.dataloaders['train'].dataset,
                                                batch_size=32, shuffle=False, num_workers=2)
            self.digest.train_dataloader = self.task.labeled_dataloader.dataloaders['train']
            self.digest.test_dataloader = self.task.labeled_dataloader.dataloaders['test']
            self.digest.device = self.device
            self.digest.__enter__()

        try:
            logger.info(f"Running {self.task.type} task {self.task.name}")

            if self.controller.checkpoint:
                save_every = self.controller.checkpoint.get("save_every", 1)
                checkpoint_base_dir = str(self.controller.checkpoint["checkpoint_dir"])

            self.validate_epoch(-1)  # validate the model once before any training occurs.
            if self.controller.checkpoint:
                current_lr = float(self.optimizer.param_groups[0]['lr'])
                self.task.save_checkpoint(self.model, checkpoint_base_dir, -1, current_lr)

            for epoch in self.controller:
                current_lr = float(self.optimizer.param_groups[0]['lr'])
                logger.info(f"Epoch {epoch} and learning rate: {current_lr}")
                self.train_epoch(epoch)
                self.validate_epoch(epoch)

                if self.controller.checkpoint and epoch % save_every == 0:
                    current_lr = float(self.optimizer.param_groups[0]['lr'])
                    self.task.save_checkpoint(self.model, checkpoint_base_dir, epoch, current_lr)
                # Updating learning rate with learning rate scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

            best_dict = self.controller.get_best_state()["model_dict"]
            self.model.load_state_dict(best_dict)
            self.task.save_model(self.model)

            self.test()
        finally:
            # Leave the digest even when training fails, so an entered digest is never left open
            if self.digest is not None:
                self.digest.__exit__()

        return self.model

    @abstractmethod
    def train_epoch(self, epoch: int) -> None:
        """Train for one epoch; called by ``train_epochs`` with the epoch number."""
        pass

    @abstractmethod
    def validate_epoch(self, epoch: int) -> None:
        """Validate after an epoch; ``train_epochs`` also calls it with ``-1`` before training."""
        pass

    def test(self):
        """Evaluate the model on the test dataloader that matches the task type.

        Each batch's targets and predictions are passed to ``self.task.loss``
        inside a ``new_epoch(0, "test")`` context, with gradients disabled.

        Raises:
            NotImplementedError: If the task type has no associated test dataloader.
        """
        self.model.eval()
        with self.task.loss.new_epoch(0, "test"), torch.no_grad():
            task_type = self.task.type
            if task_type in ("supervised-learning", "supervised-grad-reg-learning"):
                dataloader = self.task.labeled_dataloader.dataloaders['test']
            elif task_type == "semi-supervised-learning":
                dataloader = self.task.semi_supervised_dataloader.dataloaders['test']
            elif task_type in ("domain-confusion", "domain-adaptation-mdd", "domain-adaptation-dann"):
                dataloader = self.task.unlabeled_dataloader.dataloaders['test']
            else:
                raise NotImplementedError(f"The following task type is not implemented: {self.task.type}")

            for data, target in dataloader:
                data = assign_to_device(data, self.device)
                target = target.to(self.device)

                pred_target = self.model(data)
                self.task.loss({"labeled_data_targets": target,
                                "labeled_data_pred_targets": pred_target})

    def infer_labels(self, certainty_threshold: float = 0, percentile_rank: float = 0) -> Dict[str, torch.Tensor]:
        """Predict a label for every sample of each inference dataloader.

        The task's ``inference_dataloader`` is used when set, otherwise its
        ``unlabeled_dataloader``. Each dataset is iterated in order (no shuffling).
        The effective threshold is the larger of ``certainty_threshold`` and the
        ``percentile_rank`` percentile of the maximum softmax probabilities;
        predictions whose probability is below it are replaced by ``-1``.
        With both arguments at 0 every label is kept.

        Args:
            certainty_threshold: Minimum softmax probability for a label to be kept.
            percentile_rank: Percentile (0-100) of the max probabilities used as a
                second lower bound.

        Returns:
            A dict mapping each dataloader key to a 1-D tensor of predicted class
            indices (``-1`` where the threshold is not met). A dataloader that yields
            no batches maps to an empty tensor.
        """
        if self.task.inference_dataloader is not None:
            inference_dataloader = self.task.inference_dataloader
        else:
            inference_dataloader = self.task.unlabeled_dataloader

        labels = {}
        for dataloader_type, dataloader_random_sampler in inference_dataloader.dataloaders.items():
            # Rebuild the loader without shuffling so predictions line up with dataset order
            dataloader_sequential_sampler = DataLoader(dataloader_random_sampler.dataset,
                                                       batch_size=dataloader_random_sampler.batch_size,
                                                       shuffle=False,
                                                       num_workers=dataloader_random_sampler.num_workers)

            with torch.no_grad():
                # Important for batch normalization
                self.model.eval()

                logger.info(f"Start inferring labels for {dataloader_type} dataset")

                probability_batches = []
                prediction_batches = []
                for data, _ in dataloader_sequential_sampler:
                    data = assign_to_device(data, self.device)

                    pred_target = self.model(data)

                    # we need to apply softmax to get normalized class probabilities
                    pred_target_softmaxed = torch.softmax(pred_target, dim=1)

                    # index of the most probable class and its probability for every sample in the batch
                    batch_max_pred_probabilities, batch_max_predictions = torch.max(pred_target_softmaxed, 1)
                    probability_batches.append(batch_max_pred_probabilities)
                    prediction_batches.append(batch_max_predictions)

                if not probability_batches:
                    # Nothing to threshold; torch.quantile would fail on an empty input
                    logger.warning(f"No samples found for the {dataloader_type} dataset, no labels inferred")
                    labels[dataloader_type] = torch.empty(0, dtype=torch.long, device=self.device)
                    continue

                max_pred_probabilities = torch.cat(probability_batches)
                max_predictions = torch.cat(prediction_batches)

                # our config specifies percentiles, but torch works with quantiles
                q = max_pred_probabilities.new_tensor([percentile_rank / 100])
                quantiles = torch.quantile(max_pred_probabilities, q)
                threshold: float = max(certainty_threshold, quantiles.item())
                below_threshold = max_pred_probabilities < threshold
                max_predictions[below_threshold] = -1

                labels[dataloader_type] = max_predictions

                logger.info(f"Finished inferring {len(labels[dataloader_type])} labels in batches for the "
                            f"{dataloader_type} dataset. "
                            f"Number of inferred labels meeting the threshold: "
                            f"{torch.sum(~below_threshold).item()}")

        return labels


def assign_to_device(data, device):
    """Move ``data``, or every element of a tuple/list ``data``, to ``device``.

    Args:
        data: An object exposing ``.to(device)``, or a tuple/list of such objects.
        device: Target device forwarded to each ``.to`` call.

    Returns:
        The result of ``data.to(device)``, or a list of moved elements when
        ``data`` is a tuple or a list.
    """
    # Lists are treated like tuples: multi-part batches may arrive as lists
    # (e.g. PyTorch's default collate returns lists for tuple samples)
    if isinstance(data, (tuple, list)):
        return [subdata.to(device) for subdata in data]

    return data.to(device)
