import os
import argparse
import torch
import numpy as np
import random
import yaml

from explainer.sssgexplainer import sssgexplainer

from configs.config import args as config_defaults

def _str2bool(value):
    """Convert a command-line token such as 'true', 'no' or '0' to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean,
            which argparse reports as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Every option except ``--use_jk`` has no default and is therefore ``None``
    when it is not given on the command line; ``fill_missing_args`` relies on
    that ``None`` to decide which values still need to be filled in.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Hyperparameter Tuning")

    ### GNN hyperparameters ###
    parser.add_argument('--dataset', type=str, choices=['MUTAG', 'BA3', 'FC', 'MNIST'],
                        help='Dataset name')
    parser.add_argument('--model', type=str, choices=['GCN', 'GIN'], help='Model name')
    parser.add_argument('--hidden', type=int, help='Number of hidden units')
    parser.add_argument('--nlayers', type=int, help='Number of hidden layers')
    parser.add_argument('--batch_size', type=int, help='Batch size')
    parser.add_argument('--dropout', type=float, help='Dropout ratio')
    parser.add_argument('--pool_type', type=str, choices=['mean', 'sum', 'max'], help='Pooling type')
    # Without an explicit type the value would stay a string, so
    # `--use_jk False` would be the truthy string "False". Parse it as a bool.
    parser.add_argument('--use_jk', type=_str2bool, default=True, help='Use Jumping Knowledge.')

    ### Explainer hyperparameters ###
    parser.add_argument('--explainer_name', type=str, help='Explainer to be used')
    parser.add_argument('--epochs', type=int, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, help='Learning rate.')

    parser.add_argument('--round1_epochs', type=int, help='Round 1 training epochs')
    parser.add_argument('--round1_lr', type=float, help='Round 1 learning rate')
    parser.add_argument('--alpha0', type=float, help='Base quantile cutoff (0, 0.5) for pseudo-labeling')
    parser.add_argument('--c', type=float, help='Skewness adjustment factor (>0) for asymmetric quantiles')
    parser.add_argument('--w', type=float, help='weight for loss')

    # Optional second guided stage (Round 2) controls
    parser.add_argument('--second_phase_epochs', type=int, help='Round 2 training epochs (defaults to round1_epochs)')
    parser.add_argument('--second_phase_lr', type=float, help='Round 2 learning rate (defaults to round1_lr)')
    parser.add_argument('--second_alpha0', type=float, help='Round 2 base quantile cutoff (defaults to alpha0)')
    parser.add_argument('--second_c', type=float, help='Round 2 skewness adjustment factor (defaults to c)')
    parser.add_argument('--second_w', type=float, help='Loss pos_weight for Round 2 (defaults to w)')

    ### etc ###
    parser.add_argument('--ckpt_path', type=str, help='Location for saving checkpoints')
    parser.add_argument('--gpu', type=int, help='GPU device id to use')
    parser.add_argument('--seed', type=int, help='Random seed')

    return parser.parse_args(argv)


def set_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch with ``seed``; when CUDA is
    available the CUDA generators are seeded as well. cuDNN is switched to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def fill_missing_args(args, yaml_config, config_defaults):
    """Fill unset attributes of ``args`` from the YAML config, then from defaults.

    An attribute counts as unset when it is missing from ``args`` or is
    ``None``. Precedence is: value already on ``args`` > ``yaml_config`` >
    ``config_defaults``. A short report of what was taken from where is
    printed to stdout.

    Args:
        args: Namespace-like object; modified in place.
        yaml_config: Mapping of option name to value.
        config_defaults: Namespace-like object whose attributes (via
            ``vars``) provide the last-resort values.

    Returns:
        None.
    """

    def _fill_from(source):
        # Copy every entry of `source` whose name is still unset on `args`
        # and report the names that were copied.
        filled = []
        for name, candidate in source.items():
            if getattr(args, name, None) is not None:
                continue
            setattr(args, name, candidate)
            filled.append(name)
        return filled

    print("=== Argument Priority Check ===")
    from_terminal = [name for name, current in vars(args).items() if current is not None]
    print(f"Terminal args: {from_terminal}")
    print(f"YAML config keys: {list(yaml_config.keys())}")
    print(f"Config defaults keys: {list(vars(config_defaults).keys())}")

    # Second priority: the YAML config.
    yaml_applied = _fill_from(yaml_config)
    print(f"Applied from YAML: {yaml_applied}")

    # Last priority: the defaults from config.py.
    config_applied = _fill_from(vars(config_defaults))
    print(f"Applied from config.py: {config_applied}")
    print("=" * 40)


if __name__ == "__main__":
    # Entry point: resolve the arguments (terminal > YAML > config.py),
    # seed the RNGs, pick the device, then build and run the explainer.
    args = parse_args()

    # Per-dataset overrides live in configs/<dataset>.yaml, keyed by
    # explainer name. They are only looked up when --dataset was given.
    yaml_config = {}
    if args.dataset:
        config_path = os.path.join("configs", f"{args.dataset}.yaml")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                # safe_load returns None for an empty document; treat that
                # as "no overrides" instead of crashing on `.get`.
                yaml_data = yaml.safe_load(f) or {}
            # A section written as `name:` with no body also loads as None,
            # which fill_missing_args could not iterate over.
            yaml_config = yaml_data.get(args.explainer_name) or {}

    fill_missing_args(args, yaml_config, config_defaults)

    set_seed(args.seed)
    torch.set_num_threads(4)
    if torch.cuda.is_available():
        # 'cuda:None' is not a valid device string; fall back to the current
        # CUDA device when no GPU id was resolved from any source.
        device = torch.device('cuda' if args.gpu is None else f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    if not args.explainer_name:
        raise ValueError(
            "No explainer selected: pass --explainer_name or set it in the config."
        )

    # NOTE(review): eval() executes an arbitrary expression taken from the
    # command line / config files. This is acceptable only while those inputs
    # are trusted; an explicit name -> class mapping would be a safer lookup.
    explainer = eval(args.explainer_name)(args, device=device)
    explainer.train_test(args)
