import argparse
import json
import logging
import os
import sys
from typing import Any, Mapping, Optional, Sequence

import nni
import setproctitle
import torch
from pyprojroot import here as project_root

# Prepend the project root (as located by pyprojroot) to the import path so that the
# `fs_mol` package imported below can be resolved when this file is run as a script.
sys.path.insert(0, os.fspath(project_root()))

from fs_mol.modules.graph_feature_extractor import (
    add_graph_feature_extractor_arguments,
    make_graph_feature_extractor_config_from_args,
)
from fs_mol.utils.cli_utils import add_train_cli_args, set_up_train_run
from fs_mol.utils.hypro_utils import (
    HyProTrainerConfig,
    HyProTrainer,
)
setproctitle.setproctitle("hypro")
# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the caller already exported
# (e.g. when several runs are launched side by side on different GPUs). The variable only
# takes effect if it is set before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
logger = logging.getLogger(__name__)
torch.set_num_threads(8)

# Hyperparameters for this run. make_trainer_config() copies every one of these onto the
# parsed command-line arguments, overriding whatever was given on the command line, and
# main() writes this dict to <out_dir>/0hp.json. The (commented-out) NNI hook in the
# __main__ guard is where a tuner would update it before training starts.
params = {
    "ood": 3,
    "ood2": 0.1,
    "ood3": 0.1,
    "ood4": 0,
    "hyper_layer_num": 2,
    "hyper_dropout": 0.3,
    "sample_start": 2,
    "sample_end": 7,
    "sample_div": 128,

    "lr": 0.0001,
    "tasks_per_batch": 16,
}


