import numpy as np
import matplotlib.pyplot as plt
import random


# Coefficients of the quadratic running reward used by reward_func:
#   r(x, a) = -(M x^2 / 2 + R x a + N a^2 / 2 + P x + P_prime a)
M, R, N = 2, 1, 2
P, P_prime = 1, 2
# Regularisation weight: grad_Q scales dQ/da by 1/lam, and the TD error in
# cqsm subtracts 0.5 * lam * score**2.
lam = 0.1

# 定义奖励函数
def reward_func(x, a):
    return -(0.5 * M * x**2 + R * x * a + 0.5 * N * a**2 + P * x + P_prime * a)
# 奖励导数
def reward_grad(x, a):
    return -(R * x + N *a + P_prime)

def Q(theta, x, a):
    """Evaluate the quadratic Q-function parameterised by ``theta``.

    Q = theta[0] x^2 / 2 + theta[1] x + theta[2] a^2 / 2 + theta[3] a
        + theta[4] x a + theta[5]

    ``theta`` must support integer indexing for positions 0..5.
    """
    terms = (
        0.5 * theta[0] * x ** 2,
        theta[1] * x,
        0.5 * theta[2] * a ** 2,
        theta[3] * a,
        theta[4] * x * a,
        theta[5],
    )
    # Accumulate left to right so the floating-point summation order is fixed.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total

def tes_func(x, a):
    """Return the feature vector of ``Q``, i.e. the gradient dQ/dtheta.

    ``Q`` is linear in ``theta`` with these features, so
    ``Q(theta, x, a) == theta @ tes_func(x, a)``. ``cqsm`` multiplies the TD
    error by this vector to form the ``theta`` update.

    Returns:
        np.ndarray of length 6: ``[x**2/2, x, a**2/2, a, x*a, 1]``.
    """
    # The a**2 entry carries the same 1/2 factor as Q's 0.5 * theta[2] * a**2
    # term. It used to be a**2, which doubled theta[2]'s effective step size
    # relative to the other components.
    return np.array([0.5 * x ** 2, x, 0.5 * a ** 2, a, x * a, 1])

def score(v, x, a):
    """Return -exp(v[0]) * a + v[1] * x + v[2].

    This equals d/da of the log-density of a Gaussian with mean
    ``(v[1] x + v[2]) / exp(v[0])`` and variance ``exp(-v[0])``. That is the
    action distribution sampled in ``cqsm``.
    """
    precision = np.exp(v[0])
    return -precision * a + v[1] * x + v[2]

def grad_Q(theta, x, a):
    """Return (1 / lam) * dQ/da, where dQ/da = theta[2] a + theta[4] x + theta[3].

    Reads the module-level ``lam``. In ``cqsm`` this is the target that
    ``score`` is moved towards.
    """
    inv_lam = 1 / lam
    return theta[2] * a * inv_lam + theta[4] * x * inv_lam + theta[3] * inv_lam

def grad_score(v, x, a):
    """Return the gradient of ``score(v, x, a)`` with respect to ``v``.

    Returns:
        np.ndarray ``[-exp(v[0]) * a, x, 1]``.
    """
    d_v0 = -np.exp(v[0]) * a
    return np.array([d_v0, x, 1])

def _sample_action(v, x):
    """Draw one action from the Gaussian policy parameterised by ``v``.

    The mean is ``(v[1] * x + v[2]) / exp(v[0])`` and the standard deviation is
    ``exp(-v[0] / 2)``. Consumes one draw from NumPy's global RNG, so the call
    order must stay fixed to keep seeded runs reproducible.
    """
    return (v[1] * x + v[2]) / np.exp(v[0]) + np.exp(-v[0] / 2) * np.random.randn()


