import glob
import inspect
import os
import site
import torch
import datetime
import dateutil.tz
import json
import rlkit
from rlkit import conf
from rlkit.core.logging import SEPARATOR, bold, green, red
from rlkit.core.pythonplusplus import dict_to_safe_json
from rlkit.entrypoint.SUPPORTED_ALGORITHMS import ALGORITHMS, AlgorithmDescription
import argparse
import importlib
import importlib.util
from pydantic import BaseSettings
from rlkit.variants.base import FuncWrapper
from rlkit.conf import GridSearch, Parallel

from rlkit.launchers.launcher_util import (
    run_hyperparameters,
    run_parallel_pipeline_here,
    run_pipeline_here,
)


def find_attr(module, attr_substr):
    """Return the non-dunder attribute names of *module* that contain *attr_substr*.

    Names come from ``user_defined_attrs`` (i.e. ``dir()`` order, which is
    sorted) and are matched by plain substring containment.
    """
    candidates = user_defined_attrs(module)
    return list(filter(lambda name: attr_substr in name, candidates))


def user_defined_attrs(cls):
    """List the names reported by ``dir(cls)`` that do not start with ``"__"``.

    Single-underscore names are kept; inherited attributes are included
    because ``dir`` reports them.
    """
    names = []
    for name in dir(cls):
        if name.startswith("__"):
            continue
        names.append(name)
    return names


def user_defined_attrs_dict(cls):
    """Map each attribute defined directly on *cls* to its value, skipping dunders.

    Only ``cls.__dict__`` is consulted, so inherited attributes are NOT
    included (unlike ``user_defined_attrs``, which uses ``dir``).
    """
    own_attrs = {}
    for attr_name, attr_value in cls.__dict__.items():
        if not attr_name.startswith("__"):
            own_attrs[attr_name] = attr_value
    return own_attrs


def _unwrap_func_wrappers(value):
    """Return *value* with every FuncWrapper replaced by its wrapped ``.f``.

    Dicts are updated in place and searched to any nesting depth; a
    FuncWrapper's ``.f`` is returned as-is (not searched further); any other
    value is returned unchanged.
    """
    if isinstance(value, FuncWrapper):
        return value.f
    if isinstance(value, dict):
        # Only existing keys are reassigned, so iterating while updating is safe.
        for key in value:
            value[key] = _unwrap_func_wrappers(value[key])
    return value


def load_experiment(args, alg: AlgorithmDescription):
    """Build the variant dict for *alg* from the variant class named by ``args.variant``.

    The class is looked up in ``rlkit.variants.variant_<alg.name>`` among the
    attributes whose name contains ``"Variant"``. It is instantiated, converted
    with ``.dict()``, and every FuncWrapper found in the result (inside dicts of
    any depth) is replaced by the function it wraps.

    Args:
        args: parsed CLI namespace; only ``args.variant`` is read.
        alg: description of the selected algorithm.

    Returns:
        The variant as a plain dict.

    Raises:
        NotImplementedError: if ``alg.online`` is set.
        Exception: if no variant name was given, or it is not defined in the
            algorithm's variant module.
    """
    if alg.online:
        raise NotImplementedError

    # Fail before importing the (possibly heavy) variant module.
    if args.variant is None:
        raise Exception("Did not specify variant name")

    variant_module = importlib.import_module(f"rlkit.variants.variant_{alg.name}")
    variants = find_attr(variant_module, "Variant")
    if args.variant not in variants:
        raise Exception(f"Could not find the specified variant in {variants}")

    # A BaseSettings subclass (the class itself, not an instance).
    variant_cls = getattr(variant_module, args.variant)
    return _unwrap_func_wrappers(variant_cls().dict())


def print_section(name, content):
    """Print an upper-cased, colon-terminated header via ``bold``, then *content* and SEPARATOR."""
    header = f"{name.upper()}:"
    bold(header, "\n")
    print(content, SEPARATOR)


def _remove_stale_mujoco_locks():
    """Delete ``*lock`` files in mujoco_py's ``generated`` directory, printing each path.

    mujoco_py is located through its import spec; ``find_spec`` on a top-level
    name does not import the package. This covers every install location
    (any site-packages entry, user site, editable installs) and also works in
    environments where ``site.getsitepackages`` is unavailable. Nothing
    happens if mujoco_py is not installed.
    """
    spec = importlib.util.find_spec("mujoco_py")
    if spec is None or not spec.submodule_search_locations:
        return
    for package_dir in spec.submodule_search_locations:
        pattern = os.path.join(package_dir, "generated", "*lock")
        for lock_path in glob.glob(pattern):
            print(lock_path)
            os.remove(lock_path)


