import math
from collections import OrderedDict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dm_env import specs

import utils
from agent.ddpg import DDPGAgent
from agent.spectral_utils import spectral_norm


class LSD(nn.Module):
    """Three-layer MLP that maps an observation to a ``skill_dim`` vector.

    Every linear layer is wrapped by the project's ``spectral_norm`` helper
    with ``spectral_coef=1.`` (presumably to bound each layer's spectral
    norm -- confirm against ``agent.spectral_utils``).
    """

    def __init__(self, obs_dim, skill_dim, hidden_dim):
        """Build the network.

        Args:
            obs_dim: size of the input observation vector.
            skill_dim: size of the output vector.
            hidden_dim: width of the two hidden layers.
        """
        super().__init__()

        def constrained_linear(n_in, n_out):
            # One linear layer wrapped with the spectral-norm helper.
            return spectral_norm(nn.Linear(n_in, n_out), spectral_coef=1.)

        # Layers are created in input-to-output order so that parameter
        # creation order matches the layout of the Sequential container.
        layers = [
            constrained_linear(obs_dim, hidden_dim),
            nn.ReLU(),
            constrained_linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            constrained_linear(hidden_dim, skill_dim),
        ]
        self.skill_pred_net = nn.Sequential(*layers)

        # Project-wide weight initialisation applied to every sub-module.
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the output of ``skill_pred_net`` for ``obs``."""
        return self.skill_pred_net(obs)


class LSDAgent(DDPGAgent):
    """Skill-conditioned DDPG agent trained with an LSD objective.

    A one-hot skill vector is concatenated to the encoded observation before
    it reaches the actor and critic. When ``self.reward_free`` is true the
    critic is trained on an intrinsic reward derived from the ``LSD`` network
    instead of the extrinsic reward from the replay batch.
    """

    def __init__(self, update_skill_every_step, skill_dim, lsd_scale,
                 update_encoder, **kwargs):
        """Create the agent.

        Args:
            update_skill_every_step: a new skill is sampled whenever
                ``global_step`` is a multiple of this value.
            skill_dim: number of discrete skills (length of the one-hot).
            lsd_scale: multiplier applied to the intrinsic reward.
            update_encoder: if false, encoded observations are detached in
                ``update`` before being concatenated with the skill.
            **kwargs: forwarded to ``DDPGAgent``; must contain
                ``'hidden_dim'`` and ``'device'``.
        """
        self.skill_dim = skill_dim
        self.update_skill_every_step = update_skill_every_step
        self.lsd_scale = lsd_scale
        self.update_encoder = update_encoder
        # increase obs shape to include skill dim
        kwargs["meta_dim"] = self.skill_dim

        # create actor and critic
        super().__init__(**kwargs)

        # create lsd; it only sees the observation part, hence the skill
        # dimension is subtracted from obs_dim (presumably obs_dim already
        # includes meta_dim after the parent constructor -- confirm in
        # DDPGAgent).
        self.lsd = LSD(self.obs_dim - self.skill_dim, self.skill_dim,
                       kwargs['hidden_dim']).to(kwargs['device'])

        # NOTE(review): this criterion is not used by any method of this
        # class; it is kept only so existing attribute accesses keep working.
        self.lsd_criterion = nn.CrossEntropyLoss()
        # optimizers
        self.lsd_opt = torch.optim.Adam(self.lsd.parameters(), lr=self.lr)

        self.lsd.train()

    def get_meta_specs(self):
        """Return the spec of the extra ``'skill'`` entry stored with transitions."""
        return (specs.Array((self.skill_dim,), np.float32, 'skill'),)

    def init_meta(self):
        """Sample a uniformly random one-hot skill.

        Returns:
            OrderedDict with a single key ``'skill'`` holding a float32
            array of length ``skill_dim``.
        """
        skill = np.zeros(self.skill_dim, dtype=np.float32)
        skill[np.random.choice(self.skill_dim)] = 1.0
        meta = OrderedDict()
        meta['skill'] = skill
        return meta

    def update_meta(self, meta, global_step, time_step):
        """Resample the skill every ``update_skill_every_step`` steps.

        ``time_step`` is accepted for interface compatibility and not used.
        """
        if global_step % self.update_skill_every_step == 0:
            return self.init_meta()
        return meta

    def update_lsd(self, skill, obs, next_obs, step):
        """Take one gradient step on the LSD network (and the encoder, if any).

        Args:
            skill: batch of skill vectors, one row per transition.
            obs: encoded observations.
            next_obs: encoded next observations.
            step: current training step (unused here).

        Returns:
            dict with ``'lsd_loss'`` when logging is enabled, else empty.
        """
        metrics = {}

        loss = self.compute_lsd_loss(obs, next_obs, skill)

        self.lsd_opt.zero_grad()
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        loss.backward()
        self.lsd_opt.step()
        # The encoder is stepped with the gradient of the LSD loss whenever
        # an encoder optimizer exists.
        if self.encoder_opt is not None:
            self.encoder_opt.step()

        if self.use_tb or self.use_wandb:
            metrics['lsd_loss'] = loss.item()

        return metrics

    def _skill_alignment(self, skill, obs, next_obs):
        """Project the LSD latent displacement onto the centred skill.

        Computes ``(lsd(next_obs) - lsd(obs)) * mask`` summed over the skill
        dimension, where ``mask`` is the skill with its row mean removed and
        rescaled by ``skill_dim / (skill_dim - 1)``. For a one-hot skill this
        gives 1 at the active index and ``-1 / (skill_dim - 1)`` elsewhere.

        Returns:
            Tensor with one value per batch row.
        """
        d_pred = self.lsd(obs)  # B, skill_dim
        d_pred_next = self.lsd(next_obs)
        d_diff = d_pred_next - d_pred  # B, skill_dim

        # Guard against division by zero when there is a single skill.
        denom = self.skill_dim - 1 if self.skill_dim != 1 else 1
        masks = (skill - skill.mean(dim=1, keepdim=True)) * self.skill_dim / denom
        return (d_diff * masks).sum(dim=1)  # B

    def compute_intr_reward(self, skill, obs, next_obs, step):
        """Return the intrinsic reward as a ``(B, 1)`` tensor scaled by ``lsd_scale``.

        ``step`` is accepted for interface compatibility and not used.
        """
        reward = self._skill_alignment(skill, obs, next_obs).reshape(-1, 1)
        return reward * self.lsd_scale

    def compute_lsd_loss(self, state, next_state, skill):
        """Return the scalar LSD loss.

        The loss is the negated batch mean of the same alignment term used
        for the intrinsic reward, so minimising it increases that reward.
        """
        return -self._skill_alignment(skill, state, next_state).mean()

    def update(self, replay_iter, step):
        """Run one agent update from a replay batch.

        Does nothing unless ``step`` is a multiple of
        ``self.update_every_steps``. Otherwise: encodes the batch, optionally
        updates the LSD network and swaps in the intrinsic reward, then
        updates critic, actor and the critic target.

        Args:
            replay_iter: iterator yielding batches of
                ``(obs, action, reward, discount, next_obs, done, skill)``.
            step: current training step.

        Returns:
            dict of logged metrics (possibly empty).
        """
        metrics = {}

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, extr_reward, discount, next_obs, done, skill = utils.to_torch(
            batch, self.device)

        # augment and encode
        obs = self.aug_and_encode(obs)
        next_obs = self.aug_and_encode(next_obs)

        if self.reward_free:
            metrics.update(self.update_lsd(skill, obs, next_obs, step))

            # The reward is a training target, so no gradient is tracked.
            with torch.no_grad():
                intr_reward = self.compute_intr_reward(skill, obs, next_obs, step)

            if self.use_tb or self.use_wandb:
                metrics['intr_reward'] = intr_reward.mean().item()
            reward = intr_reward
        else:
            reward = extr_reward

        if self.use_tb or self.use_wandb:
            metrics['extr_reward'] = extr_reward.mean().item()
            metrics['batch_reward'] = reward.mean().item()

        if not self.update_encoder:
            obs = obs.detach()
            next_obs = next_obs.detach()

        # extend observations with skill
        obs = torch.cat([obs, skill], dim=1)
        next_obs = torch.cat([next_obs, skill], dim=1)

        # update critic
        metrics.update(
            self.update_critic(obs.detach(), action, reward, discount,
                               next_obs.detach(), step))

        # update actor
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        return metrics
