import argparse
import os

from sklearn.metrics.pairwise import KERNEL_PARAMS

import sys

sys.path.append("..")
sys.path.append(".")
from data import all_datasets


def get_args():
    """Build the command-line parser, parse ``sys.argv`` and validate it.

    The first positional argument selects the method sub-command
    (``kauri``, ``douglas``, ``imm``, ``exshallow``, ``exkmc``, ``ktree`` or
    ``rdm``). Every sub-command also accepts the common DATA / I/O /
    SELECTION options registered by ``add_common_args``.

    Returns:
        argparse.Namespace: the parsed arguments; ``args.method`` holds the
        name of the selected sub-command.

    Raises:
        SystemExit: raised by argparse on unknown or missing arguments,
            including a missing method sub-command.
        AssertionError: raised by ``check_args`` when a value is out of range.
    """
    parser = argparse.ArgumentParser()

    # Options shared by every method are declared once on a help-less parent
    # parser and inherited by each sub-command via ``parents=``.
    common_parent_parser = argparse.ArgumentParser(add_help=False)
    add_common_args(common_parent_parser)

    # required=True: otherwise omitting the method produces a Namespace with
    # only ``method=None`` (the common options live on the sub-commands), and
    # check_args fails with an AttributeError instead of a usage message.
    model_subparser = parser.add_subparsers(title="method", dest="method", required=True)

    kauri_parser = model_subparser.add_parser("kauri", parents=[common_parent_parser])
    add_kauri_args(kauri_parser)

    douglas_parser = model_subparser.add_parser("douglas", parents=[common_parent_parser])
    add_douglas_args(douglas_parser)

    # Methods whose only specific option is an optional leaf budget.
    for method_name in ("imm", "exshallow", "exkmc"):
        leaves_only_parser = model_subparser.add_parser(method_name, parents=[common_parent_parser])
        leaves_only_parser.add_argument("--max_leaves", type=int, default=None)

    ktree_parser = model_subparser.add_parser("ktree", parents=[common_parent_parser])
    ktree_parser.add_argument("--max_depth", type=int, default=None)
    ktree_parser.add_argument("--max_leaves", type=int, default=None)

    rdm_parser = model_subparser.add_parser("rdm", parents=[common_parent_parser])
    rdm_parser.add_argument("--max_leaves", type=int, default=None)

    args = parser.parse_args()
    check_args(args)

    return args


def add_common_args(parser: argparse.ArgumentParser):
    """Register the options shared by every method sub-command on *parser*.

    The options are laid out in three argument groups, "DATA", "I/O" and
    "SELECTION". The groups only affect how ``--help`` presents them.
    ``--dataset`` is mandatory and must be one of ``all_datasets``.
    """
    data_group = parser.add_argument_group("DATA")
    data_group.add_argument("--dataset", type=str, choices=all_datasets, required=True)
    data_group.add_argument("--n_clusters", type=int, default=-1)
    data_group.add_argument("--subset_size", type=float, default=0.8)

    io_group = parser.add_argument_group("I/O")
    path_defaults = (
        ("--path_to_data", "data/datasets"),
        ("--output_file", "./output.csv"),
    )
    for flag, default_path in path_defaults:
        io_group.add_argument(flag, type=str, default=default_path)

    gap_help = (
        "By setting this option, uniform data is generated to look like the dataset "
        "in shapes. This is for bootstrapping in gap statistic"
    )
    selection_group = parser.add_argument_group("SELECTION")
    selection_group.add_argument("--gap", default=False, action="store_true", help=gap_help)


