from itertools import count
import gymnasium as gym
from gymnasium import spaces, core
from gymnasium.utils import seeding
import numpy as np
from numpy import random
import math
import time
from datetime import datetime
import pickle
import pandas as pd
import os
import random
import sys

import pandapower as pp
import pandapower.networks as pn
from pandapower import control

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, MessagePassing

# Module-level flag. Nothing in this file reads it; presumably consumed by
# external training code to pick the step-API variant -- TODO confirm.
new_step_api = True

def calc_q_capacity_pf(p_mw, pf_min=0.9):
    """
    Derive the reactive-power capability from the rated active power and a
    minimum power factor.

    Parameters
    ----------
    p_mw : float
        Rated active power of the battery (MW).
    pf_min : float, optional
        Minimum power factor (default 0.9).

    Returns
    -------
    q_max : float
        Maximum reactive power (MVAr): sqrt(s_max**2 - p_mw**2), with a
        negative radicand floored to 0.
    s_max : float
        Rated apparent power of the inverter (MVA): p_mw / pf_min.
    """
    # Apparent-power rating of the inverter.
    apparent = p_mw / pf_min
    # Whatever is left of the apparent-power circle once P is accounted for.
    headroom_sq = apparent**2 - p_mw**2
    if headroom_sq < 0:
        headroom_sq = 0
    reactive = math.sqrt(headroom_sq)
    return reactive, apparent
        

