import argparse
import functools
import os
import sys
import time
from os import path

import dataclasses

import torch
import typing
from YYY import TypeHandler

import YYY

# KEEP THIS AROUND TO SET THE DIRECTORY AUTOMAGICALLY
import XXX.notebook


def get_likely_github_commit_url(github_url: str, commit: str):
    """Return a GitHub web URL for ``commit`` in the repository behind ``github_url``.

    Recognised remote URL forms (the ``.git`` suffix and a trailing ``/`` are optional)::

        git@github.com:user/repo.git
        https://github.com/user/repo.git

    For any other remote, the bare ``commit`` is returned unchanged, so callers
    always get something printable.
    """
    git_prefix = "git@github.com:"
    https_prefix = "https://github.com/"
    if github_url.startswith(git_prefix):
        # Rewrite the SSH form into the HTTPS form by keeping everything *after* the prefix.
        github_url = https_prefix + github_url[len(git_prefix):]
    if not github_url.startswith(https_prefix):
        return commit

    repo_path = github_url[len(https_prefix):].rstrip("/")
    suffix = ".git"
    if repo_path.endswith(suffix):
        repo_path = repo_path[: -len(suffix)]
    if not repo_path:
        # A bare "https://github.com/" names no repository.
        return commit

    return f"{https_prefix}{repo_path}/commit/{commit}"


def get_git_head_commit_and_url(path):
    """Return ``(commit, url)`` for the HEAD of the git repository containing ``path``.

    ``commit`` is the HEAD commit hash as a string. ``url`` is the result of
    ``get_likely_github_commit_url`` for the first URL of the ``origin`` remote; it
    falls back to the bare commit hash when there is no ``origin`` remote.

    Returns ``("", "")`` when GitPython is not installed, ``path`` does not exist,
    ``path`` is not inside a git repository, or the repository has no commits yet.
    This is best-effort metadata and must never crash the caller.
    """
    try:
        import git
    except ImportError:
        return "", ""

    try:
        repo = git.Repo(path, search_parent_directories=True)
        # GitPython raises ValueError when HEAD does not resolve (e.g. a fresh repo with no commits).
        commit = str(repo.commit())
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError, ValueError):
        return "", ""

    try:
        remote_urls = list(repo.remote("origin").urls)
    except ValueError:
        # GitPython raises ValueError when no remote named "origin" exists.
        remote_urls = []

    urls = [get_likely_github_commit_url(url, commit) for url in remote_urls] or [commit]
    return commit, urls[0]


def dummy_add_experiment_config_args(parser):
    """Default no-op hook: register no extra arguments and hand ``parser`` back unchanged."""
    unchanged_parser = parser
    return unchanged_parser


def parse_args_and_load_experiment_config(
    add_experiment_config_args=dummy_add_experiment_config_args, exposed_symbols=tuple()
):
    """Parse the command line, optionally merge in a task config, and open a result store.

    Args:
        add_experiment_config_args: Callback that receives the ``argparse.ArgumentParser``
            and returns the parser to use, after registering experiment-specific arguments.
        exposed_symbols: Forwarded to ``YYY.safe_load`` when ``--experiments_YYY`` is given.

    Returns:
        ``(args, store)``: the parsed ``argparse.Namespace``, and the store created by
        ``create_experiment_store`` with the parsed arguments recorded under ``"args"``.

    The result directory is ``--result_dir`` if given; otherwise a ``results/``
    directory next to the experiments file; otherwise ``./experiments/runs/results``.
    """
    parser = argparse.ArgumentParser(
        description="Unpacking Information Bottlenecks Experiment",
        formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120),
    )
    parser.add_argument("--experiment_task_id", type=str, default=None, help="experiment id")
    parser.add_argument(
        "--experiments_YYY", type=str, default=None, help="YYY file that contains all experiment task configs"
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default=None,
        help="directory to store results in (defaults to results/ next to the experiments file)",
    )
    parser.add_argument(
        "--experiment_description", type=str, default="Trying stuff..", help="Description of the experiment"
    )
    parser = add_experiment_config_args(parser)
    args = parser.parse_args()

    if args.experiments_YYY is not None:
        config = YYY.safe_load(args.experiments_YYY, exposed_symbols=exposed_symbols)
        try:
            task_config = config[args.experiment_task_id]
        except KeyError:
            parser.error(f"--experiment_task_id={args.experiment_task_id!r} not found in {args.experiments_YYY}")
        # Re-parse with the task config pre-seeded in the namespace. argparse only fills in
        # defaults for attributes missing from the namespace, while arguments given explicitly
        # on the command line overwrite it, so command-line args take priority over the config.
        args = parser.parse_args(namespace=argparse.Namespace(**task_config))

    # DONT TRUNCATE LOG FILES EVER AGAIN!!! (OFC THIS HAD TO HAPPEN AND BE PAINFUL)
    if args.experiment_task_id:
        store_name = f"task_{args.experiment_task_id}_result"
    else:
        store_name = "unnamed_result"

    if args.result_dir:
        result_dir = path.join(args.result_dir, "")
    elif args.experiments_YYY:
        # path.join avoids producing the filesystem root "/results/" when the
        # experiments file is given without a directory component.
        result_dir = path.join(path.dirname(args.experiments_YYY), "results", "")
    else:
        result_dir = "./experiments/runs/results"

    # create_experiment_store takes keyword-only arguments.
    store = create_experiment_store(result_dir=result_dir, store_name=store_name)
    store["args"] = vars(args)
    print("Parsed args:")
    print(vars(args))

    return args, store


