import argparse
import copy
import gymnasium as gym
import itertools
import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Tuple

from collections import namedtuple
from . import sac, utils
from .utils import ReplayMemory, StochasticTrainingAlgorithm
from ..delayed_mdp import SimulatedInteractionLayer
from latency_env.modules.model import NNPredictiveModel

import latency_env.misc.argparser_types as at
from latency_env.misc.argparser import Argument as Arg
from latency_env.misc.argparser import ArgumentList as ArgList

LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


def arguments_hook(args):
    """
    Fill in the model batch size when it was not given explicitly.

    If ``args.model_batch_size`` is ``None`` it is set (in place) to
    ``args.batch_size``; otherwise ``args`` is left untouched.
    """
    if args.model_batch_size is not None:
        return
    args.model_batch_size = args.batch_size

# Argument list for this algorithm: the SAC arguments (`sac.arguments`)
# extended with the options below. `arguments_hook` is registered through
# `_hooks`; presumably the argument list invokes it on the parsed namespace
# (TODO confirm against ArgList) so that an unset --model-batch-size falls
# back to --batch-size.
arguments = sac.arguments + ArgList(
    _hooks=[arguments_hook],
    # Length of the action window kept per observation (see `a_mem` in the
    # replay buffer and `embed_s_bar`).
    assumed_delay = Arg("--assumed-delay", default=1, type=at.posint,
                        help="Assumed delay in terms of constant time steps."),
    # Options forwarded to the model optimizer.
    lr_model = Arg("--lr-model", default=1e-4, type=at.posfloat,
                   help="Model learning rate."),
    wd_model = Arg("--wd-model", default=0.0, type=at.nonnegfloat,
                   help="Model weight decay (L2 loss)."),
    optim_model = Arg("--optim-model", default="adam", type=str.lower,
                      help="Model network optimizer.", choices=utils.OPTIMIZERS),
    norm_clip_model = Arg("--norm-clip-model", default=None, type=float,
                          help="Model gradient norm clipping."),
    # Passed as `k` to `replay.sample_trajectory` during the model update.
    model_train_window = Arg("--mtw", "--model-train-window", default=16, type=at.posint,
                             help="Model training window, aka the horizon."),
    # Default stays None here; `arguments_hook` replaces None with the
    # value of --batch-size.
    model_batch_size = Arg("--model-batch-size", default=None, type=at.posint,
                           _auto_default=False,
                           help="Size of minibatches when performing model network "
                                "updates (will default to regular --batch-size "
                                "if unset)."),
    # When enabled, `BPQL_MDA.clamp` passes states through `obs_metric.clamp`.
    model_clamp_output = Arg("--model-clamp-output", action=argparse.BooleanOptionalAction, default=False,
                             help="Clamp the outputs from the model and the environment."),
)


class BPQL_MDA_ReplayMemory(ReplayMemory):
    """
    Replay buffer extended with the fields needed for augmented states.

    On top of the fields of ``ReplayMemory`` every entry holds the observed
    state pair (``s_obs``, ``s_obs_next``), the action windows of length
    ``assumed_delay`` (``a_mem``, ``a_mem_next``) and the future state pair
    (``future_state``, ``future_next_state``).
    """
    ParameterOrder = ReplayMemory.ParameterOrder + [
        "s_obs", "s_obs_next",
        "a_mem", "a_mem_next",
        "future_state", "future_next_state",
    ]

    def __init__(self, assumed_delay, *args, **kwargs):
        """
        Allocate the additional zero-initialised buffers.

        ``assumed_delay`` is the number of actions stored per action window;
        all remaining arguments are forwarded to ``ReplayMemory``.
        """
        super().__init__(*args, **kwargs)

        state_dims = (self.maxsize,) + self.s_shape
        window_dims = (self.maxsize, assumed_delay) + self.a_shape

        # NOTE(review): the action windows are allocated with the state
        # dtype (self.s_type) -- confirm this is intended.
        extra_buffers = (
            ("s_obs", state_dims),
            ("s_obs_next", state_dims),
            ("a_mem", window_dims),
            ("a_mem_next", window_dims),
            ("future_state", state_dims),
            ("future_next_state", state_dims),
        )
        for key, dims in extra_buffers:
            self.rb[key] = torch.zeros(dims, requires_grad=False, dtype=self.s_type)