class CBEnv(core.Env):
    def __init__(self, args) -> None:
        """
        Build the environment: static parameters, the pandapower network, the
        price / PV / load / EV profiles and the observation / action spaces.

        Parameters
        ----------
        args : object
            Configuration namespace. Attributes read here: ``hop``,
            ``num_cb``, ``num_unit``, ``max_steps``, ``vol_th`` and
            ``bus_ids``. It is also forwarded to ``init_network``, which
            reads ``line_length``.

        Notes
        -----
        All data files are read from hard-coded absolute paths.
        """
        # Zero-argument super() really dispatches to the parent initializer;
        # ``super(CBEnv).__init__()`` only re-initialised an unbound super
        # object and never reached the base class.
        super().__init__()

        self.args = args
        self.seed()

        self.cb_bus_id = {0:7, 1:11, 2:15, 3:30}
        # Can be read as "GNN on/off"; when studying partial graphs this hop
        # count can be used to define the neighbourhood.
        self.hop = args.hop

        ##### static variables #####
        self.num_cb = args.num_cb   # Number of Charging station
        self.agent_num = self.num_cb   # Number of Charging station
        # Kept in case several CBs per site are considered later.
        self.num_unit = args.num_unit
        # Active (energy) capacity; divided by 1000 before being handed to
        # pandapower in init_network.
        self.battery_cap_cb = 500
        self.battery_lower_cb = 0

        # Reactive capacity and apparent capacity.
        self.max_q_mvar, self.s_max = calc_q_capacity_pf(self.battery_cap_cb/1000)
        self.min_q_mvar = 0

        self.battery_soc_cb = 0 # initial soc for CB
        self.max_power_cb = 100 # battery charge / discharge power
        self.soc_upper = 1.0
        self.soc_lower = 0.0


        # solar data is 4000 kwp; real solar is 4000/60 kwp
        self.solar_zoom = 60
        self.max_purchase = 10000
        self.max_sell = 10000
        self.max_pv_generation = 1000
        self.load_max = 1000
        self.time_step = 5 # every 5 minutes counts as one timestep
        self.truncated_length = args.max_steps * self.time_step # ep length


        ##### dynamic variables #####
        self.state = []
        self.action = []
        self.reward = 0
        self.current_pv = [] # recorded for all buses
        self.current_load = []
        self.current_sell = 0
        self.current_purchase = 0
        self.time = 0 # current minute
        self.init_net_power = [0] * self.num_cb # for runpp
        self.init_q = [0] * self.num_cb
        self.bus_volt = []

        self.bus_pv = []
        self.bus_load = []
        self.soc = []

        # For bus voltage
        self.voltage_penalty = 1
        self.vol_low = 1 - args.vol_th
        self.vol_high = 1 + args.vol_th

        # Normalise by capacity: kWh/kWh -> equivalent cycles.
        self.degradation_norm_cap = self.battery_cap_cb
        # True = variance, False = standard deviation.
        self.pv_fair_use_variance = True

        # profile duration
        self.start_index = 0 * args.max_steps
        self.end_index = 5000 * args.max_steps

        ######### DN design: allows for partial observability without
        ######### affecting the overall environment setup #########
        # Build the DN.
        self.bus_ids = args.bus_ids
        self.neighbor_map = {}

        self._cursor = 0

        self.graph_input_dim = 4 # v theta p q

        # Initialise the network and read the related environment info.
        self.network = self.init_network(self.bus_ids, args)
        self.fix_load = np.array(self.network.load['p_mw'].tolist())
        self.grid_operate(self.init_net_power, self.init_q, self.network, self.fix_load,
                          np.array([0.]*len(self.fix_load)), np.array([0.]*self.num_cb))
        self.total_buses = self.network.bus.index.tolist()

        # self.purchase_price_data = pickle.load(open('/root/omni/safe_EV/files/TOU_price', 'rb'))[self.start_index:self.end_index]
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point this at a trusted, locally produced file. The context manager
        # closes the handle, which the original bare open() never did.
        with open('/root/omni/safe_EV/files/wholesale_price', 'rb') as price_file:
            self.sell_price_data = pickle.load(price_file)

        pv_df = pd.read_csv("/root/omni/safe_EV/files/ausgrid_pv.csv").drop(columns=["191"], errors="ignore")
        pv_win = pv_df.iloc[self.start_index:self.end_index]
        house_cols = pv_win.columns[1:]
        k = len(self.total_buses)
        # Keep the k columns with the largest mean PV output.
        top_cols = pv_win[house_cols].mean(axis=0).nlargest(k).index.tolist()
        self.pv_data = pv_win.loc[:, top_cols]
        #### scale up ####
        self.pv_data = self.pv_data * 1.5

        load_df = pd.read_csv("/root/omni/safe_EV/files/ausgrid_load.csv").drop(columns=["191"], errors="ignore")
        load_win = load_df.iloc[self.start_index:self.end_index]
        # Load columns are the same ones selected for PV.
        self.load_data = load_win.loc[:, top_cols]

        ev_df = pd.read_csv("/root/urban_EV_data/urban_EV_data_bus_ev_power_5min.csv").iloc[:, 1:]
        ev_df = ev_df[ev_df.mean().sort_values().index] # order columns by mean
        # Rescale into a range the power flow can solve.
        def rescale_df(df, lower=100, upper=1000):
            df_scaled = df.copy()
            for col in df.columns:
                col_min, col_max = df[col].min(), df[col].max()
                if col_max > col_min:  # avoid division by zero
                    df_scaled[col] = (df[col] - col_min) / (col_max - col_min) * (upper - lower) + lower
                else:
                    # A constant column simply gets the midpoint.
                    df_scaled[col] = (upper + lower) / 2
            return df_scaled

        ev_df = rescale_df(ev_df, 100, 1000)

        ev_win = ev_df.iloc[self.start_index:self.end_index]
        self.ev_data = ev_win.iloc[:, :len(self.total_buses)-1]
        self.ev_data = self.ev_data / 1000


        # Scale prices and zero out values outside (-0.8, 2].
        self.sell_price_data = np.array(self.sell_price_data)/300
        self.sell_price_data[self.sell_price_data > 2] = 0
        self.sell_price_data[self.sell_price_data < -0.8] = 0
        self.sell_price_data = self.sell_price_data.tolist()

        ##### In this scenario global information is assumed for now, i.e. the
        ##### CBs are managed by the DNSP and every neighbourhood is the full
        ##### bus list. As long as self.neighborhood is right, the rest of the
        ##### environment is right. #####
        self.neighborhood = [[] for _ in range(self.num_cb)]
        for i in range(self.num_cb):
            self.neighborhood[i] += self.total_buses
        self.adjacency_matrix = [[] for _ in range(self.num_cb)]
        self.init_adjacency_matrix()


        ##### Per-episode views of the profile data; they are consumed as
        ##### the simulation steps forward. #####

        # self.purchase_price_data_ep = self.purchase_price_data[int(self.time/5):]
        self.sell_price_data_ep = self.sell_price_data[int(self.time/5):]

        self.pv_data_ep = self.pv_data
        self.load_data_ep = self.load_data
        self.ev_data_ep = self.ev_data


        ##### Standard environment variables #####
        # Unless CB-internal management is modelled there is no need for
        # per-CB information: each bus contributes [p, q, PV, V, soc] and one
        # sell price is appended. Graph info is provided flattened; the hop
        # hyper-parameter decides whether a GNN is used.
        # self.CB_info_len = 2
        self.bus_info_len = 5
        self.price_info = 1
        self.state_high = [10000, 10000, 10000, self.vol_high, self.soc_upper]
        self.state_low = [-10000, -10000, -10000, self.vol_low, self.soc_lower]

        self.obs_dim = (len(self.total_buses)-1) * self.bus_info_len + self.price_info

        high_list = (len(self.total_buses)-1) * self.state_high + [self.max_sell]
        high = np.array(high_list, np.float32).flatten()
        low_list = (len(self.total_buses)-1) * self.state_low + [-self.max_sell]
        low = np.array(low_list, np.float32)

        self.observation_space = spaces.Box(low=low, high=high, shape=(self.obs_dim, )) # state space


        # Per CB: charge/discharge value, Q value, and the traded ratio (how
        # much is bought when charging / sold when discharging); then one PV
        # curtailment ratio per non-slack bus.
        actions_per_cb = 3
        n_curtail = len(self.total_buses) - 1
        action_high = np.array([1.0, 1.0, 1.0]*self.num_cb + [1.0] * n_curtail, np.float32)
        action_low = np.array([-1.0, -1.0, 0.0]*self.num_cb + [0.0] * n_curtail, np.float32)
        # The dimension must match the bound arrays (and step(), which indexes
        # three entries per CB). The previous ``num_unit * num_cb`` only agreed
        # with them when num_unit == 3; any other value made spaces.Box reject
        # the shape.
        self.action_dim = actions_per_cb * self.num_cb + n_curtail
        self.action_space = spaces.Box(low = action_low, high = action_high, shape=(self.action_dim, ), dtype = np.float32)
        
        
        
    def seed(self, seed=None):
        """
        (Re)create the environment's random generator.

        Stores the generator returned by ``seeding.np_random`` on
        ``self.np_random`` and returns the seed it reports, wrapped in a
        one-element list.
        """
        generator, used_seed = seeding.np_random(seed)
        self.np_random = generator
        return [used_seed]
    
    # calculate voltage violation number
    def cost_num(self, bus_volt):
        volt_vio = 0
        for x in bus_volt:
            if x < self.vol_low or x > self.vol_high:
                volt_vio = volt_vio + 1
        return volt_vio
    
    # calculate voltage violation degree
    def cost_degree(self, bus_volt):
        volt_vio = 0
        for x in bus_volt:
            if x < self.vol_low:
                volt_vio = volt_vio + (self.vol_low - x)
            if x > self.vol_high:
                volt_vio = volt_vio + (x - self.vol_high)
        return volt_vio

    def init_network(self, bus_ids, args):
        """
        Load the 33-bus network from its Excel file and add storage units and
        PV generators to it.

        Parameters
        ----------
        bus_ids : iterable
            Buses that receive one storage unit each.
        args : object
            Configuration namespace; ``args.line_length`` is written to every
            line's ``length_km``.

        Returns
        -------
        The pandapower network returned by ``pp.from_excel``, modified in
        place.
        """
        network = pp.from_excel("/root/omni/safe_EV/files/33_bus.xlsx")
        # Place the batteries. Ratings are stored on the instance in units
        # 1000x smaller than the pandapower fields, hence the /1000.
        for bus_id in bus_ids:
            pp.create_storage(network, bus=bus_id, p_mw=0.000,
                              max_e_mwh=self.battery_cap_cb/1000,
                              min_e_mwh=self.battery_lower_cb/1000,
                              max_p_mw=self.max_power_cb/1000,
                              min_p_mw=-self.max_power_cb/1000,
                              max_q_mvar=self.max_q_mvar,
                              min_q_mvar=self.min_q_mvar,
                              soc_percent=self.battery_soc_cb/self.battery_cap_cb,
                              controllable=True)

        # Place the PV units: bus 0 carries no load and should get no PV, so
        # the k-th generator is attached to bus k+1 (one per load row).
        for bus in range(len(network.load)):
            pp.create_sgen(
                network,
                bus=bus+1,
                p_mw=0.0,         # no generation initially; updated at run time
                q_mvar=0.0,       # could be set to support reactive generation
                name=f"PV{bus+1}"
            )

        # Change the line lengths to influence the voltages. The column is
        # assigned in one go: the previous element-wise
        # ``network.line['length_km'][i] = ...`` is chained assignment, which
        # pandas warns about and which does not write through under
        # copy-on-write.
        network.line.loc[:, 'length_km'] = args.line_length

        return network


    # 先运行，然后写函数来获取各种信息
    def grid_operate(self, net_power, q, network, load, pv, soc):
        """
        Write the current operating point into ``network`` and run a power
        flow on it.

        Parameters
        ----------
        net_power : sequence
            Active power per storage unit, written to ``storage.p_mw``.
        q : sequence
            Reactive power per storage unit, written to ``storage.q_mvar``.
        network : pandapower network
            Modified in place; results are available in its ``res_*`` tables
            afterwards.
        load : array-like
            Values assigned to ``network.load['p_mw']``.
        pv : array-like
            Values assigned to ``network.sgen['p_mw']``.
        soc : sequence
            State of charge per storage unit, written to
            ``storage.soc_percent``.
        """
        network.load['p_mw'] = load
        network.sgen['p_mw'] = pv
        # Storage rows are addressed by label 0..len(net_power)-1.
        for i in range(len(net_power)):
            network.storage.at[i, "p_mw"] = net_power[i]
            network.storage.at[i, "q_mvar"] = q[i]
            network.storage.at[i, "soc_percent"] = soc[i]
        pp.runpp(network, trafo_loading='power', algorithm='fdbx',calculate_voltage_angles=False,
                check_connectivity=False, numba=True, tolerance_mva=1e-4)

    ####### 图信息制作 #######

    # 获取所有单相bus的voltage
    def get_bus_vol(self, network):
        return network.res_bus['vm_pu'].to_numpy()

    def get_bus_p(self, network):
        return network.res_load['p_mw'].to_numpy()

    def get_bus_q(self, network):
        return network.res_load['q_mvar'].to_numpy()
    
    def get_bus_pv(self, network):
        return network.res_sgen['p_mw'].to_numpy()
    
    def get_bus_soc(self, network):
        bus_soc = [0] * (len(self.total_buses)-1)
        for i in range(1, len(self.total_buses)):
            if i in self.cb_bus_id.values():
                key = [k for k, v in self.cb_bus_id.items() if v==i][0]
                bus_soc[i] = network.storage.at[key, 'soc_percent']
        return bus_soc

    # 获取bus的静态属性和动态特征
    # 暂时没有静态属性
    def get_bus_feature(self, network, bus_id):
        vol = self.get_bus_vol(network)[bus_id]
        p = self.get_bus_p(network)[bus_id]
        q = self.get_bus_q(network)[bus_id]
        pv = self.get_bus_pv(network)[bus_id]
        soc = self.get_bus_soc(network)[bus_id]

        info = [p, q, pv, vol, soc]
        return info

    def get_global_edge_index(self, network, global_neighborhood, args):
        total_edge_index = np.array([network.line['from_bus'].tolist(), network.line['to_bus'].tolist()])
        mask = np.isin(total_edge_index[0, :], global_neighborhood) & np.isin(total_edge_index[1, :], global_neighborhood)
        sub_edge_index = total_edge_index[:, mask]
        sorted_edges = np.sort(sub_edge_index, axis=0)
        unique_edges = np.unique(sorted_edges, axis=1)

        return unique_edges


    def map_edge_index(self, edge_index, center_node=8):
        # 获取所有唯一节点
        unique_nodes = np.unique(edge_index).tolist()
        
        # 如果存在中心节点，则移除并放在最前面
        if center_node in unique_nodes:
            unique_nodes.remove(center_node)
            unique_nodes = sorted(unique_nodes)
            local_nodes = [center_node] + unique_nodes
        else:
            local_nodes = sorted(unique_nodes)
            
        # 构造映射字典：全局 id -> 局部 id
        mapping = {global_id: local_id for local_id, global_id in enumerate(local_nodes)}
        
        # 复制 edge_index 并进行映射
        new_edge_index = edge_index.copy()
        for global_id, local_id in mapping.items():
            new_edge_index[new_edge_index == global_id] = local_id
        return new_edge_index, mapping


    #### 单相line dynamic info总共就14个 ####
    def get_line_info(self, network, bus_from, bus_to):
        line_idx = network.line[(network.line['from_bus'] == bus_from) & (network.line['to_bus'] == bus_to)].index
        if len(line_idx) == 0:
            line_idx = network.line[(network.line['from_bus'] == bus_to) & (network.line['to_bus'] == bus_from)].index
        
        if len(line_idx) > 0:
            idx = line_idx[0]
            line_info = [
                network.res_line.at[idx, 'p_from_mw'],  # 有功功率从起点流向终点 (MW)
                network.res_line.at[idx, 'p_to_mw'],  # 有功功率从终点流向起点 (MW)
                network.res_line.at[idx, 'q_from_mvar'],  # 无功功率从起点流向终点 (MVAr)
                network.res_line.at[idx, 'q_to_mvar'],  # 无功功率从终点流向起点 (MVAr)
                network.res_line.at[idx, 'pl_mw'],  # 有功功率损耗 (MW)
                network.res_line.at[idx, 'ql_mvar'],  # 无功功率损耗 (MVAr)
                network.res_line.at[idx, 'i_from_ka'],  
                network.res_line.at[idx, 'i_to_ka'],  
                network.res_line.at[idx, 'i_ka'],  # 电流 (kA)
                network.res_line.at[idx, 'vm_from_pu'],  
                network.res_line.at[idx, 'vm_to_pu'],  
                network.res_line.at[idx, 'va_from_degree'],  
                network.res_line.at[idx, 'va_to_degree'],  
                network.res_line.at[idx, 'loading_percent']  # 负载率 (%)
            ]
            return line_info
        else:
            return [0]*14
    
    
    
    # 构建station连接信息的邻接矩阵
    def init_adjacency_matrix(self):
        # 对每一个evcs bus
        for i, bus_id in enumerate(self.bus_ids):
            num_nodes = len(self.neighborhood[i])
            adjacency_matrix = np.zeros((num_nodes, num_nodes))
            neighbors = self.neighborhood[i]  # 获取邻居bus的ID列表
            # 对每一个evcs bus的neighbor
            for j, neighbor_bus_id in enumerate(neighbors):
                adjacency_matrix[j][j] = 1
                tmp = [x for x in neighbors if x != neighbor_bus_id]
                for neighbor_bus_id_ in tmp:
                    if ((neighbor_bus_id, neighbor_bus_id_) in list(zip(self.network.line.from_bus, self.network.line.to_bus))) or ((neighbor_bus_id, neighbor_bus_id_) in list(zip(self.network.line.to_bus, self.network.line.from_bus))):
                        adjacency_matrix[j][neighbors.index(neighbor_bus_id_)] = 1
                        adjacency_matrix[neighbors.index(neighbor_bus_id_)][j] = 1
            self.adjacency_matrix[i] = adjacency_matrix

    def step(self, action):
        """
        Advance the simulation by one interval of ``self.time_step`` minutes.

        Parameters
        ----------
        action : indexable, mutable sequence
            Three entries per CB (charge/discharge, reactive power, traded
            ratio) followed by one PV curtailment ratio per non-slack bus.
            It is stored by reference on ``self.action``, and a
            charge/discharge entry that would push the SOC outside
            ``[soc_lower, soc_upper]`` is overwritten in place with the value
            actually applied, so the caller's object is modified too.

        Returns
        -------
        list
            ``[observation, reward, truncated, info]``. Note that this is a
            4-element list, not the 5-tuple of the gymnasium API.
            ``info`` holds the keys ``voltage``, ``voltage_num``,
            ``voltage_degree``, ``thermal``, ``batt_deg`` and
            ``pv_curt_fair``.

        Notes
        -----
        The trade is priced with ``self.current_sell`` as it was when this
        method was entered; the price is refreshed at the end of the step.
        """
        truncated = False
        self.state = []
        interval = self.time_step / 60

        self.action = action

        # reset() (or __init__) prepared the episode windows; row 0 of each is
        # the current interval for every bus.
        pv_row = np.array(self.pv_data_ep.iloc[0, :-1].tolist())
        load_row = np.array(self.load_data_ep.iloc[0, :-1].tolist())
        ev_row = np.array(self.ev_data_ep.iloc[0, :].tolist())

        # The trailing entries of the action are the PV curtailment ratios.
        curt_vec = np.array(self.action[-len(self.total_buses)+1:], dtype=np.float32)
        curt_vec = np.clip(curt_vec, 0.0, 1.0)

        pv_gen = pv_row * interval * (1-curt_vec)
        bus_load = load_row * interval + ev_row * interval

        charging_amount = []
        trading_amount = []
        self.trading_amount = 0
        mw_power = []
        soc_percents = []
        q = []

        for i in range(self.num_cb):
            soc_now = self.network.storage.iloc[i]["soc_percent"]

            # Energy change the raw action asks for.
            delta_energy = self.action[i*3] * self.max_power_cb * interval
            soc_next = soc_now + delta_energy / self.battery_cap_cb

            # Would exceed the upper SOC bound: charge only what still fits
            # and write the corrected action back.
            if soc_next > self.soc_upper:
                allowed_delta = (self.soc_upper - soc_now) * self.battery_cap_cb
                delta_energy = min(delta_energy, allowed_delta)
                self.action[i*3] = delta_energy / (self.max_power_cb * interval)

            # Would drop below the lower SOC bound.
            if soc_next < self.soc_lower:
                allowed_delta = (self.soc_lower - soc_now) * self.battery_cap_cb
                delta_energy = max(delta_energy, allowed_delta)
                self.action[i*3] = delta_energy / (self.max_power_cb * interval)

            charging_amount.append(delta_energy)
            trading_amount.append(delta_energy * self.action[i*3+2])
            mw_power.append(delta_energy / 1000)
            soc_percents.append(soc_now + delta_energy / self.battery_cap_cb)
            q.append(self.action[i*3+1]*self.max_q_mvar)

        # Reactive set-points go into the power flow as well.
        self.grid_operate(mw_power, q, self.network, bus_load, pv_gen, soc_percents)

        self.bus_load = bus_load
        self.bus_pv = pv_gen

        # One step covers time_step minutes.
        self.time += self.time_step

        # Battery economics: buying and selling are both priced with the sell
        # price (the purchase price is currently disabled).
        economic = [amount * self.current_sell for amount in trading_amount]
        self.trading_amount = trading_amount

        bus_volt = self.get_bus_vol(self.network)
        self.bus_volt = bus_volt

        # Battery degradation: total throughput normalised by the maximum
        # possible throughput of all CBs in one step, clipped to [0, 1].
        E_step_max_cb = self.max_power_cb * interval
        E_step_max_total = max(self.num_cb * E_step_max_cb, 1e-8)  # avoid /0
        batt_deg_cost = float(np.sum(np.abs(charging_amount)) / E_step_max_total)
        batt_deg_cost = float(np.clip(batt_deg_cost, 0.0, 1.0))

        # Curtailment fairness: variance of the ratios, scaled by 0.25 (the
        # largest variance values in [0, 1] can have).
        if curt_vec.size > 0:
            var = float(np.var(curt_vec))
            pv_curt_fair_cost = float(np.clip(var / 0.25, 0.0, 1.0))
        else:
            pv_curt_fair_cost = 0.0

        # Voltage violations (count plus weighted magnitude) and line loading.
        bus_volt_vio_num = self.cost_num(self.bus_volt)
        bus_volt_vio_degree = self.cost_degree(self.bus_volt)
        bus_volt_vio = bus_volt_vio_num + self.voltage_penalty * bus_volt_vio_degree
        line_vio = np.sum(self.network.res_line.loading_percent)
        self.reward = -np.sum(economic)

        # Cost terms are reported through info.
        info = {}
        info['voltage'] = bus_volt_vio
        info['voltage_num'] = bus_volt_vio_num/len(self.total_buses)
        info['voltage_degree'] = bus_volt_vio_degree
        info['thermal'] = line_vio
        info['batt_deg'] = batt_deg_cost
        info['pv_curt_fair'] = pv_curt_fair_cost

        if self.time % self.truncated_length == 0:
            truncated = True

        # Rebuild the observation: per-bus features followed by the price.
        for i in range(1, len(self.total_buses)):
            self.state += self.get_bus_feature(self.network, i-1)
        self.current_sell = self.sell_price_data_ep.pop(0)
        self.state += [self.current_sell]

        self.soc.append(soc_percents)

        # Record the raw (unscaled, uncurtailed) profile rows of this step.
        self.current_pv.append(pv_row)
        self.current_load.append(load_row)

        # Move every profile window on by one row. The EV window has to
        # advance together with PV and load; it was previously never moved,
        # so one EV row was reused for the whole episode.
        self.pv_data_ep = self.pv_data_ep.iloc[1:]
        self.load_data_ep = self.load_data_ep.iloc[1:]
        self.ev_data_ep = self.ev_data_ep.iloc[1:]

        return [np.array(self.state, dtype='float32'), self.reward, truncated, info]

    
    def reset(self):
        """
        Start a new episode and return its first observation.

        Clears the dynamic variables, moves the data cursor on by 288 rows,
        rebuilds the network, runs an initial power flow with zero PV and zero
        storage power, and cuts the 288-row episode windows out of the price,
        PV, load and EV data.

        Returns
        -------
        numpy.ndarray
            float32 observation: the per-bus features followed by the sell
            price of the first row of the new episode window.

        Raises
        ------
        IndexError
            If the cursor has moved past the end of the price data.
        """
        # Reset all dynamic variables.
        self.state = []
        self.action = []
        self.reward = 0
        self.current_pv = []
        self.current_load = []
        self.current_sell = 0
        self.current_purchase = 0
        self.soc = []
        self.bus_pv = []
        self.bus_load = []
        self.time = 0

        # The cursor is advanced before it is used, so the first episode
        # starts at row 288.
        self._cursor += 288
        start = int(self._cursor)

        # Rebuild the network.
        self.network = self.init_network(self.bus_ids, self.args)
        self.grid_operate(self.init_net_power, self.init_q, self.network, self.fix_load,
                          np.array([0.]*len(self.fix_load)), np.array([0.]*self.num_cb))

        # Data windows read during the new episode.
        # self.purchase_price_data_ep = self.purchase_price_data[self.time:]
        self.sell_price_data_ep = self.sell_price_data[start:start+288]
        self.pv_data_ep  = self.pv_data.iloc[start:start+288, :]
        self.load_data_ep= self.load_data.iloc[start:start+288, :]
        self.ev_data_ep  = self.ev_data.iloc[start:start+288, :]

        if not self.sell_price_data_ep:
            raise IndexError("episode window starts past the end of the price data")

        # Per-bus features of the freshly initialised network.
        for i in range(1, len(self.total_buses)):
            self.state += self.get_bus_feature(self.network, i-1)

        # Initial price: take it from the episode window (not from global row
        # 0) and store it on self.current_sell, so that the first step is
        # priced with the value shown in the observation. Previously
        # current_sell stayed 0 and the first step's trade was priced at 0.
        self.current_sell = self.sell_price_data_ep[0]
        self.state += [self.current_sell]

        return np.array(self.state, dtype='float32')

    def close(self):
        pass
