import gymnasium as gym
import numpy as np
#from puckworld import PuckWorldEnv
from agents import DQNAgent4
from utils import learning_curve
from gymnasium.envs.registration import register
from mazewrapper import ConcatenateObservationNoGoal
import random
from approximator import NetApproximator
import torch
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
from sklearn.mixture import GaussianMixture
import datetime
from matplotlib.patches import Rectangle

# Seed every RNG the experiment may draw from so runs are reproducible.
# torch was previously left unseeded; it matters if DQNAgent4 initializes or
# samples anything with torch (not visible here — TODO confirm).
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

def _register_point_maze(env_id, maze_map, max_episode_steps):
    """Register a sparse-reward, terminating PointMaze variant under ``env_id``.

    ``maze_map`` legend: 1 = wall, 0 = free cell, "g" = goal,
    "r" = initial position of the agent.
    Observations are wrapped with ``ConcatenateObservationNoGoal``.
    """
    register(
        id=env_id,
        entry_point="gymnasium_robotics.envs.maze.point_maze:PointMazeEnv",
        kwargs={
            # If True the episode is not terminated when the goal is reached;
            # a new goal location is generated instead. With False the episode
            # terminates when the ball reaches the goal.
            "continuing_task": False,
            "reward_type": "sparse",
            "maze_map": maze_map,
        },
        additional_wrappers=(ConcatenateObservationNoGoal.wrapper_spec(),),
        max_episode_steps=max_episode_steps,
    )


# Environment with the constraint: row 2, column 1 of the map is a wall.
_register_point_maze(
    "Maze2d_simple",
    maze_map=[
        [1, 1, 1, 1, 1],
        [1, "g", 0, 0, 1],
        [1, 1, 0, 0, 1],
        [1, "r", 0, 0, 1],
        [1, 1, 1, 1, 1],
    ],
    max_episode_steps=10000,
)

# Environment without the constraint: the same cell is free.
_register_point_maze(
    "Maze2d_simple0",
    maze_map=[
        [1, 1, 1, 1, 1],
        [1, "g", 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, "r", 0, 0, 1],
        [1, 1, 1, 1, 1],
    ],
    max_episode_steps=500,
)

# Load the two pre-trained networks that are handed to the exploring agent:
# DQN from 'DQN.pth' and DQN0 from 'DQN0.pth'. Both are built as
# NetApproximator(4, 8, 32) — presumably (input, output, hidden) sizes; TODO confirm.
# map_location='cpu' lets checkpoints saved from GPU tensors load on CPU-only
# machines; load_state_dict then copies them onto each model's own device.
DQN = NetApproximator(4, 8, 32)
DQN.load_state_dict(torch.load('DQN.pth', map_location='cpu'))
DQN0 = NetApproximator(4, 8, 32)
DQN0.load_state_dict(torch.load('DQN0.pth', map_location='cpu'))

# Policy-constrained exploration experiment. For each episode budget ``i`` the
# agent explores the unconstrained maze ("Maze2d_simple0") using the two
# pre-trained networks, and the points it returns are plotted against the
# ground-truth constraint cell. Figures are written to ./inferred_constraints/.
import os


def _plot_constraint_figure(xs, ys, title, save_path):
    """Scatter inferred-constraint points over the ground-truth cell and save.

    xs, ys: coordinates of the inferred constraint points (array-like of equal
        length, as returned by the agent — verify against DQNAgent4.learning).
    title: figure title.
    save_path: output image path; the figure is closed after saving.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    # Ground-truth constraint: the 1x1 square spanning (-1.5, -0.5)-(-0.5, 0.5),
    # presumably the extra wall cell of "Maze2d_simple" (row 2, column 1);
    # confirm against the maze coordinate frame.
    # NOTE(review): the original passed both ls='-' and linestyle='dotted';
    # matplotlib applied ``ls`` last, so the border rendered solid. That
    # rendering is kept here — switch to 'dotted' if that was the intent.
    rectangle = Rectangle(
        (-1.5, -0.5), 1, 1,
        linewidth=2,
        edgecolor='k',
        facecolor='sandybrown',
        linestyle='-',
        label='Ground-truth constraint',
    )
    ax.add_patch(rectangle)  # added exactly once (it was added twice before)
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    scatter = ax.scatter(xs, ys, color='#1f77b4', zorder=2,
                         label='Inferred constraint')
    ax.set_title(title, fontsize=17)
    ax.set_xlabel("X", fontsize=17)
    ax.set_ylabel("Y", fontsize=17)
    ax.tick_params(axis='both', labelsize=15)
    # Legend handles come from *this* figure (the i == 1000 plot previously
    # reused the scatter object of the earlier figure).
    ax.legend(handles=[rectangle, scatter],
              labels=['Ground-truth constraint', 'Inferred constraint'],
              fontsize=15)
    fig.savefig(save_path)
    plt.close(fig)  # avoid accumulating open figures across iterations


# savefig raises FileNotFoundError when the output directory is missing.
os.makedirs('./inferred_constraints', exist_ok=True)

for i in [400]:
    # Explore in the environment without the constraint.
    env_id = "Maze2d_simple0"
    env0 = gym.make(env_id)
    # A single seeded reset suffices: resetting twice with the same seed
    # reproduces the same start state.
    reset_return = env0.reset(seed=123)
    # reset() returns (observation, info); the first two observation entries
    # are printed as the start position — assumes the wrapper puts x, y first
    # (TODO confirm).
    base_x = reset_return[0][0]
    base_y = reset_return[0][1]
    print(base_x, base_y)

    # Initialize the agent for policy-constrained exploration.
    agent = DQNAgent4(env0)
    data = agent.learning(gamma=0.99,  # decay (discount) factor
                          epsilon=1,
                          decaying_epsilon=True,
                          alpha=4e-3,  # previously tried: 1e-3
                          max_episode_num=i,
                          display=False, DQN=DQN, DQN0=DQN0)
    env0.close()

    # Entries 4-6 of the learning() result are used as the x / y coordinates
    # of the inferred constraint points and the cost lengths; the tuple layout
    # is defined by DQNAgent4.learning (not visible here).
    x = data[4]
    y = data[5]
    length_of_costs = data[6]

    formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    _plot_constraint_figure(
        x, y, f"Iteration-{i}",
        f'./inferred_constraints/active_explore_new_{formatted_time}_{i}.png',
    )
    if i == 1000:
        formatted_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        _plot_constraint_figure(
            x, y, "Inferred Constraints",
            f'./inferred_constraints/active_explore_new1_{formatted_time}_{i}.png',
        )

# Dump the final inferred-constraint results, then save the training curves
# returned by the agent (data[2] = iterations, data[1] = episodic rewards,
# data[3] = episodic costs, per how they are plotted here). The input() calls
# deliberately pause the run so the printed values can be inspected.
print(x, y, length_of_costs, sep='\n')
input('11111111111111')


def _save_training_curve(steps, values, ylabel, path):
    """Draw ``values`` against ``steps`` on a new figure and save it to ``path``."""
    plt.figure()
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.plot(steps, values)
    plt.savefig(path)


iteration = data[2]
reward = data[1]
_save_training_curve(iteration, reward, 'Episodic Rewards',
                     f'./inferred_constraints/reward_{formatted_time}.png')

cost = data[3]
print(cost)
input('cost')
_save_training_curve(iteration, cost, 'Episodic Costs',
                     f'./inferred_constraints/cost_{formatted_time}.png')

input('finish')


