import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
import hydra
import random

class Generator(nn.Module):
    """Convolutional encoder for image-based observations.

    The network is four 3x3 convolutions (stride 2 first, then stride 1)
    with ``num_filters`` channels each, followed by a ``Linear`` +
    ``LayerNorm`` head that projects the flattened feature map to
    ``feature_dim``.

    Args:
        obs_shape: ``(channels, height, width)`` of the images fed to the
            network. The head's input size is derived from the spatial
            dimensions, so they must match the actual inputs (84x84 gives
            the 35x35 feature map that used to be hard-coded).
        feature_dim: size of the embedding produced by ``forward``.

    Raises:
        ValueError: if the image is too small to survive the conv stack.
    """
    def __init__(self, obs_shape, feature_dim):
        super().__init__()
        assert len(obs_shape) == 3
        self.num_layers = 4
        self.num_filters = 32
        self.output_logits = False
        self.feature_dim = feature_dim

        self.convs = nn.ModuleList([
            nn.Conv2d(obs_shape[0], self.num_filters, 3, stride=2),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1)
        ])

        # Derive the flattened feature-map size from the input resolution
        # instead of assuming 84x84 inputs.
        out_h = self._conv_output_size(int(obs_shape[1]), axis=0)
        out_w = self._conv_output_size(int(obs_shape[2]), axis=1)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                'obs_shape %s is too small for the convolutional stack'
                % (tuple(obs_shape),))
        # Spatial height of the final feature map (equals the width for
        # square inputs; 35 for 84x84 images).
        self.output_dim = out_h

        self.head = nn.Sequential(
            nn.Linear(self.num_filters * out_h * out_w, self.feature_dim),
            nn.LayerNorm(self.feature_dim))

        self.outputs = dict()

    def _conv_output_size(self, size, axis):
        """Return the extent along ``axis`` (0=height, 1=width) after all convs."""
        for conv in self.convs:
            kernel = conv.kernel_size[axis]
            stride = conv.stride[axis]
            padding = conv.padding[axis]
            dilation = conv.dilation[axis]
            size = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        return size

    def forward_conv(self, obs):
        """Run the conv stack and return the flattened feature map.

        A batch whose maximum value exceeds 1 is divided by 255 before
        being encoded; batches already within [0, 1] are used as-is.
        """
        if obs.max() > 1.:
            obs = obs / 255.

        conv = obs
        for layer in self.convs:
            conv = torch.relu(layer(conv))

        # Flatten everything except the batch dimension.
        h = conv.view(conv.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        """Encode ``obs`` into a ``feature_dim``-sized embedding.

        Args:
            obs: batch of images, ``(batch, channels, height, width)``.
            detach: when True, gradients are stopped at the conv output so
                only the head receives gradients.

        Returns:
            The head output, passed through ``tanh`` unless
            ``self.output_logits`` is set.
        """
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        out = self.head(h)
        if not self.output_logits:
            out = torch.tanh(out)

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        for i in range(self.num_layers):
            utils.tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, logger, step):
        """Logging hook; intentionally a no-op."""
        pass

class InverseForwardDynamicsModel(nn.Module):
    """Joint inverse- and forward-dynamics heads over encoded observations.

    The inverse head maps a concatenated (state, next state) feature pair to
    a tanh-bounded action; the forward head maps a concatenated
    (state, action) pair to a tanh-bounded next-state feature.
    """
    def __init__(self, generator_cfg, feature_dim, action_shape, hidden_dim):
        super().__init__()
        action_dim = action_shape[0]

        self.generator = hydra.utils.instantiate(generator_cfg)
        self.max_sigma = 1e1
        self.min_sigma = 1e-4

        # Inverse model: (h, h_next) -> action
        self.fc_inverse = nn.Linear(2 * feature_dim, hidden_dim)
        self.ln_inverse = nn.LayerNorm(hidden_dim)
        self.head_inverse = nn.Linear(hidden_dim, action_dim)

        # Forward model: (h, action) -> h_next
        self.fc_forward = nn.Linear(action_dim + feature_dim, hidden_dim)
        self.ln_forward = nn.LayerNorm(hidden_dim)
        self.head_forward = nn.Linear(hidden_dim, feature_dim)

        self.apply(utils.weight_init)

    def _infer_action(self, h, h_next):
        """Predict the action linking two consecutive feature vectors."""
        pair = torch.cat([h, h_next], dim=1)
        hidden = torch.relu(self.ln_inverse(self.fc_inverse(pair)))
        return torch.tanh(self.head_inverse(hidden))

    def _infer_next_state(self, h, action):
        """Predict the next feature vector from a feature/action pair."""
        pair = torch.cat([h, action], dim=1)
        hidden = torch.relu(self.ln_forward(self.fc_forward(pair)))
        return torch.tanh(self.head_forward(hidden))

    def forward(self, h_clean, h_next_clean, h_aug, h_next_aug):
        """Run both heads on the clean and augmented feature streams.

        Returns:
            ``(pred_action_aug, pred_action_clean,
            pred_next_state_aug, pred_next_state_clean)``.
        """
        action_from_aug = self._infer_action(h_aug, h_next_aug)
        action_from_clean = self._infer_action(h_clean, h_next_clean)

        # The forward model is fed the action inferred from the *other*
        # stream: augmented features with the clean action and vice versa.
        next_from_aug = self._infer_next_state(h_aug, action_from_clean)
        next_from_clean = self._infer_next_state(h_clean, action_from_aug)

        return action_from_aug, action_from_clean, next_from_aug, next_from_clean


