import math
from collections import OrderedDict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dm_env import specs

import utils
from agent.ddpg import DDPGAgent
from agent.diayn_modules import MULTI_DIAYN, MULTI_TRANS_DIAYN, PARTED_DIAYN, PARTED_TRANS_DIAYN, PARTED_ANTI_DIAYN
from agent.diayn_actors import SkillActor, MCPActor, SeparateSkillActor, IndepActor
from agent.networks.attention_policy import AttnPolicy
from agent.multi_diayn import MULTI_DIAYNAgent
from agent.partition_utils import get_domain_stats, observation_filter, get_gc_features, get_gc_stats


class GC_Wipe_Discriminator(nn.Module):
	"""Fixed (non-learned) goal-conditioned "discriminator".

	Registers no parameters or submodules: its output is simply the features
	that ``get_gc_features`` extracts from the next observation.
	"""

	def __init__(self, domain, skill_dim, skill_channel, env_config):
		super().__init__()
		# NOTE: skill_channel is accepted but neither stored nor used here.
		self.skill_dim = skill_dim
		self.env_config = env_config
		self.domain = domain

	def forward(self, obs, next_obs):
		"""Return ``get_gc_features`` of ``next_obs``; ``obs`` is ignored."""
		features = get_gc_features(next_obs, self.domain, self.env_config)
		return features


