import os
import anymarkup

try:
    from tensorboardX import SummaryWriter
except ImportError:
    SummaryWriter = None

import exp_utils as PQ
from .flags import MetaFLAGS, BaseFLAGS, set_value

from rich.traceback import install
# NOTE: this runs at import time, so merely importing this module enables
# rich's traceback handler (see `rich.traceback.install`) — presumably for the
# whole process, not just scripts that opt in.
install()


def set_random_seed(seed: int):
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    try:
        import tensorflow as tf
        if tf._major_api_version < 2:
            tf.set_random_seed(seed)
        else:
            tf.random.set_seed(seed)
    except ImportError:
        tf = None
        pass

    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


def parse_string(s: str):
    """Interpret a command-line value as a Python literal when possible.

    ``"3"`` becomes ``3``, ``"[1, 2]"`` becomes ``[1, 2]``, ``"None"`` becomes
    ``None``. Anything that is not a valid Python literal (e.g. ``"adam"``) is
    returned unchanged as a string.

    Args:
        s: raw text, typically the right-hand side of a ``key=value`` item.

    Returns:
        The evaluated literal, or ``s`` itself if it is not a literal.
    """
    import ast
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval raises TypeError for well-formed but invalid containers
        # such as "{[1]: 2}" (unhashable key), and MemoryError/RecursionError
        # for pathologically nested input; all of these mean "not a literal",
        # so fall back to the raw string instead of crashing.
        return s


def setup_log_dir(log_dir):
    """Initialize the experiment's log directory and attach the log writers.

    Initializes ``PQ.fs`` with ``log_dir``. It then adds an ``out.log`` sink
    (level ``DEBUG``, ``compression="zip"``) under ``PQ.log_dir``, creates a
    TensorBoard ``SummaryWriter`` as ``PQ.writer``, and logs the resolved
    directory at warning level.

    Args:
        log_dir: directory handed to ``PQ.fs.init``.

    Raises:
        ImportError: if ``tensorboardX`` is not installed (module-level
            ``SummaryWriter`` is ``None``). This is checked before ``PQ.fs``
            is initialized.
    """
    # Without this check the failure would be an opaque
    # "'NoneType' object is not callable" after half of the setup had run.
    if SummaryWriter is None:
        raise ImportError(
            '`tensorboardX` is required to set up a log directory, '
            'but `SummaryWriter` could not be imported.')

    PQ.fs.init(log_dir)

    PQ.log.add(PQ.log_dir / "out.log", level='DEBUG', compression="zip")
    PQ.writer = SummaryWriter(logdir=str(PQ.log_dir))
    PQ.log.warning(f'log_dir = {str(PQ.log_dir)}')


def set_breakpoint():
    """Point the builtin ``breakpoint()`` at ``ipdb.set_trace`` when possible.

    Sets ``PYTHONBREAKPOINT=ipdb.set_trace`` only if ``ipdb`` is importable
    and the variable is not already present in the environment. Otherwise it
    logs, at critical level, why the patch was skipped.
    """
    env_var = 'PYTHONBREAKPOINT'

    try:
        import ipdb  # noqa: F401 -- imported only to check availability
    except ImportError:
        PQ.log.critical('skip patching `ipdb`: `ipdb` not found.')
        return

    if env_var in os.environ:
        # respect a user-provided choice of debugger
        PQ.log.critical(f'skip patching `ipdb`: environment variable `{env_var}` has been set.')
        return

    os.environ[env_var] = 'ipdb.set_trace'

