import numpy as np
import torch
# from torchmetrics import Metric

# class LocalAcc(Metric):
#     def __init__(self, batch_dim=0, target_dim=1, mask=None, device=None):
#         super().__init__()
#         self.batch_dim = batch_dim
#         self.target_dim=target_dim
#         self.device = device
#         if not(mask is None):
#             self.mask = torch.tensor(mask, device=self.device)
#         else:
#             self.mask = None
#         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
#         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

#     def update(self, preds: torch.Tensor, target: torch.Tensor):
#         assert preds.shape == target.shape

#         if self.mask is None:
#             self.correct += torch.sum(preds == target)
#             self.total += target.numel()
#         else:
#             correct = torch.all(torch.logical_or(torch.eq(outputs, labels), self.mask), dim=self.target_dim).to(torch.float)
#             self.correct += torch.sum(correct)
#             self.total += torch.masked_select(preds, mask).numel()

#     def compute(self):
#         return self.correct.float() / self.total

# class GlobalAcc(Metric):
#     def __init__(self, batch_dim=0, label_dim=1):
#         super().__init__()
#         self.batch_dim = batch_dim
#         self.label_dim=label_dim
#         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
#         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

#     def update(self, preds: torch.Tensor, target: torch.Tensor):
#         assert preds.shape == target.shape

#         self.correct += torch.sum(torch.all(torch.eq(outputs, labels), dim=self.label_dim).to(torch.float))
#         self.total += target.shape[self.batch_dim]

#     def compute(self):
#         return self.correct.float() / self.total

# class MarginalAcc(Metric):
#     def __init__(self, batch_dim=0, label_dim=1, mask=None):
#         super().__init__()
#         self.batch_dim = batch_dim
#         self.label_dim=label_dim
#         self.mask = mask
#         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
#         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

#     def update(self, preds: torch.Tensor, target: torch.Tensor):
#         assert preds.shape == target.shape
#         preds = torch.masked_select(preds, mask)
#         target = torch.masked_select(target, mask)

#         self.correct += torch.sum(preds == target)
#         self.total += target.numel()

#     def compute(self):
#         return self.correct.float() / self.total


def compute_metrics(scores, hex_outputs, labels, base_loss, hex_loss, log_diff, leaves=None, prefix="train", device=None):
    """Collect base-model, HEX-model and diagnostic scalars into one flat dict.

    Args:
        scores: raw base-model scores; an entry counts as a positive
            prediction when it is strictly greater than 0.
        hex_outputs: HEX predictions, evaluated as-is against ``labels``.
        labels: ground-truth labels shared by both prediction sets.
        base_loss: scalar loss tensor (read with ``.item()``).
        hex_loss: scalar loss tensor (read with ``.item()``).
        log_diff: tensor summarised by "base_sup_hex" (1 iff every entry is
            < 0, else 0) and "z_ratio" (mean of ``exp(log_diff)``).
            NOTE(review): presumably a log-ratio of normalisers; confirm
            with the caller.
        leaves: optional category mask forwarded to the 'leaves_acc' metric.
        prefix: when not None, every key is rewritten as ``prefix + "/" + key``.
        device: forwarded to ``compute_output_metrics``.

    Returns:
        dict of Python scalars: the "base_*" metrics, then the "hex_*"
        metrics, then "base_sup_hex", "max_score", "min_score", "base_loss",
        "hex_loss" and "z_ratio".
    """
    requested = ['perfect_acc', 'cat_acc', 'leaves_acc', 'f1', 'precision', 'recall']

    def _metrics_for(predictions, tag):
        # Same metric set for both prediction sources; only the key tag differs.
        return compute_output_metrics(outputs=predictions, labels=labels, names=requested,
                                      leaves=leaves, prefix=tag, device=device)

    merged = {}
    merged.update(_metrics_for(torch.gt(scores, 0), "base"))
    merged.update(_metrics_for(hex_outputs, "hex"))
    merged.update({
        "base_sup_hex": int(torch.all(torch.lt(log_diff, 0)).item()),
        "max_score": torch.max(scores).item(),
        "min_score": torch.min(scores).item(),
        "base_loss": base_loss.item(),
        "hex_loss": hex_loss.item(),
        "z_ratio": torch.mean(torch.exp(log_diff)).item(),
    })

    if prefix is None:
        return merged
    return {prefix + "/" + str(name): value for name, value in merged.items()}

