import os
from tap import Tap
from typing_extensions import Literal
from os.path import join as opj

#: str - Project (repository) root: ``dirname`` applied three times to this file's path, made absolute.
ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
#: str - Storage root: ``<ROOT_DIR>/storage``. Derived from ROOT_DIR so the two can never drift apart.
STORAGE_DIR = opj(ROOT_DIR, 'storage')

class GDLBackboneArgs(Tap):
    r"""
    Correspond to ``backbone`` configs in config files.

    Hyperparameters of the geometric deep learning (GDL) backbone encoder.
    """
    name: str = None  #: Name of the chosen GDL backbone. Currently available: `egnn`, `dgcnn`, `pointtrans`.
    n_layers: int = None  #: Number of GDL backbone layers.
    hidden_size: int = None  #: Dimension of the hidden node features.
    act_type: str = None  #: Activation function type.
    pool: str = None  #: Readout pooling layer type.
    norm_type: str = None  #: Normalization layer type, e.g., batch normalization.
    dropout_p: float = None  #: Dropout rate.
    kr: int = None  #: Used to build the k-NN graph (presumably the number of neighbors k; confirm in the backbone code).


class DatasetArgs(Tap):
    r"""
    Correspond to ``dataset`` configs in config files.

    Selects the dataset, the distribution shift, the input features, the evaluation
    metric and the information level.
    """
    dataset_root: str = None  #: Dataset storage root. Presumably defaults to STORAGE_DIR/datasets when unset (resolved outside this class).
    data_name: str = None  #: Name of the chosen dataset, e.g., "Track".
    shift_name: str = None  #: Name of the chosen shift, e.g., "pileup", "signal", "fidelity".
    target: str = None  #: Specific shift case, e.g., "tau", "zp_10", "hse06".
    num_task: int = None  #: Number of tasks; `num_task>1` enables multi-task learning.
    dataloader_name: str = None  #: Name of the data loader. This project includes `BaseDataloader`.
    feature_type: Literal['only_pos', 'only_x', 'both_x_pos', 'only_ones'] = None  #: Input feature type: only_pos, only_x, both_x_pos or only_ones.
    metrics_name: Literal['acc', 'auc', 'mae'] = None  #: Evaluation metric: acc, auc or mae.
    setting: Literal['No-Info', 'O-Feature', 'Par-Label'] = None  #: Information level proposed in this project.
    OOD_labels: int = None  #: Number of accessible OOD labels in the Par-Label level.


class TrainArgs(Tap):
    r"""
    Correspond to ``train`` configs in config files.
    """
    epochs: int = None  #: Maximum number of training epochs.
    iters_per_epoch: int = None  #: Iterations per epoch. Used in the O-Feature level.
    curr_epoch: int = None  #: Current training epoch. This value should not be set manually.

    train_bs: int = None  #: Batch size for training.
    id_val_bs: int = None  #: Batch size for in-distribution (ID) validation.
    id_test_bs: int = None  #: Batch size for in-distribution (ID) testing.
    ood_val_bs: int = None  #: Batch size for out-of-distribution (OOD) validation.
    ood_test_bs: int = None  #: Batch size for out-of-distribution (OOD) testing.
    lr: float = None  #: Learning rate.
    wd: float = None  #: Weight decay.


class AlgoArgs(Tap):
    r"""
    Correspond to ``algo`` configs in config files.
    """
    alg_name: str = None  #: Name of the chosen OOD algorithm.
    model_name: str = None  #: Name of the chosen model.
    # Note: if the algorithm needs specific modules besides the GDL encoder and MLP, specify a model,
    # as done for the DANN / LRI / DIR algorithms. Otherwise set `model_name` to "BaseModel".
    coeff: float = None  #: OOD algorithm hyperparameter. Currently, most algorithms use it as a single float value.
 

