#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Metrics to evaluate models. Builds off the torchmetrics module as it is
useful for distributed learning.

"""

import torchmetrics
from inspect import signature
import torch
        
class TemplateMetric(torchmetrics.Metric):
    """
    Minimal template for a custom Metric that takes the model as input.

    Because ``update`` has a parameter named ``model``, SmartMetricCollection
    calls it as ``update(model=model)`` only, so any additional inputs must
    have defaults. This template simply counts how many times ``update`` has
    been called since the last reset.
    """
    def __init__(self):
        super().__init__()
        self.quantity = 0
        # If one wanted to compute said metric once between resets, regardless
        #   of how many updates are called, one could set the default value as
        #   None and only perform the update if the stored quantity is None

    def update(self, model, outputs=None, targets=None):
        """
        Calculate and internally store the metric.

        Args:
            model: the model being evaluated.
            outputs: optional extra input (unused by this template).
            targets: optional extra input (unused by this template).
        """
        self.quantity += 1

    def compute(self):
        """
        Compute the metric.

        Returns:
            Number of ``update`` calls since the last reset.
        """
        return self.quantity

    def reset(self):
        """
        Reset the metric at the start of each batch or epoch.
        """
        # torchmetrics' Metric.reset also clears its own bookkeeping (such as a
        # cached compute() result), so it must not be skipped.
        super().reset()
        self.quantity = 0
    
    
class SmartMetricCollection(torchmetrics.MetricCollection):
    """
    MetricCollection is a useful way to evaluate numerous Metrics from
    torchmetrics.
    However, by default, said Metrics take arguments (preds,target) e.g.
    to calculate accuracy or F1 score. We would like to include custom Metrics
    that take in the model as an input e.g. compute the sparsity of the model.
    Thus, we need a custom MetricCollection that distinguishes which Metrics
    take the standard (preds,target) inputs, and which take model as an input.

    A Metric whose ``update`` has a parameter named ``model`` is called as
    ``metric.update(model=model)``; every other Metric is called as
    ``metric.update(preds, target)``.
    """
    def __init__(self, metrics: dict, prefix: str = "", postfix: str = "", compute_groups: bool = True):
        super().__init__(metrics, prefix=prefix, postfix=postfix, compute_groups=compute_groups)
        # Input kind per metric name (as yielded by self.items(), i.e. with
        # prefix/postfix applied). update() classifies unknown names lazily,
        # because names change after clone(prefix=...) or add_metrics(...).
        self._input_kinds = self._inspect_input_kinds()

    @staticmethod
    def _classify_metric(metric):
        """
        Return "model" if ``metric.update`` accepts a parameter named
        ``model``, otherwise "preds_target".
        """
        params = signature(metric.update).parameters
        return "model" if "model" in params else "preds_target"

    def _inspect_input_kinds(self):
        """
        Determine whether each metric requires 'model' as an input.

        Returns:
            dict mapping metric name -> "model" or "preds_target".
        """
        return {name: self._classify_metric(metric) for name, metric in self.items()}

    def compute(self):
        """
        Ensure all dictionary-returning-metrics have results prefixed with the
        name of the metric. Otherwise, if only one Metric in the collection
        returned a dictionary, it would be flattened without any prefixing.

        Returns:
            flat dict: ``{name: value}`` for scalar metrics and
            ``{f"{name}_{key}": value}`` for each entry of a dict-valued metric.
        """
        output = {}
        for name, metric in self.items():
            result = metric.compute()
            if isinstance(result, dict):
                for k, v in result.items():
                    output[f"{name}_{k}"] = v
            else:
                output[name] = result
        return output

    def __call__(self, preds=None, target=None, model=None):
        """Update every metric with the given inputs, then return compute()."""
        self.update(preds=preds, target=target, model=model)
        return self.compute()

    def update(self, preds=None, target=None, model=None):
        """
        Update every metric in the collection with the inputs it expects.

        Args:
            preds: predictions, passed to "preds_target" metrics.
            target: ground truth, passed to "preds_target" metrics.
            model: the model, passed to "model" metrics.

        Returns:
            self.

        Raises:
            ValueError: if an input required by some metric is missing.
        """
        for name, metric in self.items():
            input_kind = self._input_kinds.get(name)
            if input_kind is None:
                # Name not seen at construction (e.g. after clone with a new
                # prefix, or add_metrics): classify now and remember it.
                input_kind = self._classify_metric(metric)
                self._input_kinds[name] = input_kind

            if input_kind == "model":
                if model is None:
                    raise ValueError(f"Metric '{name}' requires a model, but none was provided.")
                metric.update(model=model)
            elif input_kind == "preds_target":
                if preds is None or target is None:
                    raise ValueError(f"Metric '{name}' requires preds and target.")
                metric.update(preds, target)
            else:
                raise ValueError(f"Unknown input kind for metric '{name}': {input_kind}")
        return self


#%% Metrics used to evaluate a model. Some (e.g. LayerSparsityMetric) return a dictionary.
      
class ModelSparsityMetric(torchmetrics.Metric):
    """
    Compute the overall sparsity % (share of zero-valued entries) across all
    parameters of a model.

    The total parameter count is taken from the model given at construction.
    The sparsity itself is computed at most once between resets, regardless
    of how many times ``update`` is called.
    """
    def __init__(self, model):
        super().__init__()
        self.total_params = self._calc_total_params(model)
        self.sparsity = None

    def _calc_total_params(self, model):
        """Return the total number of elements over all parameters of model."""
        return sum(p.numel() for p in model.parameters())

    def update(self, model):
        """
        Update the metric based on the model parameters.
        The model is passed here, and we count the number of non-zero parameters.
        """
        # Compare with None: a dense model has sparsity exactly 0, which is
        # falsy and would otherwise be recomputed on every update.
        if self.sparsity is not None:
            return

        if self.total_params == 0:
            # No parameters: report 0 (as the Linear/Conv metrics do) instead
            # of dividing by zero.
            self.sparsity = torch.tensor(0.0)
            return

        total_nonzero = sum(p.count_nonzero() for p in model.parameters())
        self.sparsity = 100 * (1 - total_nonzero / self.total_params)

    def compute(self):
        """
        Compute the metric.

        Returns:
            sparsity in percent, or None if ``update`` has not run since reset.
        """
        return self.sparsity

    def reset(self):
        """
        Reset the metric at the start of each batch or epoch.
        Allow for the recomputation of the model sparsity
        """
        # Also clear torchmetrics' own bookkeeping (e.g. cached compute()).
        super().reset()
        self.sparsity = None
        
        
        
class LinearSparsityMetric(torchmetrics.Metric):
    """
    Sparsity % of the weights of all torch.nn.Linear layers of a model
    (biases are not counted). Computed at most once between resets; reports
    0 if the model has no Linear layers.
    """
    def __init__(self):
        super().__init__()
        self.add_state("sparsity", default=torch.tensor(float('nan')), persistent=False)
        self.add_state("has_updated", default=torch.tensor(False), persistent=False)

    def update(self, model):
        """
        Count zero-valued Linear weights in model and store the sparsity %.
        Subsequent calls are no-ops until reset().
        """
        if self.has_updated:
            return

        numel = 0
        nnz = 0
        for m in model.modules():
            if isinstance(m, torch.nn.Linear):
                numel += m.weight.numel()
                nnz += torch.count_nonzero(m.weight).item()

        value = 100.0 * (1 - nnz / numel) if numel > 0 else 0.0
        # add_state states must stay tensors (torchmetrics' .to(), reset and
        # distributed sync operate on tensor states); keep them on the same
        # device as the existing states.
        device = self.has_updated.device
        self.sparsity = torch.tensor(value, device=device)
        self.has_updated = torch.tensor(True, device=device)

    def compute(self):
        """Return the stored sparsity % (NaN if update has not run)."""
        return self.sparsity

    def reset(self):
        """Restore the add_state defaults (NaN sparsity, not updated)."""
        # Metric.reset resets every add_state state to its default and also
        # clears torchmetrics' internal bookkeeping.
        super().reset()
        
class ConvSparsityMetric(torchmetrics.Metric):
    """
    Percentage of torch.nn.Conv2d kernels that are entirely zero.

    A kernel is one (out_channel, in_channel) slice of a Conv2d weight; it is
    active when its L1 norm is positive. Computed at most once between
    resets; reports 0 if the model has no Conv2d layers.
    """

    def __init__(self):
        super().__init__()
        self.add_state("sparsity", default=torch.tensor(float('nan')), persistent=False)
        self.add_state("has_updated", default=torch.tensor(False), persistent=False)

    def update(self, model):
        """
        Count active Conv2d kernels in model and store the % that are inactive.
        Subsequent calls are no-ops until reset().
        """
        if self.has_updated:
            return

        nnz = 0
        total = 0
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                out_ch, in_ch, kh, kw = m.weight.shape
                # One row per kernel. reshape (not view) also works for
                # non-contiguous weights, e.g. channels_last memory format.
                kernels = m.weight.reshape(out_ch * in_ch, kh * kw)
                nnz += torch.count_nonzero(torch.norm(kernels, p=1, dim=1) > 0).item()
                total += out_ch * in_ch

        value = 100.0 * (1 - nnz / total) if total > 0 else 0.0
        # add_state states must stay tensors (torchmetrics' .to(), reset and
        # distributed sync operate on tensor states); keep them on the same
        # device as the existing states.
        device = self.has_updated.device
        self.sparsity = torch.tensor(value, device=device)
        self.has_updated = torch.tensor(True, device=device)

    def compute(self):
        """Return the stored % of all-zero kernels (NaN if update has not run)."""
        return self.sparsity

    def reset(self):
        """Restore the add_state defaults (NaN sparsity, not updated)."""
        # Metric.reset resets every add_state state to its default and also
        # clears torchmetrics' internal bookkeeping.
        super().reset()


class LayerSparsityMetric(torchmetrics.Metric):
    """
    Compute the sparsity % of each individual parameter of a model, returned
    as a dictionary mapping parameter name -> sparsity %.
    """
    def __init__(self, model, split_learning=False):
        super().__init__()
        # Flag for whether mask and weights were learned separately.
        # NOTE(review): stored for reference only; nothing in this class reads it.
        self.split_learning = split_learning
        self.total_params = self._calc_total_params(model)
        self.sparsity = {}

    def _calc_total_params(self, model):
        """Return a dict mapping each parameter name to its number of elements."""
        return {n: p.numel() for n, p in model.named_parameters()}

    def update(self, model):
        """
        Update the metric based on the model parameters.
        The model is passed here, and we count the number of non-zero parameters
        as a dictionary of every parameter.
        """
        for n, p in model.named_parameters():
            # Parameters that did not exist at construction fall back to their
            # current element count instead of raising KeyError.
            total_params = self.total_params.get(n, p.numel())
            self.sparsity[n] = 100 * (1 - p.count_nonzero() / total_params)

    def compute(self):
        """
        Compute the metric.

        Returns:
            dict mapping parameter name -> sparsity % (empty before update).
        """
        return self.sparsity

    def reset(self):
        """
        Reset the metric at the start of each batch or epoch.
        """
        # Also clear torchmetrics' own bookkeeping (e.g. cached compute()).
        super().reset()
        self.sparsity = {}
        
        
#%%

def get_metrics(metrics, device, **kwargs):
    """
    Create a SmartMetricCollection of the requested list of metric names.

    Args:
        metrics: iterable of metric names, or None. Recognised names are
            'accuracy', 'sparsity', 'linear_sparsity', 'conv_sparsity' and
            'layer_sparsity'; unrecognised names are reported and skipped.
        device: device every metric in the collection is moved to.
        **kwargs: optional 'model' (required by 'sparsity' and
            'layer_sparsity') and 'num_classes' (for 'accuracy', default 10).

    Returns:
        SmartMetricCollection, or None if metrics is None.

    Raises:
        ValueError: if a metric that needs the model is requested without one.
    """
    if metrics is None:
        return None

    model = kwargs.get('model')
    num_classes = kwargs.get('num_classes', 10)

    collection_dict = {}
    for name in metrics:
        if name == 'accuracy':
            collection_dict[name] = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes, average='micro')

        elif name == 'sparsity':
            # Compare with None: container modules (e.g. an empty
            # nn.Sequential) can be falsy.
            if model is None:
                raise ValueError(f'No "model" argument provided for metrics {name}.')
            collection_dict[name] = ModelSparsityMetric(model)

        elif name == 'linear_sparsity':
            collection_dict[name] = LinearSparsityMetric()

        elif name == 'conv_sparsity':
            collection_dict[name] = ConvSparsityMetric()

        elif name == 'layer_sparsity':
            if model is None:
                raise ValueError(f'No "model" argument provided for metrics {name}.')
            collection_dict[name] = LayerSparsityMetric(model)
        else:
            print(f'(!) Did not recognize metric "{name}"')

    collection = SmartMetricCollection(collection_dict)
    # Move every metric (not only accuracy) so all metric states live on the
    # same device.
    if device is not None:
        collection = collection.to(device)
    return collection
    
        
     

