r"""
    SAC stub, copied from https://github.com/denisyarats/pytorch_sac/blob/master/agent/sac.py
    TD3 stub, copied from MaskDP https://github.com/FangchenLiu/MaskDP_public

    Finally, IQL stub, copied from https://github.com/rail-berkeley/rlkit/blob/master/rlkit/torch/sac/iql_trainer.py
    - IQL is great because it's discrete/continuous agnostic


    Design note:
    - We need an update step that reuses the backbone.
"""
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as D
from collections import OrderedDict
from einops import repeat, rearrange

from context_general_bci.components import create_block
from context_general_bci.rl.utils import soft_update_params, get_numpy
from context_general_bci.config import TransformerConfig


class UnifiedCausalLayer(nn.Module):
    r"""
        Scalar-readout head over backbone states.

        A learned query token is appended to the end of the state sequence, the
        sequence is passed through one block built by `create_block`, and the
        block output at the appended token's position is returned.
    """
    # TODO decoder layer for xent
    def __init__(
            self,
            transformer_config: TransformerConfig):
        r"""
            transformer_config: forwarded to `create_block`; `n_state` sets the
                size of the learned query token.
        """
        # nn.Module bookkeeping must exist before submodules / parameters are
        # assigned, otherwise attribute registration raises AttributeError.
        super().__init__()
        # Output a single Q value using cross-attn.
        self.qf = create_block(config=transformer_config)
        self.q_out = nn.Parameter(torch.randn(transformer_config.n_state))

    def forward(self, states: torch.Tensor):
        r"""
            states: B x T x H. Backbone states for timestep.
            - we unify since we don't know if we have consistent dimensionality
            TODO resolve padding - Assumed with padding.

            returns: the block output at the appended query position
                (presumably B x H - TODO confirm against `create_block`'s output).
        """
        # Broadcast the learned token to one extra timestep per batch element.
        q_out = repeat(self.q_out, 'H -> B T H', B=states.size(0), T=1)
        # Append along the sequence axis: the token is B x 1 x H, so joining on
        # the feature axis would only be shape-valid for T == 1 and would then
        # double the hidden size. The readout below takes the last position.
        query = torch.cat([states, q_out], dim=1)
        return self.qf(query)[:, -1]