class PathArgs(Tap):
    r"""
    Correspond to ``path`` configs in config files.
    """
    logging_dir: str = None  #: Logging directory.
    logging_id_metrics: str = None  #: Logging target for ID performance metrics.
    logging_ood_metrics: str = None  #: Logging target for OOD performance metrics.
    logging_checkpoints: str = None  #: Logging target for model checkpoints.
    loss_file: str = None  #: Logging target for training losses.
    result_path: str = None  #: Where ID performance results are written.
    result_ood_path: str = None  #: Where OOD performance results are written.

    load_pretrain_ckpt: str = None  #: Pre-trained model checkpoint to load. Used in the Par-Label level; default is the model-saving path of No-Info ERM.


class CommonArgs(Tap):
    r"""
    Correspond to general configs in config files.

    After Tap parses the top-level fields, :meth:`process_args` resolves the config
    paths. It then parses the same ``argv`` again into the nested argument groups
    (:class:`DatasetArgs`, :class:`TrainArgs`, :class:`GDLBackboneArgs`,
    :class:`AlgoArgs`, :class:`PathArgs`), each with ``known_only=True``.
    """
    config_path: str = None  #: (Required) Path of the core config file; relative paths are resolved under ``<ROOT_DIR>/configs/core_config``.
    gdl: str = None  #: (Required) GDL backbone config: a name such as ``egnn`` (resolved to ``<ROOT_DIR>/configs/backbone_config/egnn.yaml``) or a path to a ``.yaml`` file.

    seed: int = None  #: Fixed random seed for reproducibility.
    gpu_idx: int = None  #: GPU index.
    device = None  #: Automatically generated by choosing gpu_idx.
    num_workers: int = None  #: Number of workers used by data loaders.
    pipeline: str = None  #: The name of pipeline.

    dataset: DatasetArgs = None
    train: TrainArgs = None
    backbone: GDLBackboneArgs = None
    algo: AlgoArgs = None
    path: PathArgs = None


    def __init__(self, argv):
        r"""
        Args:
            argv: Command-line tokens. They are stored so that :meth:`process_args`
                can parse them again into each nested argument group.
        """
        super(CommonArgs, self).__init__()
        self.argv = argv
        # Imported lazily, presumably to avoid a circular import with GESS.utils (confirm).
        from GESS.utils.metrics import Metrics
        self.metrics: Metrics = None

    def process_args(self) -> None:
        r"""
        Validate and resolve the parsed arguments, then build the nested argument groups.

        Raises:
            AttributeError: If ``--config_path`` or ``--gdl`` was not provided.
        """
        super().process_args()
        if self.config_path is None:
            raise AttributeError('Please provide command argument --config_path.')
        if not os.path.isabs(self.config_path):
            self.config_path = opj(ROOT_DIR, 'configs', 'core_config', self.config_path)

        if self.gdl is None:
            raise AttributeError('To specify a GDL encoder, please provide command argument --gdl.')
        # Accept a bare backbone name ("egnn") as well as a name or path that already
        # ends in ".yaml". The old code always appended the suffix, producing "*.yaml.yaml".
        gdl_path = self.gdl if self.gdl.endswith('.yaml') else f"{self.gdl}.yaml"
        # Mirror config_path: absolute paths are kept as-is; relative ones are
        # resolved under the backbone config directory.
        if not os.path.isabs(gdl_path):
            gdl_path = opj(ROOT_DIR, 'configs', 'backbone_config', gdl_path)
        self.gdl = gdl_path

        # Each group picks its own fields out of the same argv and ignores unknown ones.
        self.dataset = DatasetArgs().parse_args(args=self.argv, known_only=True)
        self.train = TrainArgs().parse_args(args=self.argv, known_only=True)
        self.backbone = GDLBackboneArgs().parse_args(args=self.argv, known_only=True)
        self.algo = AlgoArgs().parse_args(args=self.argv, known_only=True)
        self.path = PathArgs().parse_args(args=self.argv, known_only=True)


def args_parser(argv: list = None):
    r"""
    Parse the project-wide command-line arguments.

    Args:
        argv: Command-line tokens to parse. The same list is forwarded to
            :class:`CommonArgs`, which parses it again into every nested argument group.

    Returns:
        The parsed :class:`CommonArgs` instance (general arguments). Unknown tokens
        are tolerated because parsing uses ``known_only=True``.
    """
    return CommonArgs(argv=argv).parse_args(args=argv, known_only=True)