def cqsm(seed=0, T=1e4, dt=0.1, A=-1, B=0, C=0, D=1, rho=1, alpha_theta=0.01, alpha_v=0.01):
    """Simulate one CQSM learning run and return its reward and parameter traces.

    The state follows the Euler-Maruyama step
    ``x' = x + (A x + B a) dt + (C x + D a) dW`` starting from ``x = 0``. At
    every step:
    - an action is drawn from the Gaussian policy parameterised by ``v``;
    - the TD error of ``Q(theta, ., .)`` is formed;
    - ``theta`` and ``v`` each take a step of size
      ``alpha / max(log((t + 1) * dt), 1)``.

    Args:
        seed: Seeds both ``random`` (initial ``v``) and NumPy's global RNG
            (Brownian increments and action noise).
        T: Time horizon.
        dt: Time step; the run has ``round(T / dt)`` steps.
        A, B: Drift coefficients (drift is ``A x + B a``).
        C, D: Diffusion coefficients (diffusion is ``C x + D a``).
        rho: Discount rate multiplying ``Q`` in the TD error.
        alpha_theta: Base learning rate for ``theta``.
        alpha_v: Base learning rate for ``v``.

    Returns:
        ``(average_reward, theta_path, v_path)``:
        - ``average_reward``: the running mean of the per-step reward
          (list of length ``nt``);
        - ``theta_path``: the ``theta`` iterates, float32 array of shape
          ``(6, nt + 1)``;
        - ``v_path``: the ``v`` iterates, float32 array of shape
          ``(3, nt + 1)``.
    """
    random.seed(seed)
    np.random.seed(seed)
    # int(T / dt) alone truncates, e.g. 0.3 / 0.1 == 2.9999999999999996 -> 2.
    # Rounding first yields the intended number of steps.
    nt = int(round(T / dt))

    x0 = 0
    # Policy parameters drawn uniformly from [0, 1]; Q weights start at zero.
    v = np.array([random.uniform(0, 1) for _ in range(3)])
    theta = np.zeros(6)

    theta_path = np.zeros((6, nt + 1), dtype=np.float32)
    v_path = np.zeros((3, nt + 1), dtype=np.float32)
    theta_path[:, 0] = theta
    v_path[:, 0] = v

    # Brownian increments drawn up front: dW ~ N(0, dt).
    dw = np.sqrt(dt) * np.random.randn(nt)

    average_reward = []
    cur_reward_sum = 0
    for t in range(nt):
        # Act in the current state, step the SDE, then sample the next action.
        a_0 = _sample_action(v, x0)
        x_now = x0 + (A * x0 + B * a_0) * dt + (C * x0 + D * a_0) * dw[t]
        a_now = _sample_action(v, x_now)

        r_now = reward_func(x0, a_0)
        cur_reward_sum += r_now
        average_reward.append(cur_reward_sum / (t + 1))

        q_cur = Q(theta, x0, a_0)
        score_cur = score(v, x0, a_0)
        # TD error between consecutive (state, action) pairs. Its running term
        # is reward minus 0.5 * lam * score**2, minus rho * Q for discounting.
        delta = Q(theta, x_now, a_now) - q_cur + dt * (r_now - 0.5 * lam * score_cur ** 2 - rho * q_cur)
        delta_theta = delta * tes_func(x0, a_0)
        # Move the policy score towards (1 / lam) * dQ/da along d(score)/dv.
        delta_v = (grad_Q(theta, x0, a_0) - score_cur) * grad_score(v, x0, a_0)

        # Step size is constant until (t + 1) * dt exceeds e, then decays as 1 / log.
        step_scale = max(np.log((t + 1) * dt), 1)
        theta = theta + alpha_theta * delta_theta / step_scale
        v = v + alpha_v * delta_v / step_scale

        x0 = x_now
        theta_path[:, t + 1] = theta
        v_path[:, t + 1] = v

    return average_reward, theta_path, v_path


