import torch

class MetricAggregator:
    """Route model outputs to one metric per task and summarise the results.

    Every wrapped metric must expose ``update(outputs, target)`` and
    ``compute()``; ``compute()`` may return either a tensor or a dict of
    tensors (``F1`` returns a dict with ``prec``/``rec``/``f1`` entries).
    """

    def __init__(self, metrics, print_all_tasks=False):
        self.metrics = metrics
        # Stored for callers; not read anywhere inside this class.
        self.print_all_tasks = print_all_tasks

    def update(self, outputs, target, tasks=None):
        """Feed one batch to the per-task metrics.

        - With a single metric, the whole batch goes to it and ``tasks`` is
          ignored.
        - With ``tasks=None``, ``outputs[t]`` / ``target[t]`` go to metric ``t``.
        - Otherwise ``tasks`` is compared against each task index
          (presumably a tensor of per-sample task ids — confirm with callers)
          and the rows where ``tasks == t`` go to metric ``t``.
        """
        if len(self.metrics) == 1:
            self.metrics[0].update(outputs, target)
            return
        for t, metric in enumerate(self.metrics):
            if tasks is None:
                metric.update(outputs[t], target[t])
            else:
                task_idx = tasks == t
                metric.update(outputs[task_idx], target[task_idx])

    def compute(self):
        """Return the list of ``compute()`` results, one per metric."""
        return [m.compute() for m in self.metrics]

    def get_mean(self):
        """Average the metric results across tasks.

        Returns a dict of 0-dim tensors (one per key of the first result) when
        the metrics return dicts, otherwise a Python float.
        """
        results = self.compute()  # compute once; it stacks all stored batches
        if isinstance(results[0], dict):
            return {
                key: torch.mean(torch.hstack([r[key] for r in results]))
                for key in results[0]
            }
        return float(torch.mean(torch.hstack(results)).item())

    def get_string(self):
        """Format ``get_mean()`` with four decimals (``"key: v "`` per dict key)."""
        means = self.get_mean()
        if isinstance(means, dict):
            return "".join(f"{key}: {value:.4f} " for key, value in means.items())
        return f"{means:.4f} "

    def get_tasks_string(self):
        """Format each task's (mean) result as ``"Task i: ..."``.

        Dict results (e.g. from ``F1``) are formatted per key, matching the
        ``"key: value"`` layout of ``get_string``.
        """
        parts = []
        for task, result in enumerate(self.compute()):
            if isinstance(result, dict):
                scores = " ".join(
                    f"{key}: {self._mean_value(value):.4f}"
                    for key, value in result.items()
                )
                parts.append(f"Task {task}: {scores} ")
            else:
                parts.append(f"Task {task}: {self._mean_value(result):.4f} ")
        return "".join(parts)

    @staticmethod
    def _mean_value(value):
        """Return the mean of a tensor-like ``value`` as a Python float."""
        tensor = torch.as_tensor(value)
        # torch.mean rejects integer dtypes; promote those to float32.
        if not tensor.is_floating_point():
            tensor = tensor.float()
        return float(tensor.mean())

