#!/usr/bin/env python

import torch
import cherry
import dataclasses

import random_shifts_aug as rsa  # from DrQ-v2's GitHub


class EncoderProjector(torch.nn.Module):
    """Flatten, normalize, and optionally project a batch of features.

    The forward pass flattens every dimension after the first, applies
    `torch.nn.functional.normalize` with its default arguments to the
    flattened tensor, and then feeds the result through the projector.

    Args:
        size: Input and output width of every linear layer of the projector.
        project: When True, build a learnable projector; when False the
            projector is a parameter-free identity.
        projector_layers: Number of linear layers in the projector. The first
            layer has no bias; each additional layer is preceded by a ReLU.
            Values below 1 still produce a single linear layer.
    """

    def __init__(self, size=512, project=True, projector_layers=2):
        super().__init__()
        if project:
            layers = [torch.nn.Linear(size, size, bias=False)]
            for _ in range(projector_layers - 1):
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Linear(size, size))
            self.projector = torch.nn.Sequential(*layers)
        else:
            # Use a real (parameter-free) module instead of a lambda so the
            # projector is registered as a submodule, shows up in repr(), and
            # does not prevent the module from being pickled.
            self.projector = torch.nn.Identity()
        self.normalize = torch.nn.functional.normalize

    def forward(self, x):
        """Return the projected, normalized, flattened version of `x`."""
        x = torch.flatten(x, start_dim=1)
        x = self.normalize(x)
        x = self.projector(x)
        return x


