#!/usr/bin/env python

import torch
import cherry
import learn2learn as l2l
import dataclasses
import dotmap

from cherry.algorithms import PPO

import random_shifts_aug as rsa  # from DrQ-v2's GitHub


class EncoderProjector(torch.nn.Module):
    """Flatten, normalize, and optionally project a batch of features.

    The forward pass flattens every dimension after the batch dimension,
    normalizes each row with `torch.nn.functional.normalize` (default
    arguments, i.e. unit L2 norm along dim 1), then applies the projector.

    Args:
        size: Input and output width of every linear layer in the projector.
            The flattened features must have this width when `project` is
            True.
        project: If True, build an MLP projector; otherwise the normalized
            features are returned unchanged.
        projector_layers: Number of linear layers in the projector. The first
            layer has no bias; each additional layer is preceded by a ReLU.
            Values <= 1 yield a single bias-free linear layer.
    """

    def __init__(self, size=512, project=True, projector_layers=2):
        super().__init__()
        if project:
            layers = [torch.nn.Linear(size, size, bias=False)]
            for _ in range(projector_layers - 1):
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Linear(size, size))
            self.projector = torch.nn.Sequential(*layers)
        else:
            # An Identity module (instead of a lambda) keeps the whole module
            # picklable and registers the projector as a proper submodule.
            # It holds no parameters, so the state_dict is unaffected.
            self.projector = torch.nn.Identity()
        self.normalize = torch.nn.functional.normalize

    def forward(self, x):
        """Return the normalized (and optionally projected) features of `x`."""
        x = torch.flatten(x, start_dim=1)
        x = self.normalize(x)
        x = self.projector(x)
        return x