class IQLAgent:
    r"""
        IQL update logic (adapted from the RLKit IQL trainer) operating on
        caller-supplied networks.

        Remove data responsibilities, that's handled by RTMA Module
        # TODO construct init from backbone

        I need
        - An actor that can act with just a (set of backbone) states
            - Outputting an appropriate number of steps too!
        - A state value function that can do the same
        - A state-action value function that can do the same.

        Call conventions this class relies on:
        - `policy(obs, actions_like)` returns an object exposing `log_prob(actions)`.
        - `qf1(obs, actions)`, `qf2(obs, actions)` and their targets return Q estimates.
        - `vf(obs)` returns state-value estimates.
        - All of `policy`, `qf1`, `qf2`, `vf` expose `.parameters()`.
    """
    def __init__(
            self,
            policy,
            qf1,
            qf2,
            vf,
            quantile=0.5,
            target_qf1=None,
            target_qf2=None,
            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-3,
            value_loss_coeff=0.5,
            qf_lr=1e-3,
            policy_weight_decay=0,
            q_weight_decay=0,
            optimizer_class=optim.Adam,

            policy_update_period=1,
            q_update_period=1,

            clip_score=None,
            soft_target_tau=1e-2,
            target_update_period=1,
            beta=1.0,
    ):
        r"""
            policy, qf1, qf2, vf: networks to train (see class docstring for call conventions).
            quantile: asymmetry of the squared value loss; 0.5 weights both error signs equally.
            target_qf1, target_qf2: target critics used for the value / advantage
                estimates. When omitted, a deep copy of the matching critic is used.
            discount: multiplier on the bootstrapped next-state value.
            reward_scale: multiplier on rewards in the Q target.
            policy_lr, policy_weight_decay: settings of the single joint optimizer.
            value_loss_coeff: weight on (qf1 + qf2 + vf) losses in the total loss.
            qf_lr, q_weight_decay: accepted for interface compatibility; NOT used,
                since all networks share one optimizer.
            optimizer_class: optimizer constructor taking (params, weight_decay=, lr=).
            policy_update_period, q_update_period: stored, but not consulted by the
                update implemented in this class.
            clip_score: optional upper clamp on exp(advantage / beta).
            soft_target_tau: rate passed to `soft_update_params`.
            target_update_period: soft-update targets every this many train steps.
            beta: temperature dividing the advantage before exponentiation.
        """
        import copy  # local import: only needed for default target construction

        super().__init__()
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        # The update step calls the targets unconditionally, so a missing target
        # must be materialized here rather than left as None.
        self.target_qf1 = copy.deepcopy(qf1) if target_qf1 is None else target_qf1
        self.target_qf2 = copy.deepcopy(qf2) if target_qf2 is None else target_qf2
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period
        self.vf = vf
        self.value_loss_coeff = value_loss_coeff

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()
        # NOTE: a single optimizer over all trainable networks replaces the
        # per-network optimizers of the reference implementation; consequently
        # only `policy_lr` / `policy_weight_decay` take effect.
        self.optimizer = optimizer_class(
            list(self.policy.parameters()) +
            list(self.qf1.parameters()) +
            list(self.qf2.parameters()) +
            list(self.vf.parameters()),
            weight_decay=policy_weight_decay,
            lr=policy_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self._n_train_steps_total = 0

        self.q_update_period = q_update_period
        self.policy_update_period = policy_update_period

        self.clip_score = clip_score
        self.beta = beta
        self.quantile = quantile

    def update(self, batch, steps=1, return_info: bool=False) -> Dict[str, float]:
        r"""
            Run `steps` gradient updates on the same batch.

            batch: mapping with keys 'rewards', 'observations', 'actions',
                'next_observations'.
            steps: number of consecutive updates to apply.
            return_info: if True, loss values are collected (from the first step only).

            returns:
            - update_info (generic Dict); empty unless `return_info` is set.
        """
        rewards = batch['rewards'] # Note RLKit uses a reward transform, IDK what that should be for us.
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        out = {}
        for i in range(steps):
            # Only the first step reports, to avoid repeated host transfers.
            step_out = self._inner_update(obs, actions, rewards, next_obs, return_info=return_info and i == 0)
            out.update(step_out)
        return out

    def _inner_update(
            self,
            obs: torch.Tensor, # B x K x H # K is max length neural state
            actions: torch.Tensor,
            rewards,
            next_obs,
            return_info: bool=False
        ) -> Dict[str, float]:
        r"""
            IQL adapted for unified arch.

            Performs one joint optimizer step on the Q, V and policy losses, then
            (every `target_update_period` steps) soft-updates the target critics.

            returns: dict of scalar losses when `return_info` is True, else empty.
        """
        # --- QF loss: regress both critics onto r + discount * V(s') ---
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        target_vf_pred = self.vf(next_obs)

        # NOTE: no terminal mask is applied to the bootstrap term here.
        # A single detach on the assembled target stops gradients into V(s').
        q_target = (self.reward_scale * rewards + self.discount * target_vf_pred).detach()
        qf1_loss = self.qf_criterion(q1_pred, q_target)
        qf2_loss = self.qf_criterion(q2_pred, q_target)

        # --- VF loss: asymmetric squared error against min of target critics ---
        q_pred = torch.min(
            self.target_qf1(obs, actions),
            self.target_qf2(obs, actions),
        ).detach()
        vf_pred = self.vf(obs)
        vf_err = vf_pred - q_pred
        vf_sign = (vf_err > 0).float()
        # Over-estimates (vf > q) are weighted by (1 - quantile), under-estimates by quantile.
        vf_weight = (1 - vf_sign) * self.quantile + vf_sign * (1 - self.quantile)
        vf_loss = (vf_weight * (vf_err ** 2)).mean()

        # --- Policy loss: advantage-weighted log-likelihood of dataset actions ---
        # The policy is conditioned on zeroed actions so it cannot read the labels.
        dist = self.policy(obs, torch.zeros_like(actions))
        policy_logpp = dist.log_prob(actions)
        adv = q_pred - vf_pred.detach() # Don't backprop VF for policy
        exp_adv = torch.exp(adv / self.beta)
        if self.clip_score is not None:
            exp_adv = torch.clamp(exp_adv, max=self.clip_score)

        # assumes the value outputs carry a trailing singleton dim - TODO confirm
        weights = exp_adv[:, 0]
        policy_loss = (-policy_logpp * weights).mean()

        # --- Update networks with one joint step ---
        self.optimizer.zero_grad()
        total_loss = (
            (qf1_loss + qf2_loss + vf_loss) * self.value_loss_coeff
            + policy_loss
        )
        total_loss.backward()
        self.optimizer.step()

        # --- Soft updates of the target critics ---
        if self._n_train_steps_total % self.target_update_period == 0:
            soft_update_params(self.qf1, self.target_qf1, self.soft_target_tau)
            soft_update_params(self.qf2, self.target_qf2, self.soft_target_tau)

        out = {}
        if return_info:
            out['QF1 Loss'] = np.mean(get_numpy(qf1_loss))
            out['QF2 Loss'] = np.mean(get_numpy(qf2_loss))
            out['Policy Loss'] = np.mean(get_numpy(policy_loss))
            out['VF Loss'] = np.mean(get_numpy(vf_loss))

        self._n_train_steps_total += 1
        return out
