from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from src.env import make_env
from algorithm.tdmpc import KETDMPC
from algorithm.helper import Episode, ReplayBuffer, PseudoCounts, SalientExperienceReplay, ContPseudoCounts
import logger
import torch
from algorithm.helper import plot_heatmaps, save_heatmaps
import time

# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one.
torch.backends.cudnn.benchmark = True
# Directory names, resolved against the current working directory:
# configs are parsed from __CONFIG__, run outputs are written under __LOGS__.
__CONFIG__ = 'cfgs'
__LOGS__ = 'logs'


def set_seed(seed):
	"""Seed every RNG the training loop draws from.

	Covers Python's `random`, NumPy's global RNG, PyTorch's CPU generator and
	the generators of all CUDA devices, in that order, with the same `seed`.
	"""
	seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
	for seed_fn in seeders:
		seed_fn(seed)


def evaluate(buffer, env, agent, pre_rollout_agent, pre_rollout_steps, pseudo_counts, num_episodes, step, env_step, video, max_steps):
	"""Evaluate the learner policy of `agent` and optionally record a video.

	Runs `num_episodes` episodes. Each episode resets `env`, optionally advances it
	`pre_rollout_steps` times with `pre_rollout_agent.learner`, then takes
	`max_steps - 1` actions planned by `agent.learner` in eval mode.

	Args:
		buffer: replay buffer forwarded to `plan`.
		env: environment; `step` must return a 5-tuple
			(obs, reward, terminated, truncated, info).
		agent: agent whose `learner.plan` chooses the evaluated actions.
		pre_rollout_agent: optional agent used to advance the env before
			evaluation starts, or None.
		pre_rollout_steps: number of pre-rollout actions per episode.
		pseudo_counts: forwarded to `plan`.
		num_episodes: number of evaluation episodes.
		step: training step forwarded to `plan`.
		env_step: value passed to `video.save`.
		video: optional recorder; `init` is called with `enabled=True` only for
			the first episode.
		max_steps: episode length counted in observations (the initial one
			included), so `max_steps - 1` actions are taken per episode.

	Returns:
		NaN-aware mean (`np.nanmean`) of the per-episode reward sums.

	Raises:
		ValueError: if `max_steps` < 1.
	"""
	if max_steps < 1:
		raise ValueError(f'max_steps must be >= 1, got {max_steps}')
	episode_rewards = []
	for episode_idx in range(num_episodes):
		obs, ep_reward = env.reset(), 0
		if pre_rollout_agent is not None:
			for _ in range(pre_rollout_steps):
				action = pre_rollout_agent.learner.plan(buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=True, key='learner')
				obs, _, _, _, _ = env.step(action.cpu().numpy())
		if video:
			video.init(env, enabled=(episode_idx == 0))
		# Only the number of steps matters, so a counter is used instead of
		# materialising per-step (obs, action) GPU tensors.
		# NOTE(review): terminated/truncated are ignored, so the env keeps being
		# stepped after an episode ends — confirm the env tolerates this.
		for t in range(max_steps - 1):
			action = agent.learner.plan(buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=(t == 0), key='learner')
			obs, reward, _, _, _ = env.step(action.cpu().numpy())
			ep_reward += reward
			if video:
				video.record(env)
		episode_rewards.append(ep_reward)
	if video:
		video.save(env_step)
	return np.nanmean(episode_rewards)