def add_douglas_args(parser: argparse.ArgumentParser):
    """Register the ``douglas``-specific options on *parser*.

    The "MODEL" group holds the training settings, ending with the
    ``--torch`` flag. The "GEMINI" group selects the distance
    (``mmd``/``wasserstein``/``mi``) and the mode (``ova``/``ovo``).
    """
    model_group = parser.add_argument_group("MODEL")
    # (flag, type, default) triples, registered in this exact order.
    typed_options = [
        ("--learning_rate", float, 1e-3),
        ("--n_cuts", int, 1),
        ("--batch_size", int, None),
        ("--n_epochs", int, 200),
        ("--temperature", float, 0.1),
    ]
    for flag, value_type, default_value in typed_options:
        model_group.add_argument(flag, type=value_type, default=default_value)
    model_group.add_argument("--torch", default=False, action="store_true")

    gemini_group = parser.add_argument_group("GEMINI")
    gemini_group.add_argument("--distance", type=str, default="wasserstein", choices=["mmd", "wasserstein", "mi"])
    gemini_group.add_argument("--mode", type=str, default="ovo", choices=["ova", "ovo"])


def add_kauri_args(parser: argparse.ArgumentParser):
    """Register the ``kauri``-specific options directly on *parser*.

    All tree-size options are integers. ``--max_depth``, ``--max_leaves`` and
    ``--max_features`` default to ``None``. ``--kernel`` is restricted to the
    kernel names listed in scikit-learn's ``KERNEL_PARAMS``.
    """
    # (flag, default) pairs, registered in this exact order.
    integer_options = (
        ("--max_depth", None),
        ("--min_samples_split", 2),
        ("--min_samples_leaf", 1),
        ("--max_leaves", None),
        ("--max_features", None),
    )
    for flag, default_value in integer_options:
        parser.add_argument(flag, type=int, default=default_value)
    parser.add_argument("--kernel", type=str, default="linear", choices=KERNEL_PARAMS)


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    A bare ``assert`` statement is removed when Python runs with ``-O``, which
    would silently disable all argument validation. Raising explicitly keeps
    the checks active and preserves the exception type callers already see.
    """
    if not condition:
        raise AssertionError(message)


def check_args(args: argparse.Namespace):
    """Validate the value ranges of the parsed command-line arguments.

    The common options (data path, number of clusters, subset size) are always
    checked. Method-specific bounds are only checked for ``kauri`` and
    ``douglas``; the other methods get the common checks only.

    Args:
        args: namespace produced by the parser built in ``get_args``. It must
            carry the attributes registered for the selected ``args.method``.

    Raises:
        AssertionError: if a checked value is out of range. Unlike bare
            ``assert`` statements, these checks also run under ``python -O``.
    """
    _require(os.path.exists(args.path_to_data),
             f"Data path does not exist: {args.path_to_data}")
    _require(args.n_clusters >= 2 or args.n_clusters == -1,
             "Invalid number of clusters. Either specify 2+ clusters or -1 "
             "to set it to the same number as the dataset target")
    _require(0 < args.subset_size <= 1,
             f"--subset_size must be in (0, 1], got {args.subset_size}")

    if args.method == "kauri":
        # None means "not set" for these options, so only bound explicit values.
        if args.max_depth is not None:
            _require(args.max_depth >= 2,
                     f"--max_depth must be >= 2, got {args.max_depth}")
        _require(args.min_samples_split >= 2,
                 f"--min_samples_split must be >= 2, got {args.min_samples_split}")
        _require(args.min_samples_leaf >= 1,
                 f"--min_samples_leaf must be >= 1, got {args.min_samples_leaf}")
        if args.max_leaves is not None:
            _require(args.max_leaves >= 2,
                     f"--max_leaves must be >= 2, got {args.max_leaves}")
        if args.max_features is not None:
            _require(args.max_features >= 1,
                     f"--max_features must be >= 1, got {args.max_features}")
    elif args.method == "douglas":
        _require(args.learning_rate > 0,
                 f"--learning_rate must be > 0, got {args.learning_rate}")
        _require(args.n_cuts > 0,
                 f"--n_cuts must be > 0, got {args.n_cuts}")
        if args.batch_size is not None:
            _require(args.batch_size >= 1,
                     f"--batch_size must be >= 1, got {args.batch_size}")
        _require(args.n_epochs >= 1,
                 f"--n_epochs must be >= 1, got {args.n_epochs}")
        _require(args.temperature > 0,
                 f"--temperature must be > 0, got {args.temperature}")


if __name__ == "__main__":
    # Manual check: parse and validate sys.argv, then echo the namespace.
    print(get_args())