@dataclasses.dataclass
class DrQv2FastPiSCO(cherry.algorithms.AlgorithmArguments):
    """DrQ-v2 style actor-critic update with an auxiliary self-supervised loss.

    Calling an instance runs `num_iters` update iterations. Each iteration
    samples a batch from the replay, applies random-shift augmentation, and,
    depending on the `update_*` flags, updates the policy, the twin
    Q-functions (jointly with the feature extractor and the projector through
    the self-supervised loss), and the target Q-function.
    """

    # Number of transitions requested from `replay.sample`.
    batch_size: int = 512
    # Passed to `replay.sample` and used to discount the bootstrap target.
    discount: float = 0.99
    # Not read by `__call__`; callers decide when to pass `update_policy`.
    policy_delay: int = 2
    # Not read by `__call__`; callers decide when to pass `update_target`.
    target_delay: int = 2
    # `alpha` argument handed to `cherry.models.polyak_average`.
    target_polyak_weight: float = 0.01
    # Not read by `__call__`.
    use_automatic_entropy_tuning: bool = True
    # Multiplier of the self-supervised loss in the critic objective.
    ssl_weight: float = 0.01
    # Number of update iterations performed per call.
    num_iters: int = 1
    # Passed to `replay.sample` as `nsteps`.
    nsteps: int = 1
    # Multiplicative decay applied to `policy.std`; disabled when <= 0.
    std_decay: float = 0.0  # decent value: 0.9999954
    # Lower bound for `policy.std` when decay is enabled.
    min_std: float = 0.1
    # Not read by `__call__`.
    projector_layers: int = 1

    def __call__(
        self,
        replay,
        policy,
        qvalue,
        target_value,
        policy_optimizer,
        features_optimizer,
        value_optimizer,
        features=None,
        encoder=None,
        encoder_optimizer=None,
        update_policy=True,
        update_target=False,
        update_value=True,
        device=None,
        replay_loader=None,
        **kwargs,
    ):
        """Run `num_iters` update iterations and return the logged statistics.

        Args:
            replay: Buffer exposing `sample(batch_size, nsteps=, discount=)`.
            policy: Callable mapping states to a distribution exposing
                `rsample()`, `sample()` and a `mean()` method.
            qvalue: Critic, callable on `(states, actions)` and exposing
                `twin_values(states, actions)`.
            target_value: Target critic, callable on `(states, actions)`.
            policy_optimizer: Optimizer stepped after the policy loss.
            features_optimizer: Optimizer stepped with the critic loss.
            value_optimizer: Optimizer stepped with the critic loss.
            features: Optional feature extractor applied to augmented states.
                When None, augmented states are used directly.
            encoder: Projector applied to features in the self-supervised
                loss. Required whenever `update_value` is True.
            encoder_optimizer: Optional optimizer for `encoder`.
            update_policy: Whether to update the policy.
            update_target: Whether to move the target critic towards `qvalue`.
            update_value: Whether to update the critic, features and encoder.
            device: Device the sampled tensors are moved to.
            replay_loader: Unused.
            **kwargs: Ignored.

        Returns:
            A dict of Python floats. Keys are only present for the branches
            that ran: 'policy/loss', 'qf/loss1', 'qf/loss2',
            'train/ssl_loss' and 'batch_rewards'. Values come from the last
            iteration.
        """
        # Log debugging values
        stats = {}

        augmentation_transform = rsa.RandomShiftsAug(4)

        for _ in range(self.num_iters):

            # fetch batch
            batch = replay.sample(
                self.batch_size,
                nsteps=self.nsteps,
                discount=self.discount,
            )
            # Keep the un-augmented states around: the self-supervised loss
            # draws its own two augmented views from them.
            ssl_states = states = batch.state().to(device, non_blocking=True).float()
            next_states = batch.next_state().to(device, non_blocking=True).float()
            actions = batch.action().to(device, non_blocking=True)
            rewards = batch.reward().to(device, non_blocking=True)
            dones = batch.done().to(device, non_blocking=True)

            # Process states
            states = augmentation_transform(states)
            next_states = augmentation_transform(next_states)
            if features is not None:
                states = features(states)
                next_states = features(next_states)

            # Update Policy
            if update_policy:
                # States are detached so the policy loss does not
                # backpropagate into the feature extractor.
                new_actions = policy(states.detach()).rsample()
                policy_loss = - qvalue(states.detach(), new_actions).mean()

                policy_optimizer.zero_grad()
                policy_loss.clamp(-1e3, 1e3).backward()
                policy_optimizer.step()
                stats['policy/loss'] = policy_loss.item()

            # Update Q-function
            if update_value:
                qf1_estimate, qf2_estimate = qvalue.twin_values(
                    states,
                    actions.detach(),
                )

                # compute targets
                with torch.no_grad():
                    next_actions = policy(next_states).sample()
                    target_q = target_value(next_states, next_actions)

                # NOTE(review): the bootstrap is discounted by `discount`
                # regardless of `nsteps` -- confirm this matches what
                # `replay.sample` returns for nsteps > 1.
                target_q = rewards + (1. - dones) * self.discount * target_q
                critic_qf1_loss = (qf1_estimate - target_q).pow(2).mean()
                critic_qf2_loss = (qf2_estimate - target_q).pow(2).mean()
                value_loss = (critic_qf1_loss + critic_qf2_loss) / 2.0

                # SSL objective: two independently augmented views of the
                # same states, each embedded and then projected.
                z1 = augmentation_transform(ssl_states)
                if features is not None:
                    z1 = features(z1)
                z2 = augmentation_transform(ssl_states)
                if features is not None:
                    z2 = features(z2)
                p1 = encoder(z1)
                # Each view goes through the projector; projecting z1 twice
                # would make the first loss term ignore the second view.
                p2 = encoder(z2)
                pi_z1 = policy(z1)
                pi_p1 = policy(p1)
                pi_z2 = policy(z2)
                pi_p2 = policy(p2)

                # Distance between policy means, standing in for the KL
                # between Gaussians with fixed covariance. The unprojected
                # branch is detached so it acts as a fixed target.
                kl1 = torch.norm(
                    pi_z1.mean().detach() - pi_p2.mean(),
                    p=2,
                    dim=1,
                ).mean()
                kl2 = torch.norm(
                    pi_z2.mean().detach() - pi_p1.mean(),
                    p=2,
                    dim=1,
                ).mean()
                ssl_loss = (kl1 + kl2) / 2.0

                # Update Critic Networks
                value_optimizer.zero_grad()
                features_optimizer.zero_grad()
                if encoder_optimizer is not None:
                    encoder_optimizer.zero_grad()
                (value_loss + self.ssl_weight * ssl_loss).backward()
                value_optimizer.step()
                features_optimizer.step()
                if encoder_optimizer is not None:
                    encoder_optimizer.step()

                stats['qf/loss1'] = critic_qf1_loss.item()
                stats['qf/loss2'] = critic_qf2_loss.item()
                stats['train/ssl_loss'] = ssl_loss.item()
                stats['batch_rewards'] = rewards.mean().item()

            # Move target approximator parameters towards critic
            # NOTE(review): confirm the meaning of `alpha` in
            # `cherry.models.polyak_average` against the 0.01 default.
            if update_target:
                cherry.models.polyak_average(
                    source=target_value,
                    target=qvalue,
                    alpha=self.target_polyak_weight,
                )

            # reduce std of policy if necessary
            if self.std_decay > 0.0 and hasattr(policy, 'std'):
                policy.std = max(self.min_std, self.std_decay * policy.std)

        return stats
