from maddpg import *
from conjugategrad import *
from spsa import *
import numpy as np
from numpy import linalg as la
from scipy.linalg import sqrtm


def feature_expectation_2(trajectories):
    """Return the per-trajectory average feature sum for the agent at ``state[1]``.

    Each ``(state, action)`` step contributes a 2-vector:

    * ``[0]``: minus the Euclidean distance from ``state[1]`` to its target
      ``target[1] = (-0.5, 0)``.
    * ``[1]``: a collision penalty ``-10 * exp(-5 * dis)`` when the two agents
      (``state[0]`` and ``state[1]``) are within 0.1 of each other, else 0.

    Args:
        trajectories: non-empty sequence of trajectories; each trajectory is a
            sequence of ``(state, action)`` pairs where ``state[k]`` holds at
            least the (x, y) coordinates of agent ``k``. ``action`` is unused.

    Returns:
        np.ndarray of shape (2,): the summed step features divided by the
        number of trajectories.

    Raises:
        ValueError: if ``trajectories`` is empty (the average is undefined and
            would otherwise be NaN).
    """
    # len() rather than truthiness so NumPy-array inputs are also accepted.
    if len(trajectories) == 0:
        raise ValueError("trajectories must be non-empty")
    target = np.array([[0, 0.5], [-0.5, 0]])
    fe = np.zeros(2)
    for trajectory in trajectories:
        for state, _action in trajectory:
            # Distance between the two agents.
            dis = np.hypot(state[0][0] - state[1][0], state[0][1] - state[1][1])
            # Distance of agent state[1] from its own target.
            dis_goal = np.hypot(state[1][0] - target[1][0], state[1][1] - target[1][1])
            feature = np.array([-dis_goal, 0.0])
            if dis <= 0.1:
                feature[1] = -10 * np.exp(5 * -dis)
            fe += feature
    return fe / len(trajectories)


def _step_features_1(state, action, target):
    """Return ``(feature, dis_g)`` for one ``(state, action)`` step of the agent at ``state[0]``.

    Single source of truth for the step features shared by
    ``feature_expectation_1`` and ``f_value``.

    * ``dis``: distance between ``state[0]`` and ``state[1]``.
    * ``speed``: Euclidean norm of ``action[1]`` minus that of ``action[0]``.
    * ``dis_g``: distance from ``state[0]`` to ``target[0]``.

    ``feature[0] = -exp(-dis) * dis_g``, replaced by ``-10 * exp(-5 * dis)``
    when the agents are within 0.15 of each other;
    ``feature[1] = -exp(-speed - 0.5) * dis_g``.
    """
    dis = np.hypot(state[0][0] - state[1][0], state[0][1] - state[1][1])
    speed = np.hypot(action[1][0], action[1][1]) - np.hypot(action[0][0], action[0][1])
    dis_g = np.hypot(state[0][0] - target[0][0], state[0][1] - target[0][1])
    feature = np.array([-np.exp(-dis) * dis_g, -np.exp(-speed - 0.5) * dis_g])
    if dis <= 0.15:
        # Close-range penalty overrides the distance-weighted term.
        feature[0] = -10 * np.exp(5 * -dis)
    return feature, dis_g


def feature_expectation_1(trajectories):
    """Return the per-trajectory average feature sum for the agent at ``state[0]``.

    Args:
        trajectories: non-empty sequence of trajectories; each trajectory is a
            sequence of ``(state, action)`` pairs where ``state[k]`` and
            ``action[k]`` hold at least two components for agent ``k``.

    Returns:
        np.ndarray of shape (2,): summed step features (see
        ``_step_features_1``) divided by the number of trajectories.

    Raises:
        ValueError: if ``trajectories`` is empty (the average is undefined and
            would otherwise be NaN).
    """
    if len(trajectories) == 0:
        raise ValueError("trajectories must be non-empty")
    target = np.array([[0, 0.5], [-0.5, 0]])
    fe = np.zeros(2)
    for trajectory in trajectories:
        for state, action in trajectory:
            feature, _ = _step_features_1(state, action, target)
            fe += feature
    return fe / len(trajectories)


def f_value(trajectories, weight):
    """Return the average per-trajectory value of ``weight . feature - dis_g``.

    Uses the same step features as ``feature_expectation_1`` and subtracts the
    distance of ``state[0]`` to its target at every step.

    Args:
        trajectories: non-empty sequence of trajectories of ``(state, action)``
            pairs (same layout as for ``feature_expectation_1``).
        weight: array-like weights for the 2 step features (1-D of length 2,
            or any shape whose transpose can be matrix-multiplied with them).

    Returns:
        np.ndarray of shape (1,) for 1-D ``weight``: the summed values divided
        by the number of trajectories.

    Raises:
        ValueError: if ``trajectories`` is empty.
    """
    if len(trajectories) == 0:
        raise ValueError("trajectories must be non-empty")
    target = np.array([[0, 0.5], [-0.5, 0]])
    # asarray lets plain lists be passed; .T kept for column-vector weights.
    weight_t = np.asarray(weight).T
    f = np.zeros(1)
    for trajectory in trajectories:
        for state, action in trajectory:
            feature, dis_g = _step_features_1(state, action, target)
            f += weight_t @ feature - dis_g
    return f / len(trajectories)


