import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from functools import reduce
import json, operator
# JK
import random
import os
from ODENet import SwingNN
from torchdiffeq import odeint
import time
import math

# Default compute device: the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# param = {'w_s': 2*np.pi*60,

#         # 'M':         [23.64/(np.pi*60), 6.4/(np.pi*60), 3.01/(np.pi*60)2
#          'M':         [47.28, 12.8, 6.2],
#         # 'D':         [0., 0., 0.],
#         # 'D':         [0.15, 0.15, 0.1
#          'D':         [2, 2, 2],   #From thesis
#          'X_d_prime': [2*0.270, 3.5*0.209, 3.75*0.304],
#          'X_q_prime': [0.470, 0.850, 0.5795],
#          'H':         [2.6309, 5.078, 1.200],
#         }

# Constants of the swing-equation reference model, one list entry per
# generator group. In the visible part of this file only 'w_s', 'H', 'D' and
# 'X_d_prime' are read; 'M' and 'X_q_prime' are kept for completeness.
# 'w_s' is 2*pi*60 (presumably the synchronous angular speed of a 60 Hz
# system - TODO confirm units).
# Alternative values that were tried and left disabled by the author:
#   'M' = [23.64/(pi*60), 6.4/(pi*60), 3.01/(pi*60)]
#   'D' = [0., 0., 0.]  or  [0.15, 0.15, 0.1]
param = {
    'w_s': 2 * np.pi * 60,
    'M': [47.28, 12.8, 6.2],
    'D': [2, 2, 2],  # From thesis
    'X_d_prime': [0.270, 0.209, 0.304],
    'X_q_prime': [0.470, 0.850, 0.5795],
    'H': [2.6309, 5.078, 1.200],
}

def get_solution(j, y_zero, t):
    """Integrate the reference swing-equation model for one generator.

    The state vector is ``U = [U0, U1, U2, U3]``. The caller in
    ``ACOPFProblem.__init__`` assembles it as ``[delta, omega, va, vm]``.
    Only the first two components evolve; the last two have zero derivative
    and therefore act as constant parameters of the trajectory.

    Args:
        j: index into ``Pm_`` and ``param['D']``.
        y_zero: initial state; squeezed before integration (four values).
        t: time grid; squeezed to one dimension before integration.

    Returns:
        The ``odeint`` solution (one state per time point), moved to ``DEVICE``.
    """
    # NOTE(review): H and X_d_prime are always taken from entry 0 while Pm_ and
    # D use entry j. This is kept as in the original; confirm it is intended.
    inertia_gain = 1 / (2 * param['H'][0])
    damping = param['D'][j] * 2 * np.pi
    inv_reactance = 1 / param['X_d_prime'][0]

    def model_ode(t, U):
        # Pure torch arithmetic: the previous numpy round trip raised for CUDA
        # tensors and always handed the solver a CPU tensor. Detaching keeps
        # the right-hand side out of the autograd graph, as before.
        U = U.detach()
        delta, omega, theta, v = U[0], U[1], U[2], U[3]
        d_delta = param['w_s'] * omega
        d_omega = inertia_gain * (
            Pm_[j] - damping * omega - v * inv_reactance * torch.sin(delta - theta)
        )
        frozen = torch.zeros_like(omega)  # U2 and U3 stay constant
        return torch.stack([d_delta, d_omega, frozen, frozen])

    sol = odeint(model_ode, y_zero.squeeze(), t.squeeze(), method=numerical_method).to(DEVICE)
    return sol

# Per-generator power inputs, indexed by the `j` argument of get_solution
# (presumably the mechanical power Pm of each generator - TODO confirm).
# Other value sets that were tried and left disabled:
#   [0.8984, 1.035, .8]
#   [0.8984, 1.035, 0]
#   [2.45, 2.45, 2.45]
Pm_ = [
    2.45,
    8.236489137147354 / 3,
    1.3203286726306795,
]

# Solver name passed to odeint through its `method=` argument.
numerical_method = "dopri5"

