import argparse
from collections.abc import Sequence
from dataclasses import astuple, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Final

import jax
import tomlkit

import __main__
from offline.utils.git import check_dirty_and_get_sha
from offline.utils.logger import ChildLogger, Logger
from offline.utils.misc import assert_argument_validity


@dataclass(frozen=True)
class ReservedKeywords:
    """Names of namespace attributes managed by ``ArgumentParser`` itself.

    Each field holds a ``dest`` string used either by a built-in command-line
    option or by run metadata attached during parsing. The module collects
    these values into ``RESERVED_KEYWORDS_SET``, and
    ``ArgumentParser.add_argument`` rejects user arguments whose ``dest`` is in
    that set. Fields are ``init=False`` so every instance carries exactly these
    values, and the class is frozen so they cannot be reassigned.
    """

    # pylint: disable=invalid-name
    # ``--config``: path of a TOML file whose contents become parser defaults.
    CONFIG: Final[str] = field(default="config", init=False)
    # Dirty flag returned by ``check_dirty_and_get_sha`` when logging.
    DIRTY: Final[str] = field(default="dirty", init=False)
    # ``--gpu`` index; ``--cpu`` stores -1 under the same dest.
    GPU: Final[str] = field(default="gpu", init=False)
    # ``--leave`` flag, forwarded to the logger constructor.
    LEAVE: Final[str] = field(default="leave", init=False)
    # Cleared by ``-s/--no-log`` to disable logging to disk.
    LOG: Final[str] = field(default="log", init=False)
    # ``--log-every`` value, forwarded to the logger constructor.
    LOG_EVERY: Final[str] = field(default="log_every", init=False)
    # Name of ``__main__.__spec__`` ("None" when there is no spec).
    MAIN: Final[str] = field(default="main", init=False)
    # ``--root``: path components of an explicit run directory.
    ROOT: Final[str] = field(default="root", init=False)
    # ``--parent``: directory of an earlier run (only when ``extra=True``).
    PARENT: Final[str] = field(default="parent", init=False)
    # Commit SHA returned by ``check_dirty_and_get_sha`` when logging.
    SHA: Final[str] = field(default="sha", init=False)
    # ISO-format timestamp of when the arguments were parsed.
    TIMESTAMP: Final[str] = field(default="timestamp", init=False)


# Default base directory for new run directories (``ArgumentParser``'s
# ``default_root``), used when neither ``--root`` nor ``--parent`` is given.
DEFAULT_ROOT: Final[str] = "runs"
# Shared instance holding the reserved ``dest`` names.
RESERVED_KEYWORDS = ReservedKeywords()
# All reserved names as a set, for O(1) membership checks.
RESERVED_KEYWORDS_SET: frozenset[str] = frozenset(astuple(RESERVED_KEYWORDS))
# Reserved entries stripped from the namespace before ``Logger.save_args``,
# so they are not recorded with the run's saved arguments (presumably because
# they only control the current invocation).
NOT_TO_SAVE: tuple[str, ...] = (
    RESERVED_KEYWORDS.CONFIG,
    RESERVED_KEYWORDS.LEAVE,
    RESERVED_KEYWORDS.LOG,
    RESERVED_KEYWORDS.LOG_EVERY,
    RESERVED_KEYWORDS.ROOT,
)


