"""
Learn an ensemble of Q-functions with SARSA, and a behavior policy (pi_beta)
with negative log-likelihood (NLL).
"""
from collections import OrderedDict, namedtuple
import copy
from typing import Tuple
from rlkit.data_management.hdf5_path_loader import (
    d4rl_qlearning_dataset_with_next_actions,
    load_hdf5_next_actions_and_val_data,
)

import numpy as np
import torch
import torch.optim as optim
from rlkit.core.loss import LossFunction, LossStatistics
from torch import nn as nn
from rlkit.launchers.pipeline.helpers import (
    create_algorithm,
    create_dataset_next_actions,
    create_eval_env,
    create_eval_path_collector,
    create_policy,
    create_replay_buffer,
    create_trainer,
    load_demos,
    create_q,
    offline_init,
    optionally_normalize_dataset,
    train,
)

import rlkit.torch.pytorch_util as ptu
from rlkit.core.logging.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from rlkit.launchers.pipeline import Pipeline, Pipelines, PipelineCtx
import rlkit.core.gtimer as gt

# Immutable record of the loss tensors returned by SarsaTrainer.compute_loss;
# it currently carries a single field, the summed ensemble Q-function loss.
SarsaLosses = namedtuple("SarsaLosses", ["qfs_loss"])


class SarsaTrainer(TorchTrainer, LossFunction):
    """Train an ensemble of Q-functions with the SARSA temporal-difference target.

    The TD target bootstraps from ``target_qfs`` evaluated at the dataset's own
    recorded next action (``batch["next_actions"]``), so no policy is queried
    during training:

        target = reward_scale * r + (1 - done) * discount * Q_target(s', a')

    Args:
        eval_env: Evaluation environment; stored as ``self.env``.
        qfs: Ensemble Q-network called as ``qfs(obs, actions)``; must expose
            ``num_heads`` and ``parameters()``.
        target_qfs: Target network for ``qfs``, soft-updated towards it.
        discount: Discount factor applied to the bootstrapped value.
        reward_scale: Multiplier applied to rewards in the TD target.
        qf_lr: Learning rate for the ``qfs`` optimizer.
        optimizer_class: Optimizer constructor used for ``qfs``.
        soft_target_tau: Interpolation coefficient passed to
            ``ptu.soft_update_from_to`` when updating ``target_qfs``.
        target_update_period: Update the target every this many train steps.
        plotter: Stored on the instance; not used by this class.
        render_eval_paths: Stored on the instance; not used by this class.
        **kwargs: Accepted and ignored (presumably so a shared trainer-kwargs
            dict can be forwarded unchanged -- confirm against callers).
    """

    def __init__(
        self,
        eval_env,
        qfs,
        target_qfs,
        discount=0.99,
        reward_scale=1.0,
        qf_lr=1e-3,
        optimizer_class=optim.Adam,
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        **kwargs,
    ):
        super().__init__()
        self.env = eval_env
        self.qfs = qfs
        self.target_qfs = target_qfs
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qfs_optimizer = optimizer_class(
            self.qfs.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self._n_train_steps_total = 0
        # Statistics are taken from the first batch of each epoch only;
        # end_epoch() re-arms this flag.
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()

        # Optional held-out validation data, populated via set_val_data().
        self.val_obs = None
        self.val_actions = None
        self.gt_val_q_values = None

    def set_val_data(self, val_obs, val_actions, gt_val_q_values):
        """Register validation inputs and reference Q-values.

        When ``gt_val_q_values`` is not None, compute_loss reports a
        "Val QF Loss" statistic: the MSE between the ensemble-mean prediction
        of ``qfs`` on (``val_obs``, ``val_actions``) and these values.
        """
        self.val_obs = val_obs
        self.val_actions = val_actions
        self.gt_val_q_values = gt_val_q_values

    def train_from_torch(self, batch):
        """Take one optimizer step on ``batch`` and maybe update the targets.

        The first call after each end_epoch() also stores the batch's
        statistics in ``self.eval_statistics``.
        """
        losses, stats = self.compute_loss(
            batch,
            skip_statistics=not self._need_to_update_eval_statistics,
        )

        # Update networks.
        self.qfs_optimizer.zero_grad()
        losses.qfs_loss.backward()
        self.qfs_optimizer.step()

        self._n_train_steps_total += 1

        self.try_update_target_networks()
        if self._need_to_update_eval_statistics:
            self.eval_statistics = stats
            # Compute statistics using only one batch per epoch.
            self._need_to_update_eval_statistics = False

    def try_update_target_networks(self):
        """Soft-update the targets every ``target_update_period`` train steps."""
        if self._n_train_steps_total % self.target_update_period == 0:
            self.update_target_networks()

    def update_target_networks(self):
        """Move ``target_qfs`` towards ``qfs`` using ``soft_target_tau``."""
        ptu.soft_update_from_to(self.qfs, self.target_qfs, self.soft_target_tau)

    def compute_loss(
        self,
        batch,
        skip_statistics=False,
    ) -> Tuple[SarsaLosses, LossStatistics]:
        """Compute the SARSA regression loss for every ensemble head.

        Args:
            batch: Mapping with tensor entries ``rewards``, ``terminals``,
                ``observations``, ``actions``, ``next_observations`` and
                ``next_actions``.
            skip_statistics: If True, the returned statistics dict is empty.

        Returns:
            ``(SarsaLosses(qfs_loss=...), eval_statistics)``.
        """
        rewards = batch["rewards"]
        terminals = batch["terminals"]
        obs = batch["observations"]
        actions = batch["actions"]
        next_obs = batch["next_observations"]
        next_actions = batch["next_actions"]

        # QF loss. Predictions are per head; the original inline note gives
        # their shape as [batch, 1, num_heads] (e.g. [512, 1, 10]).
        q_preds = self.qfs(obs, actions)
        with torch.no_grad():
            target_q_values = self.target_qfs(next_obs, next_actions)
            # Tile each transition's reward/terminal across the ensemble heads.
            # expand(-1, -1, H) requires rewards/terminals to be 2-D
            # (presumably [batch, 1]).
            terminals = terminals.unsqueeze(-1).expand(-1, -1, self.qfs.num_heads)
            rewards = rewards.unsqueeze(-1).expand(-1, -1, self.qfs.num_heads)
            q_target = (
                self.reward_scale * rewards
                + (1.0 - terminals) * self.discount * target_q_values
            )

        # Mean squared error over the batch for each head, then summed (not
        # averaged) over heads, so each head receives the gradient of its own
        # MSE regardless of ensemble size.
        qfs_loss = torch.sum(
            torch.mean(
                (q_preds - q_target) ** 2,
                dim=0,
            )
        )

        # Save some statistics for eval.
        eval_statistics = OrderedDict()
        if not skip_statistics:
            eval_statistics["QF Loss"] = ptu.get_numpy(qfs_loss)

            eval_statistics.update(
                create_stats_ordered_dict(
                    "Mean Q Predictions",
                    ptu.get_numpy(q_preds.mean(dim=-1)),
                )
            )
            eval_statistics.update(
                create_stats_ordered_dict(
                    "Mean Target Q Predictions",
                    ptu.get_numpy(target_q_values.mean(dim=-1)),
                )
            )

            # Spread across the last dimension (ensemble heads), averaged
            # over the batch.
            eval_statistics.update(
                create_stats_ordered_dict(
                    "Q std",
                    np.mean(
                        ptu.get_numpy(torch.std(q_preds, dim=-1)),
                    ),
                ),
            )

            if self.gt_val_q_values is not None:
                with torch.no_grad():
                    pred_val_q = self.qfs(self.val_obs, self.val_actions).mean(-1)
                    val_loss = torch.mean(
                        (pred_val_q - self.gt_val_q_values) ** 2,
                    )
                    eval_statistics["Val QF Loss"] = ptu.get_numpy(val_loss)

        loss = SarsaLosses(
            qfs_loss=qfs_loss,
        )

        return loss, eval_statistics

    def get_diagnostics(self):
        """Return the base diagnostics merged with the latest eval statistics."""
        stats = super().get_diagnostics()
        stats.update(self.eval_statistics)
        return stats

    def end_epoch(self, epoch):
        """Re-arm statistics collection for the first batch of the next epoch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All networks owned by this trainer (online and target ensembles)."""
        return [
            self.qfs,
            self.target_qfs,
        ]

    @property
    def optimizers(self):
        """All optimizers owned by this trainer."""
        return [
            self.qfs_optimizer,
        ]

    def get_snapshot(self):
        """Return the networks to save; keys are prefixed by the caller."""
        return dict(
            qfs=self.qfs,
            target_qfs=self.target_qfs,
        )


"""
Pipeline code
"""


# Base SARSA pipeline: an ordered list of step functions, presumably run in
# sequence against a shared PipelineCtx.
# NOTE: `create_q` in this list is the helper imported from
# rlkit.launchers.pipeline.helpers. The module-level `create_q` defined below
# is not bound yet when this list is evaluated; EnsembleSarsaPipeline swaps it
# in afterwards via `.replace("create_q", ...)`.
SarsaPipeline = Pipeline(
    "SarsaPipeline",
    [
        offline_init,
        create_eval_env,
        create_dataset_next_actions,
        optionally_normalize_dataset,
        create_q,
        create_policy,
        create_trainer,
        create_eval_path_collector,
        create_replay_buffer,
        create_algorithm,
        load_demos,
        train,
    ],
)


def create_q(ctx: PipelineCtx):
    """Build the online and target Q-networks and store them on ``ctx``.

    Both networks are made by separate calls to ``variant["qf_class"]``, with
    input size obs_dim + action_dim, a single output, and
    ``variant["qf_kwargs"]``. The target is therefore an independently
    constructed network, not a copy of the online one.

    From this point in the module on, this definition shadows the ``create_q``
    helper imported at the top of the file.
    """
    env = ctx.eval_env
    input_size = env.observation_space.low.size + env.action_space.low.size

    qf_class = ctx.variant["qf_class"]
    qf_kwargs = ctx.variant["qf_kwargs"]

    def build_network():
        return qf_class(input_size=input_size, output_size=1, **qf_kwargs)

    online_net = build_network()
    target_net = build_network()

    ctx.qfs = online_net
    ctx.target_qfs = target_net


# Derived from SarsaPipeline (presumably a copy of its step list) under a new
# name, with the "create_q" step replaced by the module-level create_q above.
# That create_q builds ctx.qfs / ctx.target_qfs from variant["qf_class"].
EnsembleSarsaPipeline = Pipeline.from_(
    SarsaPipeline, "EnsembleSarsaPipeline"
)
EnsembleSarsaPipeline.replace("create_q", create_q)


def load_demos_and_val_data(ctx: PipelineCtx):
    """Fill the replay buffer and attach a validation set to the trainer.

    ``ctx.dataset`` is split into ``ctx.replay_buffer`` and a held-out fold by
    ``load_hdf5_next_actions_and_val_data``. Reference Q-values for the
    held-out (obs, action) pairs come from a previously saved SARSA snapshot
    (its ``"trainer/qfs"`` entry). That snapshot is evaluated on unnormalized
    inputs, so the validation observations and actions are first mapped back
    out of the eval env's normalized spaces. The trainer still receives the
    normalized inputs, together with the reference values.

    Variant keys read:
        train_ratio, fold_idx: forwarded to the train/val split.
        env_id: used to build the default snapshot path.
        gt_qf_path (optional): overrides the default snapshot path
            ``data/sarsa/normal/{env_id}/itr_0.pt``.
    """
    ctx.replay_buffer, val_obs, val_actions = load_hdf5_next_actions_and_val_data(
        ctx.dataset,
        ctx.replay_buffer,
        ctx.variant["train_ratio"],
        ctx.variant["fold_idx"],
    )

    # Undo action normalization: a * half_range + center maps -1 -> low and
    # +1 -> high of the wrapped env's original action space.
    action_space = ctx.eval_env._wrapped_env.action_space
    rg = ptu.from_numpy(action_space.high - action_space.low) / 2
    center = ptu.from_numpy(action_space.high + action_space.low) / 2
    val_actions_unnormalized = val_actions * rg[None] + center[None]

    # Undo observation normalization. If the env carries no statistics
    # (None), observations were presumably never normalized, so the missing
    # scale/offset is treated as identity instead of crashing in from_numpy.
    obs_mean = getattr(ctx.eval_env, "_obs_mean", None)
    obs_std = getattr(ctx.eval_env, "_obs_std", None)
    val_obs_unnormalized = val_obs
    if obs_std is not None:
        val_obs_unnormalized = val_obs_unnormalized * ptu.from_numpy(obs_std)[None]
    if obs_mean is not None:
        val_obs_unnormalized = val_obs_unnormalized + ptu.from_numpy(obs_mean)[None]

    gt_qf_path = ctx.variant.get("gt_qf_path")
    if gt_qf_path is None:
        gt_qf_path = f'data/sarsa/normal/{ctx.variant["env_id"]}/itr_0.pt'

    # NOTE(review): torch.load unpickles arbitrary objects (here a whole
    # nn.Module), so only load trusted snapshots. PyTorch >= 2.6 defaults to
    # weights_only=True, which rejects pickled modules; pass
    # weights_only=False there if loading fails.
    params = torch.load(
        gt_qf_path,
        map_location="cpu",
    )

    with torch.no_grad():
        gt_qfs = params["trainer/qfs"].to(ptu.device)
        gt_val_q_values = gt_qfs(val_obs_unnormalized, val_actions_unnormalized)
        # Average over the last (presumably ensemble-head) dimension.
        gt_val_q_values = gt_val_q_values.mean(-1)

    ctx.trainer.set_val_data(val_obs, val_actions, gt_val_q_values)


# Ensemble SARSA pipeline with validation, derived from EnsembleSarsaPipeline.
# The "load_demos" step is replaced by load_demos_and_val_data, which also
# holds out a validation fold. It gives the trainer reference Q-values from a
# saved SARSA snapshot, so the trainer can report "Val QF Loss".
EnsembleSarsaWithValPipeline = Pipeline.from_(
    EnsembleSarsaPipeline, "EnsembleSarsaWithValPipeline"
)
EnsembleSarsaWithValPipeline.replace("load_demos", load_demos_and_val_data)