class ACOPFProblem:
    def __init__(self, filepath, acopf_name, device, args, valid_frac=0.1666/2, test_frac=0.1666/2, obj_scaler=1e0):
        """Load the network, the solved AC-OPF instances and the pretrained NODE models.

        Args:
            filepath: directory with one JSON file per solved instance; the
                network description is read from ``network.json`` in its parent
                directory (``filepath/"../network.json"``).
            acopf_name: label returned by ``__str__``.
            device: torch device for the network-parameter tensors.
            args: mapping; only ``args['T']`` (integration horizon) is read here.
            valid_frac: fraction of the instances served by the ``valid*`` properties.
            test_frac: fraction of the instances served by the ``test*`` properties.
            obj_scaler: stored as ``self.obj_scaler``.

        Side effects:
            Reads the JSON files and the NODE checkpoints under
            ``NODE model/57 bus/`` (relative to the working directory), draws
            random initial conditions, integrates the reference and NODE
            dynamics for samples 1000-1099 (when present) and prints summary
            statistics.

        Raises:
            AssertionError: if the network does not have exactly one slack bus,
                a generator cost has more than three coefficients, or the
                branch data are inconsistent.
        """
        filepath = Path(filepath)
        self.acopf_name = acopf_name
        self.device = device
        self.valid_frac = valid_frac
        self.test_frac = test_frac
        self.train_frac = 1.-valid_frac-test_frac

        # load network file (context manager so the handle is closed)
        with open(filepath/"../network.json") as network_file:
            network = json.load(network_file)
        self.nbus = len(network["bus"])
        self.ngen = len(network["gen"])
        self.nload = len(network["load"])
        self.nbranch = len(network["branch"])
        self.nshunt = len(network["shunt"])
        self.loadids = np.sort(np.array(list(network["load"].keys()),dtype=np.int64))
        self.genids = np.sort(np.array(list(network["gen"].keys()),dtype=np.int64))
        self.busids = np.sort(np.array(list(network["bus"].keys()),dtype=np.int64))
        self.branchids = np.sort(np.array(list(network["branch"].keys()),dtype=np.int64))
        self.shuntids = np.sort(np.array(list(network["shunt"].keys()),dtype=np.int64)) # note that id starts from 1 typically

        # look up slack bus (bus_type == 3); exactly one is required
        self.slack_bus_idx = []
        for bus_id, bus_data in network["bus"].items():
            if bus_data['bus_type'] == 3:
                bus_idx = np.where(self.busids==int(bus_id))[0][0]
                self.slack_bus_idx.append(bus_idx)
        assert len(self.slack_bus_idx) == 1
        self.slack_bus_idx = self.slack_bus_idx[0]

        self.baseMVA = network["baseMVA"]
        self.obj_scaler = obj_scaler

        datapath = filepath

        # NOTE(review): glob order is filesystem dependent, so the sample order
        # (and hence the train/valid/test split) is not guaranteed to be stable.
        datafiles = list(datapath.glob("*.json"))
        # Use every file when the folder name contains '%', otherwise 12% of them.
        if '%' in datapath.name:
            self.ndata = int(len(datafiles)*1)
        else:
            self.ndata = int(len(datafiles)*0.12)

        self.quad_cost = torch.zeros(self.ngen, device=self.device)
        self.lin_cost = torch.zeros(self.ngen, device=self.device)
        self.const_cost = torch.zeros(self.ngen, device=self.device)

        self.pgmin, self.pgmax, self.qgmin, self.qgmax = [], [], [], []
        self.gen2bus = [] # gen index to bus index

        # gen
        for i, gen_id in enumerate(self.genids):
            gen_data = network["gen"][str(gen_id)]
            cost_list = gen_data["cost"]
            # Cost coefficients are listed highest order first; walk them reversed.
            for idx, c in enumerate(cost_list[::-1]):
                if idx == 0:
                    self.const_cost[i] = c
                elif idx == 1:
                    self.lin_cost[i] = c
                elif idx == 2:
                    self.quad_cost[i] = c
                else:
                    assert False
            self.pgmin.append(gen_data["pmin"])
            self.qgmin.append(gen_data["qmin"])
            self.pgmax.append(gen_data["pmax"])
            self.qgmax.append(gen_data["qmax"])
            bus_id = gen_data["gen_bus"]
            bus_idx = np.where(self.busids == int(bus_id))[0][0]
            self.gen2bus.append(bus_idx)

        self.pgmin = torch.tensor(self.pgmin).to(self.device)
        self.pgmax = torch.tensor(self.pgmax).to(self.device)
        self.qgmin = torch.tensor(self.qgmin).to(self.device)
        self.qgmax = torch.tensor(self.qgmax).to(self.device)

        # bus
        self.vmmax, self.vmmin = [], []
        self.basekv = []
        for bus_key in self.busids:
            bus_data = network["bus"][str(bus_key)]
            self.vmmax.append(bus_data["vmax"])
            self.vmmin.append(bus_data["vmin"])
            self.basekv.append(bus_data["base_kv"])

        # The upper voltage bound from the network file is relaxed by 3%.
        self.vmmax = 1.03*torch.tensor(self.vmmax).to(self.device)

        self.vmmin = torch.tensor(self.vmmin).to(self.device)
        self.basekv = torch.tensor(self.basekv).to(self.device)

        # setup bus_genidxs: for every bus the generator indices attached to it,
        # padded with the out-of-range value self.ngen
        bus_genidxs = []
        max_ngens = 0 # max num of generators at bus
        for i, _ in enumerate(self.busids):
            genidxs = []
            for gen_idx, bus_idx in enumerate(self.gen2bus):
                if i==bus_idx:
                    genidxs.append(gen_idx)
            bus_genidxs.append(genidxs)
            max_ngens = max(max_ngens, len(genidxs))

        self.bus_genidxs = self.ngen*torch.ones(self.nbus,max_ngens, dtype=torch.int64).to(device)
        for i, genidxs in enumerate(bus_genidxs):
            for j,genidx in enumerate(genidxs):
                self.bus_genidxs[i,j] = genidx

        # load
        self.load2bus = [] # load idx to bus idx
        for load_id in self.loadids:
            load_data = network["load"][str(load_id)]
            bus_id = load_data["load_bus"]
            bus_idx = np.where(self.busids==int(bus_id))[0][0]
            self.load2bus.append(bus_idx)

        # branch
        br_r, br_x = [], []
        tap, shift = [], [] # related to the transformer
        self.g_to, self.g_fr = [], []
        self.b_to, self.b_fr = [], []
        self.angmin, self.angmax = [], []
        self.bus_i, self.bus_j = [], [] # bus_i == f_bus | bus_j == t_bus || branch_out_per_bus==bus_i | branch_in_per_bus==bus_j #busidx (not id)
        self.thermal_limit = []
        self.edges = [] # i, j, and branch_id

        for i, branch_id in enumerate(self.branchids):
            branch_data = network["branch"][str(branch_id)]
            br_r.append(branch_data["br_r"])
            br_x.append(branch_data['br_x'])
            tap.append(branch_data["tap"])
            shift.append(branch_data["shift"])
            self.g_to.append(branch_data["g_to"])
            self.g_fr.append(branch_data["g_fr"])
            self.b_to.append(branch_data["b_to"])
            self.b_fr.append(branch_data["b_fr"])
            self.angmin.append(branch_data["angmin"])
            self.angmax.append(branch_data["angmax"])
            self.thermal_limit.append(branch_data["rate_a"]) # only rate_a is considered in AC-OPF
            bus_i_id = branch_data["f_bus"]
            bus_j_id = branch_data["t_bus"]
            bus_i_idx = np.where(self.busids==int(bus_i_id))[0][0]
            bus_j_idx = np.where(self.busids==int(bus_j_id))[0][0]
            self.bus_i.append(bus_i_idx)
            self.bus_j.append(bus_j_idx)
            self.edges.append((self.bus_i[-1], self.bus_j[-1], {"idx": i}))

        br_r = torch.tensor(br_r).to(self.device)
        br_x = torch.tensor(br_x).to(self.device)
        tap = torch.tensor(tap).to(self.device)
        shift = torch.tensor(shift).to(self.device)
        self.g_to = torch.tensor(self.g_to).to(self.device)
        self.g_fr = torch.tensor(self.g_fr).to(self.device)
        self.b_to = torch.tensor(self.b_to).to(self.device)
        self.b_fr = torch.tensor(self.b_fr).to(self.device)
        self.angmin = torch.tensor(self.angmin).to(self.device) # radian
        self.angmax = torch.tensor(self.angmax).to(self.device) # radian
        self.bus_i = torch.tensor(self.bus_i).to(self.device)
        self.bus_j = torch.tensor(self.bus_j).to(self.device)
        self.thermal_limit = torch.tensor(self.thermal_limit).to(self.device)

        assert self.thermal_limit.max()>0.
        assert br_r.size() == br_x.size() and br_x.size() == self.g_to.size()
        assert self.g_to.size() == self.g_fr.size() and self.g_fr.size()[0] == self.nbranch

        # Series admittance g + jb = 1 / (r + jx) and transformer ratio terms.
        br_r2_x2 = br_r.pow(2) + br_x.pow(2)
        self.br_g = br_r/br_r2_x2
        self.br_b = -br_x/br_r2_x2
        self.tap2 = tap.pow(2)
        self.T_R = tap * torch.cos(shift)
        self.T_I = tap * torch.sin(shift)

        # setup bus_branchidxs_fr, bus_branchidxs_to: for every bus the branch
        # indices leaving / entering it, padded with the value self.nbranch
        # (which selects the zero column appended by self.pad in compute_flow)
        bus_branchidxs_fr, bus_branchidxs_to = [], []
        max_fr_nbranches, max_to_nbranches = 0, 0
        for i, _ in enumerate(self.busids):
            branchidxs_fr, branchidxs_to = [], []
            for branch_idx, branch_bus_idx in enumerate(self.bus_i): # for from busidxs
                if i==branch_bus_idx:
                    branchidxs_fr.append(branch_idx)
            bus_branchidxs_fr.append(branchidxs_fr)
            max_fr_nbranches = max(max_fr_nbranches, len(branchidxs_fr))

            for branch_idx, branch_bus_idx in enumerate(self.bus_j): # for to busidxs
                if i==branch_bus_idx:
                    branchidxs_to.append(branch_idx)
            bus_branchidxs_to.append(branchidxs_to)
            max_to_nbranches = max(max_to_nbranches, len(branchidxs_to))

        self.bus_branchidxs_fr = self.nbranch*torch.ones(self.nbus,max_fr_nbranches,dtype=torch.int64).to(device)
        self.bus_branchidxs_to = self.nbranch*torch.ones(self.nbus,max_to_nbranches,dtype=torch.int64).to(device)
        for i,branchidxs_fr in enumerate(bus_branchidxs_fr):
            for j,branchidx in enumerate(branchidxs_fr):
                self.bus_branchidxs_fr[i,j] = branchidx
        for i,branchidxs_to in enumerate(bus_branchidxs_to):
            for j,branchidx in enumerate(branchidxs_to):
                self.bus_branchidxs_to[i,j] = branchidx

        # shunt: accumulate conductance / susceptance per bus
        self.gs = torch.zeros(self.nbus,dtype=torch.get_default_dtype()).to(device)
        self.bs = torch.zeros(self.nbus,dtype=torch.get_default_dtype()).to(device)

        for shunt_id in self.shuntids:
            shunt_data = network["shunt"][str(shunt_id)]
            shunt_bus_id = shunt_data["shunt_bus"]
            shunt_bus_idx = np.where(self.busids==int(shunt_bus_id))[0][0]
            gs = shunt_data["gs"]
            bs = shunt_data["bs"]
            self.gs[shunt_bus_idx] += gs
            self.bs[shunt_bus_idx] += bs

        # data loading -- to cpu
        print(" Loading Data Instances...",flush=True)

        self.va = torch.empty(self.ndata,self.nbus) # ground truth - voltage angle
        self.vm = torch.empty(self.ndata,self.nbus) # ground truth - voltage magnitude
        self.pg = torch.empty(self.ndata,self.ngen) # ground truth - active power generation
        self.qg = torch.empty(self.ndata,self.ngen) # ground truth - reactive power generation
        self.pd = torch.empty(self.ndata,self.nload) # input - active demand
        self.qd = torch.empty(self.ndata,self.nload) # input - reactive demand
        # NOTE(review): self.objs is allocated but never filled in this method.
        self.objs = torch.empty(self.ndata)

        # At most 1200 instances are read, and never more than the tensors can
        # hold (the unbounded loop used to index past row ndata-1).
        # NOTE(review): when ndata > 1200 the remaining rows stay uninitialised.
        n_loaded = min(self.ndata, 1200)
        for i, datafile in enumerate(datafiles[:n_loaded]):
            with open(datafile,'r') as instance_file:
                instance = json.load(instance_file)
            self.pd[i, :] = torch.tensor(instance["pd"])   # active demand
            self.qd[i, :] = torch.tensor(instance["qd"])   # reactive demand
            self.va[i, :] = torch.tensor(instance["va"])  # voltage angle
            self.vm[i, :] = torch.tensor(instance["vm"])  # voltage magnitude
            self.pg[i, :] = torch.tensor(instance["pg"])  # active generation
            self.qg[i, :] = torch.tensor(instance["qg"])  # reactive generation

        # Only va is moved to DEVICE; vm/pg/qg/pd/qd stay on the CPU.
        self.va = self.va.to(DEVICE)
        self.dva = self.va[:,self.bus_i] - self.va[:,self.bus_j]

        self.pg_nominal = [  2.45, .0, 0.6, 0.0, 8.236489137147354, 0.0, 1.3203286726306795 ]
        self.qg_nominal = [  0.6518767456319354, 0.4999999995892575, 0.29999996780128824, -0.08, 1.3189829046670793, 0.09, 1.1954603459811538 ]

        # Scatter the per-load demand onto buses and sum the loads of each bus.
        pd_extended = torch.zeros(self.ndata,self.nload,self.nbus,dtype=torch.get_default_dtype())
        qd_extended = torch.zeros(self.ndata,self.nload,self.nbus,dtype=torch.get_default_dtype())
        pd_extended[:, torch.arange(self.nload), self.load2bus] = self.pd
        qd_extended[:, torch.arange(self.nload), self.load2bus] = self.qd
        self.pd_bus = pd_extended.sum(dim=1)
        self.qd_bus = qd_extended.sum(dim=1)
        self.neq = 2*self.nbus # power balance (p and q)
        self.nineq = 2*self.nbranch
        self.pad = nn.ConstantPad2d((0,1,0,0),0.) # for padding flow to recover bus idx based tensor # used in 'compute_flow'

        # NOTE(review): the generator count is hard-coded instead of using self.ngen.
        n_gen = 7
        self.T_DETECTION = args['T']
        self.T_NODE = args['T']

        ### RANDOM INITIAL CONDITION (\delta(0), \omega(0)) GENERATION

        self.init_cond = torch.zeros((self.pd.shape[0],n_gen,2)).to(DEVICE)

        ## pg index first group: 0, 4, 6
        ## v index first group: 0, 7, 11

        for i in range(self.pd.shape[0]):
            for j in range(n_gen):
                if j==0 or j==4 or j==6:
                    if j==0:
                        v_index = 0
                    elif j==4:
                        v_index = 7
                    else:
                        v_index = 11
                    # NOTE(review): Pm_[0] and X_d_prime[0] are used for all
                    # three generators of the group; confirm this is intended.
                    e_prime = math.sqrt((self.pg[i,j]*param['X_d_prime'][0])**2+(self.vm[i,v_index]+self.qg[i,j])**2)/self.vm[i,v_index]
                    delta_ = np.arcsin(Pm_[0]*param['X_d_prime'][0]/(self.vm[i,v_index]*e_prime)) + self.va[i,v_index]

                    if j==4:
                        mean_gauss = .04 #.42
                        high = 3 #2
                    else:
                        mean_gauss = .01
                        high = 10

                    # Seeds only numpy (the randint branch choice below); the
                    # torch draws are not seeded here.
                    np.random.seed(i)

                    if np.random.randint(1,high)==1:
                        omega = (mean_gauss)* torch.rand(1)
                    else:
                        omega = torch.normal(torch.tensor([mean_gauss]), torch.tensor([0.0025]))
                        # keep the sampled speed deviation strictly below 0.05
                        if omega>=.05:
                            omega = .049
                    self.init_cond[i,j,:] = torch.tensor([(delta_, omega)])

                    # A never-used `x_input = torch.cat(...)` stood here; it
                    # combined DEVICE and CPU tensors and failed on CUDA.

                else:
                    self.init_cond[i,j,:] = torch.tensor([(0.01 * torch.rand(1)+0.05, 0.)])

        print("Gen 1, group 1")
        print("Delta interval")
        print(torch.min(self.init_cond[:,0,0]))
        print(torch.max(self.init_cond[:,0,0]))
        print("Omega interval")
        print(torch.min(self.init_cond[:,0,1]))
        print(torch.max(self.init_cond[:,0,1]))

        print("Gen 2, group 1")
        print("Delta interval")
        print(torch.min(self.init_cond[:,4,0]))
        print(torch.max(self.init_cond[:,4,0]))
        print("Omega interval")
        print(torch.min(self.init_cond[:,4,1]))
        print(torch.max(self.init_cond[:,4,1]))

        self.NODE_models = {}

        # PRETRAINED NODE MODELS LOADING

        ### Group - GENERATOR BUS

        #    1st  - 1, 8, 12
        #    2nd  - 3
        #    3rd  - 2, 6, 9

        self.gen_bus_list = [1, 2, 3, 6, 8, 9, 12]
        self.gen_bus_list_0_based = [0, 1, 2, 5, 7, 8, 11]

        self.gen_group_1_list = [1, 8, 12]
        self.gen_group_2_list = [3]
        self.gen_group_3_list = [2, 6, 9]

        self.gen_group_1_list_0_based = [0, 7, 11]
        self.gen_group_2_list_0_based = [2]
        self.gen_group_3_list_0_based = [1, 5, 8]

        for i in self.gen_bus_list:

            # Hidden-layer widths and activation must match the checkpoint
            # loaded for the same generator further down.
            if i in self.gen_group_1_list:
                if i == 1:
                    arch_list_NODE = [200] * 2
                    activation_NODE = "SOFTPLUS"
                elif i == 8:
                    arch_list_NODE = [50] * 2
                    activation_NODE = "RELU"
                else:
                    arch_list_NODE = [200] * 2
                    activation_NODE = "TANH"
            elif i in self.gen_group_2_list:
                arch_list_NODE = [200] * 5
                activation_NODE = "RELU"
            elif i in self.gen_group_3_list:   ### BEST_MODEL_3_stable
                # (best_model_3 used the same [50]*2 RELU layout, best_model_10 used TANH)
                arch_list_NODE = [50] * 2
                activation_NODE = "RELU"

            # four inputs and four outputs around the hidden layers
            arch_list_NODE.insert(0,4)
            arch_list_NODE.append(4)
            self.NODE_models[f'gen{i}'] = SwingNN(arch_list_NODE, activation_NODE)

            if i in self.gen_group_1_list:
                if i == 1:
                    self.NODE_models[f'gen{i}'].load_state_dict(torch.load(f'NODE model/57 bus/best_model_{1}_17.pt', map_location=torch.device('cpu')))
                elif i == 8:
                    self.NODE_models[f'gen{i}'].load_state_dict(torch.load(f'NODE model/57 bus/best_model_{2}_2.pt', map_location=torch.device('cpu')))
                else:
                    self.NODE_models[f'gen{i}'].load_state_dict(torch.load(f'NODE model/57 bus/best_model_{3}_22.pt', map_location=torch.device('cpu')))
            elif i in self.gen_group_2_list:
                self.NODE_models[f'gen{i}'].load_state_dict(torch.load(f'NODE model/57 bus/best_model_{2}.pt', map_location=torch.device('cpu')))
            elif i in self.gen_group_3_list:
                self.NODE_models[f'gen{i}'].load_state_dict(torch.load(f'NODE model/57 bus/best_model_{3}_stable.pt', map_location=torch.device('cpu')))
            self.NODE_models[f'gen{i}'] = self.NODE_models[f'gen{i}'].to(DEVICE)

        t_NODE = torch.linspace(0, self.T_NODE, 100).reshape(-1, 1).to(DEVICE)

        # Counters per generator of the first group (0 -> gen 0, 1 -> gen 4, 2 -> gen 6).
        true_unstable = np.zeros(3)
        false_unstable = np.zeros(3)
        self.T = args['T']
        t = torch.linspace(0, self.T, 100).reshape(-1, 1).to(DEVICE)
        y_true_dyn = torch.zeros((self.pd.shape[0],t.shape[0],4)).to(DEVICE)
        bad_unstable = 0
        # Sanity check on samples 1000..1099, clamped so that a smaller data
        # set does not index past the end of the tensors.
        eval_stop = min(1100, self.pd.shape[0])
        for i in range(1000, eval_stop):
            for j in range(n_gen):
                if j%2==0 and j!=2:
                    if j==0:
                        gen_index = 0
                        v_index = 0
                    elif j==4:
                        gen_index = 1
                        v_index = 7
                    else:
                        gen_index = 2
                        v_index = 11
                    x_input = torch.tensor([self.init_cond[i,j,0], self.init_cond[i,j,1], self.va[i,v_index], self.vm[i,v_index]]).to(device)
                    y_true_dyn[i-1000,:,:] = get_solution(gen_index, x_input, t).to(device)
                    # A trajectory counts as unstable when |delta(T)| exceeds pi/2.
                    tmp = 1 if torch.abs(y_true_dyn[i-1000, -1,0]) > torch.pi/2 else 0
                    true_unstable[gen_index] += tmp
                    y_pred_dyn = odeint(self.NODE_models[f'gen{self.gen_bus_list[j]}'], x_input, t_NODE.squeeze(), method=numerical_method)
                    # Count disagreements between the NODE prediction and the reference.
                    if (torch.abs(y_pred_dyn[-1, 0]) > torch.pi/2 and torch.abs(y_true_dyn[i-1000, -1,0]) < torch.pi/2) or (torch.abs(y_pred_dyn[-1, 0]) < torch.pi/2 and torch.abs(y_true_dyn[i-1000, -1,0]) > torch.pi/2):
                        false_unstable[gen_index]+=1
                    if tmp ==1:
                        # Retry with the voltage magnitude raised to 1.08 ...
                        x_input = torch.tensor([self.init_cond[i,j,0], self.init_cond[i,j,1], self.va[i,v_index], 1.08]).to(device)
                        y_true_dyn[i-1000,:,:] = get_solution(gen_index, x_input, t).to(device)
                        tmp = 1 if torch.abs(y_true_dyn[i-1000, -1,0]) > torch.pi/2 else 0
                        if tmp==1:

                            bad_unstable += 1
                            # ... and, if still unstable, with half the initial omega.
                            x_input = torch.tensor([self.init_cond[i,j,0], self.init_cond[i,j,1]/2, self.va[i,v_index], 1.08]).to(device)
                            y_true_dyn[i-1000,:,:] = get_solution(gen_index, x_input, t).to(device)
                            if torch.abs(y_true_dyn[i-1000, -1,0]) > torch.pi/2:
                                print("Still unstable with omega'=omega/2")
                            else:
                                # keep the halved omega as this sample's initial condition
                                self.init_cond[i,j,1] /= 2

        print(bad_unstable)
        print("Number of false unstable")
        print(false_unstable)
        print("Number of true unstable")
        print(true_unstable)
        print("Done.")
        
    def __str__(self):
        return self.acopf_name

    @property
    def X(self):
        """Inputs for all instances: per-load and per-bus active/reactive demand."""
        features = dict(pd=self.pd, qd=self.qd)
        features["pd_bus"] = self.pd_bus
        features["qd_bus"] = self.qd_bus
        return features

    @property
    def trainX(self):
        """Inputs of the training split: the first ``int(ndata * train_frac)`` instances."""
        n_train = int(self.ndata * self.train_frac)
        names = ("pd", "qd", "pd_bus", "qd_bus")
        return {name: getattr(self, name)[:n_train] for name in names}

    @property
    def validX(self):
        """Inputs of the validation split, the slice right after the training split."""
        first = int(self.ndata * self.train_frac)
        last = int(self.ndata * (self.train_frac + self.valid_frac))
        names = ("pd", "qd", "pd_bus", "qd_bus")
        return {name: getattr(self, name)[first:last] for name in names}

    @property
    def testX(self):
        """Inputs of the test split: everything after the training and validation splits."""
        first = int(self.ndata * (self.train_frac + self.valid_frac))
        names = ("pd", "qd", "pd_bus", "qd_bus")
        return {name: getattr(self, name)[first:] for name in names}


    @property
    def Y(self):
        """Ground-truth solution for all instances: va, vm, pg, qg and branch angle differences."""
        names = ("va", "vm", "pg", "qg", "dva")
        return {name: getattr(self, name) for name in names}

    @property
    def trainY(self):
        """Ground truth of the training split: the first ``int(ndata*train_frac)`` instances."""
        n_train = int(self.ndata*self.train_frac)
        names = ("va", "vm", "pg", "qg", "dva")
        return {name: getattr(self, name)[:n_train] for name in names}

    @property
    def validY(self):
        """Ground truth of the validation split, the slice right after the training split."""
        first = int(self.ndata*self.train_frac)
        last = int(self.ndata*(self.train_frac + self.valid_frac))
        names = ("va", "vm", "pg", "qg", "dva")
        return {name: getattr(self, name)[first:last] for name in names}

    @property
    def testY(self):
        """Ground truth of the test split: everything after the training and validation splits."""
        first = int(self.ndata*(self.train_frac + self.valid_frac))
        names = ("va", "vm", "pg", "qg", "dva")
        return {name: getattr(self, name)[first:] for name in names}
    
    def get_init_cond(self,a,b):
        return self.init_cond[a:b,:,:]
    
    def get_dva(va,self):
        self.dva = self.va[:, self.bus_i] - self.va[:, self.bus_j]

    def compute_flow(self, vm, dva, verbose=False):
        """Evaluate the active and reactive flow at both ends of every branch.

        Args:
            vm: Voltage magnitudes, indexed as ``vm[:, bus]``.
            dva: Angle difference of each branch; combined elementwise with the
                product of the two end-bus magnitudes.
            verbose: When true, print the maxima of a few branch parameters and
                of the three terms that make up ``pf_fr``.

        Returns:
            dict with the branch flows ``pf_fr``, ``pf_to``, ``qf_fr`` and
            ``qf_to``, plus ``*_bus`` entries obtained by padding each flow
            with ``self.pad``, gathering it through ``self.bus_branchidxs_fr``
            / ``self.bus_branchidxs_to`` and summing over the last axis.
        """
        v_from = vm[:, self.bus_i]
        v_to = vm[:, self.bus_j]
        v_from_sq = v_from ** 2
        v_to_sq = v_to ** 2
        v_prod = v_from * v_to

        cos_d = torch.cos(dva)
        sin_d = torch.sin(dva)

        # The four tap-scaled admittance combinations; each one appears once in
        # an active-power expression and once in a reactive-power expression.
        coef_a = (-self.br_g * self.T_R + self.br_b * self.T_I) / self.tap2
        coef_b = (-self.br_b * self.T_R - self.br_g * self.T_I) / self.tap2
        coef_c = (-self.br_g * self.T_R - self.br_b * self.T_I) / self.tap2
        coef_d = (-self.br_b * self.T_R + self.br_g * self.T_I) / self.tap2

        a_v = coef_a * v_prod
        b_v = coef_b * v_prod
        c_v = coef_c * v_prod
        d_v = coef_d * v_prod

        # From-side active flow, kept as three named terms so that the verbose
        # report can reuse them.
        pf_fr_shunt = (1 / self.tap2) * (self.br_g + self.g_fr) * v_from_sq
        pf_fr_cos = a_v * cos_d
        pf_fr_sin = b_v * sin_d
        pf_fr = pf_fr_shunt + pf_fr_cos + pf_fr_sin

        pf_to = (self.br_g + self.g_to) * v_to_sq + c_v * cos_d + d_v * (-sin_d)

        qf_fr = -(1 / self.tap2) * (self.br_b + self.b_fr) * v_from_sq - b_v * cos_d + a_v * sin_d

        qf_to = -(self.br_b + self.b_to) * v_to_sq - d_v * cos_d + c_v * (-sin_d)

        if verbose:
            print("compute_flow check params:: %.4f | %.4f | %.4f | %.4f"%(self.tap2.max(), self.br_g.max(), self.g_fr.max(), v_from_sq.max()),flush=True)
            print("compute_flow:: %.4f | %.4f | %.4f "%(pf_fr_shunt.max(), pf_fr_cos.max(), pf_fr_sin.max()),flush=True)

        # Pad first (same order as before), then gather and sum.
        padded = [self.pad(branch_flow) for branch_flow in (pf_fr, pf_to, qf_fr, qf_to)]
        pf_fr_bus = padded[0][:, self.bus_branchidxs_fr].sum(dim=2)
        pf_to_bus = padded[1][:, self.bus_branchidxs_to].sum(dim=2)
        qf_fr_bus = padded[2][:, self.bus_branchidxs_fr].sum(dim=2)
        qf_to_bus = padded[3][:, self.bus_branchidxs_to].sum(dim=2)

        return {"pf_fr":pf_fr, "pf_to":pf_to, "qf_fr":qf_fr, "qf_to":qf_to,
                "pf_fr_bus":pf_fr_bus, "pf_to_bus":pf_to_bus, "qf_fr_bus":qf_fr_bus, "qf_to_bus": qf_to_bus}

    def obj_fn(self, X, Y):
        """Return the generation cost of every sample.

        The cost of each generator is ``quad_cost * pg**2 + lin_cost * pg +
        const_cost`` with ``pg = Y['pg']``; the per-generator costs are summed
        along axis 1. ``X`` is not used.
        """
        dispatch = Y['pg'].to(self.device)
        per_gen_cost = self.quad_cost*dispatch**2 + self.lin_cost*dispatch + self.const_cost
        total_cost = per_gen_cost.sum(dim=1)  # sum over generators
        total_cost /= 1  # objective scaling is disabled (was self.obj_scaler)
        return total_cost

    def opt_gap(self, X, Y, Ygt):
        """Return the relative objective gap ``|f(Y) - f(Ygt)| / |f(Ygt)|``.

        ``f`` is ``self.obj_fn``; it is evaluated on ``Y`` first and on ``Ygt``
        second.
        """
        predicted_cost = self.obj_fn(X, Y)
        reference_cost = self.obj_fn(X, Ygt)
        abs_error = (predicted_cost - reference_cost).abs()
        return abs_error / reference_cost.abs()

    def instab_resid(self, X, Y, epoch, args, batch_idx):
        """Return the instability penalty of a batch of predictions.

        For each simulated generator the matching model in ``self.NODE_models``
        is integrated with ``odeint`` (``rk4``) on 100 time points spanning
        ``[0, self.T_NODE]``. The initial state is the stored initial condition
        from ``self.init_cond`` (delta(0) and omega(0) per the original notes)
        followed by the predicted ``Y['va']`` and ``Y['vm']`` at the generator
        bus. The first state component at the final time is compared with
        ``pi/2`` and the loss is ``sum(relu(|delta| - pi/2))`` over the batch.

        Compared with the previous version this no longer allocates a
        ``(batch, 100, 4)`` tensor per generator that was never read, and an
        unsupported ``batch_idx`` raises ``ValueError`` instead of failing on
        an unbound local.

        Args:
            X: Unused.
            Y: dict of predictions; ``'pg'`` is read for the batch size only,
                ``'va'`` and ``'vm'`` are indexed as ``[:, bus]``.
            epoch: Current epoch; the penalty is active only when it exceeds
                ``args['activate_instability_computation_epoch']``.
            args: dict; additionally ``'batchsize'`` is read when
                ``batch_idx >= 0``.
            batch_idx: ``>= 0`` selects rows
                ``batch_idx*batchsize:(batch_idx+1)*batchsize`` of
                ``self.init_cond``; ``-1`` selects rows ``1000:1100``; ``-10``
                selects rows ``1100:``.

        Returns:
            The penalty as a 0-dim tensor when active, otherwise
            ``torch.tensor([0])`` moved to ``DEVICE``.

        Raises:
            ValueError: if the penalty is active and ``batch_idx`` is not one
                of the supported values.
        """
        # Original notes on the generator groups (1-based bus numbers):
        #    1st  - 1, 8, 12
        #    2nd  - 3
        #    3rd  - 2, 6, 9
        #    self.gen_bus_list_0_based = [0, 1, 2, 5, 7, 8, 11]
        n_gen = 7
        skipped = (0, 1, 2, 3, 5)  # positions in self.gen_bus_list that are not simulated
        angle_limit = torch.pi/2

        if epoch>args['activate_instability_computation_epoch']: ### ACTIVATE INSTABILITY LOSS COMPUTATION AFTER 1 (HYPERPARAM) EPOCH
            t_NODE = torch.linspace(0, self.T_NODE, 100).to(DEVICE)
            delta_pred_dyn = torch.zeros((Y['pg'].shape[0], n_gen), device=DEVICE)
            total_detected_true_unstable = 0

            # Rows of the pre-generated initial conditions that belong to this batch.
            if batch_idx>=0:
                batchsize = args['batchsize']
                rows = slice(batch_idx*batchsize, (batch_idx+1)*batchsize)
            elif batch_idx==-1:
                rows = slice(1000, 1100)
            elif batch_idx==-10:
                rows = slice(1100, None)
            else:
                raise ValueError(
                    f"unsupported batch_idx {batch_idx!r}: expected a non-negative batch number, -1 or -10"
                )

            for z in range(n_gen):
                if z in skipped:
                    continue
                gen_index = self.gen_bus_list_0_based[z]
                # Stored initial condition followed by the predicted angle and
                # magnitude at the generator bus.
                x_input = torch.cat(
                    (
                        self.init_cond[rows, z, :],
                        Y['va'][:, gen_index].reshape((-1,1)),
                        Y['vm'][:, gen_index].reshape((-1,1)),
                    ),
                    dim=1,
                ).to(DEVICE)
                y_pred_dyn = odeint(self.NODE_models[f'gen{self.gen_bus_list[z]}'], x_input, t_NODE, method='rk4')
                y_pred_dyn = torch.swapaxes(y_pred_dyn, 0, 1) # batch, time, features = 4

                delta_pred_dyn[:, z] = y_pred_dyn[:, -1, 0]

                mask = torch.abs(delta_pred_dyn[:, z]) > angle_limit
                total_detected_true_unstable += torch.sum(mask).item()

            loss_instability = torch.sum(F.relu(torch.abs(delta_pred_dyn) - angle_limit))
        else:
            loss_instability = torch.tensor([0]).to(DEVICE)
            total_detected_true_unstable = -1
        print(f"Batch {batch_idx}, detected unstable: {total_detected_true_unstable}")
        return loss_instability
    
    def instab_resid_analysis_OOD(self, X, Y, epoch, args, batch_idx):
        """Instability penalty plus detection statistics (OOD variant).

        Each simulated generator's model from ``self.NODE_models`` is
        integrated with ``odeint`` (``method=numerical_method``) on 100 points
        spanning ``[0, self.T_NODE]``; the first state component at the final
        time is compared with ``pi/2``. When ``batch_idx == -1`` the same
        initial states are also integrated one by one with ``get_solution`` on
        100 points spanning ``[0, self.T_DETECTION]`` to obtain reference
        trajectories.

        Changes with respect to the previous version:

        * ``batch_idx == -1`` built its input by concatenating a 2-D slice of
          ``self.init_cond`` with 1-D slices of ``self.va`` and used
          ``self.va`` for both trailing columns. The columns are now reshaped
          to ``(-1, 1)`` and the last one is read from ``self.vm``.
          NOTE(review): ``self.vm`` is presumed to hold the stored voltage
          magnitudes next to ``self.va`` -- confirm.
        * An unsupported ``batch_idx`` raises ``ValueError`` instead of failing
          on an unbound local.
        * Timing lists that were filled but never read were removed, and the
          solver group is looked up once per generator.

        Args:
            X: Unused.
            Y: dict of predictions; ``'pg'`` is read for the batch size only,
                ``'va'`` and ``'vm'`` are indexed as ``[:, bus]``.
            epoch: The computation is active only when ``epoch`` exceeds
                ``args['activate_instability_computation_epoch']``.
            args: dict; additionally ``'batchsize'`` is read when
                ``batch_idx >= 0``.
            batch_idx: ``>= 0`` selects rows
                ``batch_idx*batchsize:(batch_idx+1)*batchsize`` of
                ``self.init_cond``; ``-1`` selects rows ``1000:1100`` (and the
                stored ``self.va`` / ``self.vm`` for the same rows); ``-10``
                selects rows ``1100:``.

        Returns:
            ``(loss, total_true_unstable, total_false_unstable,
            total_detected_true_unstable)``. The three counters are ``-1`` when
            the computation is inactive; the first two stay ``0`` unless
            ``batch_idx == -1``.

        Raises:
            ValueError: if active and ``batch_idx`` is not supported.
        """
        # Original notes on the generator groups (1-based bus numbers):
        #    1st  - 1, 8, 12
        #    2nd  - 3
        #    3rd  - 2, 6, 9
        #    self.gen_bus_list_0_based = [0, 1, 2, 5, 7, 8, 11]
        n_gen = 7
        skipped = (0, 1, 2, 3, 5)  # positions in self.gen_bus_list that are not simulated
        solver_group = {0: 0, 4: 1, 6: 2}  # position -> first argument of get_solution
        angle_limit = torch.pi/2

        if epoch>args['activate_instability_computation_epoch']: ### ACTIVATE INSTABILITY LOSS COMPUTATION AFTER 1 (HYPERPARAM) EPOCH
            t_NODE = torch.linspace(0, self.T_NODE, 100).to(DEVICE)
            t_DETECTION = torch.linspace(0, self.T_DETECTION, 100).reshape(-1, 1).to(DEVICE)
            delta_pred_dyn = torch.zeros((Y['pg'].shape[0], n_gen), device=DEVICE)
            total_true_unstable, total_false_unstable, total_detected_true_unstable = 0, 0, 0

            # Rows of the pre-generated initial conditions that belong to this batch.
            if batch_idx>=0:
                batchsize = args['batchsize']
                rows = slice(batch_idx*batchsize, (batch_idx+1)*batchsize)
            elif batch_idx==-1:
                rows = slice(1000, 1100)
            elif batch_idx==-10:
                rows = slice(1100, None)
            else:
                raise ValueError(
                    f"unsupported batch_idx {batch_idx!r}: expected a non-negative batch number, -1 or -10"
                )

            for z in range(n_gen):
                if z in skipped:
                    continue
                gen_index = self.gen_bus_list_0_based[z]
                if batch_idx==-1:
                    # Stored values for the same rows instead of the predictions.
                    angle_col = self.va[rows, gen_index]
                    magnitude_col = self.vm[rows, gen_index]
                else:
                    angle_col = Y['va'][:, gen_index]
                    magnitude_col = Y['vm'][:, gen_index]
                x_input = torch.cat(
                    (
                        self.init_cond[rows, z, :],
                        angle_col.reshape((-1,1)),
                        magnitude_col.reshape((-1,1)),
                    ),
                    dim=1,
                ).to(DEVICE)

                y_pred_dyn = odeint(self.NODE_models[f'gen{self.gen_bus_list[z]}'], x_input, t_NODE, method=numerical_method)
                y_pred_dyn = torch.swapaxes(y_pred_dyn, 0, 1) # batch, time, features = 4

                delta_pred_dyn[:, z] = y_pred_dyn[:, -1, 0]

                mask = torch.abs(delta_pred_dyn[:, z]) > angle_limit
                total_detected_true_unstable += torch.sum(mask).item()

                if batch_idx==-1:
                    gen_group = solver_group[z]
                    y_true_dyn = torch.zeros((y_pred_dyn.shape[0],t_DETECTION.shape[0],4)).to(DEVICE)
                    for p in range(y_pred_dyn.shape[0]):
                        y_true_dyn[p,:,:] = get_solution(gen_group, x_input[p,:], t_DETECTION).to(DEVICE)

                    true_unstable = torch.abs(y_true_dyn[:, -1,0]) > angle_limit
                    total_true_unstable += torch.sum(true_unstable).item()
                    # Flagged by the learned model but not by the reference solution.
                    total_false_unstable += torch.sum(mask & ~true_unstable).item()

            loss_instability = torch.sum(F.relu(torch.abs(delta_pred_dyn) - angle_limit))
        else:
            loss_instability = torch.tensor([0]).to(DEVICE)
            total_true_unstable, total_false_unstable, total_detected_true_unstable = -1, -1, -1
        return loss_instability,  total_true_unstable, total_false_unstable, total_detected_true_unstable
    

    def instab_resid_analysis(self, X, Y, epoch, args, batch_idx):
        """Instability penalty plus detection statistics.

        Each simulated generator's model from ``self.NODE_models`` is
        integrated with ``odeint`` (``method=numerical_method``) on 100 points
        spanning ``[0, self.T_NODE]``, starting from the stored initial
        condition in ``self.init_cond`` followed by the predicted ``Y['va']``
        and ``Y['vm']`` at the generator bus. The first state component at the
        final time is compared with ``pi/2``. When ``batch_idx == -1`` the same
        initial states are also integrated one by one with ``get_solution`` on
        100 points spanning ``[0, self.T_DETECTION]`` to obtain reference
        trajectories.

        Changes with respect to the previous version:

        * When a reference trajectory is unstable, only the offending input
          row is printed; previously the whole ``x_input`` batch was printed
          once per unstable sample.
        * An unsupported ``batch_idx`` raises ``ValueError`` instead of failing
          on an unbound local.
        * Timing lists that were filled but never read were removed, and the
          solver group is looked up once per generator.

        Args:
            X: Unused.
            Y: dict of predictions; ``'pg'`` is read for the batch size only,
                ``'va'`` and ``'vm'`` are indexed as ``[:, bus]``.
            epoch: The computation is active only when ``epoch`` exceeds
                ``args['activate_instability_computation_epoch']``.
            args: dict; additionally ``'batchsize'`` is read when
                ``batch_idx >= 0``.
            batch_idx: ``>= 0`` selects rows
                ``batch_idx*batchsize:(batch_idx+1)*batchsize`` of
                ``self.init_cond``; ``-1`` selects rows ``1000:1100``; ``-10``
                selects rows ``1100:``.

        Returns:
            ``(loss, total_true_unstable, total_false_unstable,
            total_detected_true_unstable)``. The three counters are ``-1`` when
            the computation is inactive; the first two stay ``0`` unless
            ``batch_idx == -1``.

        Raises:
            ValueError: if active and ``batch_idx`` is not supported.
        """
        # Original notes on the generator groups (1-based bus numbers):
        #    1st  - 1, 8, 12
        #    2nd  - 3
        #    3rd  - 2, 6, 9
        #    self.gen_bus_list_0_based = [0, 1, 2, 5, 7, 8, 11]
        n_gen = 7
        skipped = (0, 1, 2, 3, 5)  # positions in self.gen_bus_list that are not simulated
        solver_group = {0: 0, 4: 1, 6: 2}  # position -> first argument of get_solution
        angle_limit = torch.pi/2

        if epoch>args['activate_instability_computation_epoch']: ### ACTIVATE INSTABILITY LOSS COMPUTATION AFTER 1 (HYPERPARAM) EPOCH
            t_NODE = torch.linspace(0, self.T_NODE, 100).to(DEVICE)
            t_DETECTION = torch.linspace(0, self.T_DETECTION, 100).reshape(-1, 1).to(DEVICE)
            delta_pred_dyn = torch.zeros((Y['pg'].shape[0], n_gen), device=DEVICE)
            total_true_unstable, total_false_unstable, total_detected_true_unstable = 0, 0, 0

            # Rows of the pre-generated initial conditions that belong to this batch.
            if batch_idx>=0:
                batchsize = args['batchsize']
                rows = slice(batch_idx*batchsize, (batch_idx+1)*batchsize)
            elif batch_idx==-1:
                rows = slice(1000, 1100)
            elif batch_idx==-10:
                rows = slice(1100, None)
            else:
                raise ValueError(
                    f"unsupported batch_idx {batch_idx!r}: expected a non-negative batch number, -1 or -10"
                )

            for z in range(n_gen):
                if z in skipped:
                    continue
                gen_index = self.gen_bus_list_0_based[z]
                x_input = torch.cat(
                    (
                        self.init_cond[rows, z, :],
                        Y['va'][:, gen_index].reshape((-1,1)),
                        Y['vm'][:, gen_index].reshape((-1,1)),
                    ),
                    dim=1,
                ).to(DEVICE)

                y_pred_dyn = odeint(self.NODE_models[f'gen{self.gen_bus_list[z]}'], x_input, t_NODE, method=numerical_method)
                y_pred_dyn = torch.swapaxes(y_pred_dyn, 0, 1) # batch, time, features = 4

                delta_pred_dyn[:, z] = y_pred_dyn[:, -1, 0]

                mask = torch.abs(delta_pred_dyn[:, z]) > angle_limit
                total_detected_true_unstable += torch.sum(mask).item()

                if batch_idx==-1:
                    gen_group = solver_group[z]
                    y_true_dyn = torch.zeros((y_pred_dyn.shape[0],t_DETECTION.shape[0],4)).to(DEVICE)
                    for p in range(y_pred_dyn.shape[0]):
                        y_true_dyn[p,:,:] = get_solution(gen_group, x_input[p,:], t_DETECTION).to(DEVICE)
                        if torch.abs(y_true_dyn[p, -1,0]) > angle_limit:
                            # Show the initial state that produced the unstable reference.
                            print(x_input[p, :])

                    true_unstable = torch.abs(y_true_dyn[:, -1,0]) > angle_limit
                    total_true_unstable += torch.sum(true_unstable).item()
                    print(f"Gen: {gen_index+1}, num. of true unstable trajectory: {torch.sum(true_unstable)}")
                    # Flagged by the learned model but not by the reference solution.
                    total_false_unstable += torch.sum(mask & ~true_unstable).item()

            loss_instability = torch.sum(F.relu(torch.abs(delta_pred_dyn) - angle_limit))
        else:
            loss_instability = torch.tensor([0]).to(DEVICE)
            total_true_unstable, total_false_unstable, total_detected_true_unstable = -1, -1, -1
        return loss_instability,  total_true_unstable, total_false_unstable, total_detected_true_unstable
    

    def ineq_resid(self, X, Y):
        """Return the thermal-limit residuals of both branch ends.

        For each end the residual is ``pf**2 + qf**2 - thermal_limit**2``, so a
        positive entry means the limit is exceeded. The from-side residuals
        come first, followed by the to-side residuals along axis 1.

        Args:
            X: Unused.
            Y: dict that either carries precomputed flows under ``"flow"`` or
                carries ``"vm"`` and ``"dva"``, from which the flows are
                obtained with ``self.compute_flow``. ``"dva"`` is no longer
                required when ``"flow"`` is present (it was read but never
                used there).

        Returns:
            Tensor ``cat([tl_fr, tl_to], dim=1)``.
        """
        if "flow" in Y:
            flow = Y["flow"]
        else:
            vm = Y["vm"].to(self.device)
            dva = Y["dva"].to(self.device)
            flow = self.compute_flow(vm, dva)

        limit_sq = self.thermal_limit**2
        tl_fr = flow["pf_fr"]**2 + flow["qf_fr"]**2 - limit_sq
        tl_to = flow["pf_to"]**2 + flow["qf_to"]**2 - limit_sq

        return torch.cat([tl_fr, tl_to], dim=1)

    def get_dva(self, va=None):
        """Return the angle difference across every branch.

        For each branch the angle at ``self.bus_j`` is subtracted from the
        angle at ``self.bus_i``.

        The argument used to be ignored in favour of ``self.va``; it is now
        honoured, and ``self.va`` is only the fallback.

        Args:
            va: Angle tensor indexed as ``va[:, bus]``. ``None`` (the default)
                selects ``self.va``.

        Returns:
            Tensor ``va[:, self.bus_i] - va[:, self.bus_j]``.
        """
        if va is None:
            va = self.va
        return va[:, self.bus_i] - va[:, self.bus_j]

    def ineq_dist(self, X, Y):
        """Return the inequality violations: ``self.ineq_resid(X, Y)`` clamped below at zero."""
        violation = self.ineq_resid(X, Y)
        return violation.clamp(min=0.)

    def eq_resid(self, X, Y):
        """Return the active and reactive power balance residual at each bus.

        Args:
            X: dict with the bus demands ``'pd_bus'`` and ``'qd_bus'``.
            Y: dict with ``"vm"`` and either the precomputed ``"flow"``,
                ``"qg_bus"`` and ``"pg_bus"``, or ``"dva"``, ``"pg"`` and
                ``"qg"`` from which those quantities are derived.

        Returns:
            Tensor ``cat([balance_p, balance_q], dim=1)``.
        """
        dev = self.device
        pd_bus = X['pd_bus'].to(dev)
        qd_bus = X['qd_bus'].to(dev)
        vm = Y["vm"].to(dev)

        if "flow" not in Y.keys():
            dva = Y["dva"].to(dev)
            flow = self.compute_flow(vm, dva)
            pg = Y["pg"].to(dev)
            qg = Y["qg"].to(dev)
            # Pad the generator outputs, gather them per bus and sum.
            pg_bus = self.pad(pg)[:, self.bus_genidxs].sum(dim=2)
            qg_bus = self.pad(qg)[:, self.bus_genidxs].sum(dim=2)
        else:
            flow, qg_bus, pg_bus = Y["flow"], Y["qg_bus"], Y["pg_bus"]

        vm_sq = vm**2
        # power balance at each bus
        balance_p = pg_bus - pd_bus - flow["pf_to_bus"] - flow["pf_fr_bus"] - self.gs*vm_sq
        balance_q = qg_bus - qd_bus - flow["qf_to_bus"] - flow["qf_fr_bus"] + self.bs*vm_sq
        return torch.cat([balance_p, balance_q], dim=1)