class BPQL_MDA(sac.SAC):
    """
    Implementation of BPQL, which is an extension of SAC.

    Uses the Model-based Distribution Agent (MDA) as its policy, instead of the
    direct function approximator.
    """
    def __init__(self,
                 pi: nn.Module,
                 q1: nn.Module,
                 q2: nn.Module,
                 env: SimulatedInteractionLayer,
                 model: NNPredictiveModel,
                 obs_metric: "pytorch differentiable function",
                 args: "arguments"):
        """
        Set up SAC and attach the predictive model on top of it.

        ``pi``, ``q1``, ``q2``, ``env`` and ``args`` are handed to the SAC
        constructor. ``model`` is moved to this agent's device and gets its
        own optimizer configured from ``args``. ``obs_metric`` is stored and
        later used as the distance of the model loss and for clamping.
        """
        super().__init__(pi=pi, q1=q1, q2=q2, env=env, args=args)

        self.model = model.to(self.device)
        self.obs_metric = obs_metric

        # Dedicated optimizer for the model, separate from the SAC ones.
        self.model_opt = utils.make_optimizer(
            self.model,
            optim=args.optim_model,
            lr=args.lr_model,
            weight_decay=args.wd_model,
        )

        # Replace the replay buffer created by the parent with the variant
        # that also stores observed states, action windows and future states.
        self.replay = BPQL_MDA_ReplayMemory(
            maxsize=self.p.replay_size,
            env=self.env,
            assumed_delay=self.p.assumed_delay,
        )

    def evalcopy(self):
        """
        Return the parent's evaluation blob extended with the model.

        Adds a deep copy of the model moved to the CPU under ``"model"`` and
        the observation metric (by reference) under ``"obs_metric"``.
        """
        snapshot = super().evalcopy()
        cpu_model = copy.deepcopy(self.model).to("cpu")
        snapshot["model"] = cpu_model
        snapshot["obs_metric"] = self.obs_metric
        return snapshot

    @classmethod
    def from_evalcopy(cls, blob, device=None):
        """
        Build an instance from a blob produced by ``evalcopy``.

        Args:
            blob: Mapping with the keys ``"pi"``, ``"q1"``, ``"q2"``,
                ``"env"``, ``"model"``, ``"obs_metric"`` and ``"p"``.
            device: Optional device override. When given, it is written to
                the ``device`` attribute of a copy of ``blob["p"]``.

        Returns:
            A new instance of the class this method is called on.
        """
        if device is not None:
            # Shallow-copy the blob and its argument object so that the
            # device override does not modify the caller's objects.
            blob = copy.copy(blob)
            blob["p"] = copy.copy(blob["p"])
            blob["p"].device = device
        return cls(
            pi=blob["pi"],
            q1=blob["q1"],
            q2=blob["q2"],
            env=blob["env"],
            model=blob["model"],
            obs_metric=blob["obs_metric"],
            args=blob["p"],
        )

    def modules(self):
        """Return the parent's modules followed by the predictive model."""
        inherited = super().modules()
        return inherited + (self.model,)

    @torch.no_grad()
    def clamp(self, s):
        """
        Clamp the state with the observation metric, if enabled.

        Returns ``s`` unchanged unless ``model_clamp_output`` is set; in that
        case ``s`` is converted to a tensor, passed through
        ``obs_metric.clamp`` and returned as a numpy array.
        """
        if not self.p.model_clamp_output:
            return s
        return self.obs_metric.clamp(torch.as_tensor(s)).numpy()

    def embed_s_bar(self, s, action_trace):
        """
        Embeds the state and the memorized actions into the latent
        distribution variable from the model.

        Only the last ``assumed_delay`` entries of ``action_trace`` are used,
        and NaN entries in them are replaced via ``torch.nan_to_num``. The
        returned value is the model embedding at the final position of that
        window.
        """
        delay = self.p.assumed_delay
        tensor_kwargs = dict(dtype=torch.float32, device=self.device)

        s_obs = torch.as_tensor(s, **tensor_kwargs)
        recent_actions = np.stack(action_trace[-delay:])
        a_mem = torch.as_tensor(recent_actions, **tensor_kwargs)  # [L, A]

        # Add a leading batch dimension of size one for the model call.
        batched_state = s_obs.unsqueeze(0)  # [1, S]
        batched_actions = torch.nan_to_num(a_mem).unsqueeze(0)  # [1, L, A]
        embed = self.model.embed_inputs(
            batched_state,
            batched_actions,
            omit_h0=True,
        )  # [1, L, H]
        return embed[0, delay - 1]  # [H]

    @torch.no_grad()
    def sample_action(self, state_action_info: Tuple["State", "ActionTrace"], **kwargs):
        """
        Sample an action for a (state, action trace) pair.

        The pair is embedded with ``embed_s_bar`` and the embedding is handed
        to the parent's ``sample_action`` together with ``kwargs``.
        """
        return super().sample_action(self.embed_s_bar(*state_action_info), **kwargs)

    def evaluate(self, trajectory_length=None, env=None, **kwargs):
        """
        Evaluates the BPQL policy.

        ``env`` defaults to the agent's own environment and
        ``trajectory_length`` to ``max_trajectory``. The parent's ``evaluate``
        is driven through reset/step callbacks that pair every (clamped)
        state with the running action trace, which is the input
        ``sample_action`` expects.
        """
        env = self.env if env is None else env
        if trajectory_length is None:
            trajectory_length = self.p.max_trajectory

        action_trace = []

        def fn_reset():
            # Start a fresh trace filled with copies of the default action.
            nonlocal action_trace
            observation, info = env.reset()
            observation = self.clamp(observation)
            action_trace = [env.default_action.copy() for _ in range(self.p.assumed_delay)]
            return (observation, action_trace), info

        def fn_step(a):
            # The action is appended to the trace before the environment
            # is stepped.
            action_trace.append(a)
            observation, reward, terminated, truncated, info = env.step(a)
            observation = self.clamp(observation)
            return ((observation, action_trace), reward, terminated, truncated, info)

        return super().evaluate(trajectory_length=trajectory_length, env=env,
                                fn_reset=fn_reset, fn_step=fn_step, **kwargs)


    def train(self, n_episodes=None, n_steps=None, follow_policy=True):
        """
        Trains the networks contained within with BPQL.

        Each environment step is followed (subject to ``update_delay``) by
        ``update_iterations`` minibatch updates of the critics, the policy,
        the temperature and, if ``model.optimizable``, the predictive model.

        Args:
            n_episodes: Forwarded to ``self.progress``.
            n_steps: Forwarded to ``self.progress``.
            follow_policy: When False, actions are sampled with
                ``random=True`` instead of following the policy.
        """
        delay = self.p.assumed_delay

        for ep in self.progress(n_episodes=n_episodes, n_steps=n_steps):
            self.replay.new_trajectory()
            r_total = 0.0
            (s, _), done = self.env.reset(seed=ep), False
            s = self.clamp(s)

            # obs_stack holds every transition of the episode; act_trace holds
            # `delay` default actions followed by every action taken so far.
            obs_stack = []
            act_trace = [self.env.default_action.copy() for _ in range(delay)]
            for _step in self.progress_steps(limit=self.p.max_trajectory):
                if done:
                    break

                with torch.no_grad():
                    if not self.update_before_sampling:
                        # Take a step
                        a, info = self.sample_action((s, act_trace), random=(not follow_policy), with_info=True)
                        s_next, r, terminated, truncated, _ = self.env.step(a)
                        s_next = self.clamp(s_next)
                        done = terminated or truncated

                        obs_stack.append((s, a, r, s_next, terminated))
                        act_trace.append(a)
                        s = s_next
                        r_total += r

                        self._record_per_element("actor/action", a)
                        if "std" in info:
                            self._record_per_element("actor/action_std", info["std"])

                        # Only add to action buffer once we know what the actual state was
                        if len(obs_stack) > delay:
                            # Construct the extended states: the transition
                            # recorded `delay` steps before the newest one
                            # supplies the observed states and the action,
                            # the newest one supplies the future states,
                            # the reward and the termination flag.
                            i = len(obs_stack) - (delay + 1)
                            rs, ra, _, rs_next, _ = obs_stack[i]
                            rs_future, _, rr, rsn_future, rterm = obs_stack[-1]

                            a_mem = np.stack(act_trace[i: i + delay])
                            a_mem_next = np.stack(act_trace[(i + 1): (i + 1) + delay])

                            self.replay.add(
                                rs, ra, rr, rs_next, rterm,
                                rs, rs_next,  # s_obs, s_obs_next
                                a_mem, a_mem_next,
                                rs_future, rsn_future,
                            )

                # Minibatch network updates
                self.steps_since_update -= 1
                if (self.steps_since_update <= 0 or self.update_before_sampling) and len(self.replay) > 0:
                    self.update_before_sampling = False
                    self.steps_since_update = self.p.update_delay
                    n_updates = self.p.update_iterations
                    losses_q = np.zeros((n_updates,), dtype=np.float32)
                    losses_pi = np.zeros((n_updates,), dtype=np.float32)
                    losses_alpha = np.zeros((n_updates,), dtype=np.float32)
                    losses_model = np.zeros((n_updates,), dtype=np.float32)
                    for update_step in range(n_updates):
                        # Sample batch for updates
                        (
                            bs, ba, br, bsn, bd,
                            bs_obs, bs_obs_next,
                            ba_mem, ba_mem_next,
                            bsf, bsnf,
                        ) = self.replay.sample(self.p.batch_size, device=self.device)

                        with torch.no_grad():
                            # Pre-fetch the alpha for later
                            alpha = torch.exp(self.log_alpha)

                            # Embed the augmented states (current and next
                            # in one model call)
                            embed = self.model.embed_inputs(
                                torch.cat((bs_obs, bs_obs_next), 0), # [2N, S]
                                torch.nan_to_num(torch.cat((ba_mem, ba_mem_next), 0)), # [2N, L, A]
                                omit_h0=True,
                            ) # [2N, L, H]
                            embed = embed[:, delay - 1] # [2N, H]

                            # Split by the number of rows actually sampled, so
                            # the two halves stay aligned even if the buffer
                            # returned a batch whose size differs from
                            # `batch_size`.
                            embed_bs_bar, embed_bsn_bar = torch.split(embed, bs_obs.shape[0]) # ([N, H], [N, H])

                        # Q update
                        self.q_opt.zero_grad()
                        q1val = self.q1(bsf, ba)
                        if self.p.use_double_q:
                            q2val = self.q2(bsf, ba)
                        with torch.no_grad():
                            if self.p.use_double_q:
                                an, _, logp_an, _ = self.pi(embed_bsn_bar)
                                qval_targ = torch.min(self.q1_targ(bsnf, an), self.q2_targ(bsnf, an))
                            else:
                                an, _, logp_an, _ = self.pi_targ(embed_bsn_bar)
                                qval_targ = self.q1_targ(bsnf, an)
                            ys = br + self.p.discount * (1 - bd) * (qval_targ - alpha * logp_an)

                        loss_q1 = ((q1val - ys)**2).mean()
                        if self.p.use_double_q:
                            loss_q2 = ((q2val - ys)**2).mean()
                            loss_q = loss_q1 + loss_q2
                        else:
                            loss_q = loss_q1
                        loss_q.backward()
                        if self.p.norm_clip_q is not None:
                            torch.nn.utils.clip_grad_norm_(self.q.parameters(), self.p.norm_clip_q)

                        self.q_opt.step()

                        # pi (policy) update
                        self.pi_opt.zero_grad()

                        a, _, logp_a, _ = self.pi(embed_bs_bar)
                        qval = self.q1(bsf, a)
                        if self.p.use_double_q:
                            qval = torch.min(qval, self.q2(bsf, a))

                        loss_pi = -(qval - alpha * logp_a).mean()
                        loss_pi.backward()
                        if self.p.norm_clip_pi is not None:
                            torch.nn.utils.clip_grad_norm_(self.pi.parameters(), self.p.norm_clip_pi)

                        self.pi_opt.step()

                        # alpha (temperature) update
                        self.alpha_opt.zero_grad()

                        # J(a) = E[-a logp(a|s) - a*H]
                        # The entropy term is detached so that only log_alpha
                        # receives a gradient.
                        loss_alpha = -1.0 * (
                            torch.exp(self.log_alpha) * (
                                logp_a.mean() +
                                self.p.entropy_threshold
                            ).clone().detach().requires_grad_(False)
                        )
                        loss_alpha.backward()

                        self.alpha_opt.step()

                        # model update (if it can be optimized)
                        if self.model.optimizable:
                            lengths, (tbs, tba, tbsn) = self.replay.sample_trajectory(
                                n=self.p.model_batch_size,
                                k=self.p.model_train_window,
                                keys=["state", "action", "next_state"],
                                device=self.device,
                            )

                            self.model_opt.zero_grad()

                            loss_model, _ = self.model.loss(
                                state=tbs[:,0],
                                actions=tba,
                                next_states=tbsn,
                                lengths=lengths,
                                distance=self.obs_metric,
                            )

                            loss_model.backward()
                            # Honour --norm-clip-model, mirroring the clipping
                            # applied to the q and pi updates.
                            norm_clip_model = getattr(self.p, "norm_clip_model", None)
                            if norm_clip_model is not None:
                                torch.nn.utils.clip_grad_norm_(self.model.parameters(), norm_clip_model)
                            self.model_opt.step()
                        else:
                            loss_model = torch.tensor(0.0, requires_grad=False)

                        losses_q[update_step] = loss_q.cpu().detach().numpy()
                        losses_pi[update_step] = loss_pi.cpu().detach().numpy()
                        losses_alpha[update_step] = loss_alpha.cpu().detach().numpy()
                        losses_model[update_step] = loss_model.cpu().detach().numpy()

                        # Polyak (soft target) updates
                        self.updates_since_polyak -= 1
                        if self.updates_since_polyak <= 0:
                            self.updates_since_polyak = self.p.polyak_delay
                            with torch.no_grad():
                                for train, target in [(self.pi, self.pi_targ), (self.q1, self.q1_targ), (self.q2, self.q2_targ)]:
                                    for p_train, p_targ in zip(train.parameters(), target.parameters()):
                                        p_targ.data.mul_(1 - self.p.polyak)
                                        p_targ.data.add_(self.p.polyak * p_train.data)

                    self.record_iteration_value("critic/loss", losses_q.mean())
                    self.record_iteration_value("actor/loss", losses_pi.mean())
                    self.record_iteration_value("temperature/loss", losses_alpha.mean())
                    self.record_iteration_value("temperature/value", torch.exp(self.log_alpha).item())
                    self.record_iteration_value("model/loss", losses_model.mean())

            self.record_episode_total_reward(r_total)

    def _record_per_element(self, name, values):
        """
        Record every element of ``values`` as a separate iteration value.

        Each element is recorded under ``name`` followed by "/" and its
        comma-separated multi-index; zero-dimensional input is recorded
        under ``name`` alone.
        """
        it = np.nditer(values, flags=["multi_index"])
        for val in it:
            idx = ("/" + ",".join(str(i) for i in it.multi_index)) if it.multi_index != () else ""
            self.record_iteration_value(name + idx, val)
