import argparse
import random
import numpy as np
import torch
import os

# Allowed values for the command-line options built in get_args().

# Models that get_args() handles without a local checkpoint path and
# without appending a "_<size>B" suffix to the model name.
proprietary_models = [
    "gpt-4o",
    "gpt-4.1",
    "gpt-4.1-mini",
    "gpt-4.1-nano",
]

# Every value accepted by --model: the Qwen families plus the proprietary ones.
models = ["qwen2.5", *proprietary_models, "qwen3"]

# Values accepted by --model_size (interpolated as "<size>B" in checkpoint names).
sizes = [7, 14, 32, 8]

# Values accepted by --run_type.
runtypes = [
    "marginal",
    "copula",
    "full",
    "joint",
    "marginal_cot",
    "full_cot",
    "joint_cot",
]

# Values accepted by --data_type.
datatypes = ["e_commerce", "mobility", "population"]

# Seed every random number generator used here and request deterministic kernels
def seed_everything(seed: int):
    """
    Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with
    `seed`, then switch PyTorch/cuDNN to deterministic behaviour.

    Also exports two environment variables: PYTHONHASHSEED (the seed as a
    string) and CUBLAS_WORKSPACE_CONFIG (":16:8").

    NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so the
    value written here presumably only affects processes launched afterwards,
    not hashing in the current process -- confirm if hash order matters.
    """
    # Same seed for every generator, applied in a fixed order.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    # cuDNN: pick deterministic kernels and disable the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # warn_only=True: non-deterministic ops emit a warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

    for env_name, env_value in (
        ("PYTHONHASHSEED", str(seed)),
        ("CUBLAS_WORKSPACE_CONFIG", ":16:8"),
    ):
        os.environ[env_name] = env_value


# Function to parse command-line arguments
def get_args(argv=None) -> argparse.Namespace:
    """
    Parses the command-line arguments and returns the configuration object.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            original behaviour. Passing an explicit list lets callers build a
            configuration programmatically.

    Returns:
        The parsed ``argparse.Namespace`` with these post-processing steps
        applied:

        * ``model_path`` is added for the "qwen2.5" and "qwen3" models only;
          it is NOT set for models listed in ``proprietary_models``.
        * ``model`` gets a ``"_<model_size>B"`` suffix unless it is listed in
          ``proprietary_models``.
        * ``result_dir`` is extended with the data type and a sub-directory
          name encoding the run parameters, plus ``"_concert"`` when
          ``--concert`` is given.

    Raises:
        ValueError: if ``args.model`` is neither a Qwen model nor listed in
            ``proprietary_models`` (normally prevented by ``choices=models``).

    Side effects:
        Calls ``seed_everything(args.seed)`` before returning.
    """
    parser = argparse.ArgumentParser(description='')

    # Define command-line arguments
    parser.add_argument('--data_dir', type=str, default='./data/', help='data directory')
    parser.add_argument('--config_dir', type=str, default='./config/', help='configuration directory')
    parser.add_argument('--result_dir', type=str, default='./results/', help='result directory')
    parser.add_argument('--run_type', type=str, default='copula', choices=runtypes, help='type of run')
    parser.add_argument('--data_type', type=str, default='e_commerce', choices=datatypes, help='type of data')
    parser.add_argument('--model', type=str, default='gpt-4.1-nano', choices=models, help='VLM model to use')
    parser.add_argument('--model_size', type=int, default=7, choices=sizes, help='model size (e.g., 7B)')
    parser.add_argument('--device', type=str, default='0,1,2,3', help='device(s) to use (e.g., "0,1" for GPUs)')
    parser.add_argument('--n_quantiles', type=int, default=4, help='number of quantiles for discretization')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--n_bins', type=int, default=6, help='number of bins for histogram')
    parser.add_argument('--n_sub_bins', type=int, default=8, help='number of sub-bins for histogram refinement')
    parser.add_argument('--n_real_samples', type=int, default=2000, help='number of real samples')
    parser.add_argument('--n_samples', type=int, default=2000, help='number of synthetic samples to generate')
    parser.add_argument("--n_plans", type=int, default=5, help="number of plans returned by LLM in batch mode")
    parser.add_argument("--n_joints", type=int, default=5, help="number of groups of variables for joint distribution")
    parser.add_argument("--topk_diff", type=int, default=3, help="top-k largest differences to compute")
    parser.add_argument("--batch_size", type=int, default=20, help="batch size for samples generated per iteration")
    parser.add_argument('--min_div', type=int, default=1, help='minimum samples for comparing distributions')
    parser.add_argument('--debug', action='store_true', default=False, help='debug mode')
    parser.add_argument('--save', action='store_true', default=False, help='save results')
    parser.add_argument('--concert', action='store_true', default=False, help='concert mode flag')
    parser.add_argument('--eval', action='store_true', default=False, help='evaluate per iteration')

    # Parse the supplied list, or sys.argv[1:] when argv is None
    args = parser.parse_args(argv)

    # Determine model path based on model selection
    if args.model == "qwen2.5":
        args.model_path = f"/your_path_to_ckpts/Qwen2.5-{args.model_size}B-Instruct"
    elif args.model == "qwen3":
        args.model_path = f"/your_path_to_ckpts/Qwen3-{args.model_size}B"
    elif args.model in proprietary_models:
        # Proprietary models get no checkpoint path; model_path stays unset.
        pass
    else:
        # Defensive: only reachable if `models` gains an entry not handled above.
        raise ValueError("Model not found")

    # Modify the model name if not a proprietary model, so that the size
    # becomes part of the result directory name built below
    if args.model not in proprietary_models:
        args.model = args.model + f"_{args.model_size}B"

    # Set result directory path based on parameters
    # (seed, n_quantiles and min_div are not part of the directory name)
    args.result_dir = os.path.join(
        args.result_dir, args.data_type,
        f"{args.model}_{args.run_type}_{args.n_real_samples}realsamples_{args.n_samples}generate_"
        f"{args.n_bins}bins_{args.n_sub_bins}subbins_"
        f"{args.topk_diff}topk_{args.n_joints}joints_"
        f"{args.batch_size}batchsize_{args.n_plans}plans",
    )

    # Modify result directory if concert mode is enabled
    if args.concert:
        args.result_dir += "_concert"

    # Set the seed for all libraries
    seed_everything(args.seed)

    return args
