# import os
# from datetime import datetime

import argparse
from util_cover import run_experiments
from pyepo.cspo_examples.cover import cover_params
from pyepo.cspo_examples.cover import cover_problem

"""## Run Cover Problem :

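This script runs experiments on the cover problem with the following setup:

#### Methods being compared:

Exactly one method family is trained per run, selected by the `--cspo` flag:

- With `--cspo`: cspo+ (Constrained Smart Predict-then-Optimize Plus), cspo+_is (Importance Sampling CSPO+), cspo+_T
- Without `--cspo`: mse (Mean Squared Error), mse_is (Importance Sampling MSE), mse_T

The `_T` suffix presumably denotes the truncated variant. The ws_is_cspo+ (Weighted Sampling Importance CSPO+) and mse_is_cspo+ (MSE Importance CSPO+) variants are not run by this script.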

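# Comparative Analysis of CSPO+ and MSE Methods on the Cover Problem

## Default Experiment Configuration (each value can be overridden on the command line):
- Methods: cspo+, cspo+_T, cspo+_is with `--cspo`; otherwise mse, mse_T, mse_is
- Problem Size: 5 items (`--num_var`), 10 features (`--num_feat`), 2 coverage requirements (`--num_const`)
- Data Scale: 1000 training, 3000 validation, 3000 test, 1000 CP samples
- Weight Complexity: 2nd-degree polynomial (`--weight_deg`)
- Cost Complexity: polynomial degrees 2, 4, 6, 8 (`--cost_deg_list`)
- Configuration: truncated (`_T`) variants included, 5 repeated runs (`--num_repeat`), 100 epochs (`--num_epochs`)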
"""


def get_args(argv=None):
    """Parse the command-line options for the cover-problem experiments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is how the script itself calls it.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Cover Experiment Parameters")

    # Data sizes
    parser.add_argument('--num_data', type=int, default=1000,
                        help='Number of training samples')
    parser.add_argument('--val_num_data', type=int, default=3000,
                        help='Number of validation samples')
    parser.add_argument('--test_num_data', type=int, default=3000,
                        help='Number of test samples')
    parser.add_argument('--cp_num_data', type=int, default=1000,
                        help='Number of CP samples')

    # Problem definition
    parser.add_argument('--num_feat', type=int, default=10,
                        help='Number of input features')
    parser.add_argument('--num_var', type=int, default=5,
                        help='Number of items (passed to cover_params as num_item)')
    parser.add_argument('--num_const', type=int, default=2,
                        help='Number of coverage requirements (passed as num_reqs)')

    parser.add_argument('--weight_deg', type=int, default=2,
                        help='Weight polynomial degree passed to cover_params')
    parser.add_argument('--noise_width', type=float, default=1.0,
                        help='Noise width passed to cover_params')
    parser.add_argument('--cost_deg_list', type=int, nargs='+', default=[2, 4, 6, 8],
                        help='List of cost polynomial degrees')
    parser.add_argument('--alpha_list', type=float, nargs='+', default=[0.2],
                        help='List of alpha values')

    # Experiment settings
    parser.add_argument('--num_epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--num_process', type=int, default=1,
                        help='Number of processes passed to run_experiments')
    parser.add_argument('--num_repeat', type=int, default=5,
                        help='Number of repeated runs')

    # Flags
    # test_mode defaults to True and this flag turns it OFF (store_false).
    # The historical spelling '--test_mode' is kept for backward
    # compatibility even though its name suggests the opposite;
    # '--no_test_mode' is a clearer alias for the same switch.
    parser.add_argument('--test_mode', '--no_test_mode', dest='test_mode',
                        action='store_false',
                        help='Disable test mode (test mode is enabled by default)')
    parser.add_argument('--cspo', action='store_true',
                        help='Train the CSPO+ methods instead of the MSE baselines')

    # Solver
    parser.add_argument('--solve_ratio', type=float, default=1.0,
                        help='Solve ratio passed to run_experiments')

    # Learning rate and paths
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='Base learning rate')
    parser.add_argument('--save_path', type=str, default="./results/",
                        help='Output directory passed to run_experiments')

    return parser.parse_args(argv)


def main():
    """Run the cover-problem experiments configured from the command line.

    Builds one ``cover_params`` configuration per entry in ``--cost_deg_list``,
    selects the CSPO+ or MSE method family depending on ``--cspo``, and hands
    everything to ``run_experiments``.

    Raises:
        SystemExit: if ``--num_const`` does not match the number of hard-coded
            coverage requirements.
    """
    args = get_args()

    # Hard-coded coverage requirements; one entry per constraint.
    reqs = [2.9, 7.1]
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and would let a mismatched configuration through.
    if len(reqs) != args.num_const:
        raise SystemExit(
            "Number of requirements must match the number of constraints "
            f"(got {len(reqs)} requirements but --num_const={args.num_const})"
        )

    print('Solving the Cover Problem')
    print(f'Training cspo+:{args.cspo}\n')

    # NOTE(review): weight_deg, noise_width, cost_deg_list and alpha_list are
    # not encoded in the name, so runs differing only in those settings use
    # the same file name.
    file_name = (
        f"cover_results_num_data={args.num_data}_feat={args.num_feat}"
        f"_num_var={args.num_var}_num_const={args.num_const}"
        f"_CSPO={args.cspo}_lr={args.lr}_solve_ratio={args.solve_ratio}.txt"
    )

    # Train either the CSPO+ family or the MSE baselines in a single run.
    if args.cspo:
        method_list = ["cspo+", "cspo+_T", "cspo+_is"]
    else:
        method_list = ["mse", "mse_T", "mse_is"]

    # Per-method learning rates: the plain CSPO+ variants use 4x the base
    # rate; every other method (including the "cspo+_mse*" entries, which
    # this script never selects) uses the base rate.
    base_lr = args.lr
    learning_rate = {
        "mse": base_lr,
        "mse_T": base_lr,
        "mse_is": base_lr,
        "cspo+": 4 * base_lr,
        "cspo+_is": 4 * base_lr,
        "cspo+_T": 4 * base_lr,
        "cspo+_mse": base_lr,
        "cspo+_mse_is": base_lr,
        "cspo+_mse_T": base_lr,
    }

    # One cover-problem configuration per cost polynomial degree.
    params_list = [
        cover_params(num_feat=args.num_feat,
                     num_item=args.num_var,
                     num_reqs=args.num_const,
                     weight_deg=args.weight_deg,
                     noise_width=args.noise_width,
                     cost_deg=cost_deg,
                     reqs=reqs)
        for cost_deg in args.cost_deg_list
    ]

    # Run experiments
    run_experiments(
        cover_problem,
        args.num_repeat,
        args.num_data,
        args.val_num_data,
        args.test_num_data,
        args.cp_num_data,
        params_list,
        method_list,
        learning_rate,
        args.save_path,
        file_name,
        num_epochs=args.num_epochs,
        num_process=args.num_process,
        alphas=args.alpha_list,
        test_mode=args.test_mode,
        solve_ratio=args.solve_ratio
    )


if __name__ == "__main__":
    main()
