import argparse
import torch
import os
import sys

# Silence CUDA context warnings.
# Only warnings whose message matches this regex are suppressed; every other
# warning is still reported.
import warnings
warnings.filterwarnings("ignore", message=".*Attempting to run cuBLAS, but there was no current CUDA context.*")

# Put the parent of this script's directory on sys.path so the project-level
# `train` and `data` modules imported below resolve even when the script is
# launched from a different working directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from train import grid_search
from data.Dataset import TimeSeriesDataset, Data

# ==========================================================
# Main
# ==========================================================
# ==========================================================
# Configuration helpers
# ==========================================================

# Maps each accepted --penalty_type to the importance measure handed to
# grid_search. The keys are also the CLI choices, so the two cannot drift apart.
# NOTE(review): an earlier revision also mapped 'Layer_Weight' -> 'Layer_Weight',
# but that penalty was never an accepted CLI choice, so the branch was dead.
# Add it here (which also exposes it on the CLI) once grid_search is confirmed
# to support it.
PENALTY_TO_IMPORTANCE = {
    "Fast_Shap": "Shapley",
    "Shapley": "Shapley",
    "Jacob_F": "Jacobian",
    "Jacob_L1": "Jacobian",
}

MODEL_TYPE = "ResidualMLP"


def parse_args(argv=None):
    """Parse and validate command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (via ``parser.error``) when ``--num_workers`` < 1 or ``--exec_idx``
    lies outside the 1-based range ``[1, --num_workers]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="VAR3")
    parser.add_argument("--series", type=int, default=1)
    parser.add_argument("--subject", type=int, default=1)
    parser.add_argument("--lag", type=int, default=2)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--num_workers", type=int, default=1, help="Number of parallel workers for grid search")
    parser.add_argument("--exec_idx", type=int, default=1, help="Execution index for this worker (1-based)")
    parser.add_argument("--penalty_type", type=str, required=True,
                        choices=list(PENALTY_TO_IMPORTANCE),
                        help="Penalty used during training")

    args = parser.parse_args(argv)
    if args.num_workers < 1:
        parser.error("--num_workers must be >= 1")
    if not 1 <= args.exec_idx <= args.num_workers:
        parser.error("--exec_idx must be in [1, --num_workers]")
    return args


def importance_type_for(penalty_type):
    """Return the importance measure paired with *penalty_type*.

    Raises:
        ValueError: if *penalty_type* is not a key of ``PENALTY_TO_IMPORTANCE``.
    """
    try:
        return PENALTY_TO_IMPORTANCE[penalty_type]
    except KeyError:
        raise ValueError(f"Unknown penalty_type: {penalty_type}") from None


def build_param_grid(penalty_type):
    """Return the hyperparameter grid searched for *penalty_type*.

    Only 'Fast_Shap' sweeps ``int_lambda``; every other penalty fixes it at 0.0.
    """
    int_lambda = [1e-4, 1e-3, 1e-2, 1e-1] if penalty_type == "Fast_Shap" else [0.0]
    return {
        "lr": [5e-4, 1e-3, 5e-3],
        "hidden_dim": [50, 100],
        "layers": [1, 2, 3, 4, 5],
        "dropout": [0.1, 0.2],
        "ind_lambda": [1e-4, 1e-3, 1e-2, 1e-1],
        "int_lambda": int_lambda,
        "weight_decay": [1e-5],
    }


def select_device(exec_idx):
    """Pick the torch device for 1-based worker *exec_idx*.

    Workers are spread round-robin over the visible GPUs; falls back to 'cpu'
    when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        return f"cuda:{(exec_idx - 1) % torch.cuda.device_count()}"
    return "cpu"


# ==========================================================
# Main
# ==========================================================
def main(argv=None):
    """Run this worker's share of the grid search and append results to CSV."""
    args = parse_args(argv)
    importance_type = importance_type_for(args.penalty_type)
    # grid_search is told to ignore diagonal entries for the DREAM datasets only.
    ignore_diagonal = args.dataset in ("DREAM3", "DREAM4")
    # NOTE(review): -1 presumably means "full batch" inside grid_search — confirm.
    batch_size = 512 if args.dataset == "CausalTime" else -1
    device = select_device(args.exec_idx)

    # NOTE(review): './results' and './data' are relative to the current working
    # directory, not to this script's location.
    save_path = (
        f"./results/simulation_{args.dataset}_{args.series}_{args.subject}"
        f"_{MODEL_TYPE}_{args.penalty_type}.csv"
    )
    # Create the output directory up front; writing the CSV would otherwise fail
    # on a fresh checkout where ./results does not exist yet.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Load dataset
    data_extractor = Data("./data", args.dataset, args.series, args.subject)
    X, network, _gene_names = data_extractor.load_data()
    dataset = TimeSeriesDataset(X, args.lag, Norm=True, device=device)

    results = grid_search(
        dataset=dataset,
        network=network,
        model_type=MODEL_TYPE,
        penalty_type=args.penalty_type,
        importance_type=importance_type,
        ignore_diagonal=ignore_diagonal,
        param_grid=build_param_grid(args.penalty_type),
        batch_size=batch_size,
        save_path=save_path,
        num_workers=args.num_workers,
        exec_idx=args.exec_idx,
        device=device,
        seed=args.seed,
    )

    if results is not None and len(results) > 0:
        # Write the header only when the file is new or empty.
        # NOTE(review): workers share save_path and this check is not atomic, so
        # two workers starting on an empty file can both emit a header row.
        write_header = not os.path.exists(save_path) or os.path.getsize(save_path) == 0
        results.to_csv(save_path, mode="a", index=False, header=write_header)


if __name__ == "__main__":
    main()