def _find_algorithm(name):
    """Return the entry of ALGORITHMS whose ``name`` equals *name*, or None."""
    return next((a for a in ALGORITHMS if a.name == name), None)


def main(args):
    """Run, or with ``args.dry`` describe, the experiment selected on the command line.

    The steps are:
    1. Clear leftover mujoco_py lock files.
    2. Resolve the algorithm and load its variant.
    3. Dispatch to a grid search (``args.gridsearch``), a parallel run
       (``args.parallel``), or a single pipeline run.

    In dry mode the configuration is printed for each mode that applies, and
    nothing is launched.

    Raises:
        Exception: unknown algorithm, parallel or gridsearch name, or
            ``--gridsearch`` given without ``--parallel``. ``load_experiment``
            may raise as well (missing/unknown variant, online algorithm).
    """
    # Done before load_experiment imports the variant module; presumably so a
    # stale lock from an interrupted mujoco_py build cannot block that import
    # (NOTE(review): confirm). It therefore also runs in dry mode.
    _remove_stale_mujoco_locks()

    alg = _find_algorithm(args.algorithm)
    if alg is None:
        raise Exception("This algorithm is not supported")

    variant = load_experiment(args, alg)

    if args.dry:
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime("%Y-%m-%d %H:%M:%S.%f %Z")
        print_section("time", timestamp)
        pipeline = variant["pipeline"]
        print_section("variant", json.dumps(dict_to_safe_json(variant), indent=2))
        print_section("pipeline", pipeline.composition)

    if args.gridsearch:
        # run_hyperparameters takes a Parallel configuration as its first
        # argument, so a grid search cannot run without --parallel.
        if not args.parallel:
            raise Exception("--gridsearch requires --parallel to be set")

        if args.parallel not in user_defined_attrs(Parallel):
            raise Exception(f"Parallel argument {args.parallel} not supported")

        if args.gridsearch not in user_defined_attrs(GridSearch):
            raise Exception(f"Gridsearch argument {args.gridsearch} not supported")

        if args.dry:
            print_section(
                "gridsearch args",
                inspect.getsource(getattr(GridSearch, args.gridsearch)),
            )
        else:
            run_hyperparameters(
                getattr(Parallel, args.parallel),
                variant,
                hyperparameters=user_defined_attrs_dict(
                    getattr(GridSearch, args.gridsearch)
                ),
            )
            return

    if args.parallel:
        if args.parallel not in user_defined_attrs(Parallel):
            raise Exception(f"Parallel argument {args.parallel} not supported")

        if args.dry:
            print_section(
                "parallel args",
                inspect.getsource(getattr(Parallel, args.parallel)),
            )
        else:
            run_parallel_pipeline_here(getattr(Parallel, args.parallel), variant)
            return

    if args.dry:
        red("Debug mode: ", conf.DEBUG)
        red("Root dir", conf.Log.rootdir)
        return

    run_pipeline_here(
        variant=variant,
        snapshot_mode=variant.get("snapshot_mode", "gap_and_last"),
        snapshot_gap=variant.get("snapshot_gap", 100),
        gpu_id=variant.get("gpu_id", 0),
    )


if __name__ == "__main__":
    # Command-line entry point: select the 'spawn' start method for torch
    # multiprocessing, parse the CLI options, and hand them to main().
    torch.multiprocessing.set_start_method("spawn")

    parallel_help = (
        "\n"
        "        Run multiple versions of the algorithm on different environments"
        " and seeds. Use 'small', 'medium' or 'full'         \n"
        "        to specify how many envs you want to run.\n"
        "        "
    )

    # (positional flags, keyword options) for each argument, in help order.
    cli_spec = (
        (("algorithm",), dict(help="Specify algorithm to run")),
        (
            ("--variant",),
            dict(
                default=None,
                help="Specify which variant of the algorithm to run",
            ),
        ),
        (
            ("--gridsearch",),
            dict(default=None, help="Do we want to do a gridsearch?"),
        ),
        (
            ("--dry",),
            dict(
                action="store_true",
                default=False,
                help="Do we want to just print the variant and pipeline??",
            ),
        ),
        (("--parallel",), dict(default=None, help=parallel_help)),
        (
            ("-v", "--verbose"),
            dict(action="count", default=0, help="Verbosity (-v, -vv, etc)"),
        ),
        (
            ("--version",),
            dict(
                action="version",
                version=f"%(prog)s (version {rlkit.__version__})",
            ),
        ),
    )

    parser = argparse.ArgumentParser()
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)

    main(parser.parse_args())
