import copy
import math
from typing import Any, Dict, Optional

import pandas as pd


# Per-entry templates used by print_format: scientific notation with two
# decimals, or fixed-point with four decimals.
EXP_FORMAT = '{key} = {value:.2e}'
FLOAT_FORMAT = '{key} = {value:.4f}'


def print_format(results: Dict[str, Any], separator: Optional[str] = None, use_float=True):
    """Render ``results`` as ``key = value`` entries joined by ``separator``.

    Nothing is printed; the formatted string is returned.

    Args:
        results: mapping of names to numeric values (anything accepted by
            the ``.4f`` / ``.2e`` format specs).
        separator: text placed between entries; ``None`` or ``''`` falls
            back to ``' | '``.
        use_float: use ``FLOAT_FORMAT`` (fixed-point) when True, otherwise
            ``EXP_FORMAT`` (scientific notation).

    Returns:
        The joined entries, or ``''`` when ``results`` is empty.
    """
    fmt = FLOAT_FORMAT if use_float else EXP_FORMAT
    separator = separator or ' | '
    # str.join avoids the quadratic cost of repeated string concatenation.
    return separator.join(
        fmt.format(key=key, value=value) for key, value in results.items()
    )


class Accumulator:
    """Base class for stateful running reductions.

    Subclasses implement ``reset`` (initialise ``self.state``), ``update``
    (fold in new data) and ``compute`` (return the current result).

    Args:
        **kwargs: extra entries deep-copied into ``self.state`` before
            ``reset`` is called.
    """

    def __init__(self, **kwargs):
        # Deep copy so later mutation of caller-owned objects cannot leak
        # into the accumulator state.
        self.state = copy.deepcopy(kwargs)
        self.reset()

    def reset(self):
        """Reinitialise ``self.state``; must be overridden."""
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Fold new data into ``self.state``; must be overridden."""
        raise NotImplementedError

    def compute(self):
        """Return the current result; must be overridden."""
        raise NotImplementedError


class SumAccumulator(Accumulator):
    """Accumulate a running sum of values and a count of elements seen.

    Args:
        ignore_zero: when True, ``state['step']`` counts only non-zero
            elements and ``state['all_steps']`` counts every element.
    """

    def __init__(self, ignore_zero=False):
        # Must be set before super().__init__(), which calls reset().
        self.ignore_zero = ignore_zero
        super().__init__()

    def reset(self):
        """Zero the running sum and counters."""
        self.state.update({'step': 0, 'sum': 0.0})
        if self.ignore_zero:
            self.state['all_steps'] = 0

    def update(self, value):
        """Add ``value`` to the running sum and count its elements.

        ``value`` may be a Python number or an array-like exposing ``ndim``,
        ``shape`` and ``sum()`` (e.g. a numpy array or torch tensor). Every
        element of a multi-dimensional array is counted, consistently with
        ``sum()`` reducing over all of them.
        """
        # Plain Python numbers have no ``ndim``; treat them as scalars.
        ndim = getattr(value, 'ndim', 0)
        if ndim == 0:
            total, count = value, 1
        else:
            # Count all elements, not just the first axis (len()), so that
            # sum / count is a true element-wise mean for any rank.
            total, count = value.sum(), math.prod(value.shape)
        if self.ignore_zero:
            nonzero = value != 0.0
            n = int(nonzero) if ndim == 0 else int(nonzero.sum())
            self.state['all_steps'] += count
        else:
            n = count
        self.state['sum'] += total
        self.state['step'] += n

    def compute(self):
        """Return the running sum."""
        return self.state['sum']


class MeanAccumulator(SumAccumulator):
    """Running mean: accumulated sum divided by the accumulated count."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def compute(self):
        """Return ``sum / step``, or NaN when nothing has been counted."""
        step = self.state['step']
        if step == 0:
            # No samples yet (or all were ignored as zero): the mean is
            # undefined, so return NaN rather than raising ZeroDivisionError.
            return float('nan')
        return self.state['sum'] / step

class MixedAccumulator:
    """Dispatch named values to a per-name accumulator.

    Args:
        accumulators: mapping from value name to the accumulator that
            handles it. The mapping is stored by reference, not copied.
    """

    def __init__(self, accumulators: Dict[str, Accumulator]):
        self.accum_map: Dict[str, Accumulator] = accumulators

    def reset(self):
        """Call ``reset()`` on every managed accumulator."""
        for name in self.accum_map:
            self.accum_map[name].reset()

    def step(self, key: str, value: Any):
        """Update the accumulator registered under ``key`` with ``value``
        and return its freshly computed result.

        Raises:
            KeyError: if no accumulator is registered under ``key``.
        """
        target = self.accum_map[key]
        target.update(value)
        return target.compute()

    def accumulate(self, values: Dict[str, Any]):
        """Feed every entry of ``values`` that has an accumulator.

        Entries whose key is not registered are skipped silently.

        Returns:
            dict mapping each processed key to its accumulator's result.
        """
        return {
            key: self.step(key, value)
            for key, value in values.items()
            if key in self.accum_map
        }