class ArgumentParser:
    """Wrapper around :class:`argparse.ArgumentParser` for experiment runs.

    Besides the caller's own arguments it registers built-in options
    (``--config``, ``--leave``, ``--log-every``, ``--root`` and, depending on
    the constructor flags, ``--parent``, ``--gpu``/``--cpu`` and
    ``-s``/``--no-log``). :meth:`parse_args` layers TOML defaults, creates a
    run directory, builds a logger, saves the arguments through it and
    optionally selects the JAX default device.

    Attributes not defined on this class are forwarded to the wrapped
    ``argparse.ArgumentParser``.
    """

    def __init__(
        self,
        default_root: str | Path = DEFAULT_ROOT,
        extra: bool = False,
        fix_keys: Sequence[str] | None = None,
        log: bool = True,
        use_jax: bool = True,
        **kwargs,
    ):
        """Create the wrapped parser and register the built-in options.

        Args:
            default_root: Base directory for run directories created when
                neither ``--root`` nor ``--parent`` is used.
            extra: If true, add a required ``--parent`` option naming an
                earlier run whose ``arguments.toml`` supplies defaults; the
                run is then logged through a ``ChildLogger``.
            fix_keys: Argument names passed to ``assert_argument_validity``
                together with the parent run's saved value (presumably
                requiring they agree). Only used when ``extra`` is true.
            log: If true, add ``-s``/``--no-log``; a run directory is created
                only when this is true and ``--no-log`` is not given.
            use_jax: If true, add ``--gpu``/``--cpu`` and set JAX's default
                device after parsing.
            **kwargs: Forwarded to :class:`argparse.ArgumentParser`.
        """
        self.default_root = Path(default_root)
        self.extra = extra
        self.log = log
        self.parser = argparse.ArgumentParser(**kwargs)
        self.use_jax = use_jax
        self.fix_keys = [] if fix_keys is None else fix_keys
        self.add_defaults()

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped parser."""
        # __getattr__ runs only when normal lookup fails. If ``parser`` itself
        # is missing (e.g. copy.copy or unpickling, which bypass __init__),
        # ``self.parser`` below would re-enter here forever.
        if name == "parser":
            raise AttributeError(name)
        return getattr(self.parser, name)

    def add_argument(self, *args, **kwargs):
        """Add a user argument, refusing destinations that are reserved.

        Raises:
            argparse.ArgumentError: If the resulting ``dest`` is one of the
                names in ``RESERVED_KEYWORDS_SET``.
        """
        action = self.parser.add_argument(*args, **kwargs)
        if action.dest in RESERVED_KEYWORDS_SET:
            raise argparse.ArgumentError(
                action, f"{action.dest} is a reserved keyword"
            )
        return action

    def add_defaults(self):
        """Register the built-in options on the wrapped parser."""
        self.parser.add_argument(
            "--config", default="", dest=RESERVED_KEYWORDS.CONFIG
        )
        self.parser.add_argument(
            "--leave", action="store_true", dest=RESERVED_KEYWORDS.LEAVE
        )
        # type=int: without it a value given on the command line stays a str
        # while the default is an int.
        self.parser.add_argument(
            "--log-every",
            type=int,
            default=1000,
            dest=RESERVED_KEYWORDS.LOG_EVERY,
        )
        self.parser.add_argument(
            "--root", nargs="*", dest=RESERVED_KEYWORDS.ROOT
        )
        if self.extra:
            self.parser.add_argument(
                "--parent", dest=RESERVED_KEYWORDS.PARENT, required=True
            )
        if self.use_jax:
            # --gpu N and --cpu share one dest; --cpu stores the sentinel -1.
            group = self.parser.add_mutually_exclusive_group()
            group.add_argument(
                "--gpu", type=int, default=0, dest=RESERVED_KEYWORDS.GPU
            )
            group.add_argument(
                "--cpu",
                action="store_const",
                dest=RESERVED_KEYWORDS.GPU,
                const=-1,
            )
        if self.log:
            self.parser.add_argument(
                "-s",
                "--no-log",
                action="store_false",
                dest=RESERVED_KEYWORDS.LOG,
            )

    def parse_args(self, args=None, namespace=None):
        """Parse the command line and set up the run.

        Parsing happens twice. The first pass only reads ``--parent`` and
        ``--config``; their TOML files are installed as parser defaults (the
        config file overriding the parent's ``arguments.toml``). The second
        pass produces the final namespace, so explicit command-line values
        still take precedence over those defaults.

        When logging is enabled, the dirty flag and SHA from
        ``check_dirty_and_get_sha`` are recorded and a run directory is
        created. A ``Logger`` (``ChildLogger`` with ``--parent``) is built,
        the arguments are saved through it, and it is attached as
        ``arguments.logger``. Reserved entries are removed from the namespace
        before it is returned.

        Returns:
            The parsed namespace, including a ``logger`` attribute.

        Raises:
            ValueError: If the ``--parent`` directory has no
                ``arguments.toml``.
            FileExistsError: If the run directory already exists.
        """
        # First pass: only needed to discover --parent / --config.
        arguments = self.parser.parse_args(args, namespace)
        now = datetime.now()
        if self.extra:
            parent = Path(getattr(arguments, RESERVED_KEYWORDS.PARENT))
            parent_config = parent / "arguments.toml"
            if not parent_config.exists():
                raise ValueError(f"Not a valid path: {parent}")
            self.parser.set_defaults(**self._load_toml(parent_config))
        else:
            parent = None
        config = getattr(arguments, RESERVED_KEYWORDS.CONFIG)
        if config:
            self.parser.set_defaults(**self._load_toml(config))
        # Second pass: produces a fresh namespace with the new defaults.
        arguments = self.parser.parse_args(args, namespace)
        # Must be set after the final parse: a value set on the first-pass
        # namespace would be discarded, and a parent's saved timestamp
        # (installed as a default above) would otherwise leak into this run.
        setattr(arguments, RESERVED_KEYWORDS.TIMESTAMP, now.isoformat())

        if self.log and getattr(arguments, RESERVED_KEYWORDS.LOG):
            dirty, sha = check_dirty_and_get_sha()
            setattr(arguments, RESERVED_KEYWORDS.DIRTY, dirty)
            setattr(arguments, RESERVED_KEYWORDS.SHA, sha)
            root = self._create_root(arguments, parent, now)
            print("Logging to", root)
        else:
            root = None

        leave = getattr(arguments, RESERVED_KEYWORDS.LEAVE)
        log_every = getattr(arguments, RESERVED_KEYWORDS.LOG_EVERY)
        if parent is None:
            logger = Logger(root=root, leave=leave, log_every=log_every)
        else:
            logger = ChildLogger(
                root=root, parent=parent, leave=leave, log_every=log_every
            )
            if self.fix_keys:
                parent_args = logger.parent.load_args()
                for key in self.fix_keys:
                    assert_argument_validity(arguments, key, parent_args[key])

        main_spec = getattr(__main__, "__spec__", None)
        setattr(
            arguments,
            RESERVED_KEYWORDS.MAIN,
            "None" if main_spec is None else main_spec.name,
        )
        # Tolerant removal: e.g. ``log`` is absent when log=False.
        self._discard_attributes(arguments, NOT_TO_SAVE)
        logger.save_args(arguments)
        arguments.logger = logger
        if self.use_jax:
            self._select_jax_device(arguments)
        self._discard_attributes(arguments, RESERVED_KEYWORDS_SET)
        return arguments

    @staticmethod
    def _load_toml(path):
        """Read a TOML file and return its top-level table."""
        with Path(path).open(encoding="utf-8") as file:
            return tomlkit.load(file)

    def _create_root(self, arguments, parent, now):
        """Create and return the directory this run logs to.

        Uses ``--root`` when given, otherwise a ``<date>/<time>`` directory
        under ``default_root``, or a timestamped subdirectory of ``parent``.
        """
        root_parts = getattr(arguments, RESERVED_KEYWORDS.ROOT)
        if root_parts:
            root = Path(*root_parts)
            root.mkdir(parents=True)
        elif parent is None:
            root = Path(self.default_root, str(now.date()), str(now.time()))
            root.mkdir(parents=True)
        else:
            root = parent / now.isoformat()
            root.mkdir()
        return root

    @staticmethod
    def _select_jax_device(arguments):
        """Set JAX's default device from ``--gpu`` / ``--cpu``.

        Does nothing when JAX's first device is already a CPU.
        """
        if jax.devices()[0].platform == "cpu":
            return
        gpu = getattr(arguments, RESERVED_KEYWORDS.GPU)
        if gpu == -1:
            # jax.devices() returns a list; the config needs a single Device.
            device = jax.devices("cpu")[0]
        else:
            device = jax.devices("gpu")[gpu]
        jax.config.update("jax_default_device", device)

    @staticmethod
    def _discard_attributes(namespace, names):
        """Delete each of ``names`` from ``namespace`` if it is present."""
        for name in names:
            if hasattr(namespace, name):
                delattr(namespace, name)