###################################################################
# NEURAL NETWORKS
###################################################################

def init_layer(nlayer, nhidden, nin, nout, primal=True):
    """Build the layer list of a fully connected ReLU network.

    The list starts with ``nlayer`` blocks of ``Linear -> ReLU`` of width
    ``nhidden``. ``nlayer == 0`` is accepted and yields only the output
    layers; the previous implementation raised ``TypeError`` there because it
    reduced an empty list.

    Args:
        nlayer: Number of hidden ``Linear -> ReLU`` blocks.
        nhidden: Width of every hidden layer.
        nin: Input width.
        nout: Output width.
        primal: When true the list ends with
            ``Linear(.., nout) -> ReLU -> Linear(nout, nout) -> Hardsigmoid``;
            otherwise with a single ``Linear(.., nout)``.

    Returns:
        list of ``nn.Module`` objects, in forward order.
    """
    layers = []
    width = nin
    for _ in range(nlayer):
        layers += [nn.Linear(width, nhidden), nn.ReLU()]
        width = nhidden

    if primal:
        layers += [nn.Linear(width, nout), nn.ReLU(), nn.Linear(nout, nout), nn.Hardsigmoid()]
    else:
        layers.append(nn.Linear(width, nout))
    return layers


class NNPrimalACOPFSolver(nn.Module):
    """Feed-forward network mapping loads to an ACOPF solution estimate.

    A shared ReLU trunk feeds five heads (``pg``, ``qg``, ``vm``, ``dva``,
    ``va``). Every head ends in ``Hardsigmoid``; the ``pg``, ``qg``, ``vm`` and
    ``dva`` outputs are rescaled affinely to the bounds taken from ``data``,
    while ``va`` is returned as the raw head output.

    ``args['nlayer'] == 0`` is supported (empty trunk, heads attached to the
    input); the previous implementation raised ``TypeError`` there because it
    reduced an empty list. Sub-module names and construction order are
    unchanged, so existing ``state_dict`` keys keep matching.
    """

    def __init__(self, data, args):
        """Build the network.

        Args:
            data: Problem object providing ``device``, the sizes ``ngen``,
                ``nbus``, ``nbranch``, ``nload``, ``slack_bus_idx``, the bounds
                ``pgmin/pgmax``, ``qgmin/qgmax``, ``vmmin/vmmax``,
                ``angmin/angmax`` and, for ``forward``, ``pad``,
                ``bus_genidxs`` and ``compute_flow``.
            args: dict with ``'hiddenfrac'`` (hidden width as a fraction of the
                output dimension) and ``'nlayer'`` (number of hidden layers).
        """
        super().__init__()

        self.data = data
        self.device = data.device

        self.npg = data.ngen
        self.nqg = data.ngen
        self.nvm = data.nbus
        self.nva = data.nbus - 1
        self.nbranch = data.nbranch
        self.ndva = self.nbranch

        self.xdim = 2 * data.nload  # pd and qd
        self.ydim = self.npg + self.nqg + self.nvm + self.ndva + self.nva

        all_buses = torch.arange(data.nbus)
        self.nonslack_busidxs = all_buses[all_buses != data.slack_bus_idx]

        print("X dim:%d, Y dim:%d" % (self.xdim, self.ydim), flush=True)

        self.pgmin = data.pgmin
        self.pgmax = data.pgmax
        self.qgmin = data.qgmin
        self.qgmax = data.qgmax
        self.vmmin = data.vmmin
        self.vmmax = data.vmmax
        self.dvamin = data.angmin
        self.dvamax = data.angmax

        fraction = args['hiddenfrac']
        nlayer = args['nlayer']
        nhidden = int(fraction * self.ydim)

        # Shared trunk: nlayer blocks of Linear -> ReLU (possibly none).
        layer_sizes = [self.xdim] + nlayer * [nhidden]
        trunk = []
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            trunk += [nn.Linear(n_in, n_out), nn.ReLU()]
        self.net = nn.Sequential(*trunk)

        # Heads are created in the same order as before so that seeded
        # initialisation is reproduced.
        trunk_width = layer_sizes[-1]
        self.pg = self._make_head(trunk_width, self.npg)
        self.qg = self._make_head(trunk_width, self.nqg)
        self.vm = self._make_head(trunk_width, self.nvm)
        self.dva = self._make_head(trunk_width, self.ndva)
        self.va = self._make_head(trunk_width, self.nva + 1)

    @staticmethod
    def _make_head(n_in, n_out):
        """Return one output head: Linear -> ReLU -> Linear -> Hardsigmoid."""
        return nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU(), nn.Linear(n_out, n_out),
                             nn.Hardsigmoid())

    def forward(self, x):
        """Predict generation, voltages, angle differences and flows.

        Args:
            x: dict with the load tensors ``"pd"`` and ``"qd"``, concatenated
                along axis 1 to form the network input.

        Returns:
            dict with ``"pg"``, ``"qg"``, ``"vm"``, ``"dva"`` (rescaled to
            their bounds), ``"pg_bus"`` / ``"qg_bus"`` (padded with
            ``data.pad``, gathered through ``data.bus_genidxs`` and summed over
            the last axis), ``"flow"`` (from ``data.compute_flow(vm, dva)``)
            and ``"va"`` (raw head output).
        """
        pd = x["pd"].to(self.device)
        qd = x["qd"].to(self.device)
        features = self.net(torch.cat([pd, qd], dim=1))

        # Map each [0, 1] head output to its [min, max] range.
        pg = (self.pgmax - self.pgmin) * self.pg(features) + self.pgmin
        qg = (self.qgmax - self.qgmin) * self.qg(features) + self.qgmin
        vm = (self.vmmax - self.vmmin) * self.vm(features) + self.vmmin
        dva = (self.dvamax - self.dvamin) * self.dva(features) + self.dvamin
        va = self.va(features)

        pg_bus = self.data.pad(pg)[:, self.data.bus_genidxs].sum(dim=2)
        qg_bus = self.data.pad(qg)[:, self.data.bus_genidxs].sum(dim=2)

        flow = self.data.compute_flow(vm, dva)  # flow -- Ohm's law

        return {"pg": pg, "qg": qg, "pg_bus": pg_bus, "qg_bus": qg_bus,
                "vm": vm, "dva": dva, "flow": flow, "va": va}



