import os
import sys
import warnings
import argparse
import ast
import os
import random

from datetime import datetime

import numpy as np
import ruamel.yaml as yaml
import torch
from test_cmd import create_cmd
from openpto.config import get_args, get_logger, load_conf, setup_seed
from openpto.expmanager import ExpManager
from openpto.method.Models.wrapper_loss import get_loss_fn
from openpto.method.Solvers.wrapper_solver import solver_wrapper
from openpto.problems.wrapper_prob import problem_wrapper

def _str2bool(token):
    """Convert a command-line token such as ``"True"`` or ``"false"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.
    """
    if isinstance(token, bool):
        return token
    lowered = str(token).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {token!r}")


def get_args(probs, opt_m):
    """
    Build the experiment arguments for a (problem, DFL method) pair.

    The real command line is not parsed here: the argument list handed to the
    parser is the preset returned by ``create_cmd()`` for the requested pair.

    Parameters
    ----------
    probs : str
        Key of the optimization problem in the mapping returned by
        ``create_cmd()``.
    opt_m : str
        Key of the DFL method in ``create_cmd()[probs]``.

    Returns
    -------
    args : argparse.Namespace
        Parsed arguments; ``args.data_dir`` has ``args.problem`` appended.

    Raises
    ------
    ValueError
        If ``probs`` or ``opt_m`` has no preset in ``create_cmd()``.
    """
    parser = argparse.ArgumentParser()
    # basic
    parser.add_argument(
        "--problem",
        type=str,
        choices=[
            "budgetalloc",
            "bipartitematching",
            "cubic",
            "portfolio",
            "knapsack",
            "energy",
            "advertising",
            "shortestpath",
            "TSP",
        ],
        default="knapsack",
    )
    parser.add_argument("--config_path", type=str, default="./openpto/config/probs/")
    parser.add_argument(
        "--method_path", type=str, default="openpto/config/models/default.yaml"
    )
    parser.add_argument("--trained_path", type=str, default="")
    parser.add_argument("--loss_path", type=str, default="")
    parser.add_argument(
        "--opt_model",
        type=str,
        choices=[
            "mse",
            "dfl",
            "bce",
            "ce",
            "mae",
            "spo",
            "pointLTR",
            "pairLTR",
            "listLTR",
            "intopt",
            "blackbox",
            "blackboxSolver",
            "identity",
            "identitySolver",
            "lodl",
            "nce",
            "SAApointLTR",
            "SAApairwiseLTR",
            "SAAlistwiseLTR",
            "perturb",
            "cpLayer",
        ],
        default="mse",
    )
    parser.add_argument(
        "--pred_model",
        type=str,
        choices=[
            "dense",
            "cvr",
            "cv_mlp",
            "ConvNet",
            "Resnet18",
            "CombResnet18",
            "PureConvNet",
        ],
        default="dense",
    )
    parser.add_argument(
        "--solver",
        type=str,
        choices=[
            "gurobi",
            "neural",
            "heuristic",
            "cvxpy",
            "ortools",
            "qptl",
        ],
        default="gurobi",
    )
    parser.add_argument("--gpu", type=str, default="-1", help="Visible GPU")
    # training
    parser.add_argument("--loadnew", type=ast.literal_eval, default=False)
    parser.add_argument("--opt_name", type=str, default="gd", choices=["gd", "sgd"])
    parser.add_argument("--n_epochs", type=int, default=0)
    parser.add_argument("--n_ptr_epochs", type=int, default=0)
    parser.add_argument("--earlystopping", type=ast.literal_eval, default=True)
    parser.add_argument("--patience", type=int, default=40)
    parser.add_argument("--seed", type=int, default=2023)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--use_lr_scheduling", action="store_true")
    parser.add_argument("--lr_milestone_1", type=int, default=100)
    parser.add_argument("--lr_milestone_2", type=int, default=200)
    parser.add_argument("--l1_weight", type=float, default=0)
    parser.add_argument("--l2_weight", type=float, default=0)
    # data
    parser.add_argument("--data_dir", type=str, default="./openpto/data/")
    parser.add_argument("--do_debug", action="store_true")
    parser.add_argument("--instances", type=int, default=400)
    parser.add_argument("--testinstances", type=int, default=200)
    # debug
    parser.add_argument("--valfreq", type=int, default=1)
    parser.add_argument("--savefreq", type=int, default=-1)
    parser.add_argument("--prefix", type=str, default="default")
    # model
    parser.add_argument("--n_layers", type=int, default=2)
    parser.add_argument("--n_hidden", type=int, default=32)
    parser.add_argument("--pooling", type=str, default="mean")
    parser.add_argument("--activation", type=str, default="relu")
    parser.add_argument("--kernel_size", type=int, default=1)
    # `type=bool` would turn "--rescon False" into True, so parse the text.
    parser.add_argument("--rescon", type=_str2bool, default=False)

    # Look up the preset argument list for the requested problem / method.
    test_cmd = create_cmd()
    if probs not in test_cmd:
        raise ValueError(f"The problem [{probs}] is not supported.")
    if opt_m not in test_cmd[probs]:
        raise ValueError(f"The opt model [{opt_m}] is not supported.")
    # presumably a list of argv-style strings; verify against test_cmd.create_cmd
    mock_args = test_cmd[probs][opt_m]

    args = parser.parse_args(mock_args)
    # Each problem keeps its data in its own sub-directory of data_dir.
    args.data_dir = os.path.join(args.data_dir, args.problem)
    return args

