"""
Utility functions for Adaptive Gating MetaNet
"""

import os
import numpy as np
import torch


def assign_learning_rate(param_group, new_lr):
    """Overwrite the ``"lr"`` entry of ``param_group`` with ``new_lr``.

    Args:
        param_group: Mutable mapping (e.g. one of ``optimizer.param_groups``).
        new_lr: Learning-rate value to store.
    """
    lr_key = "lr"
    param_group[lr_key] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    """Calculate learning rate during warmup period"""
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Cosine learning rate scheduler with linear warmup.

    During the first ``warmup_length`` steps the learning rate ramps linearly
    up to ``base_lr`` (see ``_warmup_lr``). After that it follows a half cosine
    from ``base_lr`` down to 0 at step ``steps``. Steps beyond ``steps`` are held
    at 0 instead of letting the cosine rise again.

    Args:
        optimizer: The optimizer whose ``param_groups`` are updated.
        base_lrs: A single base learning rate shared by every param group, or a
            list/tuple with one base learning rate per param group.
        warmup_length: Number of warmup steps.
        steps: Total number of training steps (warmup included).

    Returns:
        Function ``f(step)`` that writes the learning rate for ``step`` into
        every param group of ``optimizer``.
    """
    if isinstance(base_lrs, (list, tuple)):
        base_lrs = list(base_lrs)
    else:
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups), (
        "need one base learning rate per optimizer param group"
    )

    # Length of the cosine-decay phase; hoisted because it is step-invariant.
    decay_length = steps - warmup_length

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            elif decay_length <= 0:
                # No decay phase (steps <= warmup_length): the schedule is
                # already finished, so use the cosine's end value instead of
                # dividing by zero.
                lr = 0.0
            else:
                # Clamp so steps past the end stay at 0 instead of climbing
                # back up the cosine curve.
                elapsed = min(step - warmup_length, decay_length)
                lr = 0.5 * (1 + np.cos(np.pi * elapsed / decay_length)) * base_lr
            assign_learning_rate(param_group, lr)

    return _lr_adjuster


def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions for each k in ``topk``.

    Args:
        output: Model output scores, classes on dim 1 (e.g. shape (N, C)).
        target: Ground-truth class indices, one per row of ``output``.
        topk: Tuple of k values for which to count hits.

    Returns:
        List of floats, one per k: the NUMBER of samples whose target is among
        the k highest-scoring classes. This is a count, not a fraction; divide
        by the batch size to get accuracy.
    """
    maxk = max(topk)
    # Transpose to (maxk, N): row i holds every sample's (i+1)-th ranked class.
    pred = output.topk(maxk, 1, True, True)[1].t()
    # reshape (not view) so non-contiguous targets are accepted too.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # .item() yields a Python float directly. float() on a 1-element ndarray
    # is deprecated since NumPy 1.25.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]


def torch_save(model, save_path):
    """Serialize ``model`` to ``save_path`` with ``torch.save``.

    Any missing parent directories of ``save_path`` are created first. A bare
    filename is written as-is, with no directory creation.

    Args:
        model: Object to serialize (typically a model).
        save_path: Destination file path.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir != "":
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model, save_path)


def torch_load(save_path, device=None, weights_only=False):
    """Load an object written by ``torch.save`` / ``torch_save``.

    The object is first mapped to CPU. It is then moved with ``.to(device)``
    when ``device`` is given.

    Args:
        save_path: Path to the saved file.
        device: Optional device to move the loaded object to. The loaded
            object must then support ``.to``.
        weights_only: Forwarded to ``torch.load``. Defaults to False because
            ``torch_save`` pickles the whole model object. Since PyTorch 2.6,
            ``torch.load`` defaults to ``weights_only=True``, which refuses
            such files. Pass True for files that only contain tensors or
            state dicts.

    Returns:
        The loaded object.
    """
    # SECURITY: weights_only=False performs full unpickling, which can run
    # arbitrary code. Only load files from trusted sources.
    model = torch.load(save_path, map_location="cpu", weights_only=weights_only)
    if device is not None:
        model = model.to(device)
    return model


def get_logits(inputs, classifier):
    """Run ``classifier`` on ``inputs`` and return its raw output.

    If ``classifier`` has a ``to`` attribute (e.g. a ``torch.nn.Module``), it
    is first moved to ``inputs.device``.

    Args:
        inputs: Input features (must expose ``.device`` when the classifier
            has ``to``).
        classifier: Callable classification model.

    Returns:
        Whatever ``classifier(inputs)`` returns (the logits).
    """
    assert callable(classifier)
    model = classifier.to(inputs.device) if hasattr(classifier, "to") else classifier
    return model(inputs)


def get_probs(inputs, classifier):
    """Return class probabilities for ``inputs``.

    Classifiers exposing ``predict_proba`` (scikit-learn style) receive a
    detached CPU NumPy copy of ``inputs``, and their result is wrapped with
    ``torch.from_numpy``. Any other classifier is run through ``get_logits``,
    followed by a softmax over dim 1.

    Args:
        inputs: Input features tensor.
        classifier: Classification model.

    Returns:
        Tensor of probabilities.
    """
    if not hasattr(classifier, "predict_proba"):
        return get_logits(inputs, classifier).softmax(dim=1)
    features = inputs.detach().cpu().numpy()
    return torch.from_numpy(classifier.predict_proba(features))


class LabelSmoothing(torch.nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    Per sample the loss is
    ``(1 - smoothing) * (-log p[target]) + smoothing * mean_c(-log p[c])``.
    The result is averaged over all samples.

    Args:
        smoothing: Smoothing parameter (0 for no smoothing)
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the smoothed loss.

        Args:
            x: Logits with classes on the last dimension, e.g. (N, C) or
                (N, ..., C).
            target: Integer class indices with shape ``x.shape[:-1]``.

        Returns:
            Scalar tensor: mean loss over all samples.
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        # The class axis is the last one, so unsqueeze/squeeze at -1. The
        # original dim 1 only coincides with it for 2-D inputs.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class DotDict(dict):
    """``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` (via ``dict.get``) instead of
    raising ``AttributeError``. Deleting a missing one raises ``KeyError``.
    """

    def __getattr__(self, name):
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)