def _plot_param_paths(t_grid, mean, std, true_values, indices, colors, symbol):
    r"""Plot seed-averaged parameter paths with a +/-1 std band and their true values.

    Opens and shows one figure. For each component ``idx`` in ``indices``
    (drawn in the matching entry of ``colors``) it does three things:
    - plots ``mean[idx]`` against ``t_grid``;
    - shades ``mean[idx] +/- std[idx]``;
    - draws ``true_values[idx]`` as a dashed horizontal line.

    ``symbol`` is the LaTeX symbol used in the legend, e.g. ``r'\theta'`` or ``'v'``.
    """
    plt.figure()
    for idx, color in zip(indices, colors):
        plt.plot(t_grid, mean[idx], color=color, linewidth=2, label=rf'${symbol}_{idx}$ Path')
        plt.fill_between(t_grid, mean[idx] - std[idx], mean[idx] + std[idx], color=color, alpha=0.3)
        plt.plot(t_grid, np.full_like(t_grid, true_values[idx]), color=color, linestyle='--',
                 linewidth=2, label=rf'True Value ${symbol}_{idx}$')
    plt.legend(fontsize=10)
    plt.xlabel('Time', fontsize=18, fontweight='bold')
    plt.show()


T = 1e4
dt = 0.1
n_seeds = 5  # number of independent runs averaged in the plots

# Run CQSM once per seed and collect the trajectories.
average_reward_list = []
theta_path_list = []
v_path_list = []
for seed in range(n_seeds):
    average_reward, theta_path, v_path = cqsm(seed=seed, T=T, dt=dt, A=-1, B=0, C=0, D=1, rho=1,
                                              alpha_theta=0.01, alpha_v=0.01)
    average_reward_list.append(average_reward)
    theta_path_list.append(theta_path)
    v_path_list.append(v_path)

# Mean and standard deviation across seeds (axis 0 indexes the seed).
average_reward_mean, average_reward_std = np.mean(average_reward_list, axis=0), np.std(average_reward_list, axis=0)
theta_path_mean, theta_path_std = np.mean(theta_path_list, axis=0), np.std(theta_path_list, axis=0)
v_path_mean, v_path_std = np.mean(v_path_list, axis=0), np.std(v_path_list, axis=0)

# Take the step count from the returned paths (shape (dim, nt + 1)) so the
# time axes always match the data, rather than re-deriving it from T / dt.
nt = theta_path_mean.shape[1] - 1
t_grid = np.arange(nt + 1) * dt

# Reference values drawn as the dashed "True Value" lines.
v_true = [1.52913155, -1.5119060, -3.5624157]
theta_true = [-0.59047134, -0.23069812, -0.46141679, -0.35624157, -0.15119060, 0.17312350]

# Global font settings for all figures.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 18
plt.rcParams['font.weight'] = 'bold'

# Every curve and band uses the seed mean/std. Previously theta_2..theta_4 and
# all v curves plotted only the last seed's run.
_plot_param_paths(t_grid, theta_path_mean, theta_path_std, theta_true,
                  (0, 1, 2), ('#427AB2', '#F09148', '#DBDB8D'), r'\theta')
_plot_param_paths(t_grid, theta_path_mean, theta_path_std, theta_true,
                  (3, 4, 5), ('#C59D94', '#AFC7E8', '#FF9896'), r'\theta')

# Running average reward: one value per step, the first at t = dt.
reward_t_grid = np.arange(1, nt + 1) * dt
plt.figure()
plt.plot(reward_t_grid, average_reward_mean, 'r', linewidth=2, label='CQSM')
plt.fill_between(reward_t_grid, average_reward_mean - average_reward_std,
                 average_reward_mean + average_reward_std, color='r', alpha=0.3)
plt.legend(fontsize=10)
plt.xlabel('Time', fontsize=18, fontweight='bold')
plt.ylabel('Average Reward', fontsize=18, fontweight='bold')
plt.show()

_plot_param_paths(t_grid, v_path_mean, v_path_std, v_true,
                  (0, 1, 2), ('#43978F', '#ABD0F1', '#E56F5E'), 'v')