import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mat4py import loadmat

from env import Frequency

## Load the network data
# Every field is read from the 'Kron_39bus' entry of the .mat file; all of
# them except omega_R are converted to float32 arrays.
data = loadmat('data/IEEE_39bus_Kron.mat')
_kron = data['Kron_39bus']

K_EN = np.asarray(_kron['K'], dtype=np.float32)
H = np.asarray(_kron['H'], dtype=np.float32)
Damp = np.asarray(_kron['D'], dtype=np.float32)
omega_R = _kron['omega_R']  # kept exactly as returned by loadmat
A_EN = np.asarray(_kron['A'], dtype=np.float32)
gamma = np.asarray(_kron['gamma'], dtype=np.float32)

## The network parameter
dim_action = 10
dim_state = 2 * dim_action
delta_t = 0.01

# M is derived from H; the operation order matches the original expression.
M = H.reshape(dim_action) * 2 / omega_R * 2 * np.pi

# D is assembled from three constant segments (1 + 7 + 2 entries) and then
# scaled the same way as M.
D = np.concatenate((
    np.full(1, 2 * 590 / 100, dtype=np.float32),
    np.full(7, 2 * 865 / 100, dtype=np.float32),
    np.full(2, 2 * 911 / 100, dtype=np.float32),
))
D = D / omega_R * 2 * np.pi

KE = K_EN
Penalty_action = 0.02 * 0.2

# Row vector (1, 10) passed to the environment as Pm.
Pm = np.array(
    [[-0.19983394, -0.25653884, -0.25191885, -0.10242008, -0.34510365,
      0.23206371, 0.4404325, 0.5896664, 0.26257738, -0.36892462]],
    dtype=np.float32,
)

# Row vector (1, 10) of per-unit action bounds.
max_action = np.array(
    [[0.19606592, 0.2190382, 0.22375287, 0.0975513, 0.29071101,
      0.22091283, 0.38759459, 0.56512538, 0.24151538, 0.29821917]],
    dtype=np.float32,
)

# Row vector (1, 20): ten non-zero entries followed by ten zeros.
equilibrium_init = np.array(
    [[-0.05420687, -0.07780334, -0.07351729, -0.05827823, -0.09359571,
      -0.02447385, -0.00783582, 0.00259523, -0.0162409, -0.06477749,
      0., 0., 0., 0., 0.,
      0., 0., 0., 0., 0.]],
    dtype=np.float32,
)

Env = Frequency(Pm, M, D, KE, delta_t, max_action, dim_action, Penalty_action)
action_units = dim_action