def embedded_experiment(script_file):
    """Open the result store for a single experiment defined in ``script_file``.

    The store lives in a ``YYY/`` directory next to the script. It is named after
    the script's file stem with a ``_result`` suffix, and records the script's
    absolute path under ``"experiment"``.
    """
    absolute_script = path.abspath(script_file)
    script_dir, script_basename = path.split(absolute_script)
    script_stem, _extension = path.splitext(script_basename)

    experiment_store = create_experiment_store(
        result_dir=script_dir + "/YYY/",
        store_name=script_stem + "_result",
    )
    experiment_store["experiment"] = absolute_script
    return experiment_store


def embedded_experiments(script_file, num_jobs):
    """Yield ``(job_id, store)`` for every job assigned to this worker.

    Reads ``--id`` (this worker's id) and ``--num_workers`` from the command line.
    Jobs are distributed round-robin: worker ``w`` runs jobs ``w, w + num_workers,
    w + 2 * num_workers, ...`` that are below ``num_jobs``. Each job gets its own store
    named ``<script stem>_job_<job_id>`` in a ``results/`` directory next to
    ``script_file``. The store is closed when the generator is resumed after the job,
    or when the consumer abandons or closes the generator.

    Raises:
        ValueError: if ``--num_workers`` is negative, or if ``--id`` is missing or
            outside ``[0, min(num_workers, num_jobs))``.
    """
    parser = argparse.ArgumentParser(
        description="Unpacking Information Bottlenecks Experiment",
        formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120),
    )
    parser.add_argument("--id", type=int, default=None, help="experiment id")
    parser.add_argument(
        "--num_workers", type=int, default=None, help="number of worker (None means one worker per job)"
    )
    args = parser.parse_args()

    if args.num_workers is not None and args.num_workers < 0:
        # A negative stride would never reach num_jobs, so the loop would never end.
        raise ValueError(f"--num_workers={args.num_workers} must not be negative!")
    # None (or 0) means one worker per job.
    num_workers = args.num_workers or num_jobs
    worker_id = args.id
    # A worker id >= num_workers would redo the jobs owned by worker (id - num_workers).
    max_id = min(num_workers, num_jobs)
    if worker_id is None or not (0 <= worker_id < max_id):
        raise ValueError(f"0 <= --id={worker_id} < min(num_jobs={num_jobs}, num_workers={num_workers})!")

    # DONT TRUNCATE LOG FILES EVER AGAIN!!! (OFC THIS HAD TO HAPPEN AND BE PAINFUL)
    script_file = path.abspath(script_file)
    result_dir = path.dirname(script_file) + "/results/"
    script_stem = path.splitext(path.basename(script_file))[0]

    # The id check above guarantees num_workers >= 1, so this stride is valid.
    for job_id in range(worker_id, num_jobs, num_workers):
        store = create_experiment_store(result_dir=result_dir, store_name=f"{script_stem}_job_{job_id}")
        try:
            store["experiment"] = script_file
            store["job_id"] = job_id
            store["worker_id"] = worker_id
            store["num_workers"] = num_workers

            print(f"Task id: {job_id}/{num_jobs}")
            print(f"Worker id: {worker_id}/{num_workers}")

            yield job_id, store
        finally:
            # Also runs if the consumer breaks out early (GeneratorExit) or raises.
            store.close()


