import random
import time
import gym
import gym_minigrid.minigrid as minigrid
import networkx as nx
from networkx import grid_graph
import numpy as np
import torch as th
import gym_minigrid.minigrid as minigrid
import matplotlib.pyplot as plt
import argparse

from envs.multigrid import *
from envs.multigrid.adversarial import *
from envs.box2d import *
from envs.bipedalwalker import *

from envs.multigrid.adversarial import AdversarialEnv, GoalLastVariableBlocksAdversarialEnv
from add.diffusion_human_feedback.predictor import Tutor
from add.diffusion_human_feedback.generator import EnvGenerator
from util import create_parallel_env, DotDict, str2bool, seed

from arguments import parser
import time
import pyvirtualdisplay
# Total number of generator samples to collect (the original ran 2500 batches
# of 8 processes).
NUM_SAMPLES = 20000


def sample_generated_envs(generator, num_samples, batch_size, sample_shape=(12, 2)):
    """Collect ``num_samples`` generator outputs, each clipped to [0, 1].

    Samples are requested ``batch_size`` at a time through
    ``generator.generate_random_env(batch_size)``. If ``num_samples`` is not a
    multiple of ``batch_size``, only the leading rows of the final batch are kept.

    Args:
        generator: object exposing ``generate_random_env(n)``; assumed to return
            an array-like of shape ``(n, *sample_shape)`` -- TODO confirm against
            ``EnvGenerator``.
        num_samples: total number of samples to collect (must be >= 0).
        batch_size: number of samples requested per generator call (must be > 0).
        sample_shape: shape of a single sample; ``(12, 2)`` matches this script.

    Returns:
        A float ``np.ndarray`` of shape ``(num_samples, *sample_shape)``.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``num_samples`` is negative.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    samples = np.zeros((num_samples, *sample_shape))
    for lo in range(0, num_samples, batch_size):
        hi = min(lo + batch_size, num_samples)
        batch = np.clip(generator.generate_random_env(batch_size), 0, 1)
        # Slice the batch so the last (possibly partial) chunk fits exactly.
        samples[lo:hi] = batch[: hi - lo]
    return samples


def report_stats(label, values):
    """Print the overall mean and standard deviation of ``values`` under ``label``."""
    print(f"{label} mean: {np.mean(values)}")
    print(f"{label} std: {np.std(values)}")


if __name__ == '__main__':
    parser.add_argument(
        '--tutor_dir',
        type=str,
        default="../racing/minigrid_60/add/seed_1_cvar_03_guide_5/tutors/",
        help="path to the directory where tutor networks are saved"
    )
    parser.add_argument(
        '--tutor_model',
        type=str,
        default="model_030000.pt",
        help="name of model file"
    )
    # Parser-level defaults override argument-level defaults but, unlike
    # assigning to `args` after parsing, still let --env_name / --num_processes
    # given on the command line take effect.
    parser.set_defaults(env_name="CarRacing-Bezier-Adversarial-v0", num_processes=8)
    args = parser.parse_args()

    display = None
    if args.env_name.startswith("CarRacing"):
        # CarRacing envs presumably need an X display for rendering; run a
        # headless virtual one.
        display = pyvirtualdisplay.Display(visible=0, size=(1400, 900), color_depth=24)
        display.start()

    try:
        # NOTE(review): the envs are not used below; creation is kept in case it
        # has required side effects. Earlier revisions compared complexity
        # metrics from venv.get_complexity_info() against venv.reset_random()
        # (see version control).
        ued_venv, venv = create_parallel_env(args)

        generator = EnvGenerator(args)

        # Reference statistics of U[0, 1) noise with the same shape as the
        # generated samples.
        report_stats("uniform baseline", np.random.rand(NUM_SAMPLES, 12, 2))

        start = time.time()
        generated_env = sample_generated_envs(generator, NUM_SAMPLES, args.num_processes)
        print(f"generation time: {time.time() - start:.2f}s")
        report_stats("generated", generated_env)
    finally:
        if display is not None:
            display.stop()
