from loguru import logger
import pandas as pd
import numpy as np

from os import path, makedirs, PathLike

from ast import literal_eval
from collections.abc import MutableMapping

from molecule_movement.utils import SliceableDeque
from collections import Counter

from typing import Optional

def flatten(dictionary, parent_key='', separator='_'):
    """Collapse a nested mapping into a single-level ``dict``.

    Keys of nested ``MutableMapping`` values are joined to their parent key
    with ``separator`` (``{"a": {"b": 1}}`` becomes ``{"a_b": 1}``). Values
    that are not ``MutableMapping`` instances are copied over unchanged.

    Args:
        dictionary: Mapping to flatten; only its ``items()`` are read.
        parent_key: Prefix prepended to every key of ``dictionary``. A falsy
            prefix (the default ``''``) means "no prefix".
        separator: String placed between the prefix and each key.

    Returns:
        A new ``dict``. Top-level keys without a prefix keep their original
        object; prefixed keys are always strings.
    """
    flat = {}
    for key, value in dictionary.items():
        # Build the joined key with an f-string so that non-string keys
        # (e.g. ints) nested below a parent do not raise TypeError.
        new_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            flat.update(flatten(value, new_key, separator=separator))
        else:
            flat[new_key] = value
    return flat

class StatisticsLogger:
    """Singleton loguru sink that collects tagged log records into columns.

    On first construction the instance registers itself with ``logger.add``.
    Every record whose ``extra["task"]`` is in ``logging_tasks`` is flattened
    and each remaining key/value pair is appended to the column of that name
    in ``log_dict``. A truthy ``newline`` entry in a record marks the end of
    a row: columns that received fewer values than the fullest column are
    padded with ``None``.

    Because ``__new__`` always returns the same object and ``__init__``
    returns early once initialised, constructor arguments of later calls are
    ignored.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(StatisticsLogger, cls).__new__(cls)
            cls._instance.__initialized = False
        return cls._instance

    def __init__(self,
                 logging_tasks: list[str] = ["stats"],
                 log_trace_only: bool = True,
                 enforce_equal_column_lengths: bool = False,
                 filename: Optional[PathLike] = None,
                 maxlen: Optional[int] = None):
        """Configure the singleton and register its loguru sinks (once).

        Args:
            logging_tasks: Values of ``extra["task"]`` that are collected.
                The list is stored by reference and consulted on every
                record.
            log_trace_only: If true, only records of level ``TRACE`` are
                collected; otherwise the level is not checked by the filter.
            enforce_equal_column_lengths: If true, ``store`` raises
                ``ValueError`` when column counts drift apart by more than
                one.
            filename: Optional file sink that receives a flattened one-line
                rendering of every collected record.
            maxlen: ``maxlen`` handed to every ``SliceableDeque`` column.
        """
        if self.__initialized:
            return
        self.__initialized = True
        self.logging_tasks = logging_tasks
        self.log_dict = dict()
        self.maxlen = maxlen

        # Number of values (including None padding) ever appended per column.
        self.__column_counter = Counter()
        # Column names seen since the last clear().
        self.__keys = set()
        self.__enforce_equal_column_lengths = enforce_equal_column_lengths

        if maxlen:
            logger.info(f"Initialized StatisticsLogger with {maxlen=}")

        def logging_filter(record):
            if log_trace_only and record["level"].name != "TRACE":
                return False
            return record["extra"].get("task") in self.logging_tasks

        logger.add(self.store,
                   filter=logging_filter,
                   format="{extra}",
                   level="TRACE")

        if filename is not None:
            logger.add(filename,
                       filter=logging_filter,
                       format=self.pretty_print,
                       level="TRACE")

    def pretty_print(self, record) -> str:
        """Render ``record["extra"]`` (without ``task``) as one text line.

        The extra dict is flattened and its ``str()`` form is stripped of
        quotes and braces. Returns ``""`` when the record has no ``extra`` or
        no ``task`` entry.
        """
        try:
            # Work on a copy: the record is shared with every other sink, so
            # popping from it directly would hide "task" from later handlers.
            stats = dict(record["extra"])
            stats.pop("task")
        except KeyError:
            return ""
        pretty_str = str(flatten(stats)).replace("'", "").replace("{", "").replace("}", "")
        return f"{pretty_str}\n"

    def store(self, record) -> None:
        """Parse one formatted record and append its values to the columns.

        Args:
            record: The message produced by the ``"{extra}"`` format, i.e.
                the ``repr`` of the extra dict; it is parsed with
                ``ast.literal_eval``.

        Raises:
            ValueError: If ``enforce_equal_column_lengths`` is set and the
                column counts differ by more than one.
        """
        parsed = flatten(literal_eval(record.replace("\n", "")))
        try:
            parsed.pop("task")
        except KeyError:
            # Not a statistics record; drop it silently as before.
            return
        end_of_row = parsed.pop("newline", False)

        for key, value in parsed.items():
            try:
                self.log_dict[key].append(value)
            except KeyError:
                self.log_dict[key] = SliceableDeque([value], maxlen=self.maxlen)
        self.__keys.update(parsed)
        # Count before padding so the values of this very record are taken
        # into account when determining the fullest column.
        self.__column_counter.update(parsed.keys())

        if end_of_row:
            _, longest = self.__longest_column()
            # Pad by counter difference rather than deque length: a bounded
            # deque stops growing at maxlen while the counter keeps counting.
            for key, column in self.log_dict.items():
                missing = longest - self.__column_counter[key]
                if missing > 0:
                    column.extend([None] * missing)
                    self.__column_counter[key] += missing

        if self.__enforce_equal_column_lengths:
            self.__check_column_lengths()

    @property
    def df(self) -> pd.DataFrame:
        """Build a fresh ``DataFrame`` from all columns (cached in ``_df``).

        Raises:
            ValueError: Re-raised from pandas after logging the offending
                data, e.g. when the columns have unequal lengths.
        """
        try:
            self._df = pd.DataFrame(self.log_dict, columns=self.log_dict.keys())
        except ValueError:
            logger.error(f"Cannot produce df for invalid data: {self.log_dict}")
            raise
        return self._df

    def df_from_columns(self, columns: list) -> pd.DataFrame:
        """Build a ``DataFrame`` restricted to the given ``columns``."""
        return pd.DataFrame(self.log_dict, columns=columns)

    def df_to_csv(self, filename: str):
        """Write the table to ``filename`` as CSV, creating its directory."""
        directory = path.dirname(filename)
        logger.info(f"Saving csv of statistics to {filename}")
        if directory:
            makedirs(directory, exist_ok=True)
        return self.df.to_csv(filename, sep=',', index=False, encoding='utf-8', mode="w")

    def dump(self, filename: str) -> None:
        """Append the table, rendered as text, to ``filename``."""
        with open(filename, 'a', encoding='utf-8') as f:
            # Terminate the block so consecutive dumps do not run together.
            f.write(self.df.to_string(header=True, index=True) + "\n")

    def clear(self) -> None:
        """Drop all collected columns and reset the bookkeeping."""
        self.log_dict = dict()
        self.__column_counter.clear()
        self.__keys.clear()

    def filter_none(self, field: str):
        """Return the values of column ``field`` as an array without ``None``.

        Raises:
            KeyError: If ``field`` has not been logged.
        """
        array = np.array(self.log_dict[field])
        # `!= None` is deliberate: numpy evaluates it element-wise, whereas
        # `is not None` would test the array object itself.
        return array[array != None]  # noqa: E711

    def sum(self, field: str) -> float:
        """Sum of the non-``None`` values of ``field``."""
        return float(np.sum(self.filter_none(field)))

    def mean(self, field: str) -> float:
        """Mean of the non-``None`` values of ``field``."""
        return np.mean(self.filter_none(field))

    def max(self, field: str) -> float:
        """Maximum of the non-``None`` values of ``field``."""
        return np.max(self.filter_none(field))

    def min(self, field: str) -> float:
        """Minimum of the non-``None`` values of ``field``."""
        return np.min(self.filter_none(field))

    def std(self, field: str) -> float:
        """Standard deviation (``np.std``) of the non-``None`` values."""
        return np.std(self.filter_none(field))

    def stats(self, field: str) -> tuple:
        """Return ``(min, mean, max, std)`` of ``field``."""
        return (self.min(field), self.mean(field), self.max(field), self.std(field))

    # The last_N_* helpers slice with ``[-N:]``; note that ``N=0`` therefore
    # selects the whole column, not an empty one.

    def last_N_mean(self, field: str, N: int) -> float:
        """Mean of the last ``N`` non-``None`` values of ``field``."""
        return np.mean(self.filter_none(field)[-N:])

    def last_N_max(self, field: str, N: int) -> float:
        """Maximum of the last ``N`` non-``None`` values of ``field``."""
        return np.max(self.filter_none(field)[-N:])

    def last_N_min(self, field: str, N: int) -> float:
        """Minimum of the last ``N`` non-``None`` values of ``field``."""
        return np.min(self.filter_none(field)[-N:])

    def last_N_std(self, field: str, N: int) -> float:
        """Standard deviation of the last ``N`` non-``None`` values."""
        return np.std(self.filter_none(field)[-N:])

    def last_N_stats(self, field: str, N: int) -> tuple:
        """Return ``(min, mean, max, std)`` over the last ``N`` values."""
        return (self.last_N_min(field, N), self.last_N_mean(field, N), self.last_N_max(field, N), self.last_N_std(field, N))

    def __check_column_lengths(self) -> None:
        """Raise ``ValueError`` if column counts differ by more than one."""
        counts = self.__column_counter.values()
        if not counts:
            # Nothing logged yet, hence nothing that could be inconsistent.
            return
        if max(counts) - min(counts) > 1:
            msg = "".join(f"\t{key} -> {len(self.log_dict[key])=}\n" for key in self.log_dict)
            raise ValueError(f"The length of the logged columns differ:\n{msg}")

    def __longest_column(self) -> tuple:
        """Return ``(name, count)`` of the fullest column.

        Returns ``(None, 0)`` when no value has been counted yet.
        """
        ranked = self.__column_counter.most_common(1)
        return ranked[0] if ranked else (None, 0)

def enable_statistics_logger(
        logging_tasks: list[str] = ["stats"],
        enforce_equal_column_lengths: bool = False,
        log_trace_only: bool = True,
        filename: Optional[PathLike] = None,
        maxlen: int = None) -> StatisticsLogger:
    """Construct the ``StatisticsLogger`` with the given options and return it.

    All arguments are forwarded by keyword, unchanged, to the
    ``StatisticsLogger`` constructor.
    """
    options = {
        "logging_tasks": logging_tasks,
        "enforce_equal_column_lengths": enforce_equal_column_lengths,
        "log_trace_only": log_trace_only,
        "filename": filename,
        "maxlen": maxlen,
    }
    return StatisticsLogger(**options)

def dump_statistics(filename):
    """Dump the statistics held by ``StatisticsLogger`` to ``filename``.

    Obtains the logger through a no-argument ``StatisticsLogger()`` call and
    delegates to its ``dump`` method.
    """
    StatisticsLogger().dump(filename)
