from collections import defaultdict
import time

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data.dataloader import DataLoader
from typing import Optional, Sequence, List, Dict, Tuple, Iterable

from utils import make_batch_one_hot, is_loss_criterion_vector_based
from sklearn.metrics import confusion_matrix
import numpy as np
from cl_tools import CumulativeStatistic, DatasetPart, DatasetType, ValidationResult, \
    INCProtocolIterator, TrainingStepResultBuilder, IDataset


def average_stats(stats: Iterable[CumulativeStatistic]) -> CumulativeStatistic:
    """Merge several statistics into one freshly created ``CumulativeStatistic``.

    Each input contributes its ``average`` together with its ``overall_count``
    through ``update_using_averages`` (presumably producing a count-weighted
    mean -- confirm against ``CumulativeStatistic``).

    :param stats: The statistics to merge; they are only read.
    :return: A new ``CumulativeStatistic`` fed with every input statistic.
    """
    merged: CumulativeStatistic = CumulativeStatistic()
    for partial in stats:
        partial_average = partial.average
        partial_count = partial.overall_count
        merged.update_using_averages(partial_average, count=partial_count)
    return merged


def get_validation_data(model: Module, test_dataset: IDataset,
                        device: Optional[torch.device] = None,
                        required_top_k: Optional[Sequence[int]] = None, return_detailed_outputs: bool = False,
                        criterion: Optional[Module] = None, make_one_hot: Optional[bool] = None,
                        n_classes: int = -1,
                        **kwargs) -> Tuple[Tensor, Optional[Dict[int, CumulativeStatistic]],
                                           Optional[Tensor], Optional[Tensor]]:
    """Run the model over a dataset and collect top-k accuracy and per-class loss.

    The model is switched to evaluation mode (and left in it) and the whole
    pass runs under ``torch.no_grad()``. Progress is printed to stdout.

    :param model: The model to evaluate.
    :param test_dataset: The dataset wrapped in a ``DataLoader``; each batch
        must unpack as ``(patterns, labels)``.
    :param device: When given, patterns and targets are moved to this device
        before the forward pass. The model itself is not moved.
    :param required_top_k: The "k" values for which top-k accuracy is
        computed. Defaults to ``[1]``.
    :param return_detailed_outputs: When True, the argmax predictions and the
        labels of every pattern are also returned.
    :param criterion: Optional loss. Its output is iterated pattern by
        pattern, so it must not be reduced to a scalar. A loss with more than
        one dimension is averaged over dim 1 to obtain one value per pattern.
    :param make_one_hot: Whether labels are converted to one-hot targets
        before being fed to the criterion. When None, the choice is delegated
        to ``is_loss_criterion_vector_based(criterion)``.
    :param n_classes: Number of classes used to build the one-hot targets.
        Must be positive when one-hot targets are used.
    :param kwargs: Forwarded to the ``DataLoader`` constructor.
    :return: A tuple made of: a float tensor with one accuracy per entry of
        ``required_top_k`` (same order); the per-class loss statistics keyed
        by label value (None when no criterion is given); the argmax
        predictions and the labels (both None unless
        ``return_detailed_outputs`` is True).
    :raises ValueError: If one-hot targets are required and ``n_classes`` is
        not positive.
    """
    stats: Dict[int, CumulativeStatistic]
    max_top_k: int
    all_val_predictions_tmp: List[Tensor] = []
    all_val_labels_tmp: List[Tensor] = []
    all_val_predictions: Optional[Tensor] = None
    all_val_labels: Optional[Tensor] = None
    test_loss_per_class: Optional[Dict[int, CumulativeStatistic]] = None
    test_loader: DataLoader

    if make_one_hot is None:
        make_one_hot = is_loss_criterion_vector_based(criterion)

    if required_top_k is None:
        required_top_k = [1]

    max_top_k = max(required_top_k)

    stats = {top_k: CumulativeStatistic() for top_k in required_top_k}

    if criterion is not None:
        # Enable test loss
        test_loss_per_class = defaultdict(CumulativeStatistic)

    if make_one_hot and n_classes <= 0:
        raise ValueError("n_class must be set when using one_hot_vectors")

    # noinspection PyTypeChecker
    test_loader = DataLoader(test_dataset, **kwargs)

    n_iterations = len(test_loader)
    print('Executing', n_iterations, 'test iterations')
    start_eval_time = time.time()

    model.eval()
    with torch.no_grad():
        patterns: Tensor
        labels: Tensor
        targets: Tensor
        output: Tensor
        for test_iter, (patterns, labels) in enumerate(test_loader):
            # Progress bar: one 'x' per batch, a percentage every 30 batches
            print('x', end='')
            if ((test_iter + 1) % 30) == 0:
                print(' {:.2f}%'.format(100*(test_iter+1)/n_iterations))

            # Clear grad (kept so that the side effect on the model's gradients is unchanged)
            model.zero_grad()

            if return_detailed_outputs:
                all_val_labels_tmp.append(labels.detach().cpu())

            if make_one_hot:
                targets = make_batch_one_hot(labels, n_classes)
            else:
                targets = labels

            # Send data to device
            if device is not None:
                patterns = patterns.to(device)
                targets = targets.to(device)

            # Forward
            output = model(patterns)

            if criterion is not None:
                losses = criterion(output, targets).detach().cpu()
                # Vector based criteria yield one loss per output unit: collapse them to one
                # value per pattern (same rule used by cl_training_update_metrics), otherwise
                # .item() below would fail on a multi-element tensor.
                if len(losses.shape) > 1:
                    losses = losses.mean(dim=1)
                for pattern_idx, loss_element in enumerate(losses):
                    test_loss_per_class[labels[pattern_idx].item()].update_using_counts(
                        loss_element.item(), count=1)

            output = output.detach().cpu()
            if return_detailed_outputs:
                all_val_predictions_tmp.append(torch.argmax(output, dim=1))

            # https://gist.github.com/weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b
            # Gets the indexes of max_top_k elements
            _, top_k_idx = output.topk(max_top_k, 1)
            top_k_idx = top_k_idx.t()

            # correct will have values True where index == label
            correct = top_k_idx.eq(labels.view(1, -1).expand_as(top_k_idx))
            for top_k in required_top_k:
                correct_k = correct[:top_k].view(-1).float().sum(0)  # Number of correct patterns for this top_k
                stats[top_k].update_using_counts(correct_k, len(labels))

        if return_detailed_outputs:
            all_val_predictions = torch.cat(all_val_predictions_tmp)
            all_val_labels = torch.cat(all_val_labels_tmp)

    acc_results = torch.empty(len(required_top_k), dtype=torch.float)

    for top_idx, top_k in enumerate(required_top_k):
        acc_results[top_idx] = stats[top_k].average

    end_eval_time = time.time()
    print('Evaluation took', (end_eval_time - start_eval_time), ' seconds')

    return acc_results, test_loss_per_class, all_val_predictions, all_val_labels


