#!/usr/bin/env python
# coding: utf-8

# In[61]:


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import control
import random
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns 
from pandas.plotting import table  # EDIT: see deprecation warnings below
from matplotlib import cm
import matplotlib
# Global font settings for every figure in this notebook: bold, size 10.
# No 'family' key is set: 'normal' is not a font family name, so passing it
# only made matplotlib warn ("findfont: Font family ... not found") before
# falling back to its default family.
font = {'weight': 'bold', 'size': 10}
matplotlib.rc('font', **font)


# In[62]:


def generate_B(n, m, M):
    """Build the n x m input (actuation) matrix.

    Column ``j`` holds a single 1 in row ``M[j]``; all other entries are 0.
    If ``M[j]`` does not equal any row number in ``range(n)``, column ``j``
    stays all-zero.

    Returns a float ndarray of shape (n, m).
    """
    B = np.zeros((n, m))
    for row in range(n):
        # Columns whose actuated node is this row.
        cols = [col for col in range(m) if M[col] == row]
        if cols:
            B[row, cols] = 1
    return B

def generate_C(n, l, O):
    """Build the l x n output (observation) matrix.

    Row ``j`` holds a single 1 in column ``O[j]``; all other entries are 0.
    If ``O[j]`` does not equal any column number in ``range(n)``, row ``j``
    stays all-zero.

    Returns a float ndarray of shape (l, n).
    """
    C = np.zeros((l, n))
    for r in range(l):
        # Columns (state nodes) that this output row measures.
        hit = [c for c in range(n) if c == O[r]]
        if hit:
            C[r, hit] = 1
    return C

def choose_graph(graph_mode):
    """Draw a random graph of the requested family with networkx.

    The graph size and family hyper-parameters are NOT arguments: they are
    read from module-level globals ``n``, ``m_ba``, ``p_er``, ``k_ws`` and
    ``p_ws``, which must be defined before this is called.

    Args:
        graph_mode: 'BA' (``nx.barabasi_albert_graph``), 'ER'
            (``nx.erdos_renyi_graph``) or 'WS' (``nx.watts_strogatz_graph``).

    Returns:
        The generated networkx graph.

    Raises:
        ValueError: if ``graph_mode`` is not one of the three codes above.
    """
    if graph_mode == 'BA':
        return nx.barabasi_albert_graph(n, m_ba)
    if graph_mode == 'ER':
        return nx.erdos_renyi_graph(n, p_er)
    if graph_mode == 'WS':
        return nx.watts_strogatz_graph(n, k_ws, p_ws)
    raise ValueError(
        f"unknown graph_mode {graph_mode!r}; expected 'BA', 'ER' or 'WS'"
    )


def weighted_matrix(n):
    """Return an n x n integer matrix of random signed weights.

    Each entry is a random sign (-1 or +1) times a random integer in 1..4,
    so every entry lies in {-4, ..., -1, 1, ..., 4}. Draws from NumPy's
    global RNG: the signs first, then the magnitudes.
    """
    shape = (n, n)
    signs = np.random.choice([-1, 1], size=shape)
    magnitudes = np.random.randint(low=1, high=5, size=shape)
    return signs * magnitudes

def graph_matrix(graph):
    """Build a random signed-weighted state matrix supported on ``graph``.

    Steps:
      1. Take the adjacency matrix of ``graph``.
      2. Multiply it elementwise by ``weighted_matrix(n)``, so only existing
         edges carry a (random signed integer) weight.
      3. Transpose it and divide by (spectral radius + 0.01). The result has
         spectral radius rho / (rho + 0.01) < 1.

    Returns:
        A plain ``numpy.ndarray`` of shape (n, n), with n the node count.
    """
    n = graph.number_of_nodes()
    # ``.todense()`` yields numpy.matrix or ndarray depending on the
    # networkx/scipy versions; force ndarray so '*' stays elementwise and the
    # return type is the same everywhere.
    A = np.asarray(nx.adjacency_matrix(graph).todense())
    W = weighted_matrix(n)
    G = W * A
    spectral_radius = np.max(np.abs(np.linalg.eigvals(G)))
    return G.T / (spectral_radius + 0.01)


# # ADAM

# In[63]:


class Model(nn.Module):
    """State-space model whose n x n state matrix G is the trainable parameter.

    ``forward`` returns the Markov parameters C G^k B for k = 1, ..., n-2,
    stacked into a tensor of shape (n - 2, C.shape[0], B.shape[1]).

    Args:
        n: state dimension (G is n x n and starts at zero).
        C: NumPy output matrix with n columns.
        B: NumPy input matrix with n rows.
        theta: weight of the L1 penalty on G returned by ``l1_penalty``.
    """

    def __init__(self, n, C, B, theta):
        super().__init__()
        self.n = n
        # Fixed (not trained) input/output matrices.
        self.C = torch.from_numpy(C).float()
        self.B = torch.from_numpy(B).float()
        self.G = nn.Parameter(torch.zeros(self.n, self.n))
        self.theta = theta

    def forward(self):
        """Return C G^k B for k = 1..n-2 as an (n-2, l, m) tensor.

        The L1 regulariser is deliberately not applied here. Adding
        theta * ||G||_1 to every predicted entry inside a squared-error loss
        only offsets the predictions; it does not encourage a sparse G.
        Add ``l1_penalty()`` to the loss instead.
        """
        markov = []
        power = torch.eye(self.n)
        for _ in range(self.n - 2):
            power = power @ self.G
            markov.append(self.C @ power @ self.B)
        if not markov:
            return torch.zeros((0, self.C.shape[0], self.B.shape[1]))
        return torch.stack(markov)

    def l1_penalty(self):
        """Return theta * sum(|G_ij|), the sparsity term to add to the loss."""
        return self.theta * self.G.abs().sum()