class Actor(nn.Module):
    """Diagonal Gaussian policy built on top of an image encoder.

    The trunk outputs a mean and a log-std per action dimension; the
    log-std is squashed into ``log_std_bounds`` before the distribution
    (``utils.SquashedNormal``) is constructed.
    """
    def __init__(self, generator_cfg, action_shape, hidden_dim, hidden_depth, log_std_bounds): # obs_shape, image_pad
        super().__init__()
        self.generator = hydra.utils.instantiate(generator_cfg)
        self.log_std_bounds = log_std_bounds

        # Twice the action size: one half for the mean, one for the log-std.
        policy_out_dim = 2 * action_shape[0]
        self.trunk = utils.mlp(self.generator.feature_dim, hidden_dim, policy_out_dim, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, detach_generator=False):
        """Return the action distribution for a batch of observations.

        Args:
            obs: batch of observations passed to the encoder.
            detach_generator: forwarded to the encoder as ``detach``.
        """
        features = self.generator(obs, detach=detach_generator)
        mu, raw_log_std = torch.chunk(self.trunk(features), 2, dim=-1)

        # Map the unbounded output into [log_std_min, log_std_max] via tanh.
        squashed = torch.tanh(raw_log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (squashed + 1)

        return utils.SquashedNormal(mu, torch.exp(log_std))

    def log(self, logger, step):
        """Logging hook; intentionally a no-op."""
        pass


class Critic(nn.Module):
    """Twin Q-function critic sharing one image encoder.

    Two independent MLPs (``Q1`` and ``Q2``) score the concatenation of the
    encoded observation and the action.
    """
    def __init__(self, generator_cfg, action_shape, hidden_dim, hidden_depth):
        super().__init__()
        self.generator = hydra.utils.instantiate(generator_cfg)

        q_in_dim = self.generator.feature_dim + action_shape[0]
        self.Q1 = utils.mlp(q_in_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(q_in_dim, hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, aug_obs, action, detach_generator=False):
        """Return ``(q1, q2)`` for a batch of observation/action pairs.

        Args:
            aug_obs: batch of observations passed to the encoder.
            action: batch of actions; must have the same batch size.
            detach_generator: forwarded to the encoder as ``detach``.
        """
        assert aug_obs.size(0) == action.size(0)

        features = self.generator(aug_obs, detach=detach_generator)
        q_input = torch.cat([features, action], dim=-1)

        return self.Q1(q_input), self.Q2(q_input)

    def log(self, logger, step):
        """Logging hook; intentionally a no-op."""
        pass

class Discriminator(nn.Module):
    """Single-hidden-layer scorer over feature vectors.

    Applies Linear -> LayerNorm -> ReLU -> Linear -> tanh, producing one
    score in [-1, 1] per input row. ``hidden_depth`` is accepted but not
    used by this implementation.
    """
    def __init__(self, feature_dim, hidden_dim, hidden_depth):
        super().__init__()
        self.fc = nn.Linear(feature_dim, hidden_dim)
        self.ln = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, 1)

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the tanh-bounded score for each feature vector in ``obs``."""
        hidden = self.fc(obs)
        hidden = self.ln(hidden)
        hidden = torch.relu(hidden)
        score = self.head(hidden)
        return torch.tanh(score)


class JS2RLAgent(object):
    """Actor-critic agent with adversarial and dynamics auxiliary losses.

    Owns an actor, a twin-Q critic with a target copy, a learnable entropy
    temperature, a standalone generator (encoder) trained against a
    discriminator, and an inverse/forward dynamics model. The conv layers
    of the actor, the standalone generator and the dynamics model's
    generator are tied to the critic's encoder via
    ``copy_conv_weights_from``.

    Every sub-network is built with ``hydra.utils.instantiate`` from the
    corresponding ``*_cfg`` argument. ``obs_shape`` is accepted for
    interface compatibility but not used here.
    """
    def __init__(self, obs_shape, action_shape, action_range, device, generator_cfg, discriminator_cfg,
                 critic_cfg, actor_cfg, inv_cfg, discount, init_temperature, lr, actor_update_frequency, critic_tau,
                 critic_target_update_frequency, batch_size):
        self.action_range = action_range
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)
        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.generator = hydra.utils.instantiate(generator_cfg).to(self.device)
        self.discriminator = hydra.utils.instantiate(discriminator_cfg).to(self.device)

        # self supervised parts
        self.inv = hydra.utils.instantiate(inv_cfg).to(self.device)
        self.inv.generator.copy_conv_weights_from(self.critic.generator)
        self.generator.copy_conv_weights_from(self.critic.generator)

        self.inv_optimizer = torch.optim.Adam(self.inv.parameters(), lr=lr)

        # tie conv layers between actor and critic
        self.actor.generator.copy_conv_weights_from(self.critic.generator)

        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True

        # set target entropy to -|A|
        self.target_entropy = -action_shape[0]

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)

        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Switch every trainable sub-network between train and eval mode."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        self.generator.train(training)
        self.discriminator.train(training)
        # Keep the dynamics model in the same mode as the other networks.
        self.inv.train(training)

    @property
    def alpha(self):
        """Entropy temperature, i.e. ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    @staticmethod
    def _checkpoint_path(model_dir, name, step):
        """Return the checkpoint file path for network ``name`` at ``step``."""
        return '%s/%s_%s.pt' % (model_dir, name, step)

    def act(self, obs, sample=False):
        """Select an action for a single observation.

        Args:
            obs: one observation (no batch dimension).
            sample: draw from the policy when True, otherwise use its mean.

        Returns:
            The action as returned by ``utils.to_np``, clamped to
            ``action_range``.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            dist = self.actor(obs)
            action = dist.sample() if sample else dist.mean
            action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, aug_obs_1, action, reward, aug_next_obs_1, not_done, logger, step):
        """Take one gradient step on the critic towards the entropy-regularised target."""
        with torch.no_grad():
            dist_aug_1 = self.actor(aug_next_obs_1)
            next_action_aug_1 = dist_aug_1.rsample()
            log_prob_aug_1 = dist_aug_1.log_prob(next_action_aug_1).sum(-1, keepdim=True)

            target_Q1, target_Q2 = self.critic_target(aug_next_obs_1, next_action_aug_1)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_prob_aug_1
            target_Q = reward + (not_done * self.discount * target_V)

        current_Q1, current_Q2 = self.critic(aug_obs_1, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, aug_obs, logger, step):
        """Take one gradient step on the actor and on the temperature."""
        # detach conv filters, so we don't update them with the actor loss
        dist = self.actor(aug_obs, detach_generator=True)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)

        # detach conv filters, so we don't update them with the actor loss
        actor_Q1, actor_Q2 = self.critic(aug_obs, action, detach_generator=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)

        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()
        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        # Temperature update: push policy entropy towards target_entropy.
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha * (-log_prob - self.target_entropy).detach()).mean()

        logger.log('train_alpha/loss', alpha_loss, step)
        logger.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_adv(self, clean_obs, aug_obs_1):
        """Run one generator step, then one discriminator step.

        The generator minimises ``-log sigmoid(D(aug) - D(clean))`` and the
        discriminator minimises ``-log sigmoid(D(clean) - D(aug))``; both
        losses are scaled by 0.001 before back-propagation.
        ``F.logsigmoid`` is used so that strongly negative arguments do not
        underflow to ``log(0)``.
        """
        # ---- generator step ----
        self.generator_optimizer.zero_grad()
        # The clean-side score is a constant for the generator loss, so it
        # is computed without building a graph.
        with torch.no_grad():
            clean_feature_g = self.generator(clean_obs)
        aug_feature_g = self.generator(aug_obs_1)
        with torch.no_grad():
            source_imgs_critic_g = self.discriminator(clean_feature_g)
        aug_imgs_critic_g = self.discriminator(aug_feature_g)
        generator_loss = -F.logsigmoid(aug_imgs_critic_g - source_imgs_critic_g).mean()
        (0.001 * generator_loss).backward()
        self.generator_optimizer.step()

        # ---- discriminator step ----
        self.discriminator_optimizer.zero_grad()
        # Features are recomputed with the just-updated generator. Only the
        # discriminator is stepped here, so encoder gradients are not needed.
        with torch.no_grad():
            clean_feature_d = self.generator(clean_obs)
            aug_feature_d = self.generator(aug_obs_1)
        source_imgs_critic_d = self.discriminator(clean_feature_d)
        aug_imgs_critic_d = self.discriminator(aug_feature_d)

        discriminator_loss = -F.logsigmoid(source_imgs_critic_d - aug_imgs_critic_d).mean()
        (0.001 * discriminator_loss).backward()
        self.discriminator_optimizer.step()

    def update_inv(self, clean_obs, clean_next_obs, aug_obs, aug_next_obs, action):
        """Take one gradient step on the inverse/forward dynamics model.

        The inverse loss is the MSE between predicted and taken actions;
        the forward loss is the negative mean cosine similarity between
        predicted and (detached) encoded next states.
        """
        h_clean = self.generator(clean_obs)
        h_next_clean = self.generator(clean_next_obs)
        h_aug = self.generator(aug_obs)
        h_next_aug = self.generator(aug_next_obs)
        pred_action_g, pred_action_c, pred_next_state_g, pred_next_state_c = self.inv(
            h_clean, h_next_clean, h_aug, h_next_aug)

        target_action = action.detach()
        inv_loss = F.mse_loss(pred_action_g, target_action) + F.mse_loss(pred_action_c, target_action)

        sim_g = F.cosine_similarity(pred_next_state_g, h_next_aug.detach(), dim=-1).mean()
        sim_c = F.cosine_similarity(pred_next_state_c, h_next_clean.detach(), dim=-1).mean()
        forward_loss = - 0.5 * (sim_g + sim_c)

        total_loss = 0.5 * inv_loss + forward_loss

        self.inv_optimizer.zero_grad()
        (0.1 * total_loss).backward()
        self.inv_optimizer.step()

    def update(self, replay_buffer, logger, step):
        """Sample one batch and run every update scheduled for ``step``."""
        clean_obs, action, reward, clean_next_obs, not_done, aug_obs_1, aug_next_obs_1 = replay_buffer.sample(self.batch_size)
        self.update_adv(clean_obs, aug_obs_1)
        self.update_inv(clean_obs, clean_next_obs, aug_obs_1, aug_next_obs_1, action)
        self.update_critic(clean_obs, action, reward, clean_next_obs, not_done, logger, step)
        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(clean_obs, logger, step)
        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)

    def save(self, model_dir, step):
        """Write the actor, critic and generator state dicts to ``model_dir``."""
        for name in ('actor', 'critic', 'generator'):
            torch.save(
                getattr(self, name).state_dict(),
                self._checkpoint_path(model_dir, name, step)
            )

    def load(self, model_dir, step):
        """Restore the actor, critic and generator written by ``save``.

        Each network is loaded from its own file; the generator is read
        from ``generator_<step>.pt``. Tensors are mapped onto
        ``self.device`` so checkpoints saved on another device still load.
        """
        for name in ('actor', 'critic', 'generator'):
            state = torch.load(
                self._checkpoint_path(model_dir, name, step),
                map_location=self.device
            )
            getattr(self, name).load_state_dict(state)