def evaluate_reviewer(buffer, env, agent, pre_rollout_agent, pseudo_counts, num_episodes, step, env_step, video, pre_rollout_steps=10, device='cuda'):
	"""Roll out the reviewer policy and score it with the reviewer reward.

	For each of `num_episodes` episodes: reset `env`, optionally advance it
	`pre_rollout_steps` times with `pre_rollout_agent.learner`, then act with
	`agent.reviewer.plan` until the env reports `terminated` or `truncated`.
	The env reward is discarded. Each step is scored by `agent.reviewer_reward`
	multiplied in place by `pseudo_counts.get_intrinsic_rewards` of the encoded
	current observation. Each episode's total is printed.

	Args:
		buffer: replay buffer forwarded to `plan`.
		env: environment; `step` must return a 5-tuple
			(obs, reward, terminated, truncated, info).
		agent: agent providing `reviewer.plan`, `reviewer.model._encoder` and
			`reviewer_reward`.
		pre_rollout_agent: optional agent used to advance the env first, or None.
		pseudo_counts: provides `get_intrinsic_rewards`; also forwarded to `plan`.
		num_episodes: number of episodes to roll out (previously ignored; one
			episode was always run).
		step: training step forwarded to `plan`.
		env_step: unused; kept for call compatibility.
		video: unused; kept for call compatibility.
		pre_rollout_steps: pre-rollout actions per episode (default 10, the value
			that was previously hard-coded).
		device: torch device for the tensors fed to the reviewer model.

	Returns:
		NaN-aware mean of the per-episode reward sums, as a float.
	"""
	def _as_batch(x):
		# Copy into a fresh float32 tensor of shape (1, *x.shape) on `device`.
		return torch.tensor(np.asarray(x, dtype=np.float32), device=device).unsqueeze(0)

	episode_rewards = []
	for _ in range(num_episodes):
		obs, terminated, truncated, ep_reward, t = env.reset(), False, False, 0.0, 0
		if pre_rollout_agent is not None:
			for _ in range(pre_rollout_steps):
				action = pre_rollout_agent.learner.plan(buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=True, key='learner')
				obs, _, _, _, _ = env.step(action.cpu().numpy())
		while not (terminated or truncated):
			# NOTE(review): the reviewer plans with key='learner' here, while train()'s
			# reviewer rollout uses key='reviewer' — confirm this is intended.
			action = agent.reviewer.plan(buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=t == 0, key='learner')
			next_obs, _, terminated, truncated, _ = env.step(action.cpu().numpy())
			obs_t = _as_batch(obs)
			reward, _, _ = agent.reviewer_reward(obs_t, _as_batch(next_obs), action.unsqueeze(0).to(device))
			reward *= pseudo_counts.get_intrinsic_rewards(agent.reviewer.model._encoder(obs_t))
			# Accumulate a Python float rather than a 0-d tensor.
			ep_reward += float(reward.detach().cpu()[0][0])
			obs = next_obs
			t += 1
		episode_rewards.append(ep_reward)
		print(ep_reward)
	return float(np.nanmean(episode_rewards))

def merge_buffer(buffer1, buffer2):
	"""Concatenate two replay buffers, with `buffer1` first.

	If either argument is None, the other is returned unchanged (the same object).
	Otherwise, a deep copy of `buffer2` is returned with these changes:
	- `_obs` and `_actions` are both buffers' tensors concatenated along dim 0.
	- `capacity` is the sum of both capacities.
	- `cur_idx` is `buffer2.cur_idx + buffer1.capacity`.
	- `indices` is `buffer1.indices` followed by `buffer2.indices`, shifted by
	  ``len(buffer1._obs)`` so they address buffer2's rows inside the
	  concatenated tensors.
	Neither input is modified.

	NOTE(review): only `_obs`, `_actions`, `indices`, `capacity` and `cur_idx` are
	merged; any other storage on `buffer2` is carried over from its deep copy
	as-is — confirm nothing else needs merging. `cur_idx` is offset by
	`buffer1.capacity` while indices are offset by ``len(buffer1._obs)``; these
	agree only if `_obs` holds exactly `capacity` rows — confirm.
	"""
	if buffer1 is None:
		return buffer2
	if buffer2 is None:
		return buffer1
	offset = len(buffer1._obs)
	merged = deepcopy(buffer2)
	merged.capacity = buffer1.capacity + buffer2.capacity
	merged.cur_idx += buffer1.capacity
	# buffer2's rows start after all of buffer1's rows in the concatenated tensors,
	# so its indices must be shifted (the shifted list used to be computed and
	# then discarded in favour of the unshifted one).
	merged.indices = list(buffer1.indices) + [idx + offset for idx in buffer2.indices]
	merged._obs = torch.cat((buffer1._obs, buffer2._obs), dim=0)
	merged._actions = torch.cat((buffer1._actions, buffer2._actions), dim=0)
	return merged