def make_validation_data(required_top_k: Sequence[int], n_classes: int, task_info: INCProtocolIterator,
                         validation_type: DatasetPart, validation_dataset: DatasetType,
                         validation_result: Tuple[Tensor, Optional[Dict[int, CumulativeStatistic]],
                                                  Optional[Tensor], Optional[Tensor]]) -> ValidationResult:
    """Package the raw output of ``get_validation_data`` into a ``ValidationResult``.

    Computes the confusion matrix over classes ``0 .. n_classes - 1`` and, from
    it, the per-class accuracy (diagonal of the row-normalized matrix).

    :param required_top_k: The "k" values the accuracies in
        ``validation_result`` refer to, in the same order.
    :param n_classes: Number of classes used for the confusion matrix and the
        per-class accuracy dictionary.
    :param task_info: The task descriptor; its ``current_task`` is stored in
        the result along with the descriptor itself.
    :param validation_type: Stored as-is in the result.
    :param validation_dataset: Stored as-is in the result.
    :param validation_result: The 4-tuple ``(accuracies, loss_per_class,
        predictions, ground_truth)``. Predictions, ground truth and the loss
        statistics are all used unconditionally, so none of them may be None.
    :return: The assembled ``ValidationResult``.
    """
    accuracies_top_k, loss_per_class, predictions, ground_truth = validation_result
    matrix = confusion_matrix(ground_truth.numpy(), predictions.numpy(), labels=list(range(n_classes)), normalize=None)

    # Row-normalize the confusion matrix. Rows without any pattern (sum == 0) are skipped by
    # "where" and keep the zeros of the pre-filled output buffer instead of producing NaN.
    den = matrix.sum(axis=1)[:, np.newaxis]
    # The builtin float replaces the np.float alias, which recent NumPy versions no longer provide
    out_matrix = np.zeros(matrix.shape, dtype=float)
    np.divide(matrix.astype(float), den, where=den != 0, out=out_matrix)
    accuracy_per_class = out_matrix.diagonal()

    accuracies_top_k_dict = {top_k: accuracies_top_k[idx].item() for idx, top_k in enumerate(required_top_k)}
    accuracy_per_class_dict = {class_idx: accuracy_per_class[class_idx] for class_idx in range(n_classes)}

    return ValidationResult(
        task=task_info.current_task,
        accuracies_top_k=accuracies_top_k_dict,
        accuracy_per_class=accuracy_per_class_dict,
        confusion_matrix=matrix,
        loss=average_stats(loss_per_class.values()).average,
        loss_per_class=loss_per_class,
        validation_type=validation_type,
        validation_dataset=validation_dataset,
        task_info=task_info
    )


