import argparse
from collections import defaultdict
from pathlib import Path
import pickle
from shutil import rmtree
import sys
from typing import Any, Callable, TypeVar

from flax import nnx
from flax.nnx.filterlib import Filter
import numpy as np
from orbax.checkpoint import AsyncCheckpointer, PyTreeCheckpointHandler
from orbax.checkpoint.args import PyTreeRestore, PyTreeSave
import tensorflow as tf
import tomlkit

from offline.types import SummaryWriter


ModelT = TypeVar("ModelT", bound=nnx.Module)


class Logger:
    """Experiment logger rooted at an optional directory.

    When ``root`` is ``None`` the logger runs in "dry" mode:
    - scalar writes and TOML dumps are printed to stdout;
    - ``save_*`` methods return ``""`` without touching the disk;
    - ``load_*`` and ``restore_*`` methods raise ``ValueError``.

    Args:
        root: Directory receiving summaries, checkpoints and artifacts, or
            ``None`` to disable on-disk logging.
        leave: If ``True``, :meth:`cleanup` keeps ``root`` on disk.
        log_every: Each key passed to :meth:`write` is emitted once every
            ``log_every`` calls for that key; other values are dropped.
    """

    def __init__(
        self, root: Path | None, leave: bool = False, log_every: int = 1000
    ) -> None:
        self.leave = leave
        self.log_every = log_every
        self.root = root
        self._writer: SummaryWriter | None = None
        self._checkpointer: AsyncCheckpointer | None = None
        # Per-key number of write() calls since that key was last emitted.
        self.counts: dict[str, int] = defaultdict(int)

    @property
    def checkpointer(self) -> AsyncCheckpointer:
        """Lazily created asynchronous Orbax PyTree checkpointer.

        Raises:
            ValueError: If ``root`` is ``None`` or not an existing directory.
        """
        if self.root is None or not self.root.is_dir():
            raise ValueError(f"{self.root} is not a directory.")
        if self._checkpointer is None:
            self._checkpointer = AsyncCheckpointer(PyTreeCheckpointHandler())
        return self._checkpointer

    @property
    def writer(self) -> SummaryWriter:
        """Lazily created summary writer.

        This is a no-op writer when ``root`` is ``None``; otherwise it is a
        file writer targeting ``root``.
        """
        if self._writer is not None:
            return self._writer
        if self.root is None:
            writer: SummaryWriter = tf.summary.create_noop_writer()
        else:
            writer = tf.summary.create_file_writer(str(self.root))
        self._writer = writer
        return writer

    def cleanup(self) -> None:
        """Delete ``root`` from disk unless ``leave`` is set.

        Pending asynchronous checkpoint saves are awaited first, so the
        background save cannot race with the deletion. A ``root`` that was
        never created is skipped.
        """
        if self.root is None or self.leave:
            return
        self.wait()
        if not self.root.exists():
            return
        print(f"Removing {str(self.root)}")
        rmtree(self.root)

    def _resolve(self, *paths: str) -> str:
        """Return the absolute path of ``root/paths`` without creating dirs.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Root is None")
        return str(self.root.joinpath(*paths).absolute())

    def joinpath(self, *paths: str) -> str:
        """Return the absolute path of ``root/paths``.

        The parent directories are created if needed, so the result can be
        written to directly.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Root is None")
        path = self.root.joinpath(*paths)
        path.parent.mkdir(parents=True, exist_ok=True)
        return str(path.absolute())

    def load_args(self) -> dict[str, Any]:
        """Load the arguments written by :meth:`save_args`.

        Values that were ``None`` when saved come back as the string
        ``"None"``, since TOML has no null value.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Cannot load arguments from an empty logger")
        return self.load_toml("arguments.toml")

    def load_numpy(self, *paths: str, **kwargs):
        """Call ``np.load`` on ``root/paths``, forwarding ``kwargs``.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Cannot load a NumPy object from an empty logger")
        return np.load(self._resolve(*paths), **kwargs)

    def load_pickle(self, *paths: str):
        """Unpickle and return the object stored at ``root/paths``.

        Only use this on files this logger (or another trusted source)
        wrote: unpickling untrusted data can execute arbitrary code.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Cannot load an object from an empty logger")
        with open(self._resolve(*paths), "rb") as file:
            return pickle.load(file)

    def load_toml(self, *paths: str) -> dict[str, Any]:
        """Parse the TOML file at ``root/paths`` into a plain top-level dict.

        Raises:
            ValueError: If ``root`` is ``None``.
        """
        if self.root is None:
            raise ValueError("Cannot load a toml file from an empty logger")
        with open(self._resolve(*paths), "r", encoding="utf-8") as file:
            return dict(tomlkit.load(file))

    @staticmethod
    def _restore_model(
        model_fn: Callable[[], ModelT],
        poi: Filter | None,
        restore: Callable[[Any], Any],
    ) -> ModelT:
        """Build an abstract model, restore its state, and merge it back.

        ``restore`` maps an abstract ``nnx.State`` to the restored state.
        When ``poi`` is given, only the variables it selects are restored.
        """
        abstract_model = nnx.eval_shape(model_fn)
        if poi is None:
            graphdef, state = nnx.split(abstract_model)
            return nnx.merge(graphdef, restore(state))
        results = nnx.split(abstract_model, poi, ...)
        assert len(results) == 3  # narrows the split's return type
        graphdef, state, rest = results
        # NOTE(review): ``rest`` comes from ``nnx.eval_shape`` and so holds
        # abstract shapes rather than concrete values. Confirm callers do not
        # need the non-``poi`` variables to be materialized.
        return nnx.merge(graphdef, restore(state), rest)

    def restore_model(
        self,
        *paths: str,
        model_fn: Callable[[], ModelT],
        poi: Filter | None = None,
    ) -> ModelT:
        """Restore a model from the checkpoint at ``root/paths``.

        Args:
            paths: Checkpoint location relative to ``root``.
            model_fn: Zero-argument constructor, evaluated abstractly to
                obtain the model structure.
            poi: Optional nnx filter selecting which variables to restore.
                If omitted, the full state is restored.
        """
        return self._restore_model(
            model_fn, poi, lambda state: self.restore_state(*paths, state=state)
        )

    def restore_state(self, *paths: str, state):
        """Restore a PyTree checkpoint at ``root/paths`` into ``state``.

        Raises:
            ValueError: If ``root`` is ``None`` or not an existing directory.
        """
        if self.root is None:
            raise ValueError("Cannot load a module from an empty logger")
        return self.checkpointer.restore(
            self.joinpath(*paths), args=PyTreeRestore(state)  # type: ignore
        )

    def restore_model_from_path(
        self,
        path: Path,
        model_fn: Callable[[], ModelT],
        poi: Filter | None = None,
    ) -> ModelT:
        """Like :meth:`restore_model`, but ``path`` is used as given.

        The path is not resolved against ``root``.
        """
        return self._restore_model(
            model_fn, poi, lambda state: self.restore_state_from_path(path, state)
        )

    def restore_state_from_path(self, path: Path, state):
        """Restore a PyTree checkpoint at ``path`` into ``state``.

        The checkpointer still requires this logger's ``root`` to be an
        existing directory.

        Raises:
            ValueError: If ``root`` is ``None`` or not an existing directory.
        """
        if self.root is None:
            raise ValueError("Cannot load a module from an empty logger")
        return self.checkpointer.restore(
            path, args=PyTreeRestore(state)  # type: ignore
        )

    def save_args(self, args: argparse.Namespace) -> str:
        """Write parsed CLI arguments to ``root/arguments.toml``.

        ``None`` values are stored as the string ``"None"`` because TOML has
        no null value. Returns the written path, or ``""`` in dry mode (the
        TOML is then printed to stdout).
        """
        dict_args = {
            key: "None" if value is None else value
            for key, value in vars(args).items()
        }
        return self.save_toml("arguments.toml", obj=dict_args)

    def save_model(
        self, *paths: str, model: nnx.Module, poi: Filter | None = None
    ) -> str:
        """Checkpoint ``model`` (or only its ``poi`` variables) to ``root/paths``.

        Returns the checkpoint path, or ``""`` when ``root`` is ``None``.
        """
        if poi is None:
            state = nnx.state(model)
        else:
            state, _ = nnx.state(model, poi, ...)
        return self.save_state(*paths, state=state)

    def save_numpy(self, *paths: str, **kwargs) -> str:
        """Save ``kwargs`` as named arrays with ``np.savez`` under ``root/paths``.

        Returns the path of the file actually written, or ``""`` when
        ``root`` is ``None``.
        """
        if self.root is None:
            return ""
        path = self.joinpath(*paths)
        np.savez(path, **kwargs)
        # np.savez appends ".npz" when the name lacks it; report the real file.
        return path if path.endswith(".npz") else f"{path}.npz"

    def save_pickle(self, *paths: str, obj) -> str:
        """Pickle ``obj`` to ``root/paths``.

        Returns the path, or ``""`` when ``root`` is ``None``.
        """
        if self.root is None:
            return ""
        path = self.joinpath(*paths)
        with open(path, "wb") as file:
            pickle.dump(obj, file)
        return path

    def save_state(self, *paths: str, state) -> str:
        """Start an asynchronous PyTree save of ``state`` to ``root/paths``.

        The save may still be in progress on return; call :meth:`wait` to
        block until it completes. Returns the path, or ``""`` when ``root``
        is ``None``.
        """
        if self.root is None:
            return ""
        path = self.joinpath(*paths)
        self.checkpointer.save(path, args=PyTreeSave(state))  # type: ignore
        return path

    def save_toml(self, *paths: str, obj: dict) -> str:
        """Dump ``obj`` as TOML to ``root/paths``.

        In dry mode the TOML is written to stdout instead. Returns the path,
        or ``""`` in dry mode.
        """
        if self.root is None:
            tomlkit.dump(obj, sys.stdout)
            return ""
        path = self.joinpath(*paths)
        with open(path, "w", encoding="utf-8") as file:
            tomlkit.dump(obj, file)
        return path

    def wait(self) -> None:
        """Block until pending asynchronous checkpoint saves finish.

        Does nothing if no checkpointer was created.
        """
        if self._checkpointer is not None:
            print("Waiting for checkpointer to finish.")
            self._checkpointer.wait_until_finished()

    def write(self, step, /, **kwargs):
        """Record scalar values, throttled per key by ``log_every``.

        Each keyword counts independently. A value is emitted only on every
        ``log_every``-th call for its key; values in between are discarded.
        In dry mode emitted values are printed; otherwise they are written
        as ``tf.summary`` scalars at ``step``. If ``log_every`` is below 1,
        the counter never matches and nothing is emitted.
        """
        to_write = {}
        for key, value in kwargs.items():
            self.counts[key] += 1
            if self.counts[key] == self.log_every:
                to_write[key] = value
                self.counts[key] = 0
        if not to_write:
            return
        if self.root is None:
            print(f"Step {step}")
            print(to_write)
        else:
            with self.writer.as_default(step=step):
                for key, value in to_write.items():
                    tf.summary.scalar(key, value)


class ChildLogger(Logger):
    """Logger that also holds a separate :class:`Logger` for a parent directory.

    Everything this instance writes goes under ``root``. The ``parent``
    attribute is a plain :class:`Logger` over ``parent`` that shares this
    logger's ``log_every``. Note that ``log_every`` comes before ``leave`` in
    this signature, unlike :class:`Logger`.
    """

    def __init__(
        self,
        root: Path | None,
        parent: Path,
        log_every: int = 1000,
        leave: bool = False,
    ) -> None:
        # Logger's positional order is (root, leave, log_every).
        super().__init__(root, leave, log_every)
        # The parent logger keeps Logger's default leave=False.
        # NOTE(review): calling ``self.parent.cleanup()`` would therefore
        # delete the parent directory; confirm that is intended.
        parent_logger = Logger(parent, log_every=log_every)
        self.parent = parent_logger