class MetricsTracer:
    """Extract scalar metrics from outputs, format them as log lines, and
    optionally record them in an in-memory table for CSV export.

    Args:
        metrics_keys: names of the ``outputs`` entries to trace. ``None``
            disables tracing (``trace`` then returns ``{}``).
        index_key: column under which ``trace``'s ``index`` argument is cached.
        group_key: optional column under which ``trace``'s ``group``
            argument is cached; no group column when ``None``.
    """

    def __init__(self, metrics_keys=None, index_key='epoch', group_key=None):
        # dict.fromkeys keeps the caller's order and drops duplicates;
        # None is accepted and disables tracing.
        self.metrics_keys = None if metrics_keys is None else dict.fromkeys(metrics_keys)
        self.index_key = index_key
        self.group_key = group_key
        self.cache_memory = self._empty_cache()

    def _empty_cache(self):
        """Return one empty list per column: group (if any), index, metrics."""
        columns = []
        if self.group_key is not None:
            columns.append(self.group_key)
        columns.append(self.index_key)
        columns.extend(self.metrics_keys or ())
        return {column: [] for column in columns}

    def reset_cache_memory(self):
        """Empty every cached column while keeping the columns themselves.

        Keeping the columns means ``trace(..., use_cache=True)`` still works
        after a reset.
        """
        # Mutate in place so external references to cache_memory stay valid.
        self.cache_memory.clear()
        self.cache_memory.update(self._empty_cache())

    @staticmethod
    def _to_scalar(value):
        """Reduce one output entry to a scalar.

        Objects exposing ``detach`` (presumably torch tensors) are moved to
        CPU and converted to numpy. Anything with a non-zero ``ndim`` is
        averaged with ``.mean()``. Other values are returned unchanged.
        """
        if hasattr(value, 'detach'):
            value = value.detach().cpu().numpy()
        if getattr(value, 'ndim', 0) != 0:
            value = value.mean()
        return value

    def trace(self, outputs, use_cache=False, index=None, group=None):
        """Extract the traced metrics present in ``outputs``.

        Args:
            outputs: mapping of name -> value. Keys not in ``metrics_keys``
                are ignored; traced keys absent from it are skipped.
            use_cache: also append ``index``, ``group`` (if a group column
                exists) and every traced metric to ``cache_memory``. Metrics
                absent from ``outputs`` are recorded as ``None``.
            index: value stored in the index column when caching.
            group: value stored in the group column when caching.

        Returns:
            dict of metric name -> scalar for the traced keys found.
        """
        if self.metrics_keys is None:
            return {}

        metrics = {
            key: self._to_scalar(outputs[key])
            for key in self.metrics_keys
            if key in outputs
        }

        if use_cache:
            self.cache_memory[self.index_key].append(index)
            if self.group_key is not None:
                self.cache_memory[self.group_key].append(group)
            for key in self.metrics_keys:
                # Pad missing metrics with None so all columns stay aligned.
                cached = float(metrics[key]) if key in metrics else None
                self.cache_memory[key].append(cached)
        return metrics

    def get_logs(self, outputs, format_separator: Optional[str] = None,
                 prefix: Optional[str] = None, use_cache=False, index=None,
                 group=None, use_float=True):
        """Trace ``outputs`` and render the metrics as one log line.

        Args:
            outputs, use_cache, index, group: forwarded to ``trace``.
            format_separator: separator between entries (and after the
                prefix); ``None`` or ``''`` falls back to ``' | '``.
            prefix: optional leading text.
            use_float: fixed-point formatting when True, scientific otherwise.

        Returns:
            ``prefix + separator + metrics`` when both are non-empty,
            otherwise whichever is non-empty, or ``''``.
        """
        metrics = self.trace(outputs, use_cache=use_cache, index=index, group=group)
        separator = format_separator or ' | '
        prefix = prefix or ''
        msg = print_format(metrics, separator, use_float=use_float)
        if prefix and msg:
            return prefix + separator + msg
        return msg or prefix

    def export(self, path):
        """Write ``cache_memory`` to ``path`` as a UTF-8 CSV without an index column."""
        pd.DataFrame(self.cache_memory).to_csv(path, encoding='utf-8', index=False)
