import datetime
import os
import random

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.envs.registration import register
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import KernelDensity

from agents import DQNAgent3
from approximator import NetApproximator
from mazewrapper import ConcatenateObservationNoGoal
#from puckworld import PuckWorldEnv
from utils import learning_curve

# Seed every random number generator the script relies on so that runs are
# reproducible. The torch generator is seeded as well because the script
# loads and trains torch networks; seeding only `random` and NumPy would
# leave any torch-side randomness different on every run.
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

# Environment WITH the constraint: the cell directly between the start ("r")
# and the goal ("g") is a wall, so the agent has to go around it.
_constrained_maze_kwargs = dict(
    # If True, reaching the goal does not terminate the episode; a new goal
    # location is generated instead. If False, the episode terminates when
    # the ball reaches the final goal.
    continuing_task=False,
    reward_type="sparse",
    # Legend: 1 = wall, 0 = free cell, "g" = goal, "r" = agent start position.
    maze_map=[
        [1, 1, 1, 1, 1],
        [1, "g", 0, 0, 1],
        [1, 1, 0, 0, 1],
        [1, "r", 0, 0, 1],
        [1, 1, 1, 1, 1],
    ],
)

register(
    id="Maze2d_simple",
    entry_point="gymnasium_robotics.envs.maze.point_maze:PointMazeEnv",
    kwargs=_constrained_maze_kwargs,
    additional_wrappers=(ConcatenateObservationNoGoal.wrapper_spec(),),
    max_episode_steps=10000,
)

# Environment WITHOUT the constraint: every interior cell of the maze is
# free, so the agent can move straight from the start ("r") to the goal ("g").
_unconstrained_maze_kwargs = dict(
    # If True, reaching the goal does not terminate the episode; a new goal
    # location is generated instead. If False, the episode terminates when
    # the ball reaches the final goal.
    continuing_task=False,
    reward_type="sparse",
    # Legend: 1 = wall, 0 = free cell, "g" = goal, "r" = agent start position.
    maze_map=[
        [1, 1, 1, 1, 1],
        [1, "g", 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, "r", 0, 0, 1],
        [1, 1, 1, 1, 1],
    ],
)

register(
    id="Maze2d_simple0",
    entry_point="gymnasium_robotics.envs.maze.point_maze:PointMazeEnv",
    kwargs=_unconstrained_maze_kwargs,
    additional_wrappers=(ConcatenateObservationNoGoal.wrapper_spec(),),
    max_episode_steps=500,
)

def _load_approximator(weights_path):
    """Build a NetApproximator(4, 8, 32) and load its weights from *weights_path*.

    NOTE(review): the meaning of the three constructor arguments is defined
    in `approximator.NetApproximator`; they must match the checkpoint.
    """
    net = NetApproximator(4, 8, 32)
    state_dict = torch.load(weights_path)
    net.load_state_dict(state_dict)
    return net


# Pretrained networks restored from the checkpoints next to this script.
DQN = _load_approximator('DQN.pth')
DQN0 = _load_approximator('DQN0.pth')

# Create the output directory up front so that savefig cannot fail with a
# missing-directory error after a long training run.
os.makedirs('./inferred_constraints', exist_ok=True)

# Repeat the experiment with an increasing episode budget and save one scatter
# plot per budget. `data` and `formatted_time` from the last iteration are
# reused by the summary plots further down in the script.
for episode_budget in [20, 50, 100, 200, 500, 1000]:
    # Explore in the environment without the constraint.
    env_id = "Maze2d_simple0"
    env0 = gym.make(env_id)
    # One seeded reset is enough: resetting again with the same seed would
    # only reproduce the same initial observation.
    reset_return = env0.reset(seed=123)
    initial_obs = reset_return[0]
    base_x = initial_obs[0]
    base_y = initial_obs[1]
    print(base_x, base_y)

    # Initialize the agent for policy-constrained exploration.
    agent = DQNAgent3(env0)
    try:
        data = agent.learning(
            gamma=0.99,  # discount factor
            epsilon=1,
            decaying_epsilon=True,
            alpha=4e-3,  # previously 1e-3
            max_episode_num=episode_budget,
            display=False,
            DQN=DQN,
            DQN0=DQN0,
        )
    finally:
        # Release the environment's resources once this run is finished.
        env0.close()

    # Scatter plot of the two coordinate series returned by the agent.
    # NOTE(review): data[4] / data[5] are presumably x / y positions collected
    # during exploration -- confirm against DQNAgent3.learning.
    plt.figure()
    x = data[4]
    y = data[5]
    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    plt.scatter(x, y)

    # Add the title and axis labels.
    plt.title("Iteration-{0}".format(episode_budget))
    plt.xlabel("X")
    plt.ylabel("Y")

    # Timestamp the file name so that repeated runs do not overwrite each other.
    current_time = datetime.datetime.now()
    formatted_time = current_time.strftime("%Y-%m-%d_%H-%M-%S")
    plt.savefig('./inferred_constraints/{0}_{1}.png'.format(formatted_time, episode_budget))



# Summary curves for the last training run of the loop above; `data` and
# `formatted_time` are the values left over from its final iteration.
iteration = data[2]
reward = data[1]
# Start a fresh figure. Without it the curve would be drawn onto the last
# scatter plot, whose axes are clamped to [-3, 3] and titled "Iteration-...".
plt.figure()
plt.xlabel('Iteration')
plt.ylabel('Episodic Rewards')
plt.plot(iteration, reward)
plt.savefig('./inferred_constraints/reward_{0}.png'.format(formatted_time))
plt.show()

cost = data[3]
# Again use a new figure: on non-interactive backends plt.show() does not
# close the previous one, so the cost curve would be drawn over the rewards.
plt.figure()
plt.xlabel('Iteration')
plt.ylabel('Episodic Costs')
plt.plot(iteration, cost)
plt.savefig('./inferred_constraints/cost_{0}.png'.format(formatted_time))
plt.show()

# Keep the process alive until the user confirms.
input('finish')