def compute_output_metrics(outputs, labels, names, leaves=None, prefix="hex", device=None):
    """Compute the requested accuracy / F1-family metrics for one prediction set.

    Args:
        outputs: predicted labels, compared element-wise against ``labels``.
        labels: ground-truth labels with the same shape as ``outputs``
            (assumed (batch, n_categories) — TODO confirm with callers).
        names: container of metric names to compute. Recognised names are
            'perfect_acc', 'cat_acc', 'leaves_acc', 'f1', 'precision' and
            'recall'; anything else is ignored.
        leaves: category mask passed to ``masked_acc`` as ``on``.
            'leaves_acc' is skipped when this is None.
        prefix: every key is ``prefix + '_' + <metric name>``.
        device: device on which ``masked_acc`` builds the leaves mask.

    Returns:
        dict mapping prefixed metric names to Python numbers.
    """
    pre = prefix + '_'
    metrics = {}

    if 'perfect_acc' in names:
        metrics[pre + 'perfect_acc'] = perfect_acc(outputs, labels).item()

    if 'cat_acc' in names:
        metrics[pre + 'cat_acc'] = cat_acc(outputs, labels).item()

    if 'leaves_acc' in names and leaves is not None:
        metrics[pre + 'leaves_acc'] = masked_acc(outputs, labels, on=leaves, device=device).item()

    if 'f1' in names:
        # f1() already returns precision and recall, so reuse them.
        # They are bound to prec/rec on purpose. Assigning to the names
        # `precision`/`recall` anywhere in this function would make them
        # locals and shadow the module-level functions, so the else-branch
        # calls below would raise UnboundLocalError.
        f1_score, prec, rec = f1(outputs, labels)
        metrics[pre + 'f1'] = f1_score
        if 'precision' in names:
            metrics[pre + 'precision'] = prec
        if 'recall' in names:
            metrics[pre + 'recall'] = rec
    else:
        if 'precision' in names:
            metrics[pre + 'precision'] = precision(outputs, labels).item()
        if 'recall' in names:
            metrics[pre + 'recall'] = recall(outputs, labels).item()

    return metrics

def cat_acc(outputs, labels, reduction='mean'):
    """Per-entry (per-category) accuracy.

    Every position where ``outputs`` equals ``labels`` scores 1.0, every
    other position 0.0.

    Args:
        outputs: predictions.
        labels: ground truth, comparable element-wise with ``outputs``.
        reduction: 'mean' returns the fraction of matching entries; 'sum'
            returns the number of matches as a float tensor; any other value
            returns the unreduced float tensor of per-entry scores.
    """
    hits = torch.eq(outputs, labels).to(torch.float)
    if reduction == 'mean':
        return torch.mean(hits)
    if reduction == 'sum':
        return torch.sum(hits)
    return hits

def perfect_acc(outputs, labels, cat_dim=1, reduction='mean'):
    """Exact-match accuracy: a sample scores 1.0 only if every category matches.

    Args:
        outputs: predictions.
        labels: ground truth, comparable element-wise with ``outputs``.
        cat_dim: dimension that must match entirely (reduced with ``all``).
        reduction: 'mean' returns the fraction of exactly-matched samples;
            'sum' returns their count as a float tensor; any other value
            returns the per-sample float tensor.
    """
    exact = torch.eq(outputs, labels).all(dim=cat_dim).float()
    if reduction == 'mean':
        return torch.mean(exact)
    if reduction == 'sum':
        return torch.sum(exact)
    return exact

def masked_acc(outputs, labels, on=None, cat_dim=1, reduction='mean', device=None):
    """Exact-match accuracy restricted to a subset of categories.

    A sample counts as correct when ``outputs`` equals ``labels`` at every
    position selected by ``on``. Positions where ``on`` is falsy are ignored.

    Args:
        outputs: predictions (a tensor).
        labels: ground truth, comparable element-wise with ``outputs``.
        on: mask of categories to check (array-like or tensor; truthy means
            "checked"). It is broadcast against ``outputs`` using the usual
            right-aligned rules, so a 1-D mask lines up with the LAST
            dimension. It matches ``cat_dim`` only when that is the last
            axis. ``None`` (the default) checks every category. The former
            default, ``np.ones(1860)``, also checked everything but only
            worked for exactly 1860 (or 1) categories.
        cat_dim: dimension that must match entirely (reduced with ``all``).
        reduction: 'mean' returns the fraction of correct samples; 'sum'
            returns their count as a float tensor; any other value returns
            the per-sample float tensor.
        device: device for the mask tensor. Defaults to ``outputs.device``
            so the mask is co-located with the predictions.
    """
    matches = torch.eq(outputs, labels)
    if on is not None:
        mask_device = outputs.device if device is None else device
        # as_tensor avoids the copy/warning torch.tensor() incurs for tensor inputs.
        ignored = torch.logical_not(torch.as_tensor(on, device=mask_device))
        # Ignored categories are forced to "match" so they cannot fail a sample.
        matches = torch.logical_or(matches, ignored)
    correct = torch.all(matches, dim=cat_dim).to(torch.float)
    if reduction == 'mean':
        return torch.mean(correct)
    if reduction == 'sum':
        return torch.sum(correct)
    return correct