class RNNCell(nn.Module):
    """Recurrent cell combining per-unit control policies with one state update.

    The cell holds ``action_units`` independent policy networks. Each one is a
    ``1 -> internal_units -> internal_units -> 1`` ReLU MLP that maps a single
    column of the second half of the state (``state[:, action_units:]``) to
    one action, which is then clamped to ``[-max_action, max_action]``.
    ``forward`` additionally advances the state by one step using the
    transition matrices copied from ``Env``.

    Args:
        units: Stored as ``self.units``; not otherwise used by this class.
        action_units: Number of controlled units; the state has
            ``2 * action_units`` columns.
        internal_units: Hidden width of every policy network.
        Env: Object exposing the arrays ``state_transfer1``,
            ``state_transferF``, ``state_transfer2``, ``state_transfer3``,
            ``state_transfer4``, ``state_transfer3_Pm``, ``select_add_w``,
            ``select_w``, ``select_delta`` and ``max_action``.
        batchsize: Stored as ``self.batchsize`` and used to size
            ``Multiply_ones``. ``forward`` itself accepts any batch size.
    """

    def __init__(self,units,action_units,internal_units,Env,batchsize):
        super().__init__()
        self.units=units
        self.action_units=action_units
        self.state_size=action_units*2
        self.internal_units=internal_units
        self.batchsize=batchsize

        # Constant matrices copied from the environment. They are plain
        # attributes (not parameters or buffers), so they do not appear in
        # state_dict() and existing checkpoints keep loading unchanged.
        self.state_transfer1=torch.FloatTensor(Env.state_transfer1)
        self.state_transferF=torch.FloatTensor(Env.state_transferF)
        self.state_transfer2=torch.FloatTensor(Env.state_transfer2)
        self.state_transfer3=torch.FloatTensor(Env.state_transfer3)
        self.state_transfer4=torch.FloatTensor(Env.state_transfer4)
        self.state_transfer3_Pm=torch.FloatTensor(Env.state_transfer3_Pm)
        self.select_add_w=torch.FloatTensor(Env.select_add_w)
        self.select_w=torch.FloatTensor(Env.select_w)
        self.select_delta=torch.FloatTensor(Env.select_delta)
        self.max_action=torch.FloatTensor(Env.max_action)

        # The attributes below are not read by forward()/get_action(); they are
        # retained so that external code accessing them keeps working.
        self.Multiply_ones=torch.tile(torch.ones((action_units,action_units),dtype=torch.float32),(batchsize,1,1))
        self.w_recover=torch.triu(-torch.ones((internal_units,internal_units)),diagonal=0)-torch.triu(-torch.ones((internal_units,internal_units)),diagonal=2)+\
            2*torch.eye(internal_units,dtype=torch.float32)
        self.b_recover=torch.triu(torch.ones((internal_units,internal_units)),diagonal=1)
        self.ones_frequency=torch.ones((action_units,internal_units),dtype=torch.float32)

        ## Define the neural network for control
        self.agent_list = nn.ModuleList()
        for k in range(action_units):
            agent_k = nn.Sequential(
                nn.Linear(1, self.internal_units, bias=True), nn.ReLU(),
                nn.Linear(self.internal_units, self.internal_units, bias=True), nn.ReLU(),
                nn.Linear(self.internal_units, 1, bias=True)
            )
            self.agent_list.append(agent_k)

    def _policy(self,frequency_w):
        """Evaluate every agent on its own column and clamp the stacked result."""
        action_list = [
            self.agent_list[k](torch.unsqueeze(frequency_w[:, k], 1))
            for k in range(self.action_units)
        ]
        action = torch.hstack(action_list)
        return torch.clamp(action, -self.max_action, self.max_action)

    def forward(self,inputs,states):
        """Compute the clamped action and advance the state by one step.

        Args:
            inputs: Tensor right-multiplied by ``state_transfer3_Pm`` and added
                to the new state.
            states: Sequence whose first element is the previous state tensor
                with ``2 * action_units`` columns.

        Returns:
            ``([loss0, frequency, action], [new_state])`` where ``loss0`` is
            ``new_state ** 2 @ select_add_w`` and ``frequency`` is
            ``new_state @ select_w``.
        """
        prev_output=states[0]
        frequency_w=prev_output[:,self.action_units:]
        action = self._policy(frequency_w)

        # Pairwise differences diff[b, i, j] = delta[b, i] - delta[b, j].
        # Broadcasting yields the same values as the former
        # diag_embed(delta) @ ones - ones @ diag_embed(delta) construction, but
        # does not require the batch to match the constructor's batchsize.
        delta = torch.mm(prev_output, self.select_delta)
        angle_diff = delta.unsqueeze(2) - delta.unsqueeze(1)
        coupling = torch.sum(torch.sin(angle_diff) * self.state_transferF, dim=2)

        # integrate the state transition dynamics
        new_state = torch.mm(prev_output, self.state_transfer1) \
            + torch.mm(coupling, self.state_transfer2) \
            + self.state_transfer3 + torch.mm(action, self.state_transfer4) \
            + inputs @ self.state_transfer3_Pm

        loss0=torch.mm(torch.pow(new_state,2),self.select_add_w)
        frequency=torch.mm(new_state,self.select_w)

        return [loss0,frequency,action],[new_state]

    def get_action(self,state):
        """Return the clamped action for ``state``, detached from the graph.

        Args:
            state: Tensor with ``2 * action_units`` columns; only the columns
                from ``action_units`` onward are fed to the agents.
        """
        frequency_w=state[:,self.action_units:]
        return self._policy(frequency_w).detach()


## load the model
units = action_units  # dimension of each state
internal_units = 32  # dimension of the neural network for control policy
T = 200  # Total period considered
Batch_num = 600  # number of batch in each episode
model = RNNCell(units, action_units, internal_units, Env, Batch_num)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; every tensor in this script lives on the CPU anyway.
net_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(net_dict)

## Compare the trajectory
Trajectory_RNN = []

init_state = equilibrium_init
s = init_state.copy().astype(np.float32)
Env.set_state(s)  # start the environment from the initial state
Trajectory_RNN.append(s)
SimulationLength = 700
Record_u = []
Record_Loss = []
Loss_RNN = 0
# NOTE(review): despite its name, no discount factor is applied below; this is
# the plain sum of the rewards returned by Env.step.
Loss_RNN_discounted = 0
Pm_init = Pm.copy()
Pm1 = Pm_init.copy().astype(np.float32)  # unmodified Pm
Pm2 = (Pm_init.copy()).astype(np.float32)  # Pm with the entries in gen_id zeroed
gen_id = [2, 3, 8]
Pm2[0][gen_id] = 0