def load_conf(prob_path: str = None, method_path: str = None, prob_name: str = None) -> dict:
    """
    Function to load config file.

    Parameters
    ----------
    prob_path : str
        Path of the problem config file. When it is `None` or the default
        config directory ``"./openpto/config/probs/"``, the file
        ``<prob_name>.yaml`` inside that directory is loaded.
    method_path : str
        Path of the method (model) config file.
    prob_name : str
        Name of the corresponding problem. Required when ``prob_path`` is
        `None` or the default config directory.

    Returns
    -------
    conf : dict
        The parsed problem configuration, with the parsed method
        configuration stored under the ``"models"`` key.

    Raises
    ------
    ValueError
        If the problem config file does not exist, or if ``prob_name`` is
        needed to locate it but was not given.
    """
    default_dir = "./openpto/config/probs/"
    if prob_path is None:
        prob_path = default_dir

    if prob_path == default_dir:
        # Resolve the file from the function's own argument rather than from
        # a module-level `args`, so the function also works when imported.
        if prob_name is None:
            raise ValueError(
                "`prob_name` is required when `prob_path` is the default "
                "configuration directory."
            )
        prob_path = os.path.join(prob_path, prob_name + ".yaml")

    if not os.path.exists(prob_path):
        raise ValueError(f"The configuration file, [{prob_path}] is not provided.")

    # NOTE(review): newer ruamel.yaml releases deprecate the module-level
    # `safe_load`; confirm against the pinned version.
    with open(prob_path, "r") as prob_file:
        conf = yaml.safe_load(prob_file.read())
    with open(method_path, "r") as method_file:
        conf["models"] = yaml.safe_load(method_file.read())
    return conf

if __name__ == "__main__":
    # Interactive launcher: show the available problems and DFL methods,
    # read the user's choice, then build and run the experiment.
    banner = "-------------------------------------------------------------------------------------------------------------------------------"
    problem_menu = (
        "Budgetalloc",
        "BipartiteMatching",
        "TopK(Cubic)",
        "Portfolio",
        "Knapsack(Gen)",
        "Knapsack(Energy)",
        "Scheduling(Energy)",
    )
    method_menu = (
        "SPO",
        "Org-Lt",
        "Org-Pt",
        "Org-Pr",
        "Blackbox",
        "Identity",
        "LODL",
        "NCE",
        "SAA-Pr",
        "SAA-Lt",
        "SAA-Pt",
        "CpLayer",
        "Two-stage",
    )

    print(banner)
    print("Please select from the following optimization problems:")
    print(*problem_menu)
    print(banner)
    print("Please select from the following DFL methods:")
    print(*method_menu)
    print(banner)
    chosen_problem = input("Problem name:")
    print(banner)
    chosen_method = input("DFL method:")

    # Resolve the preset arguments and configuration for the selection.
    args = get_args(chosen_problem, chosen_method)
    conf = load_conf(args.config_path, args.method_path, args.problem)
    setup_seed(args.seed)

    # Set up logging and record the run settings.
    logger = get_logger(args, conf)
    logger.info(f" {args.bkup_log_dir}\n {args.log_dir}\n args: {args} \n")

    # Build the optimization problem.
    logger.info(f" dataset configs: {conf['dataset']} \n")
    logger.info(f" model configs: {conf['models'][args.opt_model]} \n")
    logger.info(f" Loading [{args.problem}] Problem...")
    problem = problem_wrapper(args, conf)

    # Build the solver for that problem.
    logger.info(f" Loading [{args.solver}] solver ...")
    ptoSolver = solver_wrapper(args, conf, problem)

    # Build the loss function.
    logger.info(f" Loading [{args.opt_model}] Loss Function...")
    coeff_dim = problem.get_model_shape()[1]
    loss_fn = get_loss_fn(
        args=args, ptoSolver=ptoSolver, coeff_dim=coeff_dim, conf=conf
    )

    # Build the experiment manager around the prediction-model settings.
    pred_model_args = dict(
        ipdim=problem.get_model_shape()[0],
        opdim=problem.get_model_shape()[1],
        out_act=problem.get_output_activation(),
    )
    exp = ExpManager(pred_model_args, args=args, conf=conf, logger=logger)

    # Train the prediction model with the selected loss function.
    logger.info(
        f" Start training [{args.pred_model}] model on [{args.opt_model}] loss..."
    )
    exp.run(
        problem,
        loss_fn,
        ptoSolver,
        n_epochs=args.n_epochs,
        do_debug=args.do_debug,
    )
