import os
import argparse
from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl

from equislt.methods import (TRAIN_METHODS, PRUNE_METHODS)


def dataset_args(parent_parser: ArgumentParser):
    """Register dataset-related command-line options on ``parent_parser``.

    Adds a ``dataset`` argument group with the dataset name (restricted to the
    supported list), the data directory, the number of loader workers and the
    batch size.

    Args:
        parent_parser: parser to extend; it is modified in place.

    Returns:
        ArgumentParser: the same ``parent_parser`` object, to allow chaining.
    """
    known_datasets = [
        "Cora",
        "CiteSeer",
        "RotMNIST",
        "FlipRotMNIST",
        "MUTAG",
        "PTC",
        "PROTEINS",
        "IMDBBINARY",
        "NCI1",
    ]
    # If SLURM_CPUS_PER_TASK is set this is a *string*; argparse runs string
    # defaults through ``type`` (int here), so the namespace still gets an int.
    default_workers = os.getenv('SLURM_CPUS_PER_TASK', 1)

    group = parent_parser.add_argument_group("dataset")
    group.add_argument("--dataset", type=str, choices=known_datasets, required=True)
    group.add_argument("--data_dir", type=Path, required=True)
    group.add_argument("--num_workers", type=int, default=default_workers)
    group.add_argument("--batch_size", type=int, default=64)
    return parent_parser


def checkpoint_args(parent_parser: ArgumentParser):
    """Register checkpointing options on ``parent_parser``.

    Adds a ``checkpointer`` argument group with a ``--save_checkpoint`` flag,
    the output directory (default ``trained_models``) and the checkpoint
    frequency (default 1).

    Args:
        parent_parser: parser to extend; it is modified in place.

    Returns:
        ArgumentParser: the same ``parent_parser`` object, to allow chaining.
    """
    group = parent_parser.add_argument_group("checkpointer")
    option_specs = (
        ("--save_checkpoint", dict(action="store_true")),
        ("--checkpoint_dir", dict(type=Path, default=Path("trained_models"))),
        ("--checkpoint_frequency", dict(type=int, default=1)),
    )
    for flag, spec in option_specs:
        group.add_argument(flag, **spec)
    return parent_parser


def logger_args(parent_parser: ArgumentParser):
    """Register logging-related options on ``parent_parser``.

    Adds a ``logger`` argument group with run/project identifiers
    (``--name``, ``--project``, ``--entity``, ``--group``; all default to
    ``None``) and two boolean switches, ``--wandb`` and ``--offline``.

    Args:
        parent_parser: parser to extend; it is modified in place.

    Returns:
        ArgumentParser: the same ``parent_parser`` object, to allow chaining.
    """
    group = parent_parser.add_argument_group("logger")
    # Plain string options with no explicit type or default.
    for flag in ("--name", "--project"):
        group.add_argument(flag)
    # Optional string options with an explicit ``None`` default.
    for flag in ("--entity", "--group"):
        group.add_argument(flag, default=None, type=str)
    # Boolean switches, False unless passed.
    for flag in ("--wandb", "--offline"):
        group.add_argument(flag, action="store_true")
    return parent_parser


def parse_args_train() -> argparse.Namespace:
    """Parse the command line for a training run.

    Registers, in order: ``--seed``, the dataset, checkpointer and logger
    argument groups, the PyTorch Lightning ``Trainer`` arguments and
    ``--method``. The command line is then parsed once with
    ``parse_known_args`` to discover the selected method. The matching class
    in ``TRAIN_METHODS`` registers its own model-specific arguments, and a
    final strict parse produces the namespace.

    Returns:
        argparse.Namespace: a namespace containing all args needed for training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)

    # Shared argument groups.
    parser = dataset_args(parser)
    parser = checkpoint_args(parser)
    parser = logger_args(parser)

    # PyTorch Lightning trainer arguments.
    parser = pl.Trainer.add_argparse_args(parser)

    # Restricting ``--method`` to the registered keys turns an unknown method
    # into a regular argparse usage error instead of a KeyError further down.
    parser.add_argument(
        "--method", type=str, required=True, choices=list(TRAIN_METHODS)
    )

    # Lenient first pass, used only to read ``--method``. Model-specific flags
    # are not registered yet, so they are tolerated here as unknown args.
    temp_args, _ = parser.parse_known_args()

    # Let the selected method register its own arguments, then parse strictly.
    parser = TRAIN_METHODS[temp_args.method].add_model_specific_args(parser)
    return parser.parse_args()


def parse_args_prune() -> argparse.Namespace:
    """Parse the command line for a pruning run.

    Registers, in order:

    * ``--seed``;
    * the dataset argument group;
    * ``--target_net_dir`` (default ``trained_models``);
    * ``--method``;
    * ``--gpus`` (one or more ints, default ``None``).

    The command line is then parsed once with ``parse_known_args`` to discover
    the selected method. The matching class in ``PRUNE_METHODS`` registers its
    own model-specific arguments, and a final strict parse produces the
    namespace. Unlike ``parse_args_train``, no checkpointer, logger or PyTorch
    Lightning ``Trainer`` arguments are added here.

    Returns:
        argparse.Namespace: a namespace containing all args needed for pruning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)

    # Shared argument group.
    parser = dataset_args(parser)

    # Pruning-specific arguments.
    parser.add_argument("--target_net_dir", default=Path("trained_models"), type=Path)
    # Restricting ``--method`` to the registered keys turns an unknown method
    # into a regular argparse usage error instead of a KeyError further down.
    parser.add_argument(
        "--method", type=str, required=True, choices=list(PRUNE_METHODS)
    )
    parser.add_argument("--gpus", type=int, default=None, nargs='+')

    # Lenient first pass, used only to read ``--method``. Model-specific flags
    # are not registered yet, so they are tolerated here as unknown args.
    temp_args, _ = parser.parse_known_args()

    # Let the selected method register its own arguments, then parse strictly.
    parser = PRUNE_METHODS[temp_args.method].add_model_specific_args(parser)
    return parser.parse_args()