for i in range(SimulationLength):
    # Pm1 is applied for the first 50 steps (and would be again after step
    # 700); Pm2 is applied in between. An if/else guarantees Pm_change is
    # (re)assigned on every step, including i == 700, which the former pair of
    # independent `if` statements left uncovered.
    if i < 50 or i > 700:
        Pm_change = Pm1.copy()
    else:
        Pm_change = Pm2.copy()

    u = model.get_action(torch.FloatTensor(s))
    u = np.squeeze(np.asarray(u))

    # Extra penalty on the action magnitude, added to Loss_RNN only.
    loss_action = 0.1 * Env.Penalty_action * np.sum(np.abs(u))
    next_s, r = Env.step(u, Pm_change)

    Loss_RNN_discounted += r
    Loss_RNN += r + loss_action
    s = next_s
    Trajectory_RNN.append(s)
    Record_u.append(u)
    Record_Loss.append(np.squeeze(r))

print('The control loss by trained controller:',Loss_RNN)

Trajectory_RNN = np.squeeze(np.asarray(Trajectory_RNN))
plt.figure(figsize=(11, 8), dpi=100)
# One sample per simulation step for the loss and action records.
TimeRecord = np.arange(1, SimulationLength + 1)
TimeRecord = Env.delta_t * TimeRecord

plt.subplot(2, 2, 1)
plt.plot(TimeRecord, Record_Loss)
plt.xlabel('time(s)')
plt.ylabel('Loss')

plt.subplot(2, 2, 2)
plt.plot(TimeRecord, Record_u)
plt.xlabel('time(s)')
plt.ylabel('Action')

plt.subplot(2, 2, 3)
# The trajectory also contains the initial state, hence one extra sample.
TimeRecord = np.arange(1, SimulationLength + 2)
TimeRecord = Env.delta_t * TimeRecord
plt.plot(TimeRecord, Trajectory_RNN[:, 0:action_units])
plt.xlabel('time(s)')
plt.ylabel('delta(rad)')

plt.subplot(2, 2, 4)
plt.plot(TimeRecord, Trajectory_RNN[:, action_units:action_units * 2])
plt.xlabel('time(s)')
plt.ylabel('w (Hz)')
plt.show()

# Save the control trajectory
np.save('action_constrained_learn.npy',Record_u)
np.save('trajectory_constrained_learn.npy',Trajectory_RNN)

################################################################################
## The linear droop control
def np_relu(x):
    """Return the element-wise maximum of ``x`` and zero (NumPy ReLU)."""
    floor = 0
    rectified = np.maximum(floor, x)
    return rectified

def Action_linear(state,linear_coff,env):
    """Return the saturated linear feedback action for ``state``.

    The unconstrained action is ``(state @ env.select_w) * linear_coff``; it is
    then limited element-wise to ``[-env.max_action, env.max_action]``.

    ``np.clip`` replaces the former ReLU-based saturation
    ``max - relu(max - a) + relu(-max - a)``. Both describe the same
    saturation, but clipping returns unsaturated actions exactly instead of
    passing them through two subtractions, and it mirrors the ``torch.clamp``
    used by the learned controller.

    Args:
        state: Array that is right-multiplied by ``env.select_w``.
        linear_coff: Gains, broadcast against ``state @ env.select_w``.
        env: Object exposing ``select_w`` and ``max_action``;
            ``max_action`` is assumed to be element-wise non-negative.

    Returns:
        The clipped action array.
    """
    action_nonconstrain=(state@env.select_w)*linear_coff
    return np.clip(action_nonconstrain,-env.max_action,env.max_action)

# Per-unit gains of the linear controller.
linear_coff = np.array(
    [[2.2347659, 2.720281, 2.4595585, 43.59709, 5.7675405,
      2.3880444, 156.02136, 12.0558605, 2.8042254, 25.922562]],
    dtype=np.float32,
)

