import torch
import h5py
import numpy as np

from utility.args import Args

class BaseMetric:
    """Base class for metrics that are buffered per iteration, printed, and logged to HDF5.

    Usage pattern visible in this class:
      * ``fetchMetric`` computes this worker's values via ``calcMetric`` and appends them
        (optionally gathered from all ranks) to ``self.buffer`` on rank 0;
      * ``getDisplayStr`` formats the mean of the buffered values for the console;
      * ``flushData`` reduces the buffer with ``_reduceData``, appends the result to the
        HDF5 dataset ``file[mode][self.name]`` and clears the buffer.

    Subclasses must set ``name`` and implement ``calcMetric``; register them with ``@addMetric``.
    """
    name: str = None  # metric identifier; also used as the HDF5 dataset name
    logTrain: bool = True  # create a dataset in the file's "train" group
    logTest: bool = True  # create a dataset in the file's "test" group
    printTrain: bool = True  # console-print flag for the train set (forced False if logTrain is False)
    printTest: bool = True  # console-print flag for the test set (forced False if logTest is False)
    consolePrintFormat: str = ".4f"  # format spec applied by getDisplayStr
    concatenateWorkers: bool = True  # whether to concatenate data from all workers after each fetch or just use rank 0's data

    def __init__(self):
        # If Args.contin is set, flushData does not create the datasets and expects them
        # to exist in the file already.
        self.firstFlush = not Args.contin
        self.buffer = []
        # A metric that is not logged for a set is never printed for it either.
        if not self.logTest:
            self.printTest = False
        if not self.logTrain:
            self.printTrain = False

    def getDisplayStr(self) -> str:
        """Return the mean of all buffered values formatted with ``consolePrintFormat``.

        Buffer entries are concatenated along axis 0, so each must have at least one dimension.
        """
        return f"{np.mean(np.concatenate(self.buffer)):{self.consolePrintFormat}}"

    def _reduceData(self) -> np.ndarray:
        """Reduce the buffer to the data that flushData appends to the HDF5 dataset.

        Default: the mean over all buffered values, returned as a 1-element array.

        Overrides must return an np.ndarray (or something convertible to one) of shape
        [dataLen, *featureShape]; the first axis (dataLen) is appended to the file.
        """
        return np.array([np.mean(np.concatenate(self.buffer))])

    def fetchMetric(self, state: dict):
        """Compute this worker's metric via ``calcMetric`` and append it to the buffer on rank 0.

        With ``concatenateWorkers`` the arrays of all ranks are gathered to rank 0 and
        concatenated along axis 0 (the data axis) before buffering. This is a collective
        call, so every rank must call it. Otherwise only rank 0's own array is buffered.
        Without an initialised torch.distributed process group, this process is treated
        as the only rank.
        """
        worker_metric: np.ndarray = self.calcMetric(state)

        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if not distributed:
            # Single process: this process plays the role of rank 0.
            self.buffer.append(worker_metric)
            return

        is_rank0 = torch.distributed.get_rank() == 0
        if self.concatenateWorkers:
            gather_list = [None] * torch.distributed.get_world_size() if is_rank0 else None
            torch.distributed.gather_object(
                worker_metric,
                object_gather_list=gather_list,
                dst=0,
            )
            if is_rank0:
                # Axis 0 is the data axis (see _reduceData / getDisplayStr); concatenating
                # along the last axis would mix workers into the feature dimension for
                # metrics returning more than one dimension.
                self.buffer.append(np.concatenate(gather_list, axis=0))
        elif is_rank0:
            self.buffer.append(worker_metric)

    def calcMetric(self, state: dict) -> np.ndarray:
        """Compute the metric from ``state`` for this worker.

        Implementations should move results to the CPU and return an np.ndarray ready for
        reduction by ``_reduceData``.
        """
        raise NotImplementedError

    def flushData(self, file: "h5py.File", mode: str):
        """Reduce the buffer, append the result to ``file[mode][self.name]`` and clear the buffer.

        On the first flush (when ``firstFlush`` is set), resizable, chunked float datasets
        are created in the "train" and/or "test" groups (per ``logTrain``/``logTest``). Their
        per-row shape is taken from the reduced data.

        Args:
            file: writable HDF5 file containing "train" and "test" groups.
            mode: group to write to, e.g. "train" or "test".

        Raises:
            RuntimeError: if nothing was buffered since the last flush.
        """
        if len(self.buffer) == 0:
            raise RuntimeError(f"No data collected for metric {self.name} in set {mode}")

        # _reduceData may return anything convertible to an array; normalise once so that
        # .shape and len() below are well defined. A scalar counts as a single entry.
        data = np.asarray(self._reduceData())
        if data.ndim == 0:
            data = data.reshape(1)

        if self.firstFlush:
            self.firstFlush = False
            # Per-row shape; every later flush must produce rows of this shape.
            shape = data.shape[1:]
            for group, enabled in (("train", self.logTrain), ("test", self.logTest)):
                if enabled:
                    file[group].create_dataset(self.name, shape=(0, *shape), dtype=float, maxshape=(None, *shape), chunks=True)

        dataset = file[mode][self.name]
        old_len = len(dataset)
        dataset.resize(old_len + len(data), axis=0)
        dataset[old_len:, ...] = data

        self.buffer = []