def cl_validation(model: Module, task_info: INCProtocolIterator, criterion: Module, total_classes: int,
                  device: Optional[torch.device] = None, required_top_k: Optional[Sequence[int]] = None,
                  make_one_hot: Optional[bool] = None, validate_on_validation_set: bool = True,
                  validate_on_training_set: bool = False, validate_on_part: DatasetPart = DatasetPart.COMPLETE_SET,
                  expanding_head_classes: int = -1, **data_loader_kwargs) -> Sequence[ValidationResult]:
    """Evaluate the model on the training and/or test data of the current task.

    :param model: The model to evaluate.
    :param task_info: The task descriptor the datasets are obtained from.
    :param criterion: The loss used to compute the validation loss.
    :param total_classes: Number of classes used for the confusion matrix and
        the per-class accuracy.
    :param device: Forwarded to ``get_validation_data``.
    :param required_top_k: The "k" values of the top-k accuracies.
        Defaults to ``[1]``.
    :param make_one_hot: Whether one-hot targets are fed to the criterion.
        When None, ``is_loss_criterion_vector_based(criterion)`` decides.
    :param validate_on_validation_set: When True, evaluate on the part
        returned by ``task_info.get_test_set_part``.
    :param validate_on_training_set: When True, evaluate on the training part
        obtained after ``task_info.swap_transformations()``.
    :param validate_on_part: The dataset part requested from ``task_info``.
    :param expanding_head_classes: Number of classes used to build the one-hot
        targets. A negative value falls back to ``total_classes``.
    :param data_loader_kwargs: Forwarded to the ``DataLoader`` constructor.
    :return: A list holding the training-set result (if requested) followed by
        the validation-set result (if requested).
    """
    top_ks = [1] if required_top_k is None else required_top_k
    head_classes = total_classes if expanding_head_classes < 0 else expanding_head_classes
    one_hot = is_loss_criterion_vector_based(criterion) if make_one_hot is None else make_one_hot

    def _evaluate(dataset, dataset_type: DatasetType) -> ValidationResult:
        # Run one full evaluation pass and wrap its raw output
        raw_output = get_validation_data(model, dataset, device=device, required_top_k=top_ks,
                                         return_detailed_outputs=True, criterion=criterion,
                                         make_one_hot=one_hot, n_classes=head_classes,
                                         **data_loader_kwargs)
        return make_validation_data(top_ks, total_classes, task_info,
                                    validate_on_part, dataset_type, raw_output)

    outcomes: List[ValidationResult] = []

    if validate_on_training_set:
        train_part = task_info.swap_transformations().get_training_set_part(validate_on_part)
        outcomes.append(_evaluate(train_part, DatasetType.TRAIN))

    if validate_on_validation_set:
        test_part = task_info.get_test_set_part(validate_on_part)
        outcomes.append(_evaluate(test_part, DatasetType.VALIDATION))

    return outcomes


def cl_training_update_metrics(result_builder: TrainingStepResultBuilder, losses: Tensor, predictions: Tensor,
                               ground_truth: Tensor) -> TrainingStepResultBuilder:
    """Record the metrics of one training iteration in ``result_builder``.

    The top-k accuracies listed in ``result_builder.required_accuracy_top_k``
    and the per-class loss statistics are computed on detached CPU copies of
    the inputs and handed to ``result_builder.add_iteration_result``.

    :param result_builder: The builder receiving the iteration result.
    :param losses: The losses of the batch, indexed per pattern. When they
        have more than one dimension they are averaged over dim 1.
    :param predictions: The model output; ``topk`` is taken along dim 1.
    :param ground_truth: The labels of the batch.
    :return: The same ``result_builder`` that was passed in.
    """
    # Detach and move to CPU in the order: losses, predictions, ground truth
    losses, predictions, ground_truth = (tensor.detach().cpu()
                                         for tensor in (losses, predictions, ground_truth))

    largest_k: int = max(result_builder.required_accuracy_top_k)
    ranked_classes = predictions.topk(largest_k, 1)[1].t()

    # Collapse vector shaped losses to a single value per pattern
    if losses.dim() > 1:
        losses = losses.mean(dim=1)

    # Top-k: a hit is a position where the ranked class equals the label
    hits = ranked_classes.eq(ground_truth.view(1, -1).expand_as(ranked_classes))
    batch_size = len(ground_truth)
    accuracy_by_k: Dict[int, float] = {
        k: hits[:k].view(-1).float().sum(0).item() / batch_size
        for k in result_builder.required_accuracy_top_k
    }

    # Loss: one sample per pattern, grouped by the label of that pattern
    loss_by_class: Dict[int, CumulativeStatistic] = defaultdict(CumulativeStatistic)
    for position, pattern_loss in enumerate(losses):
        loss_by_class[int(ground_truth[position])].update_using_counts(pattern_loss.item(), 1)

    result_builder.add_iteration_result(batch_size, losses.mean(), accuracy_by_k, loss_by_class)

    return result_builder