s_linear = init_state.copy().astype(np.float32)
# Reset the environment before this rollout. Env.step takes no state argument,
# so the environment presumably advances an internal state; without the reset
# the linear rollout would continue from wherever the previous rollout left the
# environment while the first action is computed from init_state.
Env.set_state(s_linear)
# NOTE(review): Loss_Linear sums only the reward from Env.step, whereas the
# learned-controller loss earlier in this file also adds an action penalty;
# confirm the two totals are meant to be compared directly.
Loss_Linear = 0
Loss_Linear_discounted = 0  # no discount factor is applied below
Trajectory_Linear = []
Trajectory_Linear.append(s_linear)
Record_u_linear = []
Record_Loss_linear = []
for i in range(SimulationLength):
    # Same schedule as the learned-controller rollout; if/else guarantees
    # Pm_change is assigned on every step (the former independent `if`
    # statements left i == 700 uncovered).
    if i < 50 or i > 700:
        Pm_change = Pm1.copy()
    else:
        Pm_change = Pm2.copy()
    u_linear = np.squeeze(Action_linear(s_linear, linear_coff, Env))
    next_s_linear, r_linear = Env.step(u_linear, Pm_change)

    Loss_Linear_discounted += r_linear
    Loss_Linear += r_linear
    s_linear = next_s_linear
    Trajectory_Linear.append(s_linear)
    Record_u_linear.append(u_linear)
    Record_Loss_linear.append(np.squeeze(r_linear))

print('The control loss by Linear controller:',Loss_Linear)
Trajectory_Linear = np.squeeze(np.asarray(Trajectory_Linear))
# The figure for the linear rollout is disabled together with its plots below;
# creating it unconditionally only opened an empty window at plt.show().
# plt.figure(figsize=(11, 8), dpi=100)
TimeRecord = np.arange(1, SimulationLength + 1)
TimeRecord = Env.delta_t * TimeRecord

# plt.subplot(2, 2, 1)
# plt.plot(TimeRecord, Record_Loss_linear)
# plt.xlabel('time(s)')
# plt.ylabel('Loss')
#
# plt.subplot(2, 2, 2)
# plt.plot(TimeRecord, Record_u_linear)
# plt.xlabel('time(s)')
# plt.ylabel('Action')
#
# plt.subplot(2, 2, 3)
# TimeRecord = np.arange(1, SimulationLength + 2)
# TimeRecord = Env.delta_t * TimeRecord
# plt.plot(TimeRecord, Trajectory_Linear[:, 0:action_units])
# plt.xlabel('time(s)')
# plt.ylabel('delta(rad)')
#
# plt.subplot(2, 2, 4)
# plt.plot(TimeRecord, Trajectory_Linear[:, action_units:action_units * 2])
# plt.xlabel('time(s)')
# plt.ylabel('w (Hz)')
# plt.show()

## Plot the action
# Random base state; inside the loop only the swept column is overwritten.
initial_state1 = np.random.uniform(0.0, 0.3, (1, action_units))
initial_state2 = np.random.uniform(-0.03, 0.03, (1, action_units))
s_concate0 = np.hstack((initial_state1, initial_state2)).astype(np.float32)
state_d = np.float32(0.2)

# Values swept through column action_units + gen_idx of the state.
state_w = np.arange(-0.1, 0.1, 0.002, dtype=np.float32)

action_dw = np.zeros(len(state_w))
action_dw_linear = np.zeros(len(state_w))

fig = plt.figure(figsize=(15, 9), dpi=100)
action_list = []
# No plt.subplot(2, 2, 1) here: it would create an empty axes that recent
# Matplotlib versions no longer remove when the grid below overlaps it.
for idx_plot in range(action_units):
    gen_idx = idx_plot
    s_concate = copy.deepcopy(s_concate0)
    for j in range(len(state_w)):
        s_concate[0, action_units + gen_idx] = state_w[j]
        u = model.get_action(torch.FloatTensor(s_concate))
        u = np.asarray(u)
        u_linear = Action_linear(s_concate, linear_coff, Env)
        action_dw[j] = u[0][gen_idx]
        action_dw_linear[j] = u_linear[0][gen_idx]
    plt.subplot(int(np.ceil(action_units / 3)), 3, idx_plot + 1)
    plt.plot(state_w, action_dw, label='Ours')
    # Store a copy: action_dw is reused as a scratch buffer on each iteration.
    action_list.append(action_dw.copy())
    plt.plot(state_w, action_dw_linear, label='Linear')
    plt.scatter(0, 0)
    plt.title('gen' + str(idx_plot + 1))
    # Raw string: '\o' is not a valid escape sequence in a normal literal.
    plt.xlabel(r'$\omega (Hz)$')
    plt.ylabel('u (p.u.)')
    plt.legend(bbox_to_anchor=(0.05, 1), loc='upper left', borderaxespad=0.)
fig.tight_layout()
plt.show()

# np.save does not create missing directories.
os.makedirs('plot_data', exist_ok=True)
np.save('plot_data/action_plot_constrain_learn.npy',action_list)