def adam(n, C, B, M, theta, epochs=50000, lr=0.0001, log_every=5000):
    """Fit G with Adam by minimising sum_k ||C G^k B - M_k||^2 + theta*||G||_1.

    Args:
        n, C, B, theta: forwarded to ``Model``.
        M: target Markov parameters as a torch tensor whose first dimension
            is n - 2, e.g. as produced by ``data_synthesis``. It is assumed to
            be float32 to match the model output (TODO confirm for other
            callers).
        epochs: number of full-batch Adam steps.
        lr: Adam learning rate.
        log_every: print the loss every ``log_every`` epochs; <= 0 disables.

    Returns:
        The learned G as an (n, n) NumPy array.
    """
    model = Model(n, C, B, theta)
    criterion = nn.MSELoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # The target never changes, so flatten it once instead of every epoch.
    target = M.reshape(M.shape[0], -1)

    for i in range(epochs):
        output = model()
        output = output.view(output.shape[0], -1)
        loss = criterion(output, target) + model.l1_penalty()
        if log_every > 0 and i % log_every == 0:
            print("Epoch: [{}/{}]\tLoss: {}".format(i, epochs, loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model.G.detach().cpu().numpy()


# # Data Synthesis

# In[64]:


def data_synthesis(n, l, m, G_truth, C=None, B=None):
    """Compute the Markov parameters C G^k B for k = 1..n-2.

    Args:
        n: state dimension (``G_truth`` is n x n).
        l: number of outputs (rows of C).
        m: number of inputs (columns of B).
        G_truth: ground-truth state matrix.
        C: (l, n) output matrix. If omitted, the module-level ``C`` is used,
            which keeps older callers that relied on the global working.
        B: (n, m) input matrix. If omitted, the module-level ``B`` is used.

    Returns:
        A float32 torch tensor of shape (n - 2, l, m).
    """
    if C is None:
        C = globals()['C']
    if B is None:
        B = globals()['B']
    markov = np.zeros((n - 2, l, m))
    power = np.identity(n)
    for k in range(n - 2):
        power = power @ G_truth
        markov[k] = C @ power @ B
    return torch.from_numpy(markov).float()


# # Experiment: relative identification error on random graphs

# In[65]:


# Node indices that receive inputs (columns of B) and that are measured
# (rows of C). Kept separate from M, which holds the Markov parameters below.
actuated_nodes = [0, 1, 5, 6, 7, 8, 9]
observed_nodes = [0, 1, 2, 3, 4, 5, 6]

n, l, m = 10, len(observed_nodes), len(actuated_nodes)

B = generate_B(n, m, actuated_nodes)
C = generate_C(n, l, observed_nodes)

graph_modes = ['ER', 'WS', 'BA']
# Random-graph hyper-parameters; choose_graph() reads them as globals.
p_er = 0.2
k_ws = 2
p_ws = 0.2
m_ba = 1

# (method label, L1 weight theta) evaluated on every accepted graph.
methods = [('SSubI', 0), ('SSSubI', 0.001)]
num_simulations = 2
# Cap on graph draws per family, so a family that (almost) never gives an
# observable (G, C) pair fails loudly instead of looping forever.
max_attempts = 1000

records = []
for graph_mode in graph_modes:
    cnt = 1
    attempts = 0
    while cnt <= num_simulations:
        attempts += 1
        if attempts > max_attempts:
            raise RuntimeError(
                f"no observable {graph_mode} graph found in {max_attempts} draws"
            )
        G_truth = graph_matrix(choose_graph(graph_mode))
        # Only keep graphs for which (G_truth, C) has full observability rank.
        if np.linalg.matrix_rank(control.obsv(G_truth, C)) != n:
            continue
        print('Processing ', cnt, 'th ', graph_mode, ' graph.')
        M = data_synthesis(n, l, m, G_truth)
        truth_norm = np.linalg.norm(G_truth, 'fro')
        for method, theta in methods:
            G_learn = adam(n, C, B, M, theta)
            error = np.linalg.norm(G_truth - G_learn, 'fro') / truth_norm
            records.append({'Network type': graph_mode, 'error': error, 'method': method})
            print('Relative Error of ' + method + ' of ', cnt, 'th', graph_mode, ' graph: ', error)
        cnt += 1

df_performance = pd.DataFrame(records, columns=['Network type', 'error', 'method'])
df_performance.to_csv('./partial_action_performance.csv')


# In[66]:


def plotting_function(csv_path='partial_action_performance.csv',
                      out_path='partial_action_performance.png'):
    """Box-plot the relative error per network type, split by method.

    Args:
        csv_path: results CSV with columns 'Network type', 'error' and
            'method'. The default is the file written by the experiment cell.
        out_path: where the figure is saved (150 dpi).
    """
    results = pd.read_csv(csv_path)
    x_name = 'Network type'
    plt.figure(figsize=(6, 5))
    sns.boxplot(x=x_name, y='error', data=results, hue='method',
                showfliers=False, showmeans=False,
                palette=sns.color_palette("tab10"), linewidth=1, width=0.3)
    plt.ylabel('Relative error', fontsize=14, fontweight='bold')
    plt.legend(loc=0, prop={'size': 10})  # loc=0 -> matplotlib picks the spot
    plt.xlabel(x_name, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)


plotting_function()


# In[ ]:




