from naive_solver import Solver as NaiveSolver
from ada_solver import Solver as AdaSolver
from solver import Solver

from matplotlib.patches import ConnectionPatch
import matplotlib.pyplot as plt
from copy import deepcopy
import numpy as np

# Small numerical tolerance.
# NOTE(review): EPS is not referenced anywhere in this script -- confirm that
# nothing imports it from here before removing it.
EPS = 1e-8

def objective(x, y):
    """Evaluate the demo objective at ``(x, y)``.

    The point is mapped into a frame rotated by ``theta = pi / 6`` (it is
    multiplied by the transpose of the 2-D rotation matrix for ``theta``), and
    the objective is the negated, x-scaled distance from ``(-1, 2)`` in that
    frame::

        -sqrt(((x_hat + 1) / 2)**2 + (y_hat - 2)**2)

    The value is therefore <= 0, reaching 0 where ``(x_hat, y_hat) == (-1, 2)``.

    Args:
        x: x-coordinate(s); a scalar or a NumPy array.
        y: y-coordinate(s); a scalar or a NumPy array broadcastable with ``x``
            (for example the two arrays returned by ``np.meshgrid``).

    Returns:
        The objective value(s), with the broadcast shape of ``x`` and ``y``.
    """
    theta = np.pi / 6.0
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # Components of rot_mat.T @ [x, y] with rot_mat = [[cos, -sin], [sin, cos]],
    # written element-wise so that inputs of any (broadcastable) shape work --
    # stacking into np.array([x, y]) and using `@` breaks for 2-D grids.
    x_hat = cos_t * x + sin_t * y
    y_hat = -sin_t * x + cos_t * y
    return -np.sqrt(((x_hat + 1.0) / 2.0) ** 2 + (y_hat - 2.0) ** 2)

def cost(x, y):
    """Return the first constraint cost: the negated x-coordinate.

    ``y`` is accepted for a uniform ``(x, y)`` cost signature but is unused.
    """
    negated_x = -x
    return negated_x

def cost2(x, y):
    """Return the second constraint cost: ``x - 2 * y``."""
    twice_y = 2 * y
    return x - twice_y


# Problem and solver configuration shared by all three solvers.
# NOTE(review): H_mat and max_kl are passed straight to the solver constructors;
# presumably a metric matrix and a step-size/trust-region bound -- confirm
# against solver.Solver.
H_mat = np.array([
    [1.0, 0.0],
    [0.0, 1.0]
])
max_kl = 0.5
cost_functions = [cost, cost2]
# Derive the constraint count from the list so the two cannot drift apart.
num_costs = len(cost_functions)
# One limit value per cost function.
# NOTE(review): presumably constraint i means cost_functions[i](x, y) <= limit_values[i]
# -- confirm against the solvers.
limit_values = [0.0, 0.0]
if len(limit_values) != num_costs:
    raise ValueError(
        f"expected {num_costs} limit values (one per cost function), "
        f"got {len(limit_values)}"
    )
slack_decay = 0.9  # only the adaptive solver takes this argument
solver = Solver(num_costs, limit_values, H_mat, max_kl)
n_solver = NaiveSolver(num_costs, limit_values, H_mat, max_kl)
ada_solver = AdaSolver(num_costs, limit_values, H_mat, max_kl, slack_decay)
init_state = np.array([-2.5, -3.0])
# All trajectories start at init_state; give every history list and every
# running state its own copy so no two solvers share one array object.
solution_list = [init_state.copy()]
solution_list2 = [init_state.copy()]
solution_list3 = [init_state.copy()]
state, state2, state3 = init_state.copy(), init_state.copy(), init_state.copy()




# solve: advance all three solvers one step per iteration and record iterates
for _step in range(10):
    new_state, info = solver.solve(state, objective, cost_functions)
    new_state2, info2 = n_solver.solve(state2, objective, cost_functions)
    new_state3, info3 = ada_solver.solve(state3, objective, cost_functions)
    # Diagnostic: Euclidean length of the adaptive solver's latest step.
    print(np.linalg.norm(state3 - new_state3, ord=2))
    for history, iterate in ((solution_list, new_state),
                             (solution_list2, new_state2),
                             (solution_list3, new_state3)):
        history.append(iterate)
    # Feed each solver a copy of its new iterate, so an in-place update of its
    # input would not alter the array already stored in its history.
    state, state2, state3 = map(deepcopy, (new_state, new_state2, new_state3))



# plot objective: contour lines of the objective over [-4, 4) x [-4, 4)
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
X = np.arange(-4.0, 4.0, 0.01)
Y = np.arange(-4.0, 4.0, 0.01)
xmesh, ymesh = np.meshgrid(X, Y)
# Evaluate objective point by point on the transposed mesh, so that
# Z[i, j] == objective(X[i], Y[j]) (x along rows); contour wants y along rows,
# hence the Z.T below.
Z = np.vectorize(objective)(xmesh.T, ymesh.T)
ax.contour(xmesh, ymesh, Z.T, levels=10, linewidths=3, zorder=0)

# plot constraint boundaries (dashed grey), i.e. where each cost function
# equals its limit value of 0.0 as configured above
boundary_style = dict(linestyle='dashed', linewidth=2, zorder=5)

# cost2(x, y) = x - 2*y = 0  <=>  y = x / 2
ax.plot(X, X / 2.0, 'grey', **boundary_style)

# cost(x, y) = -x = 0  <=>  x = 0
ax.plot(np.zeros_like(Y), Y, 'grey', **boundary_style)

# plot solutions: one arrowed trajectory per solver
# (red = Solver, orange = NaiveSolver, blue = AdaSolver)
def _plot_trajectory(axes, states, color):
    """Draw ``states`` on ``axes`` as markers joined by arrows, in visit order.

    Args:
        axes: Matplotlib axes to draw on.
        states: Sequence of 2-element points in data coordinates.
        color: Matplotlib color used for both the markers and the arrows.

    Nothing is drawn when there are fewer than two states (no step to show).
    """
    if len(states) < 2:
        return
    # Draw every iterate's marker exactly once.  Plotting both endpoints of
    # each segment overlays interior markers twice, which with alpha=0.5 makes
    # them render darker than the first and last points.
    axes.plot([s[0] for s in states], [s[1] for s in states], "o",
              c=color, markersize=5, alpha=0.5, zorder=10)
    for start, end in zip(states[:-1], states[1:]):
        arrow = ConnectionPatch(start, end, "data", "data",
                                arrowstyle="-|>", shrinkA=8, shrinkB=5,
                                mutation_scale=30, fc=color, color=color,
                                linewidth=3, zorder=10)
        axes.add_artist(arrow)


_plot_trajectory(ax, solution_list, "red")
_plot_trajectory(ax, solution_list2, "orange")
_plot_trajectory(ax, solution_list3, "blue")

# Finalise the layout of the figure built above and write it to disk.
fig.tight_layout()
fig.savefig("result.png")