import  utils.config as config
from utils.model import Enhanced_large_GCNv3,Simplified_GCN
import torch
import time
import numpy as np
from utils.evaluation_functions import get_genload_N1,get_load_mismatch_rate,get_Pgcost,get_clamp,r2_score_matrix,evs_score_matrix
from utils.violation import get_PQ_violation,get_branch_violation,get_V_violation
from utils.get_dV import get_dV
import math
import pickle
from torch_geometric.loader import DataLoader
import gc
import matplotlib.pyplot as plt





def test_model():
    """Evaluate the configured model on the test set and print the metrics.

    Side effect: sets ``config.test_type = True``.

    The model class and checkpoint path are selected from ``config.train_type``:
    ``'tea'`` builds ``Enhanced_large_GCNv3`` (hidden_dim=128) and loads
    ``config.pre_path``; any other value builds ``Simplified_GCN``
    (hidden_dim=32) and loads ``config.path``. Both models take 9 node
    features and output ``2 * config.Nbus - 1`` values.

    ``get_performance`` is run twice and prints its metrics each time:
    first on the reference voltages stored in ``config.real_test_VA`` /
    ``real_test_VM`` / ``real_test_V`` (a baseline), then on the model's
    predictions.
    """
    config.test_type = True

    if config.train_type == 'tea':
        print("tea")
        path = config.pre_path
        model = Enhanced_large_GCNv3(num_node_features=9, hidden_dim=128, output_dim=2 * config.Nbus - 1)
    else:
        print("std")
        path = config.path
        model = Simplified_GCN(num_node_features=9, hidden_dim=32, output_dim=2 * config.Nbus - 1)

    Pred_Va, Pred_Vm, Pred_V = get_Pred_V(model, path)

    # Baseline: metrics of the reference voltages themselves. Only the printed
    # output matters here, so the returned violation details are not kept.
    get_performance(config.real_test_VA, config.real_test_VM, config.real_test_V)
    # Metrics of the model's raw predictions (before any post-processing).
    get_performance(Pred_Va, Pred_Vm, Pred_V)

    

def load_model(model, path, device):
    """Load the ``state_dict`` saved at ``path`` into ``model`` and return ``model``.

    Tensors are mapped onto ``device`` while loading. Keys are used as stored;
    no ``'module.'`` prefix (from DataParallel checkpoints) is stripped.
    """
    model.load_state_dict(torch.load(path, map_location=device))
    return model