class DataclassesHandler(TypeHandler):
    """Store dataclass instances via the dict of their fields.

    Requires custom handling on safe_load: the stored representation is built
    from ``dataclasses.asdict``, not from the dataclass object itself.
    """

    def supports(self, obj):
        # ``is_dataclass`` is also true for dataclass *types*, but ``asdict`` only
        # accepts instances. Leave classes to the later handlers.
        return dataclasses.is_dataclass(obj) and not isinstance(obj, type)

    def wrap(self, obj, wrap):
        # Returned unchanged; the conversion happens in ``repr``.
        return obj

    def repr(self, obj, repr, store):
        # ``asdict`` already returns a fresh, recursively converted dict; no extra copy is needed.
        return repr(dataclasses.asdict(obj))


class TensorHandler(TypeHandler):
    """Replace ``torch.Tensor`` values with their nested-list equivalent before storing."""

    def supports(self, obj):
        is_tensor = isinstance(obj, torch.Tensor)
        return is_tensor

    def wrap(self, obj, wrap):
        # Convert up front so the store only ever sees plain Python lists/scalars.
        as_list = obj.tolist()
        return as_list

    def repr(self, obj, repr, store):
        # Presumably unreachable, since ``wrap`` already turned the tensor into a list;
        # kept as a fallback.
        as_list = obj.tolist()
        return repr(as_list)


class FunctionHandler(TypeHandler):
    """Store callables by their dotted name (``module.qualname``) instead of their value."""

    def supports(self, obj):
        # Builtin ``callable`` is the idiomatic check; ``typing.Callable`` is a deprecated
        # alias of ``collections.abc.Callable``.
        return callable(obj)

    def wrap(self, obj, wrap):
        return obj

    def repr(self, obj, repr, store):
        qualname = getattr(obj, "__qualname__", None)
        module = getattr(obj, "__module__", None)
        if qualname is None:
            # Callable *instances* (objects defining __call__, functools.partial, ...) have
            # no __qualname__. Name their type instead of crashing with AttributeError.
            obj_type = type(obj)
            qualname = obj_type.__qualname__
            module = obj_type.__module__
        return repr(f"{module}.{qualname}")


def create_experiment_store(*, result_dir, store_name=None):
    """Create a non-truncating file store under ``result_dir`` and record run metadata.

    The store is created with ``YYY.create_file_store`` using this module's type
    handlers. It is named ``store_name``, or ``"results"`` when that is empty or None.
    Before being returned, it receives the current timestamp, the command line, and the
    git commit / GitHub URL of the current working directory, which are also printed.
    """
    # Create the output directory first so opening the store does not fail on a missing path.
    os.makedirs(result_dir, exist_ok=True)
    handlers = (
        TensorHandler(),
        DataclassesHandler(),
        YYY.StrEnumHandler(),
        FunctionHandler(),
        YYY.ToReprHandler(),
    )
    store = YYY.create_file_store(
        store_name or "results",
        prefix=result_dir,
        truncate=False,
        type_handlers=handlers,
    )

    store["timestamp"] = int(time.time())
    store["cmdline"] = list(sys.argv)
    commit, github_url = get_git_head_commit_and_url(".")
    store["commit"], store["github_url"] = commit, github_url

    summary = [
        "Command line:",
        "|".join(sys.argv),
        f"GitHub URL: {github_url}",
        f"Commit: {commit}",
    ]
    print("\n".join(summary))
    # TODO: this is still missing the proper filename!
    print(f"Results stored in {store.uri}")
    return store


def create_log_epochs(store, prefix):
    """Initialise an empty ``<prefix>_epochs`` list in ``store`` and return it.

    The list is read back from the store rather than returned directly, so the
    caller gets whatever object the store actually holds for that key.
    """
    key = prefix + "_epochs"
    store[key] = []
    return store[key]


if __name__ == "__main__":
    # Running the module directly only parses the CLI and opens a result store.
    _args, _store = parse_args_and_load_experiment_config()