class F1:
    """Per-label precision / recall / F1 accumulated over batches.

    Outputs are treated as logits: a sigmoid above 0.5 counts as a positive
    prediction. For 2-D ``(batch, labels)`` targets, every row containing
    ``ignore_val`` is skipped; for 1-D targets, ignored entries are skipped
    element-wise.
    """

    def __init__(self, ignore_val=-1, ret_metric="f1") -> None:
        self.ignore_val = ignore_val
        # NOTE(review): loss_sums, ret_accs and ret_metric are never read in
        # this class; kept in case external code relies on them.
        self.loss_sums = 0
        self.ret_accs = []
        self.ret_metric = ret_metric
        self.ret_preds = []      # per-batch sum of predicted positives
        self.ret_to_pred = []    # per-batch sum of targets (positives for 0/1 labels)
        self.ret_well_pred = []  # per-batch sum of preds * targets (true positives)

    def update(self, outputs: torch.Tensor, target: torch.Tensor):
        """Accumulate prediction / target / overlap sums for one batch."""
        preds = (torch.sigmoid(outputs) > 0.5).type(torch.float32)
        valid = target != self.ignore_val
        # 2-D targets: drop the whole row if any label is ignored.
        # 1-D targets hold one label per sample, so mask element-wise
        # (previously `.all(1)` raised on 1-D input).
        mask = valid if valid.dim() == 1 else valid.all(1)
        preds_masked = preds[mask]
        targets_masked = target[mask]
        # Reduce on the source device, then move only the small sums to CPU.
        self.ret_preds.append(preds_masked.sum(dim=0).cpu())
        self.ret_to_pred.append(targets_masked.sum(dim=0).cpu())
        self.ret_well_pred.append((preds_masked * targets_masked).sum(dim=0).cpu())

    def compute(self):
        """Return ``{"prec", "rec", "f1"}`` tensors computed over all batches.

        Raises:
            RuntimeError: if ``update`` was never called.
        """
        if not self.ret_preds:
            raise RuntimeError("F1.compute() called before any update()")
        ret_preds = torch.stack(self.ret_preds).sum(dim=0)
        ret_to_pred = torch.stack(self.ret_to_pred).sum(dim=0)
        ret_well_pred = torch.stack(self.ret_well_pred).sum(dim=0)
        # 1e-7 keeps the divisions finite when a label is never predicted/present.
        val_precs = ret_well_pred / (ret_preds + 1e-7)
        val_recs = ret_well_pred / (ret_to_pred + 1e-7)
        val_fscores = 2 * val_precs * val_recs / (val_precs + val_recs + 1e-7)
        return {"prec": val_precs, "rec": val_recs, "f1": val_fscores}



class Accuracy:
    """Top-1 accuracy accumulated over batches.

    ``update`` records argmax predictions (over dim 1 of ``outputs``) and the
    targets of each batch; ``compute`` returns the overall fraction of
    correct predictions; ``get_per_class_accuracy`` breaks it down by class.
    """

    def __init__(self, ignore_val=-1):
        # NOTE(review): ignore_val is stored but not applied when scoring
        # (unlike F1), so targets equal to it count as wrong predictions.
        self.ignore_val = ignore_val
        self.corrects = []      # per-batch mean accuracy (kept for compatibility)
        self.preds = []         # per-batch argmax predictions
        self.targets = []       # per-batch targets
        self._num_correct = []  # per-batch count of correct predictions
        self._num_seen = []     # per-batch count of compared elements

    def update(self, outputs, targets):
        """Score one batch; empty batches are ignored."""
        if targets.numel() < 1:
            return
        preds = outputs.argmax(dim=1)
        # Squeeze e.g. (B, 1) labels to (B,); a single element is left as-is
        # so it does not collapse to a 0-dim tensor.
        if targets.numel() != 1:
            targets = targets.squeeze()
        hits = (preds == targets).type(torch.float32)
        self.preds.append(preds)
        self.targets.append(targets)
        self.corrects.append(hits.mean())
        self._num_correct.append(hits.sum())
        self._num_seen.append(hits.numel())

    def compute(self):
        """Return total correct / total compared as a 0-dim tensor.

        Weighting by batch size keeps small batches (e.g. a last partial batch
        or a sparse task sub-batch) from skewing the result, which the
        previous mean-of-batch-means did.

        Raises:
            RuntimeError: if no non-empty batch has been recorded.
        """
        if not self._num_correct:
            raise RuntimeError("Accuracy.compute() called before any non-empty update()")
        return torch.stack(self._num_correct).sum() / sum(self._num_seen)

    def get_per_class_accuracy(self):
        """Return ``(accuracies, classes)`` for every class present in targets.

        Each batch is flattened before concatenation so that batches whose
        stored shapes differ (e.g. an unsqueezed ``(1, 1)`` single-sample
        target next to ``(B,)`` targets) can be combined.
        """
        preds = torch.cat([p.reshape(-1) for p in self.preds])
        targets = torch.cat([t.reshape(-1) for t in self.targets])
        classes = torch.unique(targets)
        accs = [(preds[targets == c] == c).type(torch.float32).mean() for c in classes]
        return torch.stack(accs), classes