class GC_DIAYN_Agent(MULTI_DIAYNAgent):
	"""MULTI_DIAYN agent whose skill vector is extended with goal-conditioned (GC) channels.

	The skill tensor is laid out as ``(bs, skill_channel * skill_dim)`` with the
	``gc_skill_channel`` GC channels first, followed by the ``diayn_skill_channel``
	channels handled by the parent MULTI_DIAYN machinery.

	Args:
		return_dist: if True, ``self.gc_diayn`` outputs are treated as distances
			(reward ``-dist``, prediction = argmin); otherwise as logits
			(DIAYN-style log-probability reward, prediction = argmax).
		**kwargs: forwarded to ``MULTI_DIAYNAgent``.
	"""

	def __init__(self, return_dist, **kwargs):
		super().__init__(**kwargs)

		# gc_skill_channel is set by init_params(), which the parent constructor calls.
		gc_diayn_args = [self.domain, self.skill_dim, self.gc_skill_channel, self.env_config]

		self.gc_diayn = GC_Wipe_Discriminator(*gc_diayn_args)  # to(self.device)
		self.return_dist = return_dist

	# This function will be called in multi_diayn init
	def init_params(self):
		"""Set the skill-channel counts.

		``diayn_skill_channel`` is the number of domain partitions minus one,
		``gc_skill_channel`` comes from ``get_gc_stats``, and ``skill_channel``
		is their sum.
		"""
		_, partitions = get_domain_stats(self.domain, self.env_config)
		self.diayn_skill_channel = len(partitions) - 1
		self.gc_skill_channel = get_gc_stats(self.domain, self.env_config)
		self.skill_channel = self.diayn_skill_channel + self.gc_skill_channel

	def _split_gc_skill(self, skill):
		"""Split ``skill`` columns into ``(gc_skill, diayn_skill)``; GC channels come first."""
		split = self.gc_skill_channel * self.skill_dim
		return skill[:, :split], skill[:, split:]

	def update_diayn(self, skill, obs, next_obs):
		"""Update the parent DIAYN discriminator and log GC discriminator metrics.

		Only the DIAYN part of ``skill`` (shape ``(bs, skill_channel * skill_dim)``)
		is passed to the parent update. The GC loss/accuracy are computed under
		``torch.no_grad()``, so they are reported as metrics only.

		Returns:
			The parent's metrics dict, extended with ``gc_loss``,
			``gc_acc_<channel>`` and ``gc_acc_avg`` when tb/wandb logging is on.
		"""
		gc_skill, diayn_skill = self._split_gc_skill(skill)

		metrics = super().update_diayn(diayn_skill, obs, next_obs)

		with torch.no_grad():
			gc_loss, gc_acc = self.compute_gc_diayn_loss(obs, next_obs, gc_skill)

		if self.use_tb or self.use_wandb:
			metrics['gc_loss'] = gc_loss.item()
			for idx, acc in enumerate(gc_acc):
				metrics['gc_acc_{}'.format(idx)] = acc.item()
			metrics['gc_acc_avg'] = gc_acc.mean().item()

		return metrics

	def compute_intr_reward(self, skill, obs, next_obs):
		"""Compute the intrinsic reward for the GC channels plus the parent DIAYN reward.

		Returns:
			If ``monolithic_Q``: ``(bs, 1)``, the GC rewards summed over channels,
			scaled by ``diayn_scale``, plus the parent's reward.
			Otherwise: ``(bs, gc_skill_channel)`` scaled GC rewards, concatenated
			with the parent's per-channel reward when ``diayn_skill_channel > 0``.
		"""
		gc_skill, diayn_skill = self._split_gc_skill(skill)

		diayn_reward = super().compute_intr_reward(diayn_skill, obs, next_obs)

		# Fold the channel axis into the batch axis: rows are (bs * gc_skill_channel).
		z_hat = torch.argmax(gc_skill.reshape(-1, self.skill_dim), dim=-1)
		d_pred = self.gc_diayn(obs, next_obs).reshape(-1, self.skill_dim)
		rows = torch.arange(d_pred.shape[0], device=d_pred.device)

		if self.return_dist:
			# Outputs are distances to each candidate skill: r = -dist.
			reward = -d_pred[rows, z_hat]
		else:
			# DIAYN-style reward: log q(z | s') - log p(z), with uniform p(z) = 1 / skill_dim.
			d_pred_log_softmax = F.log_softmax(d_pred, dim=-1)
			reward = d_pred_log_softmax[rows, z_hat] - math.log(1 / self.skill_dim)

		reward = reward.reshape(-1, self.gc_skill_channel)
		if self.monolithic_Q:
			# A single Q-head: the reward is the SUM over GC skill channels, shape (bs, 1).
			reward = reward.sum(dim=-1, keepdim=True)

		reward = reward * self.diayn_scale

		if self.monolithic_Q:
			return reward + diayn_reward
		if self.diayn_skill_channel > 0:
			# GC channels first, then DIAYN channels, matching the skill layout.
			# No concatenation needed when only GC channels exist.
			reward = torch.cat([reward, diayn_reward], dim=-1)
		return reward

	def compute_gc_diayn_loss(self, state, next_state, skill):
		"""Compute the GC discriminator loss and per-channel prediction accuracy.

		The skill-channel axis is folded into the batch axis, so each
		(sample, channel) pair is scored independently.

		Args:
			state: observations, forwarded to ``self.gc_diayn``.
			next_state: next observations, forwarded to ``self.gc_diayn``.
			skill: GC part of the skill; the target per channel is its argmax over
				``skill_dim`` (presumably one-hot — confirm with the skill sampler).

		Returns:
			``(d_loss, df_accuracy)``: the criterion value scaled by
			``gc_skill_channel``, and a ``(gc_skill_channel,)`` tensor holding the
			fraction of samples whose predicted skill matches the target.
		"""
		z_hat = torch.argmax(skill.reshape(-1, self.skill_dim), dim=-1)
		d_pred = self.gc_diayn(state, next_state).reshape(-1, self.skill_dim)

		if self.return_dist:
			# Outputs are distances: the predicted skill is the closest one.
			pred_z = torch.argmin(d_pred, dim=-1)
		else:
			pred_z = torch.argmax(F.log_softmax(d_pred, dim=-1), dim=-1)

		# Scaled by gc_skill_channel to maintain the original lr after folding channels into the batch.
		# TODO: in return_dist mode the criterion is applied to raw distances; probably not meaningful.
		d_loss = self.diayn_criterion(d_pred, z_hat) * self.gc_skill_channel

		correct = torch.eq(z_hat, pred_z).reshape(-1, self.gc_skill_channel)
		df_accuracy = torch.sum(correct, dim=0).float() / correct.shape[0]

		return d_loss, df_accuracy