class NNPrimalACOPFSolver_personalized(nn.Module):
    """ACOPF solution network configured by explicit arguments.

    Same architecture as the ``args``-driven solver in this file: a shared
    ReLU trunk feeding five ``Hardsigmoid`` heads (``pg``, ``qg``, ``vm``,
    ``dva``, ``va``). Here the depth and the hidden-width fraction are passed
    directly instead of through a dict.

    ``nlayer == 0`` is supported (empty trunk, heads attached to the input);
    the previous implementation raised ``TypeError`` there because it reduced
    an empty list. Sub-module names and construction order are unchanged, so
    existing ``state_dict`` keys keep matching.
    """

    def __init__(self, data, nlayer, hiddenfrac):
        """Build the network.

        Args:
            data: Problem object providing ``device``, the sizes ``ngen``,
                ``nbus``, ``nbranch``, ``nload``, ``slack_bus_idx``, the bounds
                ``pgmin/pgmax``, ``qgmin/qgmax``, ``vmmin/vmmax``,
                ``angmin/angmax`` and, for ``forward``, ``pad``,
                ``bus_genidxs`` and ``compute_flow``.
            nlayer: Number of hidden layers in the shared trunk.
            hiddenfrac: Hidden width as a fraction of the output dimension.
        """
        super().__init__()

        self.data = data
        self.device = data.device

        self.npg = data.ngen
        self.nqg = data.ngen
        self.nvm = data.nbus
        self.nva = data.nbus - 1
        self.nbranch = data.nbranch
        self.ndva = self.nbranch

        self.xdim = 2 * data.nload  # pd and qd
        self.ydim = self.npg + self.nqg + self.nvm + self.ndva + self.nva

        bus_range = torch.arange(data.nbus)
        self.nonslack_busidxs = bus_range[bus_range != data.slack_bus_idx]

        print("X dim:%d, Y dim:%d" % (self.xdim, self.ydim), flush=True)

        self.pgmin, self.pgmax = data.pgmin, data.pgmax
        self.qgmin, self.qgmax = data.qgmin, data.qgmax
        self.vmmin, self.vmmax = data.vmmin, data.vmmax
        self.dvamin, self.dvamax = data.angmin, data.angmax

        hidden_width = int(hiddenfrac * self.ydim)

        # Shared trunk: nlayer blocks of Linear -> ReLU (possibly none).
        widths = [self.xdim] + nlayer * [hidden_width]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Linear(fan_in, fan_out))
            blocks.append(nn.ReLU())
        self.net = nn.Sequential(*blocks)

        def head(n_out):
            # One output head: Linear -> ReLU -> Linear -> Hardsigmoid.
            return nn.Sequential(nn.Linear(widths[-1], n_out), nn.ReLU(), nn.Linear(n_out, n_out),
                                 nn.Hardsigmoid())

        # Creation order matters for seeded initialisation; keep it.
        self.pg = head(self.npg)
        self.qg = head(self.nqg)
        self.vm = head(self.nvm)
        self.dva = head(self.ndva)
        self.va = head(self.nva + 1)

    def forward(self, x):
        """Predict generation, voltages, angle differences and flows.

        Args:
            x: dict with the load tensors ``"pd"`` and ``"qd"``, concatenated
                along axis 1 to form the network input.

        Returns:
            dict with ``"pg"``, ``"qg"``, ``"vm"``, ``"dva"`` (rescaled to
            their bounds), ``"pg_bus"`` / ``"qg_bus"`` (padded with
            ``data.pad``, gathered through ``data.bus_genidxs`` and summed over
            the last axis), ``"flow"`` (from ``data.compute_flow(vm, dva)``)
            and ``"va"`` (raw head output).
        """
        loads = torch.cat([x["pd"].to(self.device), x["qd"].to(self.device)], dim=1)
        hidden = self.net(loads)

        pg = self.pg(hidden)
        qg = self.qg(hidden)
        vm = self.vm(hidden)
        dva = self.dva(hidden)
        va = self.va(hidden)

        # Map each [0, 1] head output to its [min, max] range.
        pg = (self.pgmax - self.pgmin) * pg + self.pgmin
        qg = (self.qgmax - self.qgmin) * qg + self.qgmin
        vm = (self.vmmax - self.vmmin) * vm + self.vmmin
        dva = (self.dvamax - self.dvamin) * dva + self.dvamin

        gen_idx = self.data.bus_genidxs
        pg_bus = self.data.pad(pg)[:, gen_idx].sum(dim=2)
        qg_bus = self.data.pad(qg)[:, gen_idx].sum(dim=2)

        flow = self.data.compute_flow(vm, dva)  # flow -- Ohm's law

        return {"pg": pg, "qg": qg, "pg_bus": pg_bus, "qg_bus": qg_bus,
                "vm": vm, "dva": dva, "flow": flow, "va": va}
    