# Registry of metric classes keyed by their ``name`` attribute; populated by @addMetric.
available_metrics = {}


def addMetric(class_: type) -> type:
    """Class decorator that registers a metric class in ``available_metrics``.

    The class is stored under ``class_.name`` and returned unchanged, so the decorator can
    be stacked on a class definition. A class registered later under the same name
    replaces the earlier entry.

    Raises:
        ValueError: if ``class_.name`` is None.
    """
    metricName = class_.name
    if metricName is not None:
        available_metrics[metricName] = class_
        return class_
    raise ValueError(f"Metric.name has to be defined in metric class definition for metric {class_}.")


"""
 - to create a new metric, define a class here that inherits from BaseMetric, sets `name`, and is decorated with @addMetric
 - override calcMetric, plus any other methods you need (e.g. _reduceData or getDisplayStr); see BaseMetric above for their contracts
"""

@addMetric
class MetricLoss(BaseMetric):
    """Logs ``state["loss"]``, reduced to its mean over the buffer by BaseMetric's default.

    ``state["loss"]`` must have at least one dimension, because buffered arrays are
    concatenated.
    """
    name = "loss"

    def calcMetric(self, state: dict) -> np.ndarray:
        # detach() first: Tensor.numpy() raises for tensors that require grad. It is a
        # no-op for tensors that are already detached.
        return state["loss"].detach().cpu().numpy()

@addMetric
class MetricEqualLoss(BaseMetric):
    """Logs ``state["equalLoss"]`` (train set only), reduced to its mean over the buffer."""
    logTest: bool = False
    name = "equalLoss"

    def calcMetric(self, state: dict) -> np.ndarray:
        # detach() first: Tensor.numpy() raises for tensors that require grad.
        return state["equalLoss"].detach().cpu().numpy()

@addMetric
class MetricSubLoss(BaseMetric):
    """Logs ``state["subLoss"]`` (train set only), reduced to its mean over the buffer."""
    logTest: bool = False
    name = "subLoss"

    def calcMetric(self, state: dict) -> np.ndarray:
        # detach() first: Tensor.numpy() raises for tensors that require grad.
        return state["subLoss"].detach().cpu().numpy()

@addMetric
class MetricLossPerSample(BaseMetric):
    """Logs every individual value of ``state["loss"]`` instead of their mean."""
    name = "lossPS"

    def calcMetric(self, state: dict) -> np.ndarray:
        # detach() first: Tensor.numpy() raises for tensors that require grad.
        return state["loss"].detach().cpu().numpy()

    def _reduceData(self):
        """No reduction: concatenate all buffered values along axis 0."""
        return np.concatenate(self.buffer)