@dataclasses.dataclass
class PiSCO(cherry.algorithms.AlgorithmArguments):
    """PPO-style update augmented with a PiSCO self-supervised loss.

    The dataclass fields are the default hyper-parameters; `update` reads
    them through `unpack_config`, to which per-call keyword arguments are
    also passed.
    """

    # Number of minibatch optimization steps per call to `update`.
    num_steps: int = 320
    # Number of transitions sampled from the replay at every step.
    batch_size: int = 64
    # Clipping value forwarded to `PPO.policy_loss`.
    policy_clip: float = 0.2
    # Clipping value forwarded to `PPO.state_value_loss`.
    value_clip: float = 0.2
    # Weight of the value loss in the total loss.
    value_weight: float = 0.5
    # Weight of the entropy bonus (subtracted from the total loss).
    entropy_weight: float = 0.0
    # Weight of the self-supervised (PiSCO) loss in the total loss.
    ssl_weight: float = 0.01
    # Discount factor passed to generalized advantage estimation.
    discount: float = 0.99
    # Tau parameter passed to generalized advantage estimation.
    gae_tau: float = 0.95
    # Max gradient norm for the policy parameters; no clipping if 0.
    gradient_norm: float = 0.5
    # Epsilon passed to `cherry.normalize` when normalizing advantages.
    eps: float = 1e-8

    def update(
        self,
        replay,
        optimizer,
        policy,
        value_fn,
        projector=None,
        **kwargs,
    ):
        """Run `num_steps` minibatch updates of PPO + PiSCO on `replay`.

        Args:
            replay: Replay of collected transitions. It must expose
                `state()`, `action()`, `done()`, `reward()`, `vectorized`,
                `flatten()`, `sample()`, indexing, and iteration -- presumably
                a `cherry.ExperienceReplay`; confirm against the caller.
            optimizer: Optimizer stepped once per minibatch. It should own
                every parameter that is meant to be trained.
            policy: Policy exposing `log_prob(states, actions)`,
                `features(states)`, `actor(features)`, `parameters()`, and
                returning a distribution when called on states. The
                self-supervised loss treats `actor` outputs as Categorical
                logits.
            value_fn: Callable mapping states to state values.
            projector: Callable applied to the features before the actor on
                the online branch of the self-supervised loss. If None, the
                features are used unprojected.
            **kwargs: Forwarded to `unpack_config` together with the
                dataclass defaults.

        Returns:
            A `dotmap.DotMap` with the keys `ppo/policy_loss`,
            `ppo/entropies`, `ppo/value_loss`, and `ppo/ssl_loss`.
        """
        # Log debugging values
        stats = dotmap.DotMap()

        # Unpack arguments and variables
        config = self.unpack_config(self, kwargs)
        vectorized = replay.vectorized
        data_augmentation = rsa.RandomShiftsAug(4)
        if projector is None:
            # The signature defaults to None; fall back to an identity map
            # rather than failing with "NoneType is not callable" mid-update.
            projector = torch.nn.Identity()

        # A max-norm of 0 would rescale every gradient to zero inside
        # clip_grad_norm_, so 0 (or None) disables clipping altogether.
        gradient_norm = config.gradient_norm
        clip_gradients = gradient_norm is not None and gradient_norm > 0

        # Process replay
        all_states = replay.state()
        all_actions = replay.action()
        all_dones = replay.done()
        all_rewards = replay.reward()
        with torch.no_grad():
            if vectorized:
                # Merge the two leading dimensions for the forward passes,
                # then restore them.
                state_shape = all_states.shape[2:]
                action_shape = all_actions.shape[2:]
                all_log_probs = policy.log_prob(
                    all_states.reshape(-1, *state_shape),
                    all_actions.reshape(-1, *action_shape)
                )
                # reshape to -1 here because maybe Normal distribution.
                all_log_probs = all_log_probs.reshape(*all_states.shape[:2], -1)
                all_values = value_fn(all_states.reshape(-1, *state_shape))
                all_values = all_values.reshape(*all_states.shape[:2], 1)
            else:
                all_log_probs = policy.log_prob(all_states, all_actions)
                all_values = value_fn(all_states)

            # Bootstrap value; only used through detached advantages and
            # returns, so no autograd graph is needed for it.
            next_state_value = value_fn(replay[-1].next_state)

        # Compute advantages and returns
        all_advantages = cherry.pg.generalized_advantage(
            config.discount,
            config.gae_tau,
            all_rewards,
            all_dones,
            all_values,
            next_state_value,
        )

        # Returns use the un-normalized advantages.
        returns = all_advantages + all_values
        all_advantages = cherry.normalize(all_advantages, epsilon=config.eps)

        # Attach the pre-update quantities to each transition so they are
        # available on the sampled minibatches below.
        for i, sars in enumerate(replay):
            sars.log_prob = cherry.totensor(all_log_probs[i].detach())
            sars.value = cherry.totensor(all_values[i].detach())
            sars.advantage = cherry.totensor(all_advantages[i].detach())
            sars.retur = cherry.totensor(returns[i].detach())

        # Logging
        policy_losses = []
        entropies = []
        value_losses = []
        ssl_losses = []

        # avoids the weird shapes later in the loop and extra forward passes.
        replay = replay.flatten()

        # Perform some optimization steps
        for _ in range(config.num_steps):
            batch = replay.sample(int(config.batch_size))
            states = batch.state()
            advantages = batch.advantage()

            new_densities = policy(states)
            new_values = value_fn(states)

            # Compute PPO losses
            new_log_probs = new_densities.log_prob(batch.action())
            entropy = new_densities.entropy().mean()
            policy_loss = PPO.policy_loss(
                new_log_probs,
                batch.log_prob(),
                advantages,
                clip=config.policy_clip,
            )
            value_loss = PPO.state_value_loss(
                new_values,
                batch.value(),
                batch.retur(),
                clip=config.value_clip,
            )

            # Compute PiSCO losses: two independently augmented views of the
            # same states; the (detached) policy of one view is the target
            # for the projected policy of the other view, symmetrically.
            z1 = policy.features(data_augmentation(states))
            z2 = policy.features(data_augmentation(states))
            p1 = projector(z1)
            p2 = projector(z2)
            pi_z1 = cherry.distributions.Categorical(logits=policy.actor(z1))
            pi_p1 = cherry.distributions.Categorical(logits=policy.actor(p1))
            pi_z2 = cherry.distributions.Categorical(logits=policy.actor(z2))
            pi_p2 = cherry.distributions.Categorical(logits=policy.actor(p2))
            kl1 = torch.distributions.kl_divergence(
                p=l2l.detach_distribution(pi_z1),
                q=pi_p2,
            ).mean()
            kl2 = torch.distributions.kl_divergence(
                p=l2l.detach_distribution(pi_z2),
                q=pi_p1,
            ).mean()
            ssl_loss = (kl1 + kl2) / 2.0

            loss = policy_loss
            loss = loss + config.value_weight * value_loss
            loss = loss - config.entropy_weight * entropy
            loss = loss + config.ssl_weight * ssl_loss

            # Take optimization step
            optimizer.zero_grad()
            loss.backward()
            if clip_gradients:
                # NOTE(review): only policy.parameters() are clipped; value_fn
                # and projector parameters are clipped only if they are
                # shared with the policy -- confirm this is intended.
                torch.nn.utils.clip_grad_norm_(
                    policy.parameters(),
                    gradient_norm,
                )
            optimizer.step()

            # Detach before logging so the autograd graph of every step is
            # not kept alive until the end of the update.
            policy_losses.append(policy_loss.detach())
            entropies.append(entropy.detach())
            value_losses.append(value_loss.detach())
            ssl_losses.append(ssl_loss.detach())

        # Log metrics
        stats['ppo/policy_loss'] = PPO._mean(policy_losses)
        stats['ppo/entropies'] = PPO._mean(entropies)
        stats['ppo/value_loss'] = PPO._mean(value_losses)
        stats['ppo/ssl_loss'] = PPO._mean(ssl_losses).item()

        return stats
