
import argparse
import os
import numpy as np
import torch
import ipdb
from ast import literal_eval

def parser_csp(filepath):
    """Parse a binary CSP instance file.

    Layout consumed by this parser:

    * line index 1 holds tab-separated integers; the first three are read as
      ``n`` (number of variables), ``d`` (domain size) and ``m`` (number of
      constraint records);
    * constraint records start at line index 3, one per line, written as
      ``(var1, var2)|<allowed tuples literal>``.

    Args:
        filepath: Path of the instance file.

    Returns:
        A ``(variables, constraints, domains)`` tuple where ``variables`` is
        ``[0, ..., n-1]``, ``constraints`` maps ``(var1, var2)`` to the
        concatenation of every allowed-tuples literal listed for that pair,
        and ``domains`` is ``range(d)``.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If the file has fewer lines/fields than the header
            announces.
        ValueError, SyntaxError: If a header field or a record literal is
            malformed.
    """
    # The context manager closes the handle even when parsing fails below.
    with open(filepath, "r") as case_file:
        lines = case_file.readlines()

    header_fields = lines[1].split('\t')
    n = int(header_fields[0])
    d = int(header_fields[1])
    m = int(header_fields[2])

    variables = list(range(n))
    domains = range(d)

    constraints = {}
    for record in lines[3:3 + m] if len(lines) >= 3 + m else None or _csp_records(lines, m):
        scope_text, tuples_text = record.split('|')[:2]

        var1, var2 = literal_eval(scope_text)
        allowed_tuples = literal_eval(tuples_text)

        key = (var1, var2)
        if key in constraints:
            # Same scope listed more than once: accumulate its tuples.
            constraints[key] += allowed_tuples
        else:
            constraints[key] = allowed_tuples

    return variables, constraints, domains


def _csp_records(lines, m):
    """Yield the ``m`` constraint lines, raising IndexError if one is missing."""
    for offset in range(m):
        yield lines[offset + 3]

def parser_cop(filepath):
    """Parse a weighted binary COP instance and its companion weight file.

    The instance file uses the same layout as the CSP format: line index 1
    holds tab-separated integers whose first three fields are ``n`` (number
    of variables), ``d`` (domain size) and ``m`` (number of constraint
    records); records start at line index 3 as
    ``(var1, var2)|<allowed tuples literal>``.

    Weights are read from the path ``filepath + 'weight.txt'`` (plain string
    concatenation), one ``(var1, var2)|<weight literal>`` entry per line.
    Blank lines in the weight file are ignored.

    Args:
        filepath: Path of the instance file; also the prefix of the weight
            file path.

    Returns:
        A ``(variables, constraints, domains, weights)`` tuple where
        ``variables`` is ``[0, ..., n-1]``, ``constraints`` maps
        ``(var1, var2)`` to the concatenated allowed tuples for that pair,
        ``domains`` is a list of ``n`` independent ``[0, ..., d-1]`` lists,
        and ``weights`` maps ``(var1, var2)`` to its parsed weight.

    Raises:
        OSError: If either file cannot be opened.
        IndexError: If the instance file is shorter than its header claims.
        ValueError, SyntaxError: If a field or literal is malformed.
    """
    # Context managers close both handles even if parsing fails later.
    with open(filepath, "r") as case_file:
        lines = case_file.readlines()
    with open(filepath + 'weight.txt', "r") as weight_file:
        weight_lines = weight_file.readlines()

    # Read head
    header_fields = lines[1].split('\t')
    n = int(header_fields[0])
    d = int(header_fields[1])
    m = int(header_fields[2])

    variables = list(range(n))
    domains = [list(range(d)) for _ in range(n)]

    weights = {}
    for line in weight_lines:
        if not line.strip():
            # A trailing/blank line carries no entry; skip instead of crashing.
            continue
        fields = line.split('|')
        var1, var2 = literal_eval(fields[0])
        weights[(var1, var2)] = literal_eval(fields[1])

    constraints = {}
    for offset in range(m):
        fields = lines[offset + 3].split('|')

        var1, var2 = literal_eval(fields[0])
        allowed_tuples = literal_eval(fields[1])

        key = (var1, var2)
        if key in constraints:
            # Same scope listed more than once: accumulate its tuples.
            constraints[key] += allowed_tuples
        else:
            constraints[key] = allowed_tuples

    return variables, constraints, domains, weights

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is unusable with argparse because ``bool("False")`` is
    ``True``; this converter recognises the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