def train(cfg):
	"""Training script for TD-MPC. Requires a CUDA-enabled device.

	Trains a KETDMPC agent sequentially on ``cfg.tasks[cfg.task_idx:]``. For each
	task, it builds a fresh env, replay buffers, agent and pseudo-count model.
	Each loop iteration then runs the following, in order:
	1. an optional pre-rollout with a policy loaded from ``cfg.pre_rollout_ckpt``;
	2. a learner rollout (if ``cfg.train_learner``);
	3. a reviewer rollout (if ``cfg.ckpt`` is set and ``cfg.use_reviewer``);
	4. ``cfg.max_steps`` agent updates once ``step >= cfg.seed_steps``;
	5. logging, and periodic checkpointing, heatmap saving and evaluation.

	When ``cfg.save_model`` is set, ``fp`` is updated to each saved checkpoint
	path. Each task assigns ``cfg.ckpt = fp`` before building its agent, so later
	tasks are configured with the previous task's last saved checkpoint
	(presumably loaded by KETDMPC — confirm).

	Mutates ``cfg`` in place: se_buffer_size, time_limit, task_idx, horizon, task,
	ckpt, and latent_dim when ``cfg.use_encoder`` is false.

	Args:
		cfg: config object produced by ``parse_cfg``.
	"""
	# num_room_1 = 1
	# SE buffer holds se_buffer_trajectories trajectories of (max_steps + 1) entries;
	# presumably the +1 accounts for each trajectory's initial observation — confirm.
	cfg.se_buffer_size = cfg.se_buffer_trajectories * (cfg.max_steps + 1)
	assert torch.cuda.is_available()
	set_seed(cfg.seed)
	# NOTE(review): work_dir is computed once from the task selected at start-up, so
	# every task in the loop below logs into this same directory.
	work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name\
		  / str(cfg.task_idx) / str(cfg.seed)
	# fp: most recent checkpoint path; re-assigned to cfg.ckpt at the start of each task.
	fp = cfg.ckpt
	start_idx = cfg.task_idx
	pre_rollout_agent = None
	cfg.time_limit = cfg.max_steps * cfg.action_repeat
	
	for task_idx in range(start_idx, len(cfg.tasks)):
		cfg.task_idx = task_idx
		cfg.horizon = cfg.horizons[task_idx]
		cfg.task = cfg.tasks[task_idx]
		cfg.ckpt = fp
		# Env steps the `step` counter advances per iteration: max_steps, doubled when
		# the reviewer rollout is active.
		rollout_steps = cfg.max_steps

		if cfg.ckpt is not None and cfg.use_reviewer:
			rollout_steps = cfg.max_steps * 2

		# Fresh env and buffers per task; reviewer_se_buffer / old_se_buffer start as None
		# and are only created below when their features are enabled.
		env, learner_buffer, reviewer_buffer, learner_se_buffer, reviewer_se_buffer,\
		 old_se_buffer = make_env(cfg), ReplayBuffer(cfg), ReplayBuffer(cfg), SalientExperienceReplay(cfg), None, None
		
		# Without an encoder the latent is the raw observation. cfg.obs_shape is read
		# after make_env(cfg); presumably it is set there — confirm.
		if not cfg.use_encoder:
			cfg.latent_dim = cfg.obs_shape[0]

		# Optional frozen policy used to advance the env before each rollout/evaluation.
		if cfg.pre_rollout_ckpt is not None:
			pre_rollout_cfg = deepcopy(cfg)
			pre_rollout_cfg.ckpt = cfg.pre_rollout_ckpt
			pre_rollout_cfg.load_policy = True
			pre_rollout_agent = KETDMPC(pre_rollout_cfg)

		# Previously saved salient-experience buffer, passed to agent.update and merged into saved buffers.
		if cfg.se_buffer_path is not None:
			old_se_buffer = SalientExperienceReplay(cfg)
			old_se_buffer.load(cfg.se_buffer_path, device=cfg.device)

		if cfg.ckpt is not None and cfg.use_reviewer:
			reviewer_se_buffer = SalientExperienceReplay(cfg)
			
		agent = KETDMPC(cfg)
		# Discrete (k-based) pseudo-counts for minigrid, continuous ones over the latent otherwise.
		pseudo_counts = PseudoCounts(cfg=cfg, k=cfg.k) if cfg.env=='minigrid' else ContPseudoCounts(cfg.latent_dim)

		if cfg.show_plots:
			plot_heatmaps(cfg, agent.learner)
		
		# Run training
		L = logger.Logger(work_dir, cfg)
		episode_idx, start_time = 0, time.time()
		# top_reviewer_rewards = []

		for step in range(0, cfg.train_steps+rollout_steps, rollout_steps):
			# NOTE(review): eval-only mode never returns: it calls evaluate() forever,
			# passing cfg.eval_mode as both `step` and `env_step`, and reviewer_buffer as the buffer.
			if cfg.eval_mode > 0:
				while True:
					evaluate(reviewer_buffer, env, agent, pre_rollout_agent, cfg.pre_rollout_steps, pseudo_counts,\
			   cfg.eval_episodes, cfg.eval_mode, cfg.eval_mode, L.video, cfg.max_steps)
			# Collect trajectory
			obs = env.reset()

			if cfg.pre_rollout_ckpt is not None:
				for _ in range(cfg.pre_rollout_steps):
					action = pre_rollout_agent.learner.plan(learner_buffer, pseudo_counts, obs, eval_mode=True, step=step, t0=True, key='learner')
					obs, _, _, _, _ = env.step(action.cpu().numpy())

			# learner rollout:
			# `episode` is created even when the learner is not trained, because its
			# cumulative_reward is logged below.
			episode = Episode(cfg, obs)
			if cfg.train_learner:
				# NOTE(review): add_to_se_buffer is never set True, so learner_se_buffer
				# is never filled by this loop.
				add_to_se_buffer = False
				# terminated/truncated are stored in the episode but do not stop the rollout.
				while len(episode) < cfg.max_steps:
					action = agent.learner.plan(learner_buffer, pseudo_counts, obs, step=step, t0=episode.first, key='learner')
					obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
					episode += (obs, action, reward, terminated, truncated)
				assert len(episode) == cfg.max_steps
				learner_buffer += episode
				if add_to_se_buffer:
					learner_se_buffer += episode
			
			# reviewer rollout
			if cfg.ckpt is not None and cfg.use_reviewer:
				reviewer_episode_reward = 0
				episode_old_reward = 0
				episode_cost = 0
				# Otherwise the reviewer continues from where the learner rollout stopped.
				if cfg.reset_reviewer_env:
					obs = env.reset()
				review_episode = Episode(cfg, obs)
				# Pseudo-counts are updated with the encoded start state and every visited state.
				encoded_obs = agent.reviewer.model._encoder(torch.tensor([obs], dtype=torch.float32).to(cfg.device))
				pseudo_counts.update(encoded_obs)
				# NOTE(review): never set True (the setter below is commented out), so
				# reviewer_se_buffer is never filled.
				add_to_se_buffer = False
				
				while len(review_episode) < cfg.max_steps:
					action = agent.reviewer.plan(reviewer_buffer, pseudo_counts, obs, step=step, t0=review_episode.first, key='reviewer')
					# The env reward is discarded; the stored reward comes from agent.reviewer_reward.
					next_obs, _, terminated, truncated, _ = env.step(action.cpu().numpy())
					reward, old_wm_reward, cost = agent.reviewer_reward(
						torch.tensor([obs], dtype=torch.float32).to(cfg.device),  
						torch.tensor([next_obs], dtype=torch.float32).to(cfg.device),
						action.unsqueeze(0).to(cfg.device),
						use_learner_reward = step > cfg.seed_steps and cfg.use_learner_reward)
					obs = next_obs
					encoded_obs = agent.reviewer.model._encoder(torch.tensor([obs], dtype=torch.float32).to(cfg.device))
					pseudo_counts.update(encoded_obs)
					review_episode += (obs, action, reward, terminated, truncated)
					reviewer_episode_reward += reward
					episode_old_reward += old_wm_reward
					# Cost is accumulated with a negative sign.
					episode_cost -= cost

					# if len(top_reviewer_rewards) < cfg.se_buffer_trajectories or\
					# 	reviewer_episode_reward > min(top_reviewer_rewards):
					# 	add_to_se_buffer = True
					# 	top_reviewer_rewards.append(reviewer_episode_reward)
					# 	top_reviewer_rewards = sorted(top_reviewer_rewards, reverse=True)[:10]

				assert len(review_episode) == cfg.max_steps
				reviewer_buffer += review_episode

				if add_to_se_buffer and reviewer_se_buffer is not None:
					reviewer_se_buffer += review_episode

			# Update model
			# After the seed phase: cfg.max_steps agent updates per iteration; later
			# calls' returned metrics overwrite earlier ones in train_metrics.
			train_metrics = {}
			if step >= cfg.seed_steps:
				for i in range(cfg.max_steps):
					update_reviewer = cfg.ckpt is not None and cfg.use_reviewer
					train_metrics.update(agent.update(pseudo_counts, learner_buffer, reviewer_buffer, old_se_buffer, 
										step+i, update_learner=cfg.train_learner, update_reviewer=update_reviewer))
			# Log training episode
			episode_idx += 1
			env_step = int(step*cfg.action_repeat)
			# Reviewer-related metrics are None when no reviewer rollout ran.
			common_metrics = {
				'episode': episode_idx,
				'step': step,
				'env_step': env_step,
				'total_time': time.time() - start_time,
				'episode_reward': episode.cumulative_reward,
				'reviewer_episode_reward': reviewer_episode_reward if cfg.use_reviewer and cfg.ckpt is not None else None,
				'old_model_episode_reward': episode_old_reward if cfg.use_reviewer and cfg.ckpt is not None else None,
				'learner_model_episode_cost': episode_cost if cfg.use_reviewer and cfg.ckpt is not None else None,}
			
			train_metrics.update(common_metrics)
			L.log(train_metrics, category='train')

			# Save and evaluate agent periodically
			
			# NOTE(review): checkpoint/buffer directory is a hard-coded cluster path, not work_dir.
			if env_step % cfg.save_freq == 0:
				_model_dir = f'/oscar/data/gdk/wm_finetuning/{cfg.exp_name}/h_{cfg.horizon}/task_{cfg.task_idx}/seed_{cfg.seed}'
				if not os.path.exists(_model_dir):
					os.makedirs(_model_dir)
				if cfg.save_model:
					# Updating fp makes this checkpoint cfg.ckpt for the next task.
					fp = f'{_model_dir}/model_{env_step}.pt'
					torch.save(agent.state_dict(), fp)
				if cfg.save_buffer:
					buffer_fp = f'{_model_dir}/buffer_{env_step}.pt'
					# With a loaded old SE buffer, save old + reviewer + learner SE buffers merged;
					# otherwise only the learner SE buffer.
					if old_se_buffer is not None:
						se_buffer = merge_buffer(reviewer_se_buffer, learner_se_buffer)
						merge_buffer(old_se_buffer, se_buffer).save(buffer_fp)
					else:
						learner_se_buffer.save(buffer_fp)
			# mse/std are added after the train log was written, so they only reach the
			# eval log, and only when the eval below fires on the same step.
			if env_step % cfg.save_image_freq == 0 and cfg.save_heatmaps:
				if cfg.use_reviewer and cfg.ckpt is not None:
					mse, std = save_heatmaps(cfg, agent.reviewer, env_step)
				else:
					mse, std = save_heatmaps(cfg, agent.learner, env_step)
				common_metrics.update({'mean_wm_score': mse, 'std_wm_score': std})

			# The training episode reward is replaced by the evaluation mean before eval logging.
			if env_step % cfg.eval_freq == 0:
				common_metrics['episode_reward'] = evaluate(
					reviewer_buffer, env, agent, pre_rollout_agent, cfg.pre_rollout_steps, 
					pseudo_counts, cfg.eval_episodes, step, env_step, L.video, cfg.max_steps)
				L.log(common_metrics, category='eval')

		L.finish(agent)
		print('Training completed successfully')


if __name__ == '__main__':
	# Parse the config directory under the current working directory, then train.
	config_dir = Path().cwd() / __CONFIG__
	train(parse_cfg(config_dir))
