import ast
import os
import pathlib
import sys
import textwrap
from dataclasses import MISSING, field, is_dataclass
from inspect import getmembers, getsource
from itertools import tee, zip_longest

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from megatron.neox_arguments import neox_args, deepspeed_args


def pairwise(iterable):
    """Pair each item with its successor; the last item is paired with None.

    s -> (s0, s1), (s1, s2), ..., (s_last, None)
    """
    current_items, upcoming_items = tee(iterable, 2)
    # Move the second copy one step ahead so it always yields the successor.
    next(upcoming_items, None)
    return zip_longest(current_items, upcoming_items, fillvalue=None)


def _field_docstrings(dcls):
    """Return ``{field_name: docstring}`` for fields declared in ``dcls``'s own body.

    A field's docstring is the bare string literal that immediately follows its
    annotated assignment. Fields without such a literal are absent from the
    result, as are fields that are not declared in this class's source.
    """
    tree = ast.parse(textwrap.dedent(getsource(dcls)))
    class_node = next(
        (node for node in tree.body if isinstance(node, ast.ClassDef)), None
    )
    if class_node is None:
        return {}

    docs = {}
    body = class_node.body
    for stmt, following in zip(body, body[1:]):
        if not (isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)):
            continue
        if not isinstance(following, ast.Expr):
            continue
        literal = following.value
        # String literals are ``ast.Constant`` (``.value``) on Python 3.8+ and
        # ``ast.Str`` (``.s``) before that; anything else is not a docstring.
        text = literal.value if hasattr(literal, "value") else getattr(literal, "s", None)
        if isinstance(text, str):
            docs[stmt.target.id] = text.strip()
    return docs


def get_docs(module):
    """Collect documentation for every dataclass class found in ``module``.

    Args:
        module: Module (or any object) whose members are scanned with
            ``inspect.getmembers``; only dataclass *classes* are documented.

    Returns:
        A dict keyed by class name. Each value is
        ``{"doc": <class docstring>, "attributes": {...}}`` where
        ``attributes`` maps each field name to a dict with the keys
        ``"name"``, ``"type"``, ``"default"`` and ``"doc"``. ``"doc"`` is the
        string literal written directly below the field, or ``""`` if there
        is none. ``"default"`` is the field's default value; for fields using
        ``default_factory`` it is the value produced by the factory, and for
        fields with no default at all it is ``dataclasses.MISSING``.

    Raises:
        OSError / TypeError: propagated from ``inspect.getsource`` when a
            dataclass's source code cannot be retrieved.
    """
    results = {}
    for name, dcls in getmembers(module, is_dataclass):
        # ``is_dataclass`` is also true for dataclass instances, which carry no
        # class source of their own to document.
        if not isinstance(dcls, type):
            continue

        field_docs = _field_docstrings(dcls)
        attributes = {}
        # Class-level metadata is sufficient; no instance needs to be built.
        for field_name, field_def in dcls.__dataclass_fields__.items():
            field_type = field_def.type
            if hasattr(field_type, "__name__"):
                field_type = field_type.__name__
            else:
                field_type = str(field_type)

            field_default = field_def.default
            if field_default is MISSING and field_def.default_factory is not MISSING:
                # Report the real default instead of the MISSING sentinel.
                field_default = field_def.default_factory()

            attributes[field_name] = {
                "name": field_name,
                "type": field_type,
                "default": field_default,
                "doc": field_docs.get(field_name, ""),
            }

        results[name] = {
            "doc": (dcls.__doc__ or "").strip(),
            "attributes": attributes,
        }
    return results


def to_md(docs, intro_str=""):
    """
    Renders the docs dictionary as a markdown string.

    The output starts with ``intro_str``, followed by one ``##`` section per
    entry in ``docs`` and one bullet (type, default, description) per
    attribute. All pieces are separated by blank lines.
    """
    chunks = [intro_str]
    for section_title, section in docs.items():
        # section heading, its description, then a spacer
        chunks.extend((f"## {section_title}", f"{section['doc']}", ""))
        for attr_name, attr in section["attributes"].items():
            chunks += [
                # attribute name and type
                f"- **{attr_name}**: {attr['type']}",
                # default value
                f"    Default = {attr['default']!s}",
                f"    {attr['doc']}",
                "",
            ]
    return "\n\n".join(chunks)


if __name__ == "__main__":
    # Document the NeoX argument groups, then merge in the DeepSpeed ones
    # (an entry with the same class name is replaced by the later module's).
    docs = get_docs(neox_args)
    docs.update(get_docs(deepspeed_args))
    intro_str = """Arguments for gpt-neox. All of the following can be specified in your .yml config file(s):\n"""
    md = to_md(docs, intro_str=intro_str)
    # Write next to this script; pin UTF-8 so the generated file does not
    # depend on the locale's default encoding.
    output_path = pathlib.Path(__file__).parent.resolve() / "neox_arguments.md"
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(md)