def init(root: MetaFLAGS):
    """Parse the command line into ``root`` and prepare the experiment.

    Steps, in order:

    1. Add a ``seed`` flag if ``root`` lacks one, using a random 24-bit value.
    2. Parse argv into ``root`` via ``parse``.
    3. Freeze ``root``.
    4. Seed the RNGs and configure ``breakpoint()``.
    5. Set up the log directory and dump the flags, or log that no directory
       was given.
    6. Log the serialized config if ``print_config`` was requested.

    Args:
        root: the flag tree to populate; it is frozen on return.
    """
    if 'seed' not in root:
        # 3 random bytes -> a default seed in [0, 2**24)
        default_seed = int.from_bytes(os.urandom(3), 'little')
        root.add('seed', default_seed)

    cli_args = parse(root)
    target_dir = cli_args.log_dir
    root.freeze()

    set_random_seed(root.seed)
    set_breakpoint()

    if target_dir is None:
        PQ.log.critical('no log_dir provided')
    else:
        setup_log_dir(target_dir)
        dump_flags(root)

    if not cli_args.print_config:
        return

    config_text = anymarkup.serialize(root.to_dict(), "json5").decode('utf-8')
    PQ.log.info(f'FLAGS: {config_text}')


def dump_flags(root: MetaFLAGS):
    """Write the flag tree to disk and record run metadata.

    Serializes ``root.to_dict()`` to ``config.json5`` inside ``PQ.log_dir``.
    It then stores the log directory and the command line in ``PQ._meta``,
    which ``close`` later writes to ``meta.json5``.

    Args:
        root: the (frozen) flag tree to save.
    """
    import sys

    anymarkup.serialize_file(root.to_dict(), PQ.log_dir / 'config.json5')
    PQ._meta['log_dir'] = str(PQ.fs.log_dir)
    # Store a snapshot: keeping a reference to the live `sys.argv` list would
    # let any later in-place mutation of argv leak into the saved metadata.
    PQ._meta['argv'] = list(sys.argv)


def parse(root: MetaFLAGS):
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--param', help='additional param', nargs='*', action='append')
    parser.add_argument('--print_config', help='print configs', action='store_true')
    parser.add_argument('--log_dir', help='the directory to logs', default='/tmp')
    root_keys = root.add_to_parser(parser)

    args = parser.parse_args()
    args.print_config = True
    if args.param:
        for cmd in sum(args.param, []):
            cmd: str
            if '=' in cmd:
                path, value = cmd.split('=', maxsplit=1)
                value = parse_string(value)
            else:
                path = '_load'
                if cmd.endswith('.json5'):
                    value = cmd
                else:
                    value = f'configs/{cmd}.json5'
            set_value(root, path, value)

    args_dict = vars(args)
    for key in root_keys:
        if args_dict[key] is not None:
            # don't need to parse, as argparse has done it.
            set_value(root, key, args_dict[key])

    return args


def close():
    """Finish the run: close the TensorBoard writer, then save run metadata.

    ``PQ._meta`` (populated e.g. by ``dump_flags``) is serialized to
    ``meta.json5`` inside ``PQ.log_dir``.
    """
    writer = PQ.writer
    writer.close()

    metadata = PQ._meta
    meta_path = PQ.log_dir / 'meta.json5'
    anymarkup.serialize_file(metadata, meta_path)


def main(root: MetaFLAGS = None):
    """Deprecated decorator factory that wraps an entry point with init/close.

    Usage::

        @main()
        def run(): ...

        run()  # init(root) -> run() -> close()

    Args:
        root: the flag tree to initialize. When omitted, a fresh ``BaseFLAGS``
            subclass with ``_strict = False`` is used.

    Returns:
        A decorator that turns a zero-argument function into one that calls
        ``init(root)``, runs the function, and always calls ``close()``
        afterwards. Exceptions raised by the function propagate to the caller.
    """
    import functools
    import warnings
    # stacklevel=2 attributes the warning to the caller of `main`.
    warnings.warn("`PQ.main` is being deprecated", DeprecationWarning, stacklevel=2)

    if root is None:
        class root(BaseFLAGS):
            _strict = False

    def decorate(fn):
        @functools.wraps(fn)
        def decorated():
            init(root)
            try:
                fn()
            finally:
                # `finally` instead of `except Exception`: close() still runs
                # on errors and interrupts, but the error is no longer
                # silently swallowed.
                close()
        return decorated
    return decorate