if __name__ == '__main__':
    # Bilevel training script.
    #   Inner level: theta_2 takes gradient steps along
    #       g_l = fe_2 - fe_2e + reg * theta_2
    #   i.e. it pushes feature_expectation_2 of the 'smg' rollouts toward the
    #   reference fe_2e computed from the 'pmg' rollouts.
    #   Outer level: theta_1 takes steps along an SPSA-estimated
    #       g_f = gf1 - hl12 @ inv,   inv ~ (positive-definite hl22)^-1 @ gf2
    #   where the objective is f_value with weights theta_1_gt.
    # NOTE(review): `train`, `spsa`, `conjugate` and `plt` are not defined in
    # this file; presumably they come from the star imports at the top (`plt`
    # in particular is never imported explicitly) — confirm.
    theta_1 = np.array([-0.95, -0.95])
    theta_2 = np.array([-0.95, -0.95])
    # Row 0: target position for state[0]; row 1: target position for state[1].
    target = np.array([[0, 0.5], [-0.5, 0]])
    theta_1_gt = np.array([1, 1])  # weights handed to f_value (outer objective)
    reg = 1e-3  # L2 regularization on theta_2 in the inner gradient
    alpha = 0.0001/4  # outer (theta_1) step size
    beta = 0.0004  # inner (theta_2) step size
    theta_1_list = []
    theta_2_list = []
    c = 0.05  # SPSA perturbation size; decayed at the end of every outer iteration
    for i in range(100):
        # Reference feature expectation: rollouts are redrawn until neither
        # component is below -50. NOTE(review): no retry cap — can spin forever.
        fe_2e = np.ones(2) * -100
        while fe_2e[0] < -50 or fe_2e[1] < -50:
            # NOTE(review): the 4th positional flag (True here, False for the
            # 'smg' rollouts) is interpreted by maddpg.train — TODO confirm meaning.
            trajectories_pmg = train(theta_1, theta_2, False, True, np.zeros(2), np.zeros(2), False, i)
            fe_2e = feature_expectation_2(trajectories_pmg)
            print('fe_2e:', fe_2e)
        # Inner loop: ceil((i + 1) ** 0.25 / 2) gradient steps on theta_2
        # (the step count grows slowly with the outer iteration).
        for j in range(int(np.ceil((i + 1) ** 0.25 / 2))):
            trajectories_smg = train(theta_1, theta_2, False, False, np.zeros(2), np.zeros(2), False, i)
            fe_2 = feature_expectation_2(trajectories_smg)
            print('fe_2:', fe_2)
            g_l = fe_2 - fe_2e + reg * theta_2
            print('g_l:', g_l)
            theta_2 -= beta * g_l
            # Component 1 gets an extra 10x step (11 * beta in total).
            theta_2[1] -= 10 * beta * g_l[1]
            # Clip theta_2 element-wise into [-1, 1].
            for k in range(theta_2.shape[0]):
                if theta_2[k] > 1:
                    theta_2[k] = 1
                if theta_2[k] < -1:
                    theta_2[k] = -1
            print('theta_2:', theta_2)
        # Outer hypergradient, re-estimated until both components satisfy
        # |g_f| <= 5000 (outlier rejection). NOTE(review): no retry cap.
        g_f = np.ones(2) * 10001
        while np.abs(g_f[0]) > 5000 or np.abs(g_f[1]) > 5000:
            # Rademacher (+1/-1) SPSA perturbation directions.
            delta_1 = np.random.choice([-1, 1], size=theta_1.shape[0])
            delta_2 = np.random.choice([-1, 1], size=theta_2.shape[0])

            # Rollouts with +/- c*delta_1 in the 5th argument and +/- c*delta_2
            # in the 6th. NOTE(review): presumably train adds these offsets to
            # theta_1 / theta_2 respectively — confirm in maddpg.
            trajectories_fb_1p = train(theta_1, theta_2, False, False, c * delta_1, np.zeros(2), False, i)
            trajectories_fb_1m = train(theta_1, theta_2, False, False, -c * delta_1, np.zeros(2), False, i)
            trajectories_fb_2p = train(theta_1, theta_2, False, False, np.zeros(2), c * delta_2, False, i)
            trajectories_fb_2m = train(theta_1, theta_2, False, False, np.zeros(2), -c * delta_2, False, i)
            # Outer objective at each of the four perturbations.
            f_1p = f_value(trajectories_fb_1p, theta_1_gt)
            f_1m = f_value(trajectories_fb_1m, theta_1_gt)
            f_2p = f_value(trajectories_fb_2p, theta_1_gt)
            f_2m = f_value(trajectories_fb_2m, theta_1_gt)
            # Feature expectations under the theta_2 perturbations only; they
            # feed the hl12 and hl22 estimates below.
            fe_1p = feature_expectation_1(trajectories_fb_2p)
            fe_1m = feature_expectation_1(trajectories_fb_2m)
            fe_2p = feature_expectation_2(trajectories_fb_2p)
            fe_2m = feature_expectation_2(trajectories_fb_2m)
            print('f_1p:', f_1p)
            print('f_1m:', f_1m)
            print('f_2p:', f_2p)
            print('f_2m:', f_2m)
            print('fe_1p:', fe_1p)
            print('fe_1m:', fe_1m)
            print('fe_2p:', fe_2p)
            print('fe_2m:', fe_2m)
            # SPSA estimates of the outer objective's gradient w.r.t. theta_1
            # and theta_2.
            gf1 = spsa(f_1p,
                       f_1m, delta_1, c).reshape(2)
            gf2 = spsa(f_2p,
                       f_2m, delta_2, c).reshape(2)
            # SPSA estimate of how feature_expectation_1 changes with theta_2.
            hl12 = spsa(fe_1p, fe_1m, delta_2, c)
            # SPSA estimate of how the inner gradient (fe_2 + reg * theta_2)
            # changes with theta_2; presumably a (2, 2) matrix, symmetrized next.
            g22 = spsa(fe_2p + reg * (theta_2 + c * delta_2),
                       fe_2m + reg * (theta_2 - c * delta_2), delta_2, c)
            hl22 = (g22 + g22.T) / 2
            print('gf1:', gf1, 'gf2:', gf2, 'hl12:', hl12, 'hl22:', hl22)
            # (hl22^2 + 1e-5 I)^(1/2): for symmetric hl22 this has eigenvalues
            # sqrt(lambda^2 + 1e-5) > 0, i.e. a positive-definite surrogate.
            hl22pd = sqrtm(hl22 @ hl22 + 0.00001 * np.eye(2))
            print('eigenval：', np.linalg.eigvals(hl22pd))
            # pd = NPD.nearestPD(hl22)
            # print('pd:', pd)
            # Presumably solves hl22pd @ inv = gf2 (conjugate gradient, per the
            # module name) — TODO confirm in conjugategrad.
            inv = conjugate(hl22pd, gf2)
            print('inv:', inv)
            g_f = gf1 - hl12 @ inv
            print('g_f:', g_f)
        theta_1 -= alpha * g_f
        # Clip theta_1 element-wise into [-1, 1].
        for m in range(theta_1.shape[0]):
            if theta_1[m] > 1:
                theta_1[m] = 1
            if theta_1[m] < -1:
                theta_1[m] = -1
        print(i, 'theta_1:', theta_1)
        # Store copies: theta_1 / theta_2 are modified in place above.
        theta_1_list.append(theta_1.copy())
        theta_2_list.append(theta_2.copy())
        # alpha = 0.0001/5* 0.8**((1+i)/10)
        # Decay the SPSA perturbation size: 0.05 * 0.8 ** ((i + 1) / 10).
        c = 0.05* 0.8**((1+i)/10)
        # train(theta_1, theta_2, False, False, np.zeros(2), np.zeros(2), True, i)
    # Persist the per-iteration parameter history (str() of each array, one per line).
    with open('theta_1.txt', 'w') as file:
        for number in theta_1_list:
            file.write(f"{number.copy()}\n")
    with open('theta_2.txt', 'w') as file:
        for number in theta_2_list:
            file.write(f"{number.copy()}\n")
    # Final evaluation rollout with the learned parameters.
    result = train(theta_1, theta_2, False, False, np.zeros(2), np.zeros(2), False, 0)
    collision = []
    dones = []
    for i in range(len(result)):
        learner_x = []
        learner_y = []
        expert_x = []
        expert_y = []
        collision_count = 0
        done = 0
        # Only the first 40 steps are inspected.
        # NOTE(review): assumes every trajectory has at least 40 steps.
        for j in range(40):
            (state, action) = result[i][j]
            learner_x.append(state[0][0])
            learner_y.append(state[0][1])
            expert_x.append(state[1][0])
            expert_y.append(state[1][1])
            # Collision flag: agents within 0.1 at any inspected step.
            if ((state[0][0] - state[1][0]) ** 2 + (state[0][1] - state[1][1]) ** 2) ** 0.5 <= 0.1:
                collision_count = 1
        collision.append(collision_count)
        # Success flag: `state` is now the last inspected step; both agents
        # must be within 0.05 of their respective targets.
        if ((state[0][0] - target[0][0]) ** 2 + (state[0][1] - target[0][1]) ** 2) ** 0.5 <= 0.05 and (
                (state[1][0] - target[1][0]) ** 2 + (state[1][1] - target[1][1]) ** 2) ** 0.5 <= 0.05:
            done = 1
        dones.append(done)
        # Plot both paths and radius-0.05 circles around the two targets.
        plt.plot(learner_x, learner_y, '+')
        plt.plot(expert_x, expert_y, 'x')
        theta = np.linspace(0, 2 * np.pi, 100)
        plt.plot(0 + 0.05 * np.cos(theta), 0.5 + 0.05 * np.sin(theta))
        plt.plot(-0.5 + 0.05 * np.cos(theta), 0 + 0.05 * np.sin(theta))
        plt.show()
    print(dones, collision)
