import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils
from agent.ddpg import DDPGAgent
from utils import ProbabilisticTransitionModel, StateRewardDecoder

# Small additive constant that keeps the cosine-similarity denominator in
# `cosine_distance` away from zero.
EPSILON: float = 1e-4

def _sqrt(x, tol=0.):
    tol = torch.zeros_like(x)
    return torch.sqrt(torch.maximum(x, tol))

def cosine_distance(x, y):
    """Angle (in radians) between ``x`` and ``y`` along the last axis.

    The cosine similarity is computed with ``EPSILON`` added to the product of
    the norms. The angle is then recovered with ``atan2(sin, cos)``, where
    ``sin = sqrt(1 - cos^2)`` is clamped at 0 by ``_sqrt``.

    Returns:
        Tensor with the last axis kept as size 1 (``keepdim=True``).
    """
    dot = (x * y).sum(dim=-1, keepdim=True)
    x_norm = x.pow(2.).sum(dim=-1, keepdim=True).sqrt()
    y_norm = y.pow(2.).sum(dim=-1, keepdim=True).sqrt()
    cos_sim = dot / (x_norm * y_norm + EPSILON)

    sin_sim = _sqrt(1. - cos_sim.pow(2.))
    return torch.atan2(sin_sim, cos_sim)


class TEA(nn.Module):
    """Forward/backward dynamics model that reports per-sample prediction errors.

    ``forward_net`` maps ``cat(obs, action)`` to a predicted next observation.
    ``backward_net`` maps ``cat(obs, next_obs)`` to a predicted action squashed
    by ``tanh``. ``forward`` returns the L2 norms (last axis, kept as size 1)
    of both prediction residuals.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        # (obs, action) -> predicted next observation.
        self.forward_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

        # (obs, next_obs) -> predicted action in [-1, 1] (tanh output).
        self.backward_net = nn.Sequential(
            nn.Linear(2 * obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

        self.apply(utils.weight_init)

    def forward(self, obs, action, next_obs):
        """Return ``(forward_error, backward_error)`` for a batch.

        All three inputs must share the same leading (batch) size; this is
        checked with ``assert``.
        """
        batch = obs.shape[0]
        assert batch == next_obs.shape[0]
        assert batch == action.shape[0]

        predicted_next = self.forward_net(torch.cat([obs, action], dim=-1))
        predicted_action = self.backward_net(torch.cat([obs, next_obs], dim=-1))

        forward_error = self._l2(next_obs - predicted_next)
        backward_error = self._l2(action - predicted_action)
        return forward_error, backward_error

    @staticmethod
    def _l2(residual):
        """L2 norm over the last axis, keeping that axis as size 1."""
        return torch.norm(residual, dim=-1, p=2, keepdim=True)


class TEAAgent(DDPGAgent):
    """DDPG agent with a ``TEA`` dynamics model and a bisimulation-style
    auxiliary objective.

    In reward-free mode (``self.reward_free``), ``update`` trains the
    auxiliary models through ``aux_supervisor``. It then uses
    ``5.0 * self.ri`` as the reward for the critic, where ``self.ri`` is the
    detached ``r_int`` returned by ``update_encoder_v5``. Otherwise the
    extrinsic reward is used.
    """

    def __init__(self, icm_scale, update_encoder, **kwargs):
        super().__init__(**kwargs)
        self.icm_scale = icm_scale
        self.update_encoder = update_encoder

        self.icm = TEA(self.obs_dim, self.action_dim,
                       self.hidden_dim).to(self.device)

        # optimizers
        self.icm_opt = torch.optim.Adam(self.icm.parameters(), lr=self.lr)

        # ---- BSM auxiliary models ----
        self.transition_model = ProbabilisticTransitionModel(
            self.obs_dim, self.action_dim, layer_width=64
        ).to(self.device)
        self.reward_decoder = nn.Sequential(
            nn.Linear(self.obs_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, 1)).to(self.device)
        self.state_reward_decoder = StateRewardDecoder(
            self.obs_dim, self.action_dim).to(self.device)

        # A single optimizer steps all three auxiliary models. NOTE(review):
        # no encoder/critic parameters are in it, so aux_supervisor itself
        # never steps the encoder.
        self.decoder_optimizer = torch.optim.Adam(
            list(self.reward_decoder.parameters())
            + list(self.transition_model.parameters())
            + list(self.state_reward_decoder.parameters()),
            lr=self.lr, weight_decay=0.0,
        )

        # Latest intrinsic-reward tensor, written by aux_supervisor and read
        # by update() in reward-free mode.
        self.ri = None
        self.icm.train()

    def update_icm(self, obs, action, next_obs, step):
        """One optimisation step of the TEA model (and encoder, if present).

        Minimises mean forward error + mean backward error. Returns a metrics
        dict with key ``'icm_loss'``. ``step`` is unused.
        """
        metrics = dict()

        forward_error, backward_error = self.icm(obs, action, next_obs)

        loss = forward_error.mean() + backward_error.mean()

        self.icm_opt.zero_grad(set_to_none=True)
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        loss.backward()
        self.icm_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()

        metrics['icm_loss'] = loss.item()

        return metrics

    def aux_supervisor(self, h, action, reward, next_h, a_next, r_next, discount, step):
        """Train the auxiliary models on one batch and refresh ``self.ri``.

        The total loss is ``0.5 * (bsm_loss + prer_loss)`` plus
        ``1e-4 * (state_loss + reward_loss)``. It is backpropagated and only
        ``decoder_optimizer`` is stepped. Side effect: ``self.ri`` is set to
        the detached intrinsic-reward tensor from ``update_encoder_v5``.

        Returns:
            Metrics dict with ``r_dist``, ``bsm_loss``, ``prer_loss``,
            ``state_loss`` and ``reward_loss``.
        """
        metrics = dict()
        state_loss, reward_loss = self.update_transition_reward_model(h, action, next_h, reward)
        bsm_loss, prer_loss, self.ri, r_dist = self.update_encoder_v5(
            h, action, reward, next_h, a_next, r_next, discount, step)
        total_loss = 0.5 * (bsm_loss + prer_loss) + 1e-4 * (state_loss + reward_loss)

        self.decoder_optimizer.zero_grad()
        total_loss.backward()
        self.decoder_optimizer.step()

        metrics['r_dist'] = r_dist.mean().item()
        metrics['bsm_loss'] = bsm_loss.item()
        metrics['prer_loss'] = prer_loss.item()
        metrics['state_loss'] = state_loss.item()
        metrics['reward_loss'] = reward_loss.item()
        return metrics

    def update_transition_reward_model(self, h, action, next_h, reward):
        """Compute the transition-model and reward-decoder losses (no step).

        ``state_loss`` is ``mean(0.5 * ((mu - next_h) / sigma)^2 + log sigma)``,
        with ``next_h`` detached and ``sigma`` defaulting to ones when the
        model returns ``None``. ``reward_loss`` is the MSE between
        ``reward_decoder`` applied to a sampled next latent and ``reward``.
        """
        h_action = torch.cat([h, action], dim=1)
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(h_action)
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)

        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        state_loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))

        pred_next_latent = self.transition_model.sample_prediction(h_action)
        pred_next_reward = self.reward_decoder(pred_next_latent)
        reward_loss = F.mse_loss(pred_next_reward, reward)

        return state_loss, reward_loss

    def update_encoder_v4(self, h, action, reward, next_h, a_next, r_next, discount, step):
        """Earlier variant of the bisimulation objective (not called by ``update``).

        Same return contract as ``update_encoder_v5``:
        ``(loss, loss_reward_decoder, r_int.detach(), r_dist)``.
        ``step`` is unused.
        """
        batch_size = h.size(0)
        perm = np.random.permutation(batch_size)
        h2 = h[perm]
        h_action = torch.cat([h, action], dim=1)

        reward_mu, reward_sigma = self.state_reward_decoder(h_action)
        loss_reward_decoder = self.state_reward_decoder.loss(reward_mu, reward_sigma, reward)
        with torch.no_grad():
            pred_next_latent_mu1, _ = self.transition_model(h_action)
            pred_next_latent_mu11, _ = self.transition_model(torch.cat([next_h, a_next], dim=1))
            prer = self.state_reward_decoder.sample_prediction(h_action)
            reward2 = reward[perm]
            prer2 = prer[perm]

        pred_next_latent_mu2 = pred_next_latent_mu1[perm]  # permuted batch

        z_dist = self.metric_func(h, h2)

        # Reward distances are measured against the (permuted) batch mean.
        reward2_mean = reward2.mean().unsqueeze(-1)
        r_distt0 = F.smooth_l1_loss(reward, reward2_mean, reduction='none')
        r_distt10 = F.smooth_l1_loss(r_next, reward2_mean, reduction='none')
        # expand(-1) keeps the transition model's own output width.
        mu_mean = pred_next_latent_mu2.mean(dim=0, keepdim=True).expand(batch_size, -1)
        dist_to_mean = self.metric_func(pred_next_latent_mu1, mu_mean)
        next_dist_to_mean = self.metric_func(pred_next_latent_mu11, mu_mean)

        r_int = (discount * (discount * next_dist_to_mean + r_distt10)
                 - (discount * dist_to_mean + r_distt0)).pow(2.)
        with torch.no_grad():
            r_dist = (prer - prer2).pow(2.)
        transition_dist = self.metric_func(pred_next_latent_mu1, pred_next_latent_mu2)
        loss = F.smooth_l1_loss((z_dist - discount * transition_dist).pow(2.), r_dist, reduction='mean')
        return loss, loss_reward_decoder, r_int.detach(), r_dist

    def update_encoder_v5(self, h, action, reward, next_h, a_next, r_next, discount, step):
        """Bisimulation-style loss plus the intrinsic-reward tensor.

        ``h`` is paired with a random permutation of the batch. The loss
        regresses ``(d(h, h_perm) - discount * d(P(h), P(h_perm)))^2`` onto
        the squared difference of sampled predicted rewards. Here
        ``P = transition_model`` mean and ``d = metric_func``.

        Returns:
            ``(loss, loss_reward_decoder, r_int.detach(), r_dist)``; ``step``
            is unused.
        """
        batch_size = h.size(0)
        perm = np.random.permutation(batch_size)
        h2 = h[perm]
        h_action = torch.cat([h, action], dim=1)

        reward_mu, reward_sigma = self.state_reward_decoder(h_action)
        loss_reward_decoder = self.state_reward_decoder.loss(reward_mu, reward_sigma, reward)
        with torch.no_grad():
            pred_next_latent_mu, _ = self.transition_model(h_action)
            # Two-step prediction from the predicted next latent ("hat P'").
            pred_next_latent_mu1, _ = self.transition_model(
                torch.cat([pred_next_latent_mu, a_next], dim=1))
            pred_next_latent_mu2, _ = self.transition_model(torch.cat([next_h, a_next], dim=1))

            prer = self.state_reward_decoder.sample_prediction(h_action)

            r_next2 = r_next[perm]
            prer2 = prer[perm]

        pred_next_latent_mup = pred_next_latent_mu[perm]

        z_dist = self.metric_func(h, h2)

        # Distances are measured against batch means. The per-sample
        # (permuted) alternative was an unreachable `else` branch and has
        # been removed.
        r_next2_mean = r_next2.mean().unsqueeze(-1)
        r_distt0 = F.smooth_l1_loss(prer, r_next2_mean, reduction='none')
        r_distt10 = F.smooth_l1_loss(r_next, r_next2_mean, reduction='none')
        # expand(-1) keeps the transition model's output width; this was
        # hard-coded to 2, which broke for any other latent size.
        mu_mean = pred_next_latent_mu2[perm].mean(dim=0, keepdim=True).expand(batch_size, -1)

        transition_dist = self.metric_func(pred_next_latent_mu1, mu_mean)
        transition_distn = self.metric_func(pred_next_latent_mu2, mu_mean)

        r_int = (discount * (discount * transition_distn + r_distt10)
                 - (discount * transition_dist + r_distt0)).pow(2.)
        with torch.no_grad():
            r_dist = (prer - prer2).pow(2.)
        w_dist = self.metric_func(pred_next_latent_mu, pred_next_latent_mup)
        loss = F.smooth_l1_loss((z_dist - discount * w_dist).pow(2.), r_dist, reduction='mean')
        return loss, loss_reward_decoder, r_int.detach(), r_dist

    def metric_func(self, x, y, distance="mico_angular"):
        """Element-wise smooth-L1 distance between ``x`` and ``y`` (no reduction).

        ``distance`` is accepted for interface compatibility but ignored; only
        the smooth-L1 metric is implemented. ``cosine_distance`` is not used
        here.
        """
        return F.smooth_l1_loss(x, y, reduction='none')

    def compute_intr_reward(self, obs, action, next_obs, step):
        """ICM-style reward ``log(1 + icm_scale * forward_error)``.

        ``step`` is unused.
        """
        forward_error, _ = self.icm(obs, action, next_obs)

        reward = forward_error * self.icm_scale
        reward = torch.log(reward + 1.0)
        return reward

    def update(self, replay_iter, step):
        """Run one agent update every ``update_every_steps`` steps.

        Returns a metrics dict; it is empty on skipped steps.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, extr_reward, discount, next_obs, a_next, r_next, done = utils.to_torch(
            batch, self.device)

        # augment and encode
        obs = self.aug_and_encode(obs)
        with torch.no_grad():
            next_obs = self.aug_and_encode(next_obs)
        if self.reward_free:
            metrics.update(self.aux_supervisor(
                obs, action, extr_reward, next_obs, a_next, r_next, discount, step))

            # NOTE(review): self.ri keeps the element-wise shape produced by
            # metric_func (not reduced over the feature axis) — confirm that
            # update_critic broadcasts it as intended.
            with torch.no_grad():
                intr_reward = self.ri * 5.0

            metrics['intr_reward'] = intr_reward.mean().item()
            reward = intr_reward
        else:
            reward = extr_reward

        metrics['extr_reward'] = extr_reward.mean().item()
        metrics['batch_reward'] = reward.mean().item()

        # NOTE: the critic and actor below always receive detached inputs, so
        # this flag has no further effect on them in this method.
        if not self.update_encoder:
            obs = obs.detach()
            next_obs = next_obs.detach()

        # update critic
        metrics.update(
            self.update_critic(obs.detach(), action, reward, discount,
                               next_obs.detach(), step))

        # update actor
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        return metrics