@addMetric
class MetricLR(BaseMetric):
    """Learning rate(s) reported by ``state["lrScheduler"].get_last_lr()``; train set only.

    Only rank 0's values are buffered (concatenateWorkers = False). This presumably
    assumes the schedule is identical on every rank — TODO confirm.
    """
    name = "learningRate"
    logTest: bool = False
    concatenateWorkers = False
    consolePrintFormat = ".3e"

    def calcMetric(self, state: dict) -> np.ndarray:
        # Presumably a torch LR scheduler, whose get_last_lr() returns one value per param group.
        lastLearningRates = state["lrScheduler"].get_last_lr()
        return np.array(lastLearningRates)

@addMetric
class MetricAccuracy(BaseMetric):
    """Accuracy of ``state["predictions"]`` against ``state["targets"]``.

    calcMetric yields 1.0 (float64) where the argmax over dim 1 of the predictions equals
    the target, and 0.0 otherwise. Presumably predictions are (batch, numClasses) scores —
    TODO confirm.
    """
    name = "accuracy"
    consolePrintFormat = ".2%"

    def calcMetric(self, state: dict) -> np.ndarray:
        predictedClass = state["predictions"].data.argmax(dim=1)
        hits = predictedClass.eq(state["targets"])
        return hits.to(float).cpu().numpy()

@addMetric
class MetricSubAccuracy(BaseMetric):
    """Accuracy of ``state["subPredictions"]`` against ``state["subTargets"]``; train set only.

    calcMetric yields 1.0 (float64) where the argmax over dim 1 of the sub-predictions
    equals the sub-target, and 0.0 otherwise.
    """
    name = "subAccuracy"
    logTest: bool = False
    consolePrintFormat = ".2%"

    def calcMetric(self, state: dict) -> np.ndarray:
        predictedClass = state["subPredictions"].data.argmax(dim=1)
        hits = predictedClass.eq(state["subTargets"])
        return hits.to(float).cpu().numpy()

@addMetric
class MetricLayerNormBias(BaseMetric):
    """Snapshot of ``state["model"].module.finalLayerNorm.bias`` per fetch; train set only, never printed."""
    logTest: bool = False
    printTrain: bool = False
    # ``.module`` suggests a DDP-wrapped model whose parameters are identical on every
    # rank, so gathering would only place world_size duplicate copies side by side. Use
    # rank 0's copy, as MetricLayerNormWeight does.
    # NOTE: an HDF5 file written with the old setting (width world_size * D) cannot be
    # continued with this one.
    concatenateWorkers = False
    name = "layerNormBias"

    def calcMetric(self, state: dict) -> np.ndarray:
        bias = state["model"].module.finalLayerNorm.bias
        # For a CPU tensor, .cpu().numpy() shares memory with the live parameter, which is
        # updated in place by the optimizer. Copy so each buffered snapshot keeps its values.
        return bias.detach().cpu().numpy().copy()

    def _reduceData(self):
        """Stack the buffered snapshots into an array of shape [numSnapshots, *bias.shape]."""
        return np.stack(self.buffer)

@addMetric
class MetricLayerNormWeight(BaseMetric):
    """Snapshot of ``state["model"].module.finalLayerNorm.weight`` per fetch; train set only, never printed."""
    logTest: bool = False
    printTrain: bool = False
    # Only rank 0's copy is buffered; parameters are presumably identical across ranks.
    concatenateWorkers = False
    name = "layerNormWeight"

    def calcMetric(self, state: dict) -> np.ndarray:
        weight = state["model"].module.finalLayerNorm.weight
        # For a CPU tensor, .cpu().numpy() shares memory with the live parameter, and the
        # array is buffered as-is (no gather copy with concatenateWorkers=False). Copy so
        # later in-place optimizer updates do not rewrite earlier snapshots.
        return weight.detach().cpu().numpy().copy()

    def _reduceData(self):
        """Stack the buffered snapshots into an array of shape [numSnapshots, *weight.shape]."""
        return np.stack(self.buffer)
