import numpy as np
from mat4py import loadmat
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from env import Frequency

## Load the network data
data = loadmat('data/IEEE_39bus_Kron.mat')

# Pull each field out of the 'Kron_39bus' struct. Every matrix is converted to
# a float32 ndarray; omega_R is kept exactly as loadmat returned it.
K_EN = np.asarray(data['Kron_39bus']['K'], dtype=np.float32)
H = np.asarray(data['Kron_39bus']['H'], dtype=np.float32)
Damp = np.asarray(data['Kron_39bus']['D'], dtype=np.float32)
omega_R = data['Kron_39bus']['omega_R']
A_EN = np.asarray(data['Kron_39bus']['A'], dtype=np.float32)
gamma = np.asarray(data['Kron_39bus']['gamma'], dtype=np.float32)

## The network parameter
# Number of control actions (one agent per controlled bus); the simulated
# state concatenates two vectors of this length.
dim_action=10
dim_state=2*dim_action
# Integration step handed to the Frequency environment (units not shown here -- presumably seconds).
delta_t=0.01
# Per-generator coefficient derived from the loaded H and omega_R
# (presumably inertia 2H/omega_R scaled by 2*pi -- confirm units).
M=H.reshape(dim_action)*2/omega_R*2*np.pi
# Per-generator damping, hard-coded here; note that `Damp` loaded from the
# .mat file is not used by the visible code.
D=np.zeros(dim_action,dtype=np.float32)
D[0]=2*590/100
D[1:8]=2*865/100
D[8:10]=2*911/100
D=D/omega_R*2*np.pi
# Kron-reduced coupling matrix from the data file, passed to the environment.
KE=K_EN
# Weight of the action-magnitude penalty (forwarded to the environment).
Penalty_action=0.02*0.2
# Per-generator injection vector, shape (1, dim_action)
# (presumably the nominal mechanical power set-points -- confirm).
Pm = np.array([[-0.19983394, -0.25653884, -0.25191885, -0.10242008, -0.34510365,
         0.23206371,  0.4404325 ,  0.5896664 ,  0.26257738, -0.36892462]],dtype = np.float32)

# Per-agent action bound, shape (1, dim_action); forwarded to the environment.
max_action = np.array([[0.19606592, 0.2190382 , 0.22375287, 0.0975513 , 0.29071101,
        0.22091283, 0.38759459, 0.56512538, 0.24151538, 0.29821917]],dtype = np.float32)

# Starting state, shape (1, dim_state): first dim_action entries are non-zero,
# the last dim_action entries are zero.
equilibrium_init = np.array([[ -0.05420687, -0.07780334, -0.07351729, -0.05827823, -0.09359571,
        -0.02447385, -0.00783582,  0.00259523, -0.0162409 , -0.06477749,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.       ]],dtype = np.float32)

# Environment that precomputes the transfer/selection matrices used by the RNN cell.
Env=Frequency(Pm,M,D,KE,delta_t,max_action,dim_action,Penalty_action)