def precision(outputs, labels, batch_dim=0, method='macro'):
    """Precision TP / (TP + FP) of binary multi-label predictions.

    Args:
        outputs: predictions (truthy = predicted positive).
        labels: ground truth (truthy = positive); cast to float before the
            logical AND.
        batch_dim: axis summed over in the 'micro' branch.
        method: 'macro' pools the counts of all entries into one ratio.
            'micro' computes one ratio per slice left after summing over
            ``batch_dim``, then averages those ratios. Any other value
            returns None.

    NOTE(review): under the usual convention (e.g. scikit-learn's
    ``average`` parameter), pooling counts is *micro* averaging and
    per-class-then-mean is *macro*. The two names here appear swapped; they
    are kept as-is for compatibility with existing callers.

    A zero denominator is not special-cased: 0/0 yields NaN.
    """
    if method not in ("micro", "macro"):
        return None
    true_pos = torch.logical_and(outputs, labels.to(torch.float))
    if method == "macro":
        return torch.div(torch.sum(true_pos), torch.sum(outputs))
    per_slice = torch.div(torch.sum(true_pos, dim=batch_dim), torch.sum(outputs, dim=batch_dim))
    return torch.mean(per_slice)

def recall(outputs, labels, batch_dim=0, method='macro'):
    """Recall TP / (TP + FN) of binary multi-label predictions.

    Args:
        outputs: predictions (truthy = predicted positive).
        labels: ground truth (truthy = positive); cast to float before use.
        batch_dim: axis summed over in the 'micro' branch.
        method: 'macro' pools the counts of all entries into one ratio.
            'micro' computes one ratio per slice left after summing over
            ``batch_dim``, then averages those ratios. Any other value
            returns None.

    NOTE(review): in standard terminology (e.g. scikit-learn) pooled counts
    are *micro* and per-class-then-mean is *macro*. These names appear
    swapped and are kept as-is for compatibility.

    A zero denominator is not special-cased: 0/0 yields NaN.
    """
    if method not in ("micro", "macro"):
        return None
    gold = labels.to(torch.float)
    true_pos = torch.logical_and(outputs, gold)
    if method == "macro":
        return torch.div(torch.sum(true_pos), torch.sum(gold))
    per_slice = torch.div(torch.sum(true_pos, dim=batch_dim), torch.sum(gold, dim=batch_dim))
    return torch.mean(per_slice)

def f1(outputs, labels, batch_dim=0, method='macro'):
    """F1 score together with the precision and recall it is derived from.

    Args:
        outputs: predictions (truthy = predicted positive).
        labels: ground truth (truthy = positive); cast to float before use.
        batch_dim: axis summed over in the 'micro' branch.
        method: 'macro' pools the counts of all entries. 'micro' computes
            per-slice ratios (after summing over ``batch_dim``) and averages
            them. Any other value returns None instead of a tuple.

    NOTE(review): the 'micro'/'macro' names appear swapped relative to the
    standard (scikit-learn) convention; they are kept for compatibility.

    Returns:
        ``(f1, precision, recall)`` as Python numbers. F1 is
        ``2*p*r/(p+r)`` when ``p + r > 0``; otherwise it is 0, which also
        covers NaN precision/recall produced by 0/0 denominators.
    """
    if method not in ("micro", "macro"):
        return None
    gold = labels.to(torch.float)
    hits = torch.logical_and(outputs, gold)

    if method == "micro":
        tp = torch.sum(hits, dim=batch_dim)
        prec = torch.mean(torch.div(tp, torch.sum(outputs, dim=batch_dim))).item()
        rec = torch.mean(torch.div(tp, torch.sum(gold, dim=batch_dim))).item()
    else:
        tp = torch.sum(hits)
        prec = torch.div(tp, torch.sum(outputs)).item()
        rec = torch.div(tp, torch.sum(gold)).item()

    # `> 0` (not `<= 0` in the other branch) so that NaN sums fall back to 0.
    score = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
    return score, prec, rec

