from plot_traj_ineq import *
import numpy as np
from ZO_methods_ineq import *
import scipy.io

def simulate_trajectory(K, b, A, B, d, x0, T, x_set):
    """Roll out an affine state-feedback controller on an affine linear system.

    At every step the control is ``u_t = K @ x_t + b`` and the state evolves
    as ``x_{t+1} = A @ x_t + B @ u_t + d``.

    Parameters
    ----------
    K, b : feedback gain and constant offset of the controller.
    A, B, d : system matrix, input matrix and constant drift term.
    x0 : initial state; ``x0.shape[0]`` is used to normalise both metrics.
    T : number of simulation steps (must be >= 1, since the metrics are
        divided by ``T``).
    x_set : upper set-point; only the part of the state above it is penalised.

    Returns
    -------
    tuple
        ``(avg_u_squared, avg_max_x_penalty, x_trajectory, u_trajectory)``:
        the accumulated ``u_t @ u_t.T`` and the accumulated squared
        set-point excess, each divided by ``T`` and then by ``x0.shape[0]``,
        followed by the visited states (``x0`` first, ``T + 1`` entries)
        and the applied controls (``T`` entries), stacked with ``np.vstack``.
    """
    state_dim = x0.shape[0]

    state = x0
    states = [state]
    controls = []

    # Running totals of the two performance metrics.
    effort_total = 0
    violation_total = 0

    for _ in range(T):
        control = K @ state + b

        # The penalty is evaluated on the state *before* it is advanced,
        # so x0 contributes and the final state does not.
        overshoot = np.maximum(state - x_set, 0)
        effort_total += control @ control.T
        violation_total += np.sum(overshoot ** 2)
        controls.append(control)

        # Advance the dynamics and record the new state.
        state = A @ state + B @ control + d
        states.append(state)

    # Average over time first, then over the state dimension.
    mean_effort = effort_total / T / state_dim
    mean_violation = violation_total / T / state_dim

    return mean_effort, mean_violation, np.vstack(states), np.vstack(controls)


def f(x):
    """Objective: normalised average control effort for decision vector ``x``.

    The first half of ``x`` is placed on the diagonal of the feedback gain
    and the second half is the constant control offset. The closed loop is
    simulated with the module-level ``A0``, ``B0``, ``d0``, ``x0``, ``T0``
    and ``x_set`` (assigned in this file's ``__main__`` section), and the
    first value returned by ``simulate_trajectory`` is passed back.
    """
    half = x.shape[0] // 2
    gain = np.diag(x[:half])
    offset = x[half:2 * half]
    outcome = simulate_trajectory(gain, offset, A0, B0, d0, x0, T0, x_set)
    return outcome[0]

def h(x):
    """Constraint value: average set-point excess penalty minus 1.5.

    ``x`` is unpacked like in ``f``: the first half forms a diagonal
    feedback gain and the second half the constant control offset. The
    closed loop is simulated with the module-level ``A0``, ``B0``, ``d0``,
    ``x0``, ``T0`` and ``x_set`` (assigned in this file's ``__main__``
    section). The second value returned by ``simulate_trajectory`` is
    shifted by the threshold 1.5.
    NOTE(review): presumably the optimizers treat ``h(x) <= 0`` as
    feasible -- confirm against ZO_methods_ineq.
    """
    half = x.shape[0] // 2
    gain = np.diag(x[:half])
    offset = x[half:2 * half]
    outcome = simulate_trajectory(gain, offset, A0, B0, d0, x0, T0, x_set)
    return outcome[1] - 1.5

if __name__ == "__main__":
    # Script entry point: load the building model, build the discrete-time
    # system, run every zeroth-order optimizer on the same problem and
    # save/plot their histories.
    #
    # The objective ``f``, the constraint ``h`` and ``simulate_trajectory``
    # are the module-level definitions earlier in this file; they used to be
    # re-declared here as byte-identical copies, which only risked the two
    # versions drifting apart.

    # --- Problem data -----------------------------------------------------
    mat_data = scipy.io.loadmat('Building-2x2x5-room.mat')
    A = mat_data['A']
    B = mat_data['B']
    in_door_gain = mat_data['in_door_gain'].squeeze()
    # loadmat returns scalars as 2-D arrays, hence the [0][0] indexing.
    x_set = mat_data['x_set'][0][0]
    xo = mat_data['xo'][0][0]

    N = A.shape[0]
    C = 200
    delta_t = 40

    # The names A0, B0, d0, T0, x0 and x_set must stay module-level globals:
    # f and h read them directly.
    # NOTE(review): A0 = I + (delta_t / C) * A looks like a forward-Euler
    # discretisation of a continuous-time model -- confirm.
    scale = 1/C*delta_t
    A0 = np.eye(N) + scale*A
    B0 = scale*B
    d0 = scale*(xo+in_door_gain)
    T0 = 100
    x0 = xo*np.ones(N,)

    # Decision vector is [diag(K), b]; start from K = -0.3*I, b = -0.3*1.
    initial_guess = -0.3*np.ones(2*N)

    # --- Optimizer runs ---------------------------------------------------
    best_point4, best_value4, history4 = ZOGDA(
        f, h, initial_guess,
        learning_rate_x=1e-3, learning_rate_lambda=0.03, r=0.1, TB=8,
        max_iter=1500)

    best_point5, best_value5, history5 = ConEx(
        f, h, initial_guess,
        learning_rate_x=3e-3, learning_rate_lambda=0.005, r=0.01, TB=8,
        max_iter=1500, theta=0.1)

    # Diagnostic output kept from the original script.
    print(0.001/2/N)
    print(20*4*N*N)

    # NOTE(review): k is 32000 here but 3200 for the two ZOFL variants
    # below -- confirm the factor of ten is intentional.
    best_point, best_value, history = ZO_baseline(
        f, h, initial_guess,
        learning_rate=2.5e-5, r=0.1, epsilon=1e-3, TB=8, k=32000,
        max_iter=1500)
    best_point2, best_value2, history2 = ZOFL(
        f, h, initial_guess,
        learning_rate=2.5e-5, r=0.1, epsilon=1e-3, TB=8, k=3200,
        max_iter=1500)
    best_point3, best_value3, history3 = ZOFL_midpoint(
        f, h, initial_guess,
        learning_rate=2.5e-5, r=0.1, epsilon=1e-3, TB=8, k=3200,
        max_iter=1500)

    # --- Persist and plot -------------------------------------------------
    np.savez('test_thermal_control.npz', history=history, history2=history2,
             history3=history3, history4=history4, history5=history5)
    plot_optimization_history(
        history4, history5, history, history2, history3,
        legends=['ZOGDA', 'SZO-ConEx', 'ZO-Baseline', 'ZOFL', 'ZOFL-midpoint'],
        linewidth=3,
        savename='comparison_ineq.png',
        colors=['#1f77b4', '#2ca02c', '#ff7f0e', '#d62728', '#9467bd'],
    )