###################### Define the minimal RNN cell ###########################
class RNNCell(nn.Module):
    """One integration step of the closed-loop network with per-bus neural controllers.

    Each of the ``action_units`` buses owns a small MLP that maps that bus's
    frequency signal (one column of the second half of the state) to a scalar
    control action.  The cell then advances the state one step using the
    matrices exposed by ``Env``.  The only nonlinear term is a
    ``sin(delta_i - delta_j)`` coupling weighted by ``Env.state_transferF``
    (presumably the line power-flow term of the swing equation -- confirm
    against ``env.Frequency``).

    Args:
        units: Stored as ``self.units``; not otherwise used by this class.
        action_units: Number of controlled buses (one agent each).  The state
            vector has ``2 * action_units`` entries.
        internal_units: Hidden width of every agent MLP.
        Env: Object providing the ``state_transfer*``, ``select_*`` and
            ``max_action`` arrays.
        batchsize: Stored as ``self.batchsize`` and used to size
            ``self.Multiply_ones``; :meth:`forward` accepts any batch size.
    """

    def __init__(self, units, action_units, internal_units, Env, batchsize):
        super().__init__()
        self.units = units
        self.action_units = action_units
        self.state_size = action_units * 2
        self.internal_units = internal_units
        self.batchsize = batchsize
        self.state_transfer1 = torch.FloatTensor(Env.state_transfer1)
        self.state_transferF = torch.FloatTensor(Env.state_transferF)
        self.state_transfer2 = torch.FloatTensor(Env.state_transfer2)
        self.state_transfer3 = torch.FloatTensor(Env.state_transfer3)
        self.state_transfer4 = torch.FloatTensor(Env.state_transfer4)
        self.state_transfer3_Pm = torch.FloatTensor(Env.state_transfer3_Pm)
        self.select_add_w = torch.FloatTensor(Env.select_add_w)
        self.select_w = torch.FloatTensor(Env.select_w)
        self.select_delta = torch.FloatTensor(Env.select_delta)
        self.max_action = torch.FloatTensor(Env.max_action)
        # Kept for backward compatibility only: forward() now forms the pairwise
        # differences by broadcasting, so it no longer depends on this
        # batch-size-specific tensor.
        self.Multiply_ones = torch.tile(torch.ones((action_units, action_units), dtype=torch.float32), (batchsize, 1, 1))
        # The following three tensors are not used within this class.
        self.w_recover = torch.triu(-torch.ones((internal_units, internal_units)), diagonal=0) - torch.triu(-torch.ones((internal_units, internal_units)), diagonal=2) + \
            2 * torch.eye(internal_units, dtype=torch.float32)
        self.b_recover = torch.triu(torch.ones((internal_units, internal_units)), diagonal=1)
        self.ones_frequency = torch.ones((action_units, internal_units), dtype=torch.float32)

        ## Define the neural network for control: one independent 1-in/1-out MLP per bus.
        self.agent_list = nn.ModuleList()
        for _ in range(action_units):
            self.agent_list.append(nn.Sequential(
                nn.Linear(1, self.internal_units, bias=True), nn.ReLU(),
                nn.Linear(self.internal_units, self.internal_units, bias=True), nn.ReLU(),
                nn.Linear(self.internal_units, 1, bias=True),
            ))

        # Start every linear layer with zero bias (weights keep PyTorch's default init).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.bias.data.zero_()
        print('Initialization!')

    def _agent_actions(self, signal):
        """Apply agent k to column k of ``signal`` and return the (batch, action_units) actions."""
        return torch.hstack([
            agent(signal[:, k:k + 1]) for k, agent in enumerate(self.agent_list)
        ])

    def forward(self, inputs, states):
        """Advance the state by one time step.

        Args:
            inputs: 2-D per-step input, multiplied by ``state_transfer3_Pm``.
            states: List whose first element is the (batch, 2*action_units)
                previous state.

        Returns:
            ``([loss0, frequency, action], [new_state])`` where ``loss0`` is
            ``new_state**2 @ select_add_w``, ``frequency`` is
            ``new_state @ select_w`` and ``action`` is the clamped control.
        """
        prev_output = states[0]
        frequency_w = prev_output[:, self.action_units:]

        # Each bus controller only sees its own frequency signal.
        action = self._agent_actions(frequency_w)
        action = torch.clamp(action, -self.max_action, self.max_action)

        # delta_diff[b, i, j] = delta[b, i] - delta[b, j].  Broadcasting gives
        # exactly the values of the old diag_embed/matmul construction, but
        # works for any batch size and avoids two batched n x n matmuls.
        delta = torch.mm(prev_output, self.select_delta)
        delta_diff = delta.unsqueeze(2) - delta.unsqueeze(1)
        coupling = torch.sum(torch.sin(delta_diff) * self.state_transferF, dim=2)

        # integrate the state transition dynamics
        new_state = (torch.mm(prev_output, self.state_transfer1)
                     + torch.mm(coupling, self.state_transfer2)
                     + self.state_transfer3
                     + torch.mm(action, self.state_transfer4)
                     + inputs @ self.state_transfer3_Pm)

        loss0 = torch.mm(torch.pow(new_state, 2), self.select_add_w)
        frequency = torch.mm(new_state, self.select_w)

        return [loss0, frequency, action], [new_state]

    def reg_forward(self, feature_num, num=512):
        """Evaluate the (unclamped) agents on random probe inputs for the slope constraint.

        Draws ``num`` samples uniformly from ``[-0.15, 0.15)`` in each of
        ``feature_num`` columns (must be at least ``action_units``), marks them
        ``requires_grad`` and feeds column k to agent k.

        Returns:
            ``([inputs], [actions])`` -- single-element lists holding the
            (num, feature_num) probe tensor and the (num, action_units) actions.
        """
        input_mono = (torch.rand(num, feature_num) - 0.5) * 0.3
        input_mono.requires_grad = True

        action = self._agent_actions(input_mono)
        return [input_mono], [action]


