# import os
# from datetime import datetime

import argparse
from util_knapsack import run_experiments
from pyepo.cspo_examples.knapsack import knapsack_params
from pyepo.cspo_examples.knapsack import knapsack_problem

"""## Run Knapsack Problem :

This notebook runs experiments on the knapsack problem with the following setup:

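#### Methods being compared:

- cspo+ - Constrained Shortest Path Optimization Plus
- is_cspo+ - Importance Sampling CSPO+
- mse - Mean Squared Error
- is_mse - Importance Sampling MSE
- ws_is_cspo+ - Weighted Sampling Importance CSPO+
- mse_is_cspo+ - MSE Importance CSPO+

Note: the method keys used by the code below are spelled differently. When
`--cspo` is passed they are `cspo+`, `cspo+_is` and `cspo+_T`; otherwise they
are `mse`, `mse_T` and `mse_is`.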

# Comparative Analysis of Multiple CSPO+ Methods on Standard Knapsack Problem

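## Experiment Configuration:
(These values record the original notebook run. The current command-line
defaults in `get_args()` differ: 10 features, 3000 validation/test/CP samples,
a degree-2 weight polynomial and 5 repeats.)
- Methods: IS_MSE, MSE_IS_CSPO+, IS_CSPO+
- Problem Size: 5 items, 3 features
- Data Scale: 1000 training, 300 validation/test, 500 CP samples
- Weight Complexity: 1st degree polynomial
- Configuration: With truncation, 10 instance runs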
"""


def get_args(argv=None):
    """Parse the command-line options for the knapsack experiments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the real command line. Passing a list lets
            tests and notebooks drive the parser directly.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Knapsack Experiment Parameters")

    # Data sizes
    parser.add_argument("--num_data", type=int, default=1000,
                        help="Number of training samples")
    parser.add_argument("--val_num_data", type=int, default=3000,
                        help="Number of validation samples")
    parser.add_argument("--test_num_data", type=int, default=3000,
                        help="Number of test samples")
    parser.add_argument("--cp_num_data", type=int, default=3000,
                        help="Number of CP samples")

    # Problem definition / model parameters
    parser.add_argument("--num_feat", type=int, default=10,
                        help="Number of features")
    parser.add_argument("--num_var", type=int, default=5,
                        help="Number of items (passed as num_item)")
    # NOTE(review): num_const is not forwarded to knapsack_params by the
    # driver, and the results file name hard-codes num_const=1.
    parser.add_argument("--num_const", type=int, default=1,
                        help="Number of constraints (currently not used "
                             "when building the problem)")

    parser.add_argument("--weight_deg", type=int, default=2,
                        help="Polynomial degree for the item weights")
    parser.add_argument("--noise_width", type=float, default=1.0,
                        help="Noise width passed to knapsack_params")
    parser.add_argument("--cost_deg_list", type=int, nargs="+", default=[2, 4, 6, 8],
                        help="List of cost polynomial degrees")
    parser.add_argument("--alpha_list", type=float, nargs="+", default=[0.2],
                        help="List of alpha values")

    # Experiment settings
    parser.add_argument("--num_epochs", type=int, default=50,
                        help="Number of training epochs")
    parser.add_argument("--num_process", type=int, default=1,
                        help="Number of processes for parallel computation")
    parser.add_argument("--num_repeat", type=int, default=5,
                        help="Number of instances to solve (random seeds set)")

    # Flags
    # --test_mode is a store_false flag: test mode is ON by default and
    # passing the flag turns it OFF. --no_test_mode is a clearer alias for
    # the same switch. Both spellings are accepted so existing command
    # lines keep their meaning.
    parser.add_argument("--test_mode", "--no_test_mode", dest="test_mode",
                        action="store_false",
                        help="Disable test mode (enabled by default)")
    parser.add_argument("--cspo", action="store_true",
                        help="Run the cspo+ method family instead of the mse family")

    # Solver
    parser.add_argument("--solve_ratio", type=float, default=1.0,
                        help="Solve ratio passed to run_experiments")

    # Learning rate and paths
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Base learning rate (scaled per method)")
    parser.add_argument("--save_path", type=str, default="./results/",
                        help="Directory where results are saved")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args()
    # Knapsack capacity used for every generated instance.
    capacity = 1

    print("Solving the Knapsack Problem")
    print(f"Training cspo+:{args.cspo}\n")

    # Results file name, assembled piece by piece. NOTE: num_const is the
    # literal 1 rather than args.num_const, because --num_const is not
    # forwarded to knapsack_params below.
    file_name = (
        "knapsack_results_"
        f"num_data={args.num_data}_"
        f"feat={args.num_feat}_"
        f"num_var={args.num_var}_"
        "num_const=1_"
        f"CSPO={args.cspo}_"
        f"solve_ratio={args.solve_ratio}.txt"
    )

    # --cspo selects which method family is trained and evaluated.
    method_list = (
        ["cspo+", "cspo+_is", "cspo+_T"]
        if args.cspo
        else ["mse", "mse_T", "mse_is"]
    )

    # Per-method learning rate, expressed as a multiple of the base --lr.
    base_lr = args.lr
    lr_scale = {
        "mse": 1,
        "mse_T": 1,
        "mse_is": 2,
        "cspo+": 4,
        "cspo+_is": 4,
        "cspo+_T": 4,
        "cspo+_mse": 1,
        "cspo+_mse_is": 2,
        "cspo+_mse_T": 2,
    }
    learning_rate = {method: scale * base_lr for method, scale in lr_scale.items()}

    # One knapsack parameter set per requested cost polynomial degree.
    params_list = [
        knapsack_params(
            num_feat=args.num_feat,
            num_item=args.num_var,
            weight_deg=args.weight_deg,
            noise_width=args.noise_width,
            cost_deg=cost_deg,
            capacity=capacity,
        )
        for cost_deg in args.cost_deg_list
    ]

    # Run experiments
    run_experiments(
        knapsack_problem,
        args.num_repeat,
        args.num_data,
        args.val_num_data,
        args.test_num_data,
        args.cp_num_data,
        params_list,
        method_list,
        learning_rate,
        args.save_path,
        file_name,
        num_epochs=args.num_epochs,
        num_process=args.num_process,
        alphas=args.alpha_list,
        test_mode=args.test_mode,
        solve_ratio=args.solve_ratio,
    )