def parse_command_line(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the HyPro training argument parser and parse the command line.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before; pass an explicit list to drive the parser
            programmatically (e.g. from a test or a sweep launcher).

    Returns:
        The parsed ``argparse.Namespace``.

    Note:
        The defaults of ``--lr``, ``--tasks_per_batch`` and the HyPro-specific
        hyperparameters are taken from the module-level ``params`` dict, because
        ``make_trainer_config`` overwrites those attributes with ``params`` anyway.
        This keeps the ``--help`` output consistent with the values actually used.
    """
    parser = argparse.ArgumentParser(
        description="Train a HyPro model on molecules.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    add_train_cli_args(parser)

    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "gnn",
            "ecfp",
            "pc-descs",
            "ecfp+fc",
            "pc-descs+fc",
            "gnn+ecfp+fc",
            "gnn+ecfp+pc-descs+fc",
        ],
        default="gnn+ecfp+fc",
        help="Choice of features to use",
    )
    parser.add_argument(
        "--distance_metric",
        type=str,
        choices=["mahalanobis", "euclidean"],
        default="mahalanobis",
        help="Choice of distance to use.",
    )
    add_graph_feature_extractor_arguments(parser)

    parser.add_argument("--support_set_size", type=int, default=64, help="Size of support set")
    parser.add_argument(
        "--query_set_size",
        type=int,
        default=256,
        help="Size of target set. If -1, use everything but train examples.",
    )
    parser.add_argument(
        "--tasks_per_batch",
        type=int,
        default=params["tasks_per_batch"],
        help="Number of tasks to accumulate gradients for.",
    )

    parser.add_argument("--batch_size", type=int, default=512, help="Number of examples per batch.")
    parser.add_argument(
        "--num_train_steps", type=int, default=10000, help="Number of training steps."
    )
    parser.add_argument(
        "--validate_every",
        type=int,
        default=50,
        help="Number of training steps between model validations.",
    )
    parser.add_argument(
        "--validation-support-set-sizes",
        type=json.loads,
        default=[4, 128],
        help="JSON list selecting the number of datapoints sampled as support set data during evaluation through finetuning on the validation tasks.",
    )

    parser.add_argument(
        "--validation-query-set-size",
        type=int,
        default=512,
        help="Maximum number of datapoints sampled as query data during evaluation through finetuning on the validation tasks.",
    )

    parser.add_argument(
        "--validation-num-samples",
        type=int,
        default=5,
        help="Number of samples considered for each train set size for each validation task during evaluation through finetuning.",
    )
    parser.add_argument("--lr", type=float, default=params["lr"], help="Learning rate")
    parser.add_argument(
        "--clip_value", type=float, default=1.0, help="Gradient norm clipping value"
    )
    parser.add_argument(
        "--pretrained_gnn",
        type=str,
        default=None,
        help="Path to a pretrained GNN model to use as a starting point.",
    )

    # HyPro-specific hyperparameters, forwarded verbatim to HyProTrainerConfig under the
    # same names. Defaults mirror `params`, which make_trainer_config() applies on top.
    hypro_hyperparameters = (
        ("ood", int),
        ("ood4", float),
        ("ood2", float),
        ("ood3", float),
        ("hyper_layer_num", int),
        ("hyper_dropout", float),
        ("sample_start", int),
        ("sample_end", int),
        ("sample_div", int),
    )
    for name, arg_type in hypro_hyperparameters:
        parser.add_argument(
            f"--{name}",
            type=arg_type,
            default=params[name],
            help=f"HyPro hyperparameter `{name}` passed to HyProTrainerConfig. "
            "Overwritten by the script's `params` dict in make_trainer_config.",
        )

    return parser.parse_args(argv)


def make_trainer_config(
    args: argparse.Namespace,
    overrides: Optional[Mapping[str, Any]] = None,
) -> HyProTrainerConfig:
    """Build a ``HyProTrainerConfig`` from parsed command-line arguments.

    Before the config is built, a fixed set of hyperparameters (the HyPro-specific
    ones plus ``lr`` and ``tasks_per_batch``) is overwritten on ``args`` with the
    values from ``overrides``. The mutated ``args`` therefore agrees with the returned
    config. Whenever an override replaces a different value, a warning is logged,
    because a value given on the command line would otherwise be ignored silently.

    Args:
        args: Namespace produced by ``parse_command_line``; mutated in place.
        overrides: Mapping that must contain every overridden key. Defaults to the
            module-level ``params`` dict, preserving the previous behaviour.

    Returns:
        The trainer configuration assembled from ``args``.

    Raises:
        KeyError: If ``overrides`` lacks one of the overridden keys.
    """
    if overrides is None:
        overrides = params

    overridden_keys = (
        "ood",
        "ood4",
        "ood2",
        "ood3",
        "hyper_layer_num",
        "hyper_dropout",
        "sample_start",
        "sample_end",
        "sample_div",
        "lr",
        "tasks_per_batch",
    )
    for key in overridden_keys:
        new_value = overrides[key]
        old_value = getattr(args, key, None)
        if old_value != new_value:
            logger.warning(
                "Overriding argument %s=%r with params value %r.", key, old_value, new_value
            )
        setattr(args, key, new_value)

    return HyProTrainerConfig(
        graph_feature_extractor_config=make_graph_feature_extractor_config_from_args(args),
        used_features=args.features,
        distance_metric=args.distance_metric,
        batch_size=args.batch_size,
        tasks_per_batch=args.tasks_per_batch,
        support_set_size=args.support_set_size,
        query_set_size=args.query_set_size,
        validate_every_num_steps=args.validate_every,
        validation_support_set_sizes=tuple(args.validation_support_set_sizes),
        validation_query_set_size=args.validation_query_set_size,
        validation_num_samples=args.validation_num_samples,
        num_train_steps=args.num_train_steps,
        learning_rate=args.lr,
        clip_value=args.clip_value,
        ood=args.ood,
        ood4=args.ood4,
        ood2=args.ood2,
        ood3=args.ood3,
        hyper_layer_num=args.hyper_layer_num,
        hyper_dropout=args.hyper_dropout,
        sample_start=args.sample_start,
        sample_end=args.sample_end,
        sample_div=args.sample_div,
    )


def main():
    """Parse arguments, build the HyPro trainer and run its training loop.

    Side effects: sets up the run (output directory, dataset, optional AML run) via
    ``set_up_train_run``, writes the module-level ``params`` dict to
    ``<out_dir>/0hp.json``, and loads pretrained GNN weights when
    ``--pretrained_gnn`` is given, before calling ``train_loop``.
    """
    cli_args = parse_command_line()
    trainer_config = make_trainer_config(cli_args)

    run_name = f"HyPro_{trainer_config.used_features}"
    out_dir, dataset, aml_run = set_up_train_run(run_name, cli_args, torch=True)

    # Keep a copy of the hyperparameter dict next to the other run outputs.
    hp_path = os.path.join(out_dir, "0hp.json")
    with open(hp_path, "w", encoding="utf-8") as hp_file:
        json.dump(params, hp_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_trainer = HyProTrainer(config=trainer_config).to(device)

    n_parameters = sum(param.numel() for param in model_trainer.parameters())
    logger.info("\tDevice: %s", device)
    logger.info("\tNum parameters %s", n_parameters)
    logger.info("\tModel:\n%s", model_trainer)

    pretrained_path = cli_args.pretrained_gnn
    if pretrained_path is not None:
        logger.info("Loading pretrained GNN weights from %s.", pretrained_path)
        model_trainer.load_model_gnn_weights(path=pretrained_path, device=device)

    model_trainer.train_loop(out_dir, dataset, device, aml_run)


if __name__ == "__main__":
    try:
        # To run as an NNI trial, uncomment the two lines below so that the tuner's
        # suggested values replace entries of `params` before main() uses them.
        # optimized_params = nni.get_next_parameter()
        # params.update(optimized_params)
        print(params)

        main()
    except Exception:
        import pdb
        import traceback

        traceback.print_exc()
        # Only open the post-mortem debugger when someone can interact with it. In
        # unattended runs stdin is not a terminal, and pdb would block or consume
        # unrelated input instead of letting the process fail.
        if sys.stdin is not None and sys.stdin.isatty():
            pdb.post_mortem(sys.exc_info()[2])
        # Exit non-zero so shells and job schedulers see the failure; previously the
        # script exited with status 0 after the debugger returned.
        sys.exit(1)
