from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional

import torch
from torch.utils.data import DataLoader
from torchpercentile import Percentile

from ..models import pick_model
from ..tasks import Task
from ..utils import Controller
from ..utils.log import get_logger

# Module-level logger, obtained from the project's get_logger helper under the name "base_learning".
logger = get_logger("base_learning")


class BaseLearning(ABC):
    def __init__(self, **kwargs):
        self.config = kwargs

        # Is overwritten in set_model_to_device()
        self.device: torch.device = torch.device("cpu:0")

        self.controller = Controller(**kwargs["controller"])

        self.model = None
        self.optimizer = None
        self.first_task = None

        self.task: Optional[Task] = None

    def reset(self, task=None):
        self.task = None
        self.model = None
        self.optimizer = None

        self.controller.reset(task=task)

    def set_model(self, model: Optional[torch.nn.Module] = None):
        if model is None or model == "pretrained":
            model_config = deepcopy(self.config["model"])

            if model is None:
                logger.warning(f"An untrained model is used for task {self.task.uid}")
                # we need to manually set the use_pretrained parameter to false just for this model config
                model_config["use_pretrained"] = False
            elif model == "pretrained":
                logger.warning(f"A pretrained model is used for task {self.task.uid}")
                # we need to manually set the use_pretrained parameter to true just for this model config
                model_config["use_pretrained"] = True

            model = pick_model(**self.config["model"])
            self.first_task = True
        else:
            self.first_task = False

        self.model = model
        self.set_model_to_device()
        self.pick_optimizer()

    def set_model_to_device(self):
        self.device = torch.device("cuda:0")
        self.model.to(self.device)

    def pick_optimizer(self):
        optimizer_config = self.config["optimizer"]
        task_lr = self.task.config.get("loss", {}).get("task_lr", None)

        if optimizer_config["optimizer_type"] == "SGD":
            optimizer_config_sgd = deepcopy(optimizer_config)

            if task_lr is not None:
                optimizer_config_sgd['lr'] = task_lr

            logger.info(f"Learning rate: {optimizer_config_sgd['lr']}")

            del optimizer_config_sgd["optimizer_type"]
            print(optimizer_config_sgd)
            self.optimizer = torch.optim.SGD(params=self.model.parameters(), **optimizer_config_sgd)

        else:
            raise NotImplementedError("Pick a valid optimizer")

    @abstractmethod
    def train(self, task: Task):
        pass

    def train_epoches(self):
        # Passing task to allow controller to change if specified for individual task
        self.controller.reset(task=self.task)

        logger.info(f"Running {self.task.type} task {self.task.name}")

        if bool(self.controller.checkpoint) is True:
            save_every = self.controller.checkpoint.get("save_every", 1)
            checkpoint_dir = str(self.controller.checkpoint["checkpoint_dir"])

        if bool(self.controller.lr_decay) is True:
            decay_rate = self.controller.lr_decay["decay_rate"]
            decay_every_n_epochs = self.controller.lr_decay["decay_every_n_epochs"]

        self.validate_epoch(-1)  # validate the model once before any training occurs.

        for epoch in self.controller:
            self.train_epoch(epoch)
            self.validate_epoch(epoch)

            if bool(self.controller.checkpoint) is True:
                if epoch % save_every == 0:
                    current_lr = float(self.optimizer.param_groups[0]['lr'])
                    self.task.save_checkpoint(self.model, checkpoint_dir, epoch, current_lr)

            if bool(self.controller.lr_decay) is True:
                self.optimizer.param_groups[0]['lr'] = self.optimizer.defaults['lr'] * (decay_rate ** (epoch // decay_every_n_epochs))

        best_dict = self.controller.get_best_state()["model_dict"]
        self.model.load_state_dict(best_dict)
        self.task.save_model(self.model)

        self.test()

        return self.model

    @abstractmethod
    def train_epoch(self, epoch) -> None:
        pass

    @abstractmethod
    def validate_epoch(self, epoch: int) -> None:
        pass

    def test(self):
        self.model.eval()
        with self.task.loss.new_epoch(0, "test"), torch.no_grad():
            if self.task.type == "supervised-learning":
                dataloader = self.task.labeled_dataloader.dataloaders['test']
            else:
                raise NotImplementedError(f"The following task type is not implemented: {self.task.type}")

            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(self.device), target.to(self.device)

                pred_target = self.model(data)
                self.task.loss({"labeled_data_targets": target,
                                "labeled_data_pred_targets": pred_target})

    def infer_labels(self, certainty_threshold: float = 0, percentile_rank: float = 0) -> Dict[str, torch.Tensor]:
        # if certainty threshold is set to 0, all labels are inferred
        # all max labels which don't meet the certainty threshold, will get a label of -1

        if self.task.inference_dataloader is not None:
            inference_dataloader = self.task.inference_dataloader
        else:
            inference_dataloader = self.task.unlabeled_dataloader

        labels = {}
        for dataloader_type, dataloader_random_sampler in inference_dataloader.dataloaders.items():
            dataloader_sequential_sampler = DataLoader(dataloader_random_sampler.dataset,
                                                       batch_size=dataloader_random_sampler.batch_size,
                                                       shuffle=False,
                                                       num_workers=dataloader_random_sampler.num_workers)
            num_samples = len(dataloader_sequential_sampler.dataset)

            with torch.no_grad():
                # Important for batch normalization
                self.model.eval()

                logger.info(f"Start inferring labels for {dataloader_type} dataset")

                max_pred_probabilities = None
                max_predictions = None
                for batch_idx, (data, target) in enumerate(dataloader_sequential_sampler):
                    start_idx = batch_idx*dataloader_sequential_sampler.batch_size

                    data, target = data.to(self.device), target.to(self.device)

                    pred_target = self.model(data)

                    # we need to apply softmax to get normalized class probabilities
                    softmax = torch.nn.Softmax(dim=1)
                    pred_target_softmaxed = softmax(pred_target)

                    # returns the indice of the class with the highest probability and its associated certainty for
                    # all images in the batch
                    batch_max_pred_probabilities, batch_max_predictions = torch.max(pred_target_softmaxed, 1)

                    if max_pred_probabilities is None:
                        max_pred_probabilities = batch_max_pred_probabilities.new_zeros(size=(num_samples, ))
                    max_pred_probabilities[start_idx:start_idx + batch_max_pred_probabilities.size(0)] \
                        = batch_max_pred_probabilities

                    if max_predictions is None:
                        max_predictions = batch_max_predictions.new_zeros(size=(num_samples, ))
                    max_predictions[start_idx:start_idx + batch_max_predictions.size(0)] \
                        = batch_max_predictions

                percentiles = Percentile()(max_pred_probabilities, [percentile_rank])
                threshold: float = max(certainty_threshold, percentiles[0].item())
                max_predictions[max_pred_probabilities < threshold] = -1

                labels[dataloader_type] = max_predictions

                logger.info(f"Finished inferring {len(labels[dataloader_type])} labels in batches for the "
                            f"{dataloader_type} dataset. "
                            f"Number of inferred labels meeting the threshold: "
                            f"{torch.sum(max_pred_probabilities < threshold).item()}")

        return labels
