import torch
import torch.nn as nn
from torch.distributions import Normal
import torch.optim as optim
import numpy as np
import copy
import argparse

import torch
import torch.nn as nn
import argparse

class Args:
    """Training hyperparameters and environment settings read from the CLI.

    Instantiating the class parses ``sys.argv`` and copies every parsed option
    onto the instance as an attribute (``args.lr``, ``args.gamma``, ...).
    Options the parser does not define are ignored rather than rejected.

    Two attributes are added on top of the parsed options:

    * ``device``  -- the ``torch.device`` selected for computation.
    * ``loss_fn`` -- an ``nn.MSELoss`` instance.
    """

    @staticmethod
    def _str2bool(value):
        """Return True only when *value* spells ``true`` (case-insensitive)."""
        return str(value).lower() == 'true'

    def __init__(self):
        parser = argparse.ArgumentParser(description="Training hyperparameters and environment settings")

        # Method and model selection
        parser.add_argument('--method', type=str, default='mappo', help='DRL algorithm')
        parser.add_argument('--comm', type=self._str2bool, default=False,
                            help='if comm existing among agents in MARL')

        # Model hyperparameters
        parser.add_argument('--lr', type=float, default=3e-4, help='learning rate')
        parser.add_argument('--hidden_dim', type=int, default=32, help='hidden_dim')
        parser.add_argument('--gamma', type=float, default=0.99, help='discount factor')
        parser.add_argument('--epsilon', type=float, default=0.2, help='PPO clip epsilon')
        parser.add_argument('--sigma', type=float, default=0.5, help='PPO exploration rate')
        parser.add_argument('--sigma_decay', type=float, default=0.997, help='PPO exp rate decay')
        parser.add_argument('--K_epochs', type=int, default=3, help='update epochs')
        parser.add_argument('--num_agents', type=int, default=1, help='agent number')
        parser.add_argument('--buffer_size', type=int, default=100000, help='replay buffer size')
        parser.add_argument('--batch_size', type=int, default=64, help='update batch size')
        parser.add_argument('--tau', type=float, default=0.01, help='soft update coef')
        parser.add_argument('--exploration_noise', type=float, default=0.1, help='Gaussian noise')

        # Training parameters
        parser.add_argument('--max_episodes', type=int, default=1000, help='training eps')
        parser.add_argument('--max_steps', type=int, default=288, help='max steps of each PPO ep')
        parser.add_argument('--update_step', type=int, default=288*3, help='update steps of other algs')

        # Environment and model initialisation parameters
        # nargs='+' with type=int keeps a command-line value a list of ints,
        # matching the default (e.g. ``--bus_ids 7 11 15 30``).
        parser.add_argument('--bus_ids', type=int, nargs='+', default=[7, 11, 15, 30],
                            help='Bus IDs of EVCS')
        parser.add_argument('--vol_th', type=float, default=0.1, help='voltage threshold')
        parser.add_argument('--num_cb', type=int, default=4, help='number of charging station')
        parser.add_argument('--num_unit', type=int, default=3, help='number of controllable units of one agent')
        parser.add_argument('--staying_length', type=str, default='short', help='charging duration length')
        parser.add_argument('--line_length', type=float, default=1.5, help='bus system line length (for voltage varying sensitivity)')
        # type=int so a command-line value has the same type as the default.
        parser.add_argument('--window_length', type=int, default=288, help='Historical data length for STGAT')

        # Graph
        parser.add_argument('--hop', type=int, default=0, help='neighboring hop')
        parser.add_argument('--entire', type=self._str2bool,
                            default=True, help='if get lines between neignbors')

        # Test parameters
        parser.add_argument('--num_test_ep', type=int, default=100, help='test eps')
        parser.add_argument('--test', action='store_true', default=False, help='if the env is in test mode')

        # Lagrangian multiplier and cost constraints
        parser.add_argument('--lambda_lr', type=float, default=0.01, help='Lagrangian lr')
        parser.add_argument('--cost_threshold', type=float, default=0.1, help='adjusting factor for policy exploration')
        parser.add_argument('--cost_threshold_volt', type=float, default=-0.5, help='adjusting factor for policy exploration')
        parser.add_argument('--cost_threshold_demand', type=float, default=0.5, help='adjusting factor for policy exploration')
        parser.add_argument('--lambda_coef', type=float, default=0.5, help='Lagrangian coef')

        # Hardware selection (default keeps the previously hard-coded index)
        parser.add_argument('--device_id', type=int, default=7,
                            help='CUDA device index used when a GPU is available')

        # parse_known_args() instead of parse_args(): options this parser does
        # not define are returned separately and discarded, not treated as errors.
        args, _ = parser.parse_known_args()

        # Select the device; an index outside the visible GPU range would only
        # fail later on first tensor use, so fall back to cuda:0 up front.
        if torch.cuda.is_available():
            device_id = args.device_id
            if not 0 <= device_id < torch.cuda.device_count():
                print(f"CUDA device {device_id} is not available; falling back to cuda:0")
                device_id = 0
            args.device = torch.device(f"cuda:{device_id}")
        else:
            args.device = torch.device("cpu")
        print(f"Using device: {args.device}")

        # Expose every parsed option (plus ``device``) as an instance attribute.
        self.__dict__.update(vars(args))

        self.loss_fn = nn.MSELoss()
        