def load_acopf_data(args, current_path, device):
    """Build the ACOPF problem from ``<current_path>/pglib_57/pascal``.

    The data directory is fixed: ``args['probtype']`` is forwarded to
    ``ACOPFProblem`` and printed, but it does not select a sub-directory.

    Args:
        args: Mutable configuration mapping that must contain ``'probtype'``.
            Its ``'nex'``, ``'nineq'`` and ``'neq'`` entries are overwritten in
            place with the sizes reported by the loaded problem.
        current_path: Base directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        device: Forwarded unchanged to ``ACOPFProblem``.

    Returns:
        Tuple ``(data, args)``: the constructed ``ACOPFProblem`` and the same
        ``args`` object after the update.
    """
    # A per-probtype sub-directory table used to be built here but was never
    # read; it has been dropped. The objective scaler is pinned to 1e0.
    filepath = current_path / "pglib_57" / "pascal"
    data = ACOPFProblem(filepath, args['probtype'], device, args, obj_scaler=1e0)

    # Publish the problem dimensions so downstream code can size itself.
    args['nex'] = data.ndata
    args['nineq'] = data.nineq
    args['neq'] = data.neq

    print("Problem %s redefines the configuration--> nex:%d | nineq:%d | neq:%d"%(args['probtype'],data.ndata,data.nineq,data.neq),flush=True)
    print("   #bus:%d | #gen:%d | #load:%d | #branch:%d"%(data.nbus, data.ngen, data.nload, data.nbranch),flush=True)
    return data, args