## This class defines the recurrent implementation
class RNNWrap(nn.Module):
    """Unroll a per-step cell over the time axis of a batch-first input sequence.

    Args:
        Cell: Per-step cell exposing ``batchsize`` and callable as
            ``cell(x_t, [state]) -> ([loss0, frequency, action], [new_state])``.
        return_sequence: Accepted for API compatibility; unused (the whole
            sequence is always returned).
    """

    def __init__(self, Cell, return_sequence=True):
        super().__init__()
        self.cell = Cell
        self.batch_size = Cell.batchsize
        self.initial_state = None

    def forward(self, inputs, state=None):
        """Run the cell over every time step of ``inputs``.

        Args:
            inputs: Array-like of shape (batch, seq_len, n_in); converted to float32.
            state: Optional array-like initial state of shape (batch, state_size).
                When omitted, the value set by :meth:`reset_states` is used.

        Returns:
            ``[loss, frequency, action]``, each stacked batch-first with shape
            (batch, seq_len, ...).

        Raises:
            ValueError: If no initial state is available or ``seq_len`` is 0.
        """
        inputs = torch.as_tensor(inputs, dtype=torch.float32)
        if state is None:
            state = self.initial_state
        if state is None:
            raise ValueError("RNNWrap has no initial state: call reset_states() or pass state=")
        # The cell expects a list of state tensors and returns one in the same form.
        states = [torch.as_tensor(state, dtype=torch.float32)]

        _, seq_len, _ = inputs.size()
        if seq_len == 0:
            raise ValueError("inputs must contain at least one time step")

        output_loss, output_freq, output_actn = [], [], []
        for t in range(seq_len):
            (loss_t, freq_t, actn_t), states = self.cell(inputs[:, t, :], states)
            output_loss.append(loss_t)
            output_freq.append(freq_t)
            output_actn.append(actn_t)

        # Stack along dim=1 so the outputs are batch-first.
        return [torch.stack(output_loss, dim=1),
                torch.stack(output_freq, dim=1),
                torch.stack(output_actn, dim=1)]

    def reset_states(self, initial_state):
        """Set the initial state used when :meth:`forward` is called without ``state``."""
        self.initial_state = initial_state


def compute_dual_loss(in_list, out_list, mu_lower, t_lower):
    """Per-sample Lagrangian penalty for a lower bound on each controller's slope.

    For every column i, the derivative of output i with respect to input i is
    taken per sample (the diagonal of the per-sample Jacobian).  Any shortfall
    below ``t_lower[i]`` is weighted by the dual variable ``mu_lower[i]``.

    Args:
        in_list: Single-element list holding the (num, n) input tensor; it must
            have ``requires_grad=True`` and be part of the outputs' graph.
        out_list: Single-element list holding the (num, >=n) output tensor.
        mu_lower: Length-n array of non-negative dual variables.
        t_lower: Length-n array of slope lower bounds.

    Returns:
        ``(lagrangian_loss, grad_mat)``.  ``lagrangian_loss`` has shape (num,) and
        holds ``sum_i mu_i * max(t_i - grad_mat[:, i], 0)``.  ``grad_mat`` is
        (num, n), built with ``create_graph=True`` so the loss is differentiable.
    """
    inputs = in_list[0]
    outputs = out_list[0]
    num_features = inputs.shape[1]
    # Copy (not alias) the numpy arrays: the caller updates mu_lower in place.
    t_lower = torch.tensor(np.asarray(t_lower, dtype=np.float32))
    mu_lower = torch.tensor(np.asarray(mu_lower, dtype=np.float32))

    grad_columns = []
    for i in range(num_features):
        grad_input = torch.autograd.grad(torch.sum(outputs[:, i]), inputs,
                                         create_graph=True, allow_unused=True)[0]
        if grad_input is None:
            # Output i does not depend on the inputs at all: its slope is zero.
            grad_columns.append(torch.zeros_like(inputs[:, i]))
        else:
            # record the gradient for dual update
            grad_columns.append(grad_input[:, i])
    grad_mat = torch.stack(grad_columns, dim=1)

    # To tighten the bound, add a small positive margin to t_lower here.
    indicator = torch.clamp(t_lower - grad_mat, min=0)
    lagrangian_loss = torch.sum(indicator * mu_lower, dim=1)
    return lagrangian_loss, grad_mat

################################### The start of training #########################################
# Optimisation schedule
episodes = 600        # total number of weight-update iterations
T = 200               # number of time steps simulated per trajectory
Batch_num = 600       # trajectories per batch (one batch per episode)

# Controller dimensions
action_units = dim_action
units = action_units  # `units` argument for RNNCell (equal to action_units)
internal_units = 32   # hidden width of each per-bus control network

cell = RNNCell(units, action_units, internal_units, Env, Batch_num)
model = RNNWrap(cell)
optimizer = torch.optim.Adam(model.cell.parameters(), lr=5e-3)

# Loss history and disturbance settings
Loss_record = []
Loss_record_orginal = []
num_gen_step = 3
Percent_step_change = 1
range_step_change = 1
PrintUpdate = 1

