from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gymnasium as gym
gym.logger.set_level(40)
import random
from pathlib import Path
from cfg import parse_cfg
from src.env import make_env
from algorithm.tdmpc import TDMPC, KETDMPC
from algorithm.helper import Episode, ReplayBuffer, PseudoCounts, SalientExperienceReplay
import logger
import matplotlib.pyplot as plt
import torch
from algorithm.helper import plot_heatmaps, save_heatmaps
from stable_baselines3 import DQN, SAC

# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True
# Name of the configuration directory; resolved against the current working
# directory when the script is run as a program.
__CONFIG__ = 'cfgs'
# Name of the log directory (presumably consumed by the logging helpers --
# not referenced in this part of the file).
__LOGS__ = 'logs'


def set_seed(seed):
	"""Seed every random number generator this script relies on.

	The same ``seed`` is applied, in order, to Python's ``random`` module,
	NumPy's global generator, PyTorch's default generator and the generators
	of all CUDA devices.
	"""
	seeders = (
		random.seed,
		np.random.seed,
		torch.manual_seed,
		torch.cuda.manual_seed_all,
	)
	for apply_seed in seeders:
		apply_seed(seed)


def evaluate(buffer, env, agent, pseudo_counts, num_episodes, step, env_step, video):
	"""Roll out the learner's planner for ``num_episodes`` episodes.

	Each action comes from ``agent.learner.plan`` in eval mode; the environment
	rewards of an episode are summed until it reports ``terminated`` or
	``truncated``. When ``video`` is truthy, only the first episode is enabled
	for recording and the video is saved under ``env_step`` at the end.

	NOTE(review): ``env.reset()`` is used as the observation itself, so the
	environment wrapper is assumed not to return an ``(obs, info)`` pair --
	confirm against ``make_env``.

	Returns:
		The NaN-ignoring mean of the per-episode reward sums.
	"""
	returns = []
	for episode_idx in range(num_episodes):
		obs = env.reset()
		episode_return = 0
		t = 0
		finished = False
		if video:
			video.init(env, enabled=(episode_idx == 0))
		while not finished:
			# t0 tells the planner that this is the first step of the episode.
			action = agent.learner.plan(
				buffer, pseudo_counts, obs,
				eval_mode=True, step=step, t0=(t == 0), key='learner',
			)
			obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
			episode_return += reward
			if video:
				video.record(env)
			t += 1
			finished = bool(truncated) or bool(terminated)
		returns.append(episode_return)
	if video:
		video.save(env_step)
	return np.nanmean(returns)

def evaluate_reviewer(buffer, env, agent, pseudo_counts, num_episodes, step, env_step, video):
	"""Roll out the reviewer's planner for one episode and score it.

	Actions come from ``agent.reviewer.plan`` in eval mode. The environment
	reward is discarded; each step is scored with the first output of
	``agent.reviewer_reward(obs, next_obs, action)`` multiplied by
	``pseudo_counts.get_intrinsic_rewards`` of the encoded observation.
	Requires a CUDA device, because every tensor is moved with ``.cuda()``.

	NOTE(review): exactly one episode is run; ``num_episodes``, ``env_step``
	and ``video`` are accepted for signature parity with ``evaluate`` but are
	not used. The planner is also called with ``key='learner'`` -- confirm
	that this is intended for the reviewer.

	Returns:
		The mean accumulated reviewer reward over the evaluated episodes
		(previously nothing was returned). The per-episode value is still
		printed.
	"""
	episode_rewards = []
	for _ in range(1):
		obs, terminated, truncated, ep_reward, t = env.reset(), False, False, 0, 0
		while not truncated and not terminated:
			action = agent.reviewer.plan(buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=t==0, key='learner')
			next_obs, _, terminated, truncated, _ = env.step(action.cpu().numpy())
			# Build each batch once; going through np.asarray avoids the slow
			# "tensor from a list of ndarrays" conversion path.
			obs_batch = torch.tensor(np.asarray(obs)[None], dtype=torch.float32).cuda()
			next_obs_batch = torch.tensor(np.asarray(next_obs)[None], dtype=torch.float32).cuda()
			reward, _, _ = agent.reviewer_reward(
					obs_batch,
					next_obs_batch,
					action.unsqueeze(0).cuda()
				)
			# Scale the reviewer reward by the intrinsic reward of the encoded observation.
			reward *= pseudo_counts.get_intrinsic_rewards(agent.reviewer.model._encoder(obs_batch))
			reward = reward.detach().cpu()
			obs = next_obs
			ep_reward += reward[0][0]
			t += 1
		episode_rewards.append(ep_reward)
		print(ep_reward)
	# Convert the accumulated values to plain floats before averaging.
	return np.nanmean([float(r) for r in episode_rewards])

def merge_buffer(buffer1,  buffer2):
	"""Concatenate two replay buffers into a new buffer, ``buffer1`` first.

	If either argument is ``None`` the other one is returned unchanged (not
	copied). Otherwise a deep copy of ``buffer2`` is used as the base, so every
	attribute not listed below keeps ``buffer2``'s value:

	* ``capacity`` is the sum of both capacities;
	* ``cur_idx`` is ``buffer2.cur_idx`` shifted by ``buffer1.capacity``;
	* ``_obs`` / ``_actions`` are concatenated along dim 0;
	* ``indices`` are ``buffer1``'s indices followed by ``buffer2``'s indices
	  shifted by ``len(buffer1._obs)``, so they still address the same rows
	  after concatenation.

	NOTE(review): only the attributes above are merged; any other per-sample
	storage on the buffer would still hold ``buffer2``'s data only -- verify
	against the buffer class.
	"""
	if buffer1 is None:
		return buffer2
	if buffer2 is None:
		return buffer1
	merged = deepcopy(buffer2)
	merged.capacity = buffer1.capacity + buffer2.capacity
	merged.cur_idx = buffer2.cur_idx + buffer1.capacity
	# buffer2's rows land after buffer1's rows in the concatenated storage,
	# so its indices must be offset (the original discarded this shift).
	offset = len(buffer1._obs)
	merged.indices = list(buffer1.indices) + [idx + offset for idx in buffer2.indices]
	merged._obs = torch.cat((buffer1._obs, buffer2._obs), dim=0)
	merged._actions = torch.cat((buffer1._actions, buffer2._actions), dim=0)
	return merged


def train(cfg):
	"""Train a Stable-Baselines3 DQN agent on the configured environment.

	Stores ``cfg.se_buffer_size`` (``se_buffer_trajectories`` trajectories of
	``max_steps + 1`` steps each), asserts that CUDA is available, seeds the
	random number generators with ``cfg.seed``, builds the environment with
	``make_env`` and runs ``DQN.learn`` for a fixed 10,000,000 timesteps.
	"""
	trajectory_count = cfg.se_buffer_trajectories
	steps_per_trajectory = cfg.max_steps + 1
	cfg.se_buffer_size = trajectory_count * steps_per_trajectory
	assert torch.cuda.is_available()
	set_seed(cfg.seed)
	environment = make_env(cfg)
	# Alternative environment for a quick visual check:
	# environment = gym.make("CartPole-v1", render_mode="human")
	model = DQN("MlpPolicy", environment, verbose=1)
	model.learn(total_timesteps=10000000, log_interval=4)


if __name__ == '__main__':
	# multiprocessing.set_start_method('spawn')
	# The configuration directory is looked up in the current working directory.
	config_root = Path.cwd().joinpath(__CONFIG__)
	train(parse_cfg(config_root))
