import os
from typing import Optional, Tuple

import torch
from torch.utils.tensorboard import SummaryWriter

@torch.no_grad()
def accuracy(logits, labels):
    """Return the fraction of entries whose arg-max over dim 1 equals the label.

    Args:
        logits: tensor whose dimension 1 is reduced with ``max`` to obtain the
            predicted index for each entry.
        labels: tensor of reference indices compared element-wise with the
            predictions.

    Returns:
        Tensor holding ``(#matches) / len(labels)``.
    """
    # torch.max over a dimension yields (values, indices); only the indices matter.
    _, predicted = torch.max(logits, 1)
    n_correct = torch.eq(predicted, labels).sum()
    return n_correct / len(labels)

def get_writer(args):
    """Create the TensorBoard ``SummaryWriter`` for the run described by ``args``.

    The log directory is ``./<args.dataset>/<run name>``, where the run name
    encodes the model, depth, nprocs, epoch, dataset, seed and label settings.

    Args:
        args: object exposing the attributes ``model``, ``depth``, ``nprocs``,
            ``epoch``, ``dataset``, ``seed`` and ``label``.

    Returns:
        A ``SummaryWriter`` constructed with the run's log directory.
    """
    # exist_ok avoids the check-then-create race when several processes start
    # at once, and makedirs also copes with a nested dataset path.
    os.makedirs(args.dataset, exist_ok=True)
    log_name = f'MD_{args.model}_D_{args.depth}_N_{args.nprocs}_EP_{args.epoch}_DATA_{args.dataset}_SD_{args.seed}_LB_{args.label}'
    log_dir = os.path.join('.', args.dataset, log_name)
    return SummaryWriter(log_dir)

def save_model(args, net):
    """Save the network weights and summary metrics to ``./checkpoint``.

    The checkpoint is a dict with the keys ``'net'`` (the model's
    ``state_dict()``), ``'best_val_acc'``, ``'test_acc'`` and ``'best_epoch'``,
    written with ``torch.save`` to a ``.pth`` file whose name encodes the run
    configuration.

    Args:
        args: object exposing ``best_val_acc``, ``test_acc``, ``best_epoch`` and
            the naming attributes ``model``, ``depth``, ``nprocs``, ``epoch``,
            ``dataset``, ``seed`` and ``label``.
        net: object providing ``state_dict()``.
    """
    state = {
        'net': net.state_dict(),
        'best_val_acc': args.best_val_acc,
        'test_acc': args.test_acc,
        'best_epoch': args.best_epoch,
    }
    # exist_ok avoids the check-then-create race between concurrent savers.
    os.makedirs('checkpoint', exist_ok=True)
    file_name = f'MD_{args.model}_D_{args.depth}_N_{args.nprocs}_EP_{args.epoch}_DATA_{args.dataset}_SD_{args.seed}_LB_{args.label}.pth'
    checkpoint_path = os.path.join('./checkpoint', file_name)
    torch.save(state, checkpoint_path)

class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are plain fields updated by ``update``:
    ``val`` is the most recent value, ``sum`` the weighted sum of all values,
    ``count`` the total weight, and ``avg`` equals ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new value (e.g. a per-batch mean).
            n: weight of the value, typically the number of samples it covers.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With zero accumulated weight (e.g. an empty first batch) there is
        # nothing to average; keep avg at 0 instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0

class F1Score:
    """
    Class for f1 calculation in Pytorch.

    Classes are taken from the values that occur in ``labels``; a class that
    only appears in ``predictions`` does not get its own score.
    """

    def __init__(self, average: Optional[str] = 'weighted'):
        """
        Init.

        Args:
            average: averaging method, one of ``'micro'``, ``'macro'``,
                ``'weighted'`` or ``None`` (per-class scores, no averaging)

        Raises:
            ValueError: if ``average`` is not one of the supported values
        """
        if average not in [None, 'micro', 'macro', 'weighted']:
            raise ValueError('Wrong value of average parameter')
        self.average = average

    @staticmethod
    def calc_f1_micro(predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 micro.

        Computed as the share of elements where prediction and label agree.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """
        true_positive = torch.eq(labels, predictions).sum().float()
        return torch.div(true_positive, labels.numel())

    @staticmethod
    def calc_f1_count_for_label(predictions: torch.Tensor,
                                labels: torch.Tensor, label_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate f1 and true count for the label

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels
            label_id: id of current label

        Returns:
            f1 score and true count for label
        """
        is_label = torch.eq(labels, label_id)
        # label count
        true_count = is_label.sum()

        # true positives: labels equal to prediction and to label_id
        true_positive = torch.logical_and(torch.eq(labels, predictions), is_label).sum().float()

        # precision for label; 0 / 0 (label never predicted) gives nan -> 0
        precision = torch.div(true_positive, torch.eq(predictions, label_id).sum().float())
        precision = torch.where(torch.isnan(precision), torch.zeros_like(precision), precision)

        # recall for label; a nan here (label absent) propagates into f1 and
        # is zeroed below
        recall = torch.div(true_positive, true_count)

        # f1; precision + recall == 0 gives nan -> 0
        f1 = 2 * precision * recall / (precision + recall)
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
        return f1, true_count

    def __call__(self, predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 score based on averaging method defined in init.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score: a scalar tensor for ``'micro'``, ``'macro'`` and
            ``'weighted'``; for ``average=None`` a 1-D tensor with one score
            per distinct label value, in ascending label order. Empty input
            yields zero (or an empty tensor for ``average=None``).
        """
        # nothing to score: avoid 0 / 0 and stacking an empty list
        if labels.numel() == 0:
            shape = (0,) if self.average is None else ()
            return torch.zeros(shape, dtype=torch.float32, device=labels.device)

        # simpler calculation for micro
        if self.average == 'micro':
            return self.calc_f1_micro(predictions, labels)

        # Iterate over the label values that actually occur rather than
        # assuming they are exactly 0..K-1; a batch may lack some classes.
        per_class = [self.calc_f1_count_for_label(predictions, labels, label_id)
                     for label_id in labels.unique().tolist()]
        f1_scores = torch.stack([f1 for f1, _ in per_class])
        true_counts = torch.stack([count for _, count in per_class])

        if self.average == 'weighted':
            # weights are the per-class supports, which sum to labels.numel()
            return torch.div((f1_scores * true_counts).sum(), labels.numel())
        if self.average == 'macro':
            return torch.div(f1_scores.sum(), f1_scores.numel())
        # average is None: per-class scores without averaging
        return f1_scores