## define the dual variables and auxiliary variables
mu_lower = np.zeros((dim_action,), dtype=np.float32)
t_lower = 0.0001 * np.ones((dim_action,), dtype=np.float32)
ZEROS = [0] * 128
alpha = 0.1
lr_mu = 10
mu_lower_rec = []

for i in range(0, episodes):
    # Every trajectory starts from the same equilibrium operating point.
    initial_state = np.zeros((Batch_num, action_units * 2)) + equilibrium_init
    # Random step disturbances fed to the cell as per-step inputs, indexed
    # [batch, time, generator] (presumably mechanical-power changes).
    Pm_change = np.zeros((Batch_num, T, units))
    n_events = Batch_num * Percent_step_change
    for gen_interupt in range(0, num_gen_step):
        idx_gen_deviation = np.random.randint(0, action_units, n_events)
        idx_batch_deviation = np.random.randint(0, Batch_num, n_events)
        # randint expects an integer bound; T // 2 draws from the same range as T / 2.
        slot_start_deviation = np.random.randint(0, T // 2, n_events)
        step_change = np.random.uniform(-1, 1, n_events) * range_step_change
        for t_interupt in range(0, T):
            # NOTE(review): the step is active while slot_start >= t, i.e. from
            # t=0 until slot_start, and each draw overwrites (rather than adds
            # to) earlier ones for the same (batch, gen). Confirm intended.
            Pm_change[idx_batch_deviation, t_interupt, idx_gen_deviation] = \
                (slot_start_deviation >= t_interupt) * step_change

    # reset the state
    model.reset_states(initial_state)
    # The forward
    [loss0, frequency, action] = model(Pm_change)

    # Apply the constrained learning: penalise controller slopes below t_lower.
    in_list, out_list = model.cell.reg_forward(feature_num=dim_action, num=128)
    Lagragian_loss, grad_mat = compute_dual_loss(in_list, out_list, mu_lower, t_lower)
    Lagragian_loss = Lagragian_loss.mean()
    grad_mat = grad_mat.detach().cpu().numpy()

    loss_action = 0.1 * Env.Penalty_action * torch.sum(torch.abs(action)) / Batch_num
    # Peak absolute frequency over time, summed over buses, averaged over the batch.
    loss_freq = torch.sum(torch.max(torch.abs(frequency), 1).values) / Batch_num
    loss = loss_action + loss_freq + Lagragian_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    ## The dual update: projected gradient ascent on mu, vectorised over buses.
    # Uses np.maximum with a scalar 0 instead of the fixed-length ZEROS list so
    # it works for any number of probe samples.
    violation = np.maximum(t_lower - grad_mat, 0)
    grad_mu_lower = violation.mean(axis=0) - alpha * t_lower
    mu_lower[:] = np.maximum(mu_lower + lr_mu * grad_mu_lower, 0)

    # record the dual variables
    mu_lower_rec.append(mu_lower.copy())

    Loss_record.append(loss.detach().numpy())
    Loss_record_orginal.append((loss_action + loss_freq).detach().numpy())
    if i % PrintUpdate == 0:
        print('Episode: {}, Total loss: {}, Frequency loss: {}, objective loss: {}'.format(
            i, loss.item(), loss_freq.item(), (loss_freq + loss_action).item()))

# save the model
# NOTE(review): this model was trained with the slope constraint, yet the file
# name says "unconstraint" -- confirm the intended checkpoint name.
torch.save(model.cell.state_dict(), "./model_unconstraint.pth", _use_new_zipfile_serialization=False)

plt.figure()
plt.plot(Loss_record)
plt.xlabel('episode')
plt.ylabel('Loss')
plt.show()

np.save('loss_record_constrained_learn.npy', Loss_record_orginal)

# Trajectory of each dual variable over training.
mu_lower_rec = np.asarray(mu_lower_rec)
timeLine = list(range(mu_lower_rec.shape[0]))
plt.figure()
for i in range(mu_lower_rec.shape[1]):
    plt.plot(timeLine, mu_lower_rec[:, i])
plt.xlabel('training episode')
plt.ylabel('dual variables')
plt.grid(True)
plt.savefig('trend_dual_variables.png')
plt.show()

# Sum of negative controller slopes: on fresh probe inputs after training vs.
# the last training episode. Both are reduced as numpy arrays so the printed
# values have the same format.
in_list, out_list = model.cell.reg_forward(feature_num=dim_action, num=128)
_, grad_mat1 = compute_dual_loss(in_list, out_list, mu_lower, t_lower)
grad_mat1 = grad_mat1.detach().cpu().numpy()
neg_grad_sum_final = np.minimum(grad_mat1, 0).sum()
neg_grad_sum_train = np.minimum(grad_mat, 0).sum()
print('The sum of negative gradient: {} and in training: {}'.format(neg_grad_sum_final, neg_grad_sum_train))

