import argparse
import json

import pandas as pd

from model.gnn.grin import MainGRIN

def get_args():
    parser = argparse.ArgumentParser(description='GRIN method for repetition-invariant polymer representation learning')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--root', type=str, default='data_pyg',
                        help='root directory to store the dataset folder (default: data_pyg)')
    # model
    parser.add_argument('--model_type', type=str, default='gcn',
                        help='GNN model type (default: gcn)')
    parser.add_argument('--graph_pooling', type=str, default='mean',
                        help='graph pooling method (default: mean)')
    parser.add_argument('--drop_ratio', type=float, default=0.1,
                        help='dropout ratio (default: 0.1)')
    parser.add_argument('--num_layer', type=int, default=2,
                        help='number of GNN message passing layers (default: 2)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    
    # experiment
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    parser.add_argument('--train_rep', type=list, default=[[1,3]],
                        help='training repetition (default: [1])')
    parser.add_argument('--test_rep', type=list, default=[1,5,10],
                        help='testing repetition (default: [1,5,10])')
    parser.add_argument('--task_name', type=str, default='mt',
                        help='task name (default: mt)')
    parser.add_argument('--polymer_type', type=str, default='homopolymer',
                        help='polymer type (default: homopolymer)')

    # training
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=400,
                        help='number of epochs to train (default: 400)')
    parser.add_argument('--patience', type=int, default=100,
                        help='patience for early stop (default: 100)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate (default: 1e-3)')

    parser.add_argument('--seeds', type=int, default=1,
                        help='number of experiments (default: 1)')
    parser.add_argument('--use_ck', type=bool, default=False,
                        help='use checkpoint (default: False)')
    parser.add_argument('--by_default', default=False, action='store_true',
                        help='use default configuration for hyperparameters')
    args = parser.parse_args()
    
    if args.by_default:
        try:
            df = pd.read_csv('hyperparameters_summary.csv')
            # Find the row that matches model_type and task_name
            # Note: the CSV uses 'task' for task_name
            condition = (df['model_type'] == args.model_type) & (df['task'] == args.task_name)
            params_row = df[condition]
            
            if not params_row.empty:
                params = params_row.iloc[0] # Get the first matching row
                args.num_layer = int(params['num_layer'])
                args.lr = float(params['lr'])
                args.polymer_type = params['polymer_type']
                print(f"Loaded hyperparameters from CSV for {args.model_type} and {args.task_name}:")
                print(f"  num_layer: {args.num_layer}")
                print(f"  lr: {args.lr}")
                print(f"  polymer_type: {args.polymer_type}")
            else:
                print(f"Warning: No matching hyperparameters found in CSV for model_type='{args.model_type}' and task='{args.task_name}'. Using default CLI/parser values.")
        except FileNotFoundError:
            print(f"Warning: hyperparameters_summary.csv not found at 'hyperparameters_summary.csv'. Using default values.")
        except Exception as e:
            print(f"Warning: Error reading or processing hyperparameters_summary.csv: {e}. Using default values.")
            
    return args

if __name__ == "__main__":
    # Repetition settings must be lists:
    #   [1]   -> operate on the single model for repetition 1
    #   [1,3] -> operate on the merged model of repetitions 1 and 3
    cli_args = get_args()
    runner = MainGRIN(cli_args)
    runner.main()

