import torch
import torch.nn as nn
import random
import os
import numpy as np
from sklearn.metrics import average_precision_score


def mixup_data(x, y, l):
    """Mix a batch with a randomly shuffled copy of itself (mixup).

    Args:
        x: input batch; its first dimension is the batch dimension.
        y: targets aligned with ``x`` along the first dimension.
        l: mixing coefficient; ``x`` is weighted by ``l`` and its shuffled
            partner by ``1 - l``.

    Returns:
        ``(mixed_x, y_a, y_b)``: the mixed inputs, the original targets, and
        the targets of the shuffled partner samples. ``l`` itself is NOT
        returned; pass the same ``l`` to ``mixup_criterion``.
    """
    # A single random permutation pairs each sample with a partner from the
    # same batch; the permutation is drawn on the CPU generator and then moved.
    indices = torch.randperm(x.shape[0]).to(x.device)

    mixed_x = l * x + (1 - l) * x[indices]
    y_a, y_b = y, y[indices]
    return mixed_x, y_a, y_b


def mixup_criterion(criterion, pred, y_a, y_b, l):
    """Blend two losses with the mixup coefficient.

    Args:
        criterion: callable ``criterion(pred, target)`` returning a loss.
        pred: model output for the mixed batch.
        y_a: original targets (weighted by ``l``).
        y_b: shuffled partner targets (weighted by ``1 - l``).
        l: mixing coefficient used to build the mixed batch.

    Returns:
        ``l * criterion(pred, y_a) + (1 - l) * criterion(pred, y_b)``.
    """
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    return l * loss_original + (1 - l) * loss_partner


def seed_all(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator (default 1029).
    """
    # NOTE: the running interpreter's hash seed is fixed at startup; this
    # environment variable only affects Python processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Current CUDA device first, then every visible device (multi-GPU).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Disable kernel autotuning and request deterministic cuDNN algorithms.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def freeze_except_fc(model):
    """Freeze every parameter except those belonging to the classifier head.

    A parameter is kept trainable when its qualified name contains any of
    ``'head.conv'``, ``'fc'`` or ``'heads'`` (plain substring match). Kept
    parameters keep their current ``requires_grad`` value; all others get
    ``requires_grad = False``.

    Args:
        model: an ``nn.Module``.

    Returns:
        list of the kept parameters, in ``named_parameters()`` order.
    """
    head_markers = ('head.conv', 'fc', 'heads')
    trainable = []
    for name, param in model.named_parameters():
        if any(marker in name for marker in head_markers):
            trainable.append(param)
            continue
        param.requires_grad = False
    return trainable


def get_per_class_accuracy(args, dataset):
    '''Returns the custom ``custom_acc(logits, labels)`` metric for ``args.dataset``.

    When using this custom function, look only at the validation accuracy.
    Ignore the training set accuracy.

    Behaviour of the returned function by ``args.dataset``:

    - ``'pascalvoc'``: sum over the batch of per-sample average precision of
      ``sigmoid(logits)`` against ``labels``.
    - ``'pets'``, ``'flowers'``, ``'caltech101'``, ``'caltech256'``,
      ``'aircraft'``, ``'ncaltech101'``: sum over the batch of correct top-1
      predictions, each weighted by the class weight of its true label.
    - anything else: number of correct top-1 predictions in the batch.

    Args:
        args: object with a ``dataset`` attribute naming the dataset.
        dataset: dataset whose class distribution provides the weights.

    Returns:
        callable ``custom_acc(logits, labels) -> float``.
    '''
    weighted_datasets = ('pets', 'flowers', 'caltech101', 'caltech256',
                         'aircraft', 'ncaltech101')

    def _get_class_weights(args, dataset):
        '''Returns per-class weights derived from the class distribution of ``dataset``.

        For class c with n_c samples out of N samples and K classes, the weight
        is N / (n_c * K); ``ncaltech101`` uses ``dataset.class_weights`` as is.

        Raises:
            ValueError: if ``args.dataset`` has no class-weight rule.
        '''
        if args.dataset in ['pets', 'flowers']:
            targets = dataset.targets
        elif args.dataset in ['ncaltech101']:
            return dataset.class_weights
        elif args.dataset in ['caltech101', 'caltech256']:
            targets = np.array([dataset.ds.dataset.y[idx]
                                for idx in dataset.ds.indices])
        elif args.dataset == 'aircraft':
            targets = [s[1] for s in dataset.samples]
        else:
            raise ValueError(
                'No class-weight rule for dataset {!r}'.format(args.dataset))

        # np.unique sorts class ids, so counts[i] belongs to the i-th smallest
        # label; this assumes labels are 0..K-1 with every class present.
        counts = np.unique(targets, return_counts=True)[1]
        class_weights = counts.sum() / (counts * len(counts))
        return torch.Tensor(class_weights)

    # The class weights depend only on the dataset, so compute them once (on
    # first use) instead of re-scanning the whole dataset on every batch.
    cached_weights = None

    def _cached_class_weights():
        nonlocal cached_weights
        if cached_weights is None:
            cached_weights = _get_class_weights(args, dataset)
        return cached_weights

    def get_multiclass_average_precision(y_true, y_scores):
        """
        Sum of per-sample average precision over a batch.

        Args:
            y_true: 2-D numpy array (batch x classes) of true binary labels
            y_scores: 2-D numpy array (batch x classes) of confidence scores

        Returns:
            sum over rows of ``average_precision_score(row_true, row_scores)``
        """
        scores = 0.0

        for i in range(y_true.shape[0]):
            scores += average_precision_score(y_true=y_true[i], y_score=y_scores[i])

        return float(scores)

    @torch.no_grad()
    def custom_acc(logits, labels):
        '''Returns the top1 accuracy, weighted by the class distribution.
        This is important when evaluating an unbalanced dataset.
        '''
        if args.dataset == 'pascalvoc':
            # sklearn needs host numpy arrays: move both inputs to the CPU.
            return get_multiclass_average_precision(
                labels.detach().cpu().numpy(),
                torch.sigmoid(logits).detach().cpu().numpy())

        maxk = 1
        _, pred = logits.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(labels.view(1, -1).expand_as(pred))
        prec1 = correct[:1].reshape(-1).float()

        if args.dataset in weighted_datasets:
            weights = _cached_class_weights()[labels.cpu()]
            # Gathered weights may live on the CPU while prec1 lives on the
            # logits' device; align them to avoid a device-mismatch error.
            weights = torch.as_tensor(weights, device=prec1.device)
            return (prec1 * weights).sum().item()
        return prec1.sum().item()

    return custom_acc