def _parse_network_params(value):
    """Parse a Python literal such as ``"[('mlp', 64)]"`` into a list."""
    try:
        parsed = literal_eval(value)
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(
            "expected a Python list literal, got %r" % (value,)) from err
    if not isinstance(parsed, (list, tuple)):
        raise argparse.ArgumentTypeError(
            "expected a Python list literal, got %r" % (value,))
    return list(parsed)


def _latest_model_dir(root):
    """Return the last entry (in sorted name order) under ``root``.

    Falls back to ``root`` itself when the directory is missing or empty so
    that building the parser never fails just because no run exists yet.
    """
    try:
        entries = sorted(os.listdir(root))
    except OSError:
        return root
    if not entries:
        return root
    return os.path.join(root, entries[-1])


def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Boolean options take an explicit value (``--print_log false``), network
    parameter options take a Python list literal
    (``--v_network_params "[('mlp', 64)]"``), and ``--model_path`` defaults
    to the last entry in sorted order under ``logs/ppo/THP`` (or that
    directory itself when it is missing or empty).

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Setup for training PPO agent')
    parser.add_argument('--input_dim', type=int, default=100, help='Input dimension size')
    parser.add_argument('--embed_dim', type=int, default=64, help='Embedding dimension size')
    parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension size')
    parser.add_argument('--normalize', action='store_true', help='Whether to normalize input data')
    parser.add_argument('--reward_mode', type=str, default='episodic', choices=['episodic', 'continuous'], help='Mode of reward calculation')
    parser.add_argument('--reward_score_type', type=str, default='BIC', choices=['BIC', 'AIC', 'LL'], help='Type of score for reward calculation')
    parser.add_argument('--reward_regression_type', type=str, default='LR', choices=['LR', 'Ridge', 'Lasso'], help='Type of regression model for reward calculation')
    parser.add_argument('--reward_gpr_alpha', type=float, default=1.0, help='Alpha parameter for Gaussian Process Regression')
    parser.add_argument('--lambda_iter_num', type=int, default=500, help='Number of lambda iterations')

    parser.add_argument('--alpha', type=float, default=0.99, help='Alpha value for the score function')
    parser.add_argument('--init_baseline', type=float, default=-1.0, help='Initial baseline value for reward normalization')

    parser.add_argument('--smooth', type=float, default=0.1)
    # The default is an int; parse CLI values as int too so the type is stable.
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--decay', type=float, default=0.0)
    parser.add_argument('--log', type=str, default='./PolicyNet/Train_result/log.txt')
    parser.add_argument('--save_path', type=str, default='./PolicyNet/Train_result/best_model.pth')

    parser.add_argument('--num_epochs', type=int, default=5)

    # Buffer parameters
    parser.add_argument('--buffer_size', type=int, default=128)
    parser.add_argument('--search_num', type=int, default=10000)
    parser.add_argument('--data_path', type=str, default='../datasets/24V_439N_Microwave/')
    parser.add_argument('--buffer_max_buffer_size', type=int, default=5000, help="Maximum buffer size")

    # Model parameters
    parser.add_argument('--log_dir', type=str, default=os.path.join("logs", "ppo"), help="Directory to save logs")
    parser.add_argument('--info', type=str, default='', help="Additional info string for logger")
    parser.add_argument('--print_log', type=_str2bool, default=True)
    parser.add_argument('--enable_pbar', type=_str2bool, default=True)
    parser.add_argument('--gpu', type=int, default=0, help="GPU id to use. Use -1 for CPU")

    # Agent value network parameters
    parser.add_argument('--v_network_params', type=_parse_network_params, default=[("mlp", 64), ("mlp", 64)], help="Network parameters for value network")
    parser.add_argument('--v_network_optimizer_class', type=str, default="Adam", help="Optimizer class for value network")
    parser.add_argument('--v_network_learning_rate', type=float, default=0.001, help="Learning rate for value network")
    parser.add_argument('--v_network_act_fn', type=str, default="tanh", help="Activation function for value network")
    parser.add_argument('--v_network_out_act_fn', type=str, default="identity", help="Output activation function for value network")

    # Agent policy network parameters
    parser.add_argument('--policy_network_params', type=_parse_network_params, default=[("mlp", 64), ("mlp", 64)], help="Network parameters for policy network")
    parser.add_argument('--policy_network_optimizer_class', type=str, default="Adam", help="Optimizer class for policy network")
    parser.add_argument('--policy_network_learning_rate', type=float, default=0.0003, help="Learning rate for policy network")
    parser.add_argument('--policy_network_act_fn', type=str, default="tanh", help="Activation function for policy network")
    parser.add_argument('--policy_network_out_act_fn', type=str, default="identity", help="Output activation function for policy network")
    parser.add_argument('--policy_network_re_parameterize', type=_str2bool, default=False, help="Whether to re-parameterize")
    parser.add_argument('--policy_network_predicted_std', type=_str2bool, default=False, help="Whether standard deviation is predicted")
    parser.add_argument('--policy_network_parameterized_std', type=_str2bool, default=True, help="Whether standard deviation is parameterized")
    parser.add_argument('--policy_network_stablize_log_prob', type=_str2bool, default=False, help="Whether to stabilize log probability")
    # Was declared as float although it is a boolean switch ("0"/"1" still work).
    parser.add_argument("--use_tanh", type=_str2bool, default=False, help="Trick 10: tanh activation function")
    parser.add_argument("--use_gru", type=_str2bool, default=True, help="Whether to use GRU")
    parser.add_argument("--use_orthogonal_init", type=_str2bool, default=True, help="Trick 8: orthogonal initialization")

    # Trainer parameters
    parser.add_argument('--max_env_steps', type=int, default=500000, help="Maximum number of environment steps")
    parser.add_argument('--num_env_steps_per_epoch', type=int, default=2000, help="Number of environment steps per epoch")
    parser.add_argument('--max_trajectory_length', type=int, default=1000, help="Maximum trajectory length")
    parser.add_argument('--batch_size', type=int, default=64, help="Batch size for training")
    parser.add_argument('--eval_interval', type=int, default=2000, help="Evaluation interval")
    parser.add_argument('--num_eval_trajectories', type=int, default=5, help="Number of evaluation trajectories")
    parser.add_argument('--snapshot_interval', type=int, default=10000, help="Snapshot interval")
    parser.add_argument('--start_timestep', type=int, default=0, help="Start timestep")
    parser.add_argument('--save_video_demo_interval', type=int, default=-1, help="Save video demo interval")
    parser.add_argument('--log_interval', type=int, default=1, help="Log interval")

    # Locate the most recent model path: list the entries of logs/ppo/THP and
    # take the last one in sorted order (names are presumably timestamps).
    model_path = _latest_model_dir(os.path.join("logs", "ppo", "THP"))

    parser.add_argument('--model_path', type=str, default=model_path, help="Path of the model directory to load")
    args = parser.parse_args()

    return args

    # if torch.cuda.is_available():
    #         logging.info('GPU is available.')
    # else:
    #     logging.info('GPU is unavailable.')
    #     if args.device_type == 'gpu':
    #         raise ValueError("GPU is unavailable, "
    #                          "please set device_type = 'cpu'.")
    # if args.device_type == 'gpu':
    #     if args.device_ids:
    #         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_ids)
    #     device = torch.device('cuda')
    # else:
    #     device = torch.device('cpu')
    # args.device = device
  
if __name__ == "__main__":
    # Manual smoke run against a sample instance; guarded so that importing
    # this module does not read (or fail on) a hard-coded relative path.
    parser_csp('data/0.txt')