def get_Pred_V(model,path):
    """Run the trained model on the pickled test set and de-normalise its output.

    Loads the weights at ``path`` into ``model`` and runs batched inference
    (``config.batch_size``, no shuffling) over the dataset pickled at
    ``config.pickle_test_dataset_path``. The normalised outputs are then
    converted back to voltage quantities.

    Each model output row is split column-wise. The first ``config.Nbus``
    columns are the normalised voltage magnitudes. The remaining
    ``config.Nbus - 1`` columns are the normalised voltage angles with the
    slack bus omitted.

    Args:
        model: Model instance to load weights into. Calling it on a batch must
            return a ``(features, predictions)`` pair.
        path: Path of the saved ``state_dict``.

    Returns:
        ``(Pred_Va, Pred_Vm, Pred_V)`` NumPy arrays:
        - ``Pred_Va``: angles with the slack-bus column re-inserted at
          ``config.bus_slack``, de-normalised with ``config.VaLb``/``VaUb``.
        - ``Pred_Vm``: magnitudes de-normalised with ``config.VmLb``/``VmUb``.
        - ``Pred_V``: the complex voltage ``Pred_Vm * exp(1j * Pred_Va)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(config.pickle_test_dataset_path, 'rb') as f:
        test_dataset = pickle.load(f)

    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    print("load traind model")
    print(path)
    model = load_model(model, path, config.device)
    model.eval()
    model.to(config.device)

    testSample = len(test_loader.dataset)
    print(f'testSampleNum :{testSample}')

    VA_prediction, VM_prediction, total_reason_time = _predict_normalised_V(model, test_loader, testSample)
    print(f'Total reasoning time: {total_reason_time:.4f} seconds')

    return _denormalise_V(VA_prediction, VM_prediction)


def _predict_normalised_V(model, test_loader, testSample):
    """Collect normalised (VA, VM) predictions on CPU and the inference loop time in seconds."""
    VA_prediction = torch.zeros((testSample, config.Nbus - 1))
    VM_prediction = torch.zeros((testSample, config.Nbus))
    index = 0

    start = time.perf_counter()
    with torch.no_grad():  # inference only: no gradients needed
        # Batches unpack into (inputs, targets). The targets are not needed for
        # prediction, so they are neither moved to the device nor stored.
        for test_x, _test_y in test_loader:
            test_x = test_x.to(config.device)
            _, V_pred = model(test_x)
            batch_size = V_pred.shape[0]
            VM_prediction[index:index + batch_size] = V_pred[:, :config.Nbus]
            VA_prediction[index:index + batch_size] = V_pred[:, config.Nbus:]
            index += batch_size
    # Copying each batch into the CPU buffers above forces any device work to
    # finish, so this measures the complete inference loop.
    elapsed = time.perf_counter() - start

    return VA_prediction, VM_prediction, elapsed


def _denormalise_V(VA_prediction, VM_prediction):
    """Map normalised angle/magnitude predictions back to (Va, Vm, complex V) arrays."""
    Pred_Va = VA_prediction.detach().cpu().numpy()

    # The model does not predict the slack bus angle. Re-insert it, normalised
    # the same way as the other angles (config.bus_slack_VA is converted from
    # degrees to radians first).
    # TODO(review): re-check whether the value inserted here should be 0 or 0.5.
    slack_VA = (config.bus_slack_VA * math.pi / 180 - config.VaLb) / (config.VaUb - config.VaLb)
    Pred_Va = np.insert(Pred_Va, config.bus_slack, values=slack_VA, axis=1)  # returns a new array
    Pred_Va = Pred_Va * (config.VaUb - config.VaLb) + config.VaLb

    Pred_Vm = VM_prediction.detach().cpu().numpy() * (config.VmUb - config.VmLb) + config.VmLb

    # Polar -> complex voltage.
    Pred_V = Pred_Vm * np.exp(1j * Pred_Va)

    return Pred_Va, Pred_Vm, Pred_V


def get_performance(Pred_Va,Pred_Vm,Pred_V):
    """Print constraint-satisfaction, load-mismatch, cost and EVS metrics for a voltage solution.

    Generator and load injections are computed with ``get_genload_N1`` both for
    the given voltages and for the reference ``config.real_test_V``. The
    reference values feed the load-mismatch rates, the optimality gap and the
    explained-variance scores (``evs_score_matrix``).

    Args:
        Pred_Va: Voltage angles to evaluate.
        Pred_Vm: Voltage magnitudes to evaluate.
        Pred_V: Complex voltages to evaluate.

    Returns:
        ``(PG_violation_gen, QG_violation_gen, PG_violation, QG_violation,
        PQ_violation_num, PQG_violation_index)`` exactly as produced by
        ``get_PQ_violation`` for the evaluated voltages.
    """
    load_args = (config.test_Pd, config.test_Qd, config.bus_Pg, config.bus_Qg)
    Pred_Pg, Pred_Qg, Pred_Pd, Pred_Qd = get_genload_N1(Pred_V, *load_args)
    Real_Pg, Real_Qg, Real_Pd, Real_Qd = get_genload_N1(config.real_test_V, *load_args)

    (pg_ratio, qg_ratio,
     pg_gen, qg_gen,
     pg_viol, qg_viol,
     pq_num, pqg_index) = get_PQ_violation(Pred_Pg, Pred_Qg)
    branch_viol = get_branch_violation(Pred_V)
    vm_ratio, va_ratio, _vm_bus, _va_bus = get_V_violation(Pred_Vm, Pred_Va)

    def total_load(arr):
        # Per-sample total load, as a torch tensor (summed over axis 1).
        return torch.from_numpy(arr).sum(axis=1)

    mre_Pd = get_load_mismatch_rate(total_load(Real_Pd), total_load(Pred_Pd))
    mre_Qd = get_load_mismatch_rate(total_load(Real_Qd), total_load(Pred_Qd))

    Pred_cost = get_Pgcost(Pred_Pg, config.idxPg, config.gencost)
    Real_cost = get_Pgcost(Real_Pg, config.idxPg, config.gencost)
    # Mean relative cost difference versus the reference, in percent.
    opt_gap = np.mean(np.divide(Pred_cost - Real_cost, Real_cost)) * 100

    evs_Va = evs_score_matrix(config.real_test_VA, Pred_Va)
    evs_Vm = evs_score_matrix(config.real_test_VM, Pred_Vm)
    evs_Pg = evs_score_matrix(Real_Pg, Pred_Pg)
    evs_Qg = evs_score_matrix(Real_Qg, Pred_Qg)

    report = (
        ("The PG satisfaction rate is: {name}%", 100 - pg_ratio),
        ("The QG satisfaction rate is: {name}%", 100 - qg_ratio),
        ("The Branch satisfaction rate is: {name}%", 100 - branch_viol),
        ("The VM satisfaction rate is: {name}%", 100 - vm_ratio),
        ("The VA satisfaction rate is: {name}%", 100 - va_ratio),
        ("The Pd satisfaction rate is: {name}%", 100 - torch.mean(mre_Pd)),
        ("The Qd satisfaction rate is: {name}%", 100 - torch.mean(mre_Qd)),
        ("The optimality gap is: {name}%", opt_gap),
    )
    for template, value in report:
        print(template.format(name=value))

    print(f'The evs_Va:{evs_Va}')
    print(f'The evs_Vm:{evs_Vm}')
    print(f'The evs_Pg:{evs_Pg}')
    print(f'The evs_Qg:{evs_Qg}')

    return pg_gen, qg_gen, pg_viol, qg_viol, pq_num, pqg_index





def plot_multiple_complex_voltage_data(data1, data2, data3=None, output_filename='voltage_scatter_plot.png',
                                       labels=('ground-truth', 'large-reGAF', 'simple-GAF')):
    """Scatter-plot complex voltages in the complex plane and save the figure as an image.

    Each input is converted with ``np.asarray`` and flattened. Its real parts go
    on the x axis and its imaginary parts on the y axis.

    Args:
        data1: Complex array-like drawn in red (legend ``labels[0]``).
        data2: Complex array-like drawn in green (legend ``labels[1]``).
        data3: Optional complex array-like drawn in blue (legend ``labels[2]``).
        output_filename: Path the figure is saved to (300 dpi).
        labels: Legend labels for data1, data2 and data3, in that order.

    Returns:
        ``output_filename``.
    """
    datasets = [data1, data2] if data3 is None else [data1, data2, data3]
    arrays = [np.asarray(d) for d in datasets]
    colors = ('r', 'g', 'b')

    fig, ax = plt.subplots(figsize=(12, 9))
    try:
        # Semi-transparent markers with thin black edges so overlapping points stay visible.
        for arr, color, label in zip(arrays, colors, labels):
            ax.scatter(arr.real.flatten(), arr.imag.flatten(), s=15, alpha=0.5, c=color,
                       label=label, edgecolors='black', linewidth=0.5)

        # Axis limits: bounds over all plotted points plus a fixed margin.
        margin = 0.03
        min_real = np.min([arr.real.min() for arr in arrays])
        max_real = np.max([arr.real.max() for arr in arrays])
        min_imag = np.min([arr.imag.min() for arr in arrays])
        max_imag = np.max([arr.imag.max() for arr in arrays])
        ax.set_xlim(min_real - margin, max_real + margin)
        ax.set_ylim(min_imag - margin, max_imag + margin)

        ax.set_xlabel('Real Part', fontsize=12)
        ax.set_ylabel('Imaginary Part', fontsize=12)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)
        ax.legend(loc='upper left', fontsize=12, frameon=True)

        fig.tight_layout()
        fig.savefig(output_filename, dpi=300)
    finally:
        # Release the figure; otherwise every call leaves an open pyplot figure behind.
        plt.close(fig)

    print("Plot task completed successfully")
    print("********************************")

    return output_filename




        
        
