import numpy as np
from typing import Dict

import torch
from torch.autograd import Variable
from tqdm import tqdm


def show_stats(data: np.ndarray) -> Dict[str, float]:
    """
    Compute, print and return summary statistics of the input array.

    Each statistic is printed on its own line as ``name: value``.

    :param data: input data in a numpy array.

    :return a dict with the keys 'min_score', 'max_score', 'median',
        'mean', 'std', 'count' and 'sum'; 'count' is ``len(data)``.
    """
    stats = {
        'min_score': np.min(data),
        'max_score': np.max(data),
        'median': np.median(data),
        'mean': np.mean(data),
        'std': np.std(data),
        'count': len(data),
        'sum': np.sum(data),
    }

    # Report the statistics in the same order they are stored in the dict.
    for name in stats:
        print(f'{name}: {stats[name]}')

    return stats


if __name__ == "__main__":
    # Quick manual check: print the statistics of a tiny sample array.
    demo_values = np.array([1, 2, 3])
    show_stats(data=demo_values)


def show_thresholding(threshold, results, data_type=""):
    """
    Print how many scores lie strictly above a threshold.

    :param threshold: cut-off value; when None nothing is printed.
    :param results: scores in an array that supports boolean-mask indexing
        (e.g. a numpy array). Plain lists and tuples are converted to a
        numpy array first.
    :param data_type: optional label used as a ``Data type: ...`` prefix of
        the message; None or an empty string omits the prefix.

    :return None; the summary is written to stdout.
    """
    if threshold is None:
        return

    # Python sequences do not support `results > threshold`; arrays and
    # tensors are left untouched.
    if isinstance(results, (list, tuple)):
        results = np.asarray(results)

    above_threshold = results[results > threshold]
    nr_scores_above = len(above_threshold)
    total_nr_scores = len(results)

    # An empty input has no meaningful rate: report NaN instead of raising
    # ZeroDivisionError.
    if total_nr_scores:
        rate = nr_scores_above / total_nr_scores
    else:
        rate = float('nan')

    # Always bind the prefix so that data_type=None no longer fails with an
    # unbound local, and skip it for the empty default label.
    if data_type is None or data_type == "":
        data_type_str = ''
    else:
        data_type_str = f'Data type: {data_type}. '

    print(
        f'{data_type_str}'
        f'Number of scores above the threshold: {nr_scores_above}. '
        f'Total number of scores: {total_nr_scores}. '
        f'Rate of results above threshold: {rate}. '
    )


def generate_scores(model, CUDA_DEVICE, data_loader, title='Testing'):
    """
    Run the model over a data loader and collect the scores it produces.

    The model is switched to evaluation mode via ``model.eval()`` and is
    left in that mode. Progress is shown with a tqdm progress bar.

    :param model: object exposing ``eval()`` and ``get_scores(images=...)``.
    :param CUDA_DEVICE: device the image batches are moved to.
    :param data_loader: sized iterable yielding ``(images, targets)`` pairs;
        the targets are not used.
    :param title: prefix of the progress-bar description.

    :return numpy array built from the per-batch scores, in batch order.
    """
    model.eval()
    num_batches = len(data_loader)
    results = []

    # The context manager closes the progress bar even when scoring raises.
    with tqdm(data_loader) as data_iter:
        for j, (images, _targets) in enumerate(data_iter):
            data_iter.set_description(
                f'{title} | Processing image batch {j + 1}/{num_batches}')
            # Inputs keep requires_grad=True, exactly as with the deprecated
            # Variable wrapper this replaces; presumably get_scores needs
            # gradients w.r.t. the images -- TODO confirm.
            images = images.to(CUDA_DEVICE).detach().requires_grad_(True)
            scores = model.get_scores(images=images)
            results.extend(scores)

        data_iter.set_description(
            f'{title} | Processing image batch {num_batches}/{num_batches}')

    return np.array(results)


def get_model_accuracy(model, data_iter, CUDA_DEVICE):
    """
    Compute the classification accuracy of a model, in percent.

    The predicted class of each example is the argmax of the model output
    along dimension 1. The model's train/eval mode is not changed here.

    :param model: callable mapping a batch of images to per-class logits.
    :param data_iter: iterable yielding ``(images, targets)`` tensor pairs.
    :param CUDA_DEVICE: device the images and targets are moved to.

    :return ``100 * correct / num_examples``; a torch tensor whenever at
        least one batch was processed.
    :raises ZeroDivisionError: if ``data_iter`` yields no batches.
    """
    num_examples = 0
    correct = 0

    # Pure inference: do not build autograd graphs for the forward passes.
    with torch.no_grad():
        for images, targets in data_iter:
            images = images.to(CUDA_DEVICE)
            targets = targets.to(CUDA_DEVICE)
            acc_logits = model(images)

            num_examples += len(acc_logits)
            predicted = torch.argmax(acc_logits, dim=1)
            correct += (predicted == targets).int().sum()

    accuracy = 100 * correct / num_examples
    return accuracy