def load_acopf_OOD_data(args, current_path, device, perc):
    """Build the ACOPF problem from ``<current_path>/pglib_57/OOD/<perc>%``.

    The data directory depends only on ``perc``: ``args['probtype']`` is
    forwarded to ``ACOPFProblem`` and printed, but it does not select a
    sub-directory.

    Args:
        args: Mutable configuration mapping that must contain ``'probtype'``.
            Its ``'nex'``, ``'nineq'`` and ``'neq'`` entries are overwritten in
            place with the sizes reported by the loaded problem.
        current_path: Base directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).
        device: Forwarded unchanged to ``ACOPFProblem``.
        perc: Value whose ``str()`` form followed by ``"%"`` names the
            sub-directory under ``OOD`` (e.g. ``10`` -> ``"10%"``).

    Returns:
        Tuple ``(data, args)``: the constructed ``ACOPFProblem`` and the same
        ``args`` object after the update.
    """
    # A per-probtype sub-directory table used to be built here but was never
    # read; it has been dropped. The objective scaler is pinned to 1e0.
    perc_dirname = str(perc) + "%"
    filepath = current_path / "pglib_57" / "OOD" / perc_dirname
    data = ACOPFProblem(filepath, args['probtype'], device, args, obj_scaler=1e0)

    # Publish the problem dimensions so downstream code can size itself.
    args['nex'] = data.ndata
    args['nineq'] = data.nineq
    args['neq'] = data.neq

    print("Problem %s redefines the configuration--> nex:%d | nineq:%d | neq:%d"%(args['probtype'],data.ndata,data.nineq,data.neq),flush=True)
    print("   #bus:%d | #gen:%d | #load:%d | #branch:%d"%(data.nbus, data.ngen, data.nload, data.nbranch),flush=